Merge pull request #2994 from robinwatts/pushback11

Improve speed of tesseract by optimising for intSimdMatrix case
zdenop 2020-10-17 17:19:49 +02:00 committed by GitHub
commit 514a7893f4
11 changed files with 164 additions and 84 deletions

View File

@@ -316,6 +316,11 @@ class GenericVector {
return true;
}
void scale(T factor) const {
for (int i = 0; i < size_used_; ++i) {
data_[i] *= factor;
}
}
protected:
// Internal recursive version of choose_nth_item.
int choose_nth_item(int target_index, int start, int end, unsigned int* seed);
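The new GenericVector::scale() multiplies every element in place; WeightMatrix uses it further down to apply and remove the 1/INT8_MAX factor around serialisation. A minimal usage sketch with made-up values:

    GenericVector<double> scales;
    scales.resize_no_init(4);
    for (int i = 0; i < 4; ++i) scales[i] = 1.0;  // dummy scales
    scales.scale(1.0 / INT8_MAX);                 // every element becomes 1/127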

View File

@@ -27,7 +27,8 @@ const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr;
// Computes a reshaped copy of the weight matrix w.
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
std::vector<int8_t>& shaped_w) const {
std::vector<int8_t>& shaped_w,
GenericVector<double>& scales) const {
const int num_out = w.dim1();
const int num_in = w.dim2() - 1;
// The rounded-up sizes of the reshaped weight matrix, excluding biases.
@@ -35,6 +36,7 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
int rounded_num_out = RoundOutputs(num_out);
// Add the bias and compute the required size.
shaped_w.resize((rounded_num_in + 1) * rounded_num_out, 0);
scales.resize_no_init(rounded_num_out);
int shaped_index = 0;
int output = 0;
// Each number of registers needs a different format! Iterates over the
@@ -87,7 +89,7 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
int total = 0;
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
// Add in the bias and correct for integer values.
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
v[i] = (total + wi[num_in] * INT8_MAX) * scales[i];
}
}
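The change to the scalar MatrixDotVector above is the arithmetic rewrite this PR applies everywhere: the division by INT8_MAX moves out of the per-output expression and into the stored scales, applied once at load/convert time. A small sketch of the equivalence, with illustrative names:

    #include <cstdint>
    // Old per-output formula: one divide per output.
    double old_form(int total, int8_t bias, double scale) {
      return (static_cast<double>(total) / INT8_MAX + bias) * scale;
    }
    // New per-output formula: 'prescaled' holds scale / INT8_MAX, computed
    // once in WeightMatrix::ConvertToInt / DeSerialize.
    double new_form(int total, int8_t bias, double prescaled) {
      return (total + bias * INT8_MAX) * prescaled;
    }
    // old_form(t, b, s) == new_form(t, b, s / INT8_MAX) up to floating-point
    // rounding, since (t / 127 + b) * s == (t + 127 * b) * (s / 127).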

View File

@@ -62,7 +62,8 @@ namespace tesseract {
struct IntSimdMatrix {
// Computes a reshaped copy of the weight matrix w.
void Init(const GENERIC_2D_ARRAY<int8_t>& w,
std::vector<int8_t>& shaped_w) const;
std::vector<int8_t>& shaped_w,
GenericVector<double>& scales) const;
// Rounds the size up to a multiple of the input register size (in int8_t).
int RoundInputs(int size) const {
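RoundInputs/RoundOutputs round a dimension up to a multiple of the SIMD register size; that is why the reshaped weights, the scales and (later) the output buffers are all padded. A hedged sketch of the round-up arithmetic (the real methods use the per-backend register sizes held in the struct):

    // Round 'size' up to the next multiple of 'factor', e.g. round_up(21, 8) == 24.
    static int round_up(int size, int factor) {
      return (size + factor - 1) / factor * factor;
    }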

View File

@@ -80,24 +80,73 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
result = _mm256_add_epi32(result, weights);
}
// Extracts and converts 8x32-bit results from result, adding the bias from wi
// and scaling by scales, before storing in *v. Note that wi, scales and v are
// expected to contain 8 consecutive elements or num_out if less.
static inline void ExtractResults(__m256i& result, __m256i& shift_id,
const int8_t*& wi, const double*& scales,
int num_out, double*& v) {
for (int out = 0; out < num_out; ++out) {
#ifndef _MSC_VER
auto res = _mm256_extract_epi32(result, 0);
#else
// Workaround MSVC's ICE
// _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
auto res = ((int32_t*)&result)[0];
#endif
*v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
// Rotate the results in int32_t units, so the next result is ready.
result = _mm256_permutevar8x32_epi32(result, shift_id);
}
// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_)
{
const int64_t *wi = reinterpret_cast<const int64_t *>(wi_);
return _mm_set_epi64x(0, wi[0]);
}
static inline void ExtractResults8(__m256i result,
const int8_t* wi,
const double* scales,
double* v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
__m256d scale0123 = _mm256_loadu_pd(scales);
__m256d scale4567 = _mm256_loadu_pd(scales+4);
__m256d bias0123, bias4567, res0123, res4567;
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
result = _mm256_permute4x64_epi64(result, 2+(3<<2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_pd(v, res0123);
_mm256_storeu_pd(v+4, res4567);
}
static inline void ExtractResults16(__m256i result0,
__m256i result1,
const int8_t*& wi,
const double*& scales,
double*& v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_pd(scales);
__m256d scale4567 = _mm256_loadu_pd(scales+4);
__m256d bias0123, bias4567, res0123, res4567;
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
result0 = _mm256_permute4x64_epi64(result0, 2+(3<<2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_pd(v, res0123);
_mm256_storeu_pd(v+4, res4567);
w8 = _mm_shuffle_epi32(w8,2+(3<<2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_pd(scales+8);
scale4567 = _mm256_loadu_pd(scales+12);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
result1 = _mm256_permute4x64_epi64(result1, 2+(3<<2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_pd(v+8, res0123);
_mm256_storeu_pd(v+12, res4567);
wi += 16;
scales += 16;
v += 16;
}
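A scalar rendering of what ExtractResults8 computes, assuming the eight outputs are fully padded as the callers now guarantee:

    static void ExtractResults8Scalar(const int32_t result[8], const int8_t* wi,
                                      const double* scales, double* v) {
      for (int i = 0; i < 8; ++i) {
        // scales[i] already carries the 1/INT8_MAX factor, so a single
        // multiply replaces the old divide; wi holds the int8 biases.
        v[i] = (result[i] + wi[i] * 127) * scales[i];
      }
    }

ExtractResults16 is the same computation unrolled over two accumulator registers and sixteen outputs, advancing wi, scales and v by 16 as it goes.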
// Computes part of matrix.vector v = Wu. Computes N=64 results.
@@ -108,7 +157,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
const int8_t* u, int num_in, int num_out,
const int8_t* u, int num_in,
double* v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
@@ -149,22 +198,16 @@ static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
}
}
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
num_out -= kNumOutputsPerRegister * 7;
ExtractResults(result7, shift_id, wi, scales,
std::min(kNumOutputsPerRegister, num_out), v);
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
ExtractResults16(result4, result5, wi, scales, v);
ExtractResults16(result6, result7, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
const int8_t* u, int num_in, int num_out,
const int8_t* u, int num_in,
double* v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
@@ -197,18 +240,14 @@ static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
}
}
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
num_out -= kNumOutputsPerRegister * 3;
ExtractResults(result3, shift_id, wi, scales,
std::min(kNumOutputsPerRegister, num_out), v);
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
const int8_t* u, int num_in, int num_out,
const int8_t* u, int num_in,
double* v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
@@ -237,17 +276,18 @@ static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
}
}
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
num_out -= kNumOutputsPerRegister;
ExtractResults(result1, shift_id, wi, scales,
std::min(kNumOutputsPerRegister, num_out), v);
ExtractResults16(result0, result1, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v) {
static inline void PartialMatrixDotVector8(const int8_t *wi,
const double *scales,
const int8_t *u,
int num_in,
double *v) {
double *ov = v;
double temp[8];
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones =
@@ -273,7 +313,7 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
}
}
ExtractResults(result0, shift_id, wi, scales, num_out, v);
ExtractResults8(result0, wi, scales, v);
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
@@ -294,7 +334,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
@@ -302,30 +342,28 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
group_size /= 2;
w_step /= 2;
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
wi += w_step;
scales += group_size;
v += group_size;
}
if (output + group_size <= rounded_num_out)
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
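The remaining group-size loops can become plain if-checks because rounded_num_out is a multiple of kNumOutputsPerRegister (8): whatever is left after the 64-wide loop is less than 64 and a multiple of 8, so it decomposes into at most one 32-block, one 16-block and one 8-block (e.g. rounded_num_out = 120 runs one 64-block, then one each of 32, 16 and 8). The num_out parameter disappears because callers now pass output buffers padded to rounded_num_out, so the kernels can always store full registers.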
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {

View File

@@ -61,6 +61,7 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
// Initialize all the results to 0.
int32x4_t result0123 = { 0, 0, 0, 0 };
int32x4_t result4567 = { 0, 0, 0, 0 };
int8x8_t bias_scale = { 127, 127, 127, 127, 127, 127, 127, 127 };
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in; j += 8) {
int8x8_t vu = vld1_s8(u); // vu = u0 u1 u2 u3 u4 u5 u6 u7
@@ -112,21 +113,27 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
u += 8;
wi += 64;
}
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 0)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 1)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 1)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 2)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 0)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 3)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 1)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 4)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 0)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 5)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 1)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 6)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 0)) / INT8_MAX + *wi++) * *scales++;
if (num_out > 7)
*v = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 1)) / INT8_MAX + *wi ) * *scales;
{
int8x8_t bias = vld1_s8(wi); // vw0 = b0 b1 b2 b3 b4 b5 b6 b7
int16x8_t scaled_bias = vmull_s8(bias, bias_scale);
result0123 = vaddw_s16(result0123, vget_low_s16(scaled_bias));
result4567 = vaddw_s16(result4567, vget_high_s16(scaled_bias));
*v++ = vget_lane_s32(vget_low_s32 (result0123), 0) * *scales++;
if (num_out > 1)
*v++ = vget_lane_s32(vget_low_s32 (result0123), 1) * *scales++;
if (num_out > 2)
*v++ = vget_lane_s32(vget_high_s32(result0123), 0) * *scales++;
if (num_out > 3)
*v++ = vget_lane_s32(vget_high_s32(result0123), 1) * *scales++;
if (num_out > 4)
*v++ = vget_lane_s32(vget_low_s32 (result4567), 0) * *scales++;
if (num_out > 5)
*v++ = vget_lane_s32(vget_low_s32 (result4567), 1) * *scales++;
if (num_out > 6)
*v++ = vget_lane_s32(vget_high_s32(result4567), 0) * *scales++;
if (num_out > 7)
*v = vget_lane_s32(vget_high_s32(result4567), 1) * *scales;
}
}
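The NEON path folds the bias the same way: vmull_s8 multiplies the eight int8 biases by 127 into 16-bit lanes (worst case 127 * 127 = 16129, comfortably inside int16 range), and vaddw_s16 widens and adds them into the 32-bit accumulators, so each lane again reduces to v[k] = (acc[k] + bias[k] * 127) * scales[k]. Unlike the AVX2 kernels, this function keeps its num_out guards and so still handles a partial final group of outputs itself.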
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,

View File

@@ -74,7 +74,7 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
double* v) {
double total = IntDotProductSSE(u, wi, num_in);
// Add in the bias and correct for integer values.
*v = (total / INT8_MAX + wi[num_in]) * *scales;
*v = (total + wi[num_in] * INT8_MAX) * *scales;
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,

View File

@@ -132,8 +132,11 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
GenericVector<NetworkScratch::FloatVec> curr_input;
curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
int ro = no_;
if (IntSimdMatrix::intSimdMatrix)
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
for (int i = 0; i < kNumThreads; ++i) {
temp_lines[i].Init(no_, scratch);
temp_lines[i].Init(no_, ro, scratch);
curr_input[i].Init(ni_, scratch);
}
#ifdef _OPENMP
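temp_lines now reserves RoundOutputs(no_) doubles because the SIMD kernels store whole 8-wide registers: for example, with no_ = 21 outputs and 8 outputs per register the reserve becomes 24, leaving 3 padding doubles that the vector stores may write while the visible size stays 21. The LSTM change below applies the same pattern to its gate buffers.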

View File

@@ -264,7 +264,10 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
ResizeForward(input);
// Temporary storage of forward computation for each gate.
NetworkScratch::FloatVec temp_lines[WT_COUNT];
for (auto & temp_line : temp_lines) temp_line.Init(ns_, scratch);
int ro = ns_;
if (source_.int_mode() && IntSimdMatrix::intSimdMatrix)
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
for (auto & temp_line : temp_lines) temp_line.Init(ns_, ro, scratch);
// Single timestep buffers for the current/recurrent output and state.
NetworkScratch::FloatVec curr_state, curr_output;
curr_state.Init(ns_, scratch);

View File

@@ -144,15 +144,24 @@ class NetworkScratch {
if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
}
void Init(int size, NetworkScratch* scratch) {
void Init(int size, int reserve, NetworkScratch* scratch) {
if (scratch_space_ != nullptr && vec_ != nullptr)
scratch_space_->vec_stack_.Return(vec_);
scratch_space_ = scratch;
vec_ = scratch_space_->vec_stack_.Borrow();
// Abuse vec_ here; first resize to 'reserve', which is larger
// than 'size' (i.e. it's size rounded up) then resize down again
// to the desired size. This assumes that the implementation does
// not shrink the storage on a resize.
vec_->resize_no_init(reserve);
vec_->resize_no_init(size);
data_ = &(*vec_)[0];
}
void Init(int size, NetworkScratch *scratch) {
Init(size, size, scratch);
}
// Use the cast operator instead of operator[] so the FloatVec can be used
// as a double* argument to a function call.
operator double*() const { return data_; }
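A minimal sketch of the new two-argument Init, with hypothetical sizes and assuming a NetworkScratch instance named scratch:

    NetworkScratch::FloatVec line;
    line.Init(21, 24, &scratch);  // visible size 21, storage reserved for 24 doubles
    double* out = line;           // the cast operator hands the raw buffer to the SIMD code
    // A padded SIMD kernel may store up to out[23] without overrunning the allocation.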

View File

@@ -135,7 +135,7 @@ void WeightMatrix::ConvertToInt() {
if (abs_val > max_abs) max_abs = abs_val;
}
double scale = max_abs / INT8_MAX;
scales_[t] = scale;
scales_[t] = scale / INT8_MAX;
if (scale == 0.0) scale = 1.0;
for (int f = 0; f < dim2; ++f) {
i_line[f] = IntCastRounded(f_line[f] / scale);
@@ -144,7 +144,7 @@ void WeightMatrix::ConvertToInt() {
wf_.Resize(1, 1, 0.0);
int_mode_ = true;
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
}
}
@@ -177,7 +177,12 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
if (!fp->Serialize(&mode)) return false;
if (int_mode_) {
if (!wi_.Serialize(fp)) return false;
/* The scales stored in memory have an extra factor applied to them
* to allow faster operation. We have to remove that factor here
* before writing to disc, and put it back afterwards. */
scales_.scale(INT8_MAX);
if (!scales_.Serialize(fp)) return false;
scales_.scale(1.0/INT8_MAX);
} else {
if (!wf_.Serialize(fp)) return false;
if (training && !updates_.Serialize(fp)) return false;
@@ -197,8 +202,9 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if (int_mode_) {
if (!wi_.DeSerialize(fp)) return false;
if (!scales_.DeSerialize(fp)) return false;
scales_.scale(1.0/INT8_MAX);
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
}
} else {
if (!wf_.DeSerialize(fp)) return false;
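The net effect on model files is nil: the scales written to disk are unchanged, and only the in-memory copy carries the extra 1/INT8_MAX factor. A quick round-trip check with an illustrative value:

    double on_disk   = 0.02;                 // scale as stored by existing models
    double in_memory = on_disk / INT8_MAX;   // applied after DeSerialize (and in ConvertToInt)
    double rewritten = in_memory * INT8_MAX; // scale(INT8_MAX) around Serialize
    // rewritten == on_disk up to floating-point rounding, so traineddata files stay compatible.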

View File

@@ -68,11 +68,17 @@ class IntSimdMatrixTest : public ::testing::Test {
GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
std::vector<int8_t> u = RandomVector(num_in, matrix);
GenericVector<double> scales = RandomScales(num_out);
std::vector<double> base_result(num_out);
scales.scale(1.0/INT8_MAX);
int ro = num_out;
if (IntSimdMatrix::intSimdMatrix)
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
std::vector<double> base_result(ro);
base_result.resize(num_out);
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
std::vector<double> test_result(num_out);
std::vector<double> test_result(ro);
test_result.resize(num_out);
std::vector<int8_t> shaped_wi;
matrix.Init(w, shaped_wi);
matrix.Init(w, shaped_wi, scales);
if (matrix.matrixDotVectorFunction) {
matrix.matrixDotVectorFunction(w.dim1(), w.dim2(), &shaped_wi[0],
&scales[0], &u[0], &test_result[0]);