mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-22 01:30:49 +08:00
intsimdmatrixavx2: Do biasing in SIMD.
We also move to relying on both scales and output having been padded to accomodate us writing more results than are actually needed here. This was allowed for a few commits back.
This commit is contained in:
parent
872816897a
commit
d1e49d6dd2
@ -80,24 +80,80 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
|
||||
result = _mm256_add_epi32(result, weights);
|
||||
}
|
||||
|
||||
// Extracts and converts 8x32-bit results from result, adding the bias from wi
|
||||
// and scaling by scales, before storing in *v. Note that wi, scales and v are
|
||||
// expected to contain 8 consecutive elements or num_out if less.
|
||||
static inline void ExtractResults(__m256i& result, __m256i& shift_id,
|
||||
const int8_t*& wi, const double*& scales,
|
||||
int num_out, double*& v) {
|
||||
for (int out = 0; out < num_out; ++out) {
|
||||
#ifndef _MSC_VER
|
||||
auto res = _mm256_extract_epi32(result, 0);
|
||||
// Load 64 bits into the bottom of a 128bit register.
|
||||
// We don't actually care what the top 64bits are, but this ends
|
||||
// up with them being zero.
|
||||
static inline __m128i load64_to_128(const int8_t *wi_)
|
||||
{
|
||||
#if 1
|
||||
const int64_t *wi = reinterpret_cast<const int64_t *>(wi_);
|
||||
return _mm_set_epi64x(0, wi[0]);
|
||||
#else
|
||||
// Workaround MSVC's ICE
|
||||
// _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
|
||||
auto res = ((int32_t*)&result)[0];
|
||||
// This version doesn't work on MSVC 32bits mode.
|
||||
__m64 wi64 = _m_from_int64(*reinterpret_cast<const __int64*>(wi));
|
||||
// wi64 = 8 x 8 bit values
|
||||
return _mm_movpi64_epi64(wi64); // 8x8 bit vals in 128bit reg
|
||||
#endif
|
||||
*v++ = (res + *wi++ * INT8_MAX) * *scales++;
|
||||
// Rotate the results in int32_t units, so the next result is ready.
|
||||
result = _mm256_permutevar8x32_epi32(result, shift_id);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void ExtractResults8(__m256i result,
|
||||
const int8_t* wi,
|
||||
const double* scales,
|
||||
double* v) {
|
||||
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
|
||||
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
|
||||
__m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
|
||||
__m256d scale0123 = _mm256_loadu_pd(scales);
|
||||
__m256d scale4567 = _mm256_loadu_pd(scales+4);
|
||||
__m256d bias0123, bias4567, res0123, res4567;
|
||||
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||
result = _mm256_add_epi32(result, w256); // result += bias * 127
|
||||
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
|
||||
result = _mm256_permute4x64_epi64(result, 2+(3<<2));
|
||||
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
|
||||
res0123 = _mm256_mul_pd(res0123, scale0123);
|
||||
res4567 = _mm256_mul_pd(res4567, scale4567);
|
||||
_mm256_storeu_pd(v, res0123);
|
||||
_mm256_storeu_pd(v+4, res4567);
|
||||
}
|
||||
|
||||
static inline void ExtractResults16(__m256i result0,
|
||||
__m256i result1,
|
||||
const int8_t*& wi,
|
||||
const double*& scales,
|
||||
double*& v) {
|
||||
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
|
||||
// 8x8bit vals in bottom of 128bit reg
|
||||
const __m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
|
||||
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
|
||||
__m256d scale0123 = _mm256_loadu_pd(scales);
|
||||
__m256d scale4567 = _mm256_loadu_pd(scales+4);
|
||||
__m256d bias0123, bias4567, res0123, res4567;
|
||||
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
|
||||
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
|
||||
result0 = _mm256_permute4x64_epi64(result0, 2+(3<<2));
|
||||
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
|
||||
res0123 = _mm256_mul_pd(res0123, scale0123);
|
||||
res4567 = _mm256_mul_pd(res4567, scale4567);
|
||||
_mm256_storeu_pd(v, res0123);
|
||||
_mm256_storeu_pd(v+4, res4567);
|
||||
w8 = _mm_shuffle_epi32(w8,2+(3<<2));
|
||||
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
|
||||
scale0123 = _mm256_loadu_pd(scales+8);
|
||||
scale4567 = _mm256_loadu_pd(scales+12);
|
||||
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
|
||||
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
|
||||
result1 = _mm256_permute4x64_epi64(result1, 2+(3<<2));
|
||||
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
|
||||
res0123 = _mm256_mul_pd(res0123, scale0123);
|
||||
res4567 = _mm256_mul_pd(res4567, scale4567);
|
||||
_mm256_storeu_pd(v+8, res0123);
|
||||
_mm256_storeu_pd(v+12, res4567);
|
||||
wi += 16;
|
||||
scales += 16;
|
||||
v += 16;
|
||||
}
|
||||
|
||||
// Computes part of matrix.vector v = Wu. Computes N=64 results.
|
||||
@ -108,7 +164,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
|
||||
// u must be padded out with zeros to
|
||||
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
|
||||
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
const int8_t* u, int num_in,
|
||||
double* v) {
|
||||
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
||||
// conversion.
|
||||
@ -149,22 +205,16 @@ static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
|
||||
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
|
||||
}
|
||||
}
|
||||
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
num_out -= kNumOutputsPerRegister * 7;
|
||||
ExtractResults(result7, shift_id, wi, scales,
|
||||
std::min(kNumOutputsPerRegister, num_out), v);
|
||||
ExtractResults16(result0, result1, wi, scales, v);
|
||||
ExtractResults16(result2, result3, wi, scales, v);
|
||||
ExtractResults16(result4, result5, wi, scales, v);
|
||||
ExtractResults16(result6, result7, wi, scales, v);
|
||||
}
|
||||
|
||||
// Computes part of matrix.vector v = Wu. Computes N=32 results.
|
||||
// For details see PartialMatrixDotVector64 with N=32.
|
||||
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
const int8_t* u, int num_in,
|
||||
double* v) {
|
||||
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
||||
// conversion.
|
||||
@ -197,18 +247,14 @@ static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
|
||||
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
|
||||
}
|
||||
}
|
||||
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
num_out -= kNumOutputsPerRegister * 3;
|
||||
ExtractResults(result3, shift_id, wi, scales,
|
||||
std::min(kNumOutputsPerRegister, num_out), v);
|
||||
ExtractResults16(result0, result1, wi, scales, v);
|
||||
ExtractResults16(result2, result3, wi, scales, v);
|
||||
}
|
||||
|
||||
// Computes part of matrix.vector v = Wu. Computes N=16 results.
|
||||
// For details see PartialMatrixDotVector64 with N=16.
|
||||
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
const int8_t* u, int num_in,
|
||||
double* v) {
|
||||
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
||||
// conversion.
|
||||
@ -237,17 +283,18 @@ static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
|
||||
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
|
||||
}
|
||||
}
|
||||
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
||||
num_out -= kNumOutputsPerRegister;
|
||||
ExtractResults(result1, shift_id, wi, scales,
|
||||
std::min(kNumOutputsPerRegister, num_out), v);
|
||||
ExtractResults16(result0, result1, wi, scales, v);
|
||||
}
|
||||
|
||||
// Computes part of matrix.vector v = Wu. Computes N=8 results.
|
||||
// For details see PartialMatrixDotVector64 with N=8.
|
||||
static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
double* v) {
|
||||
static inline void PartialMatrixDotVector8(const int8_t *wi,
|
||||
const double *scales,
|
||||
const int8_t *u,
|
||||
int num_in,
|
||||
double *v) {
|
||||
double *ov = v;
|
||||
double temp[8];
|
||||
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
||||
// conversion.
|
||||
__m256i ones =
|
||||
@ -273,7 +320,7 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
|
||||
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
|
||||
}
|
||||
}
|
||||
ExtractResults(result0, shift_id, wi, scales, num_out, v);
|
||||
ExtractResults8(result0, wi, scales, v);
|
||||
}
|
||||
|
||||
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
@ -294,7 +341,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
// Run with this group size, until it would produce too much output, then
|
||||
// switch to a smaller size.
|
||||
for (; output + group_size <= rounded_num_out; output += group_size) {
|
||||
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
|
||||
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
|
||||
wi += w_step;
|
||||
scales += group_size;
|
||||
v += group_size;
|
||||
@ -302,30 +349,28 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
group_size /= 2;
|
||||
w_step /= 2;
|
||||
|
||||
for (; output + group_size <= rounded_num_out; output += group_size) {
|
||||
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
|
||||
if (output + group_size <= rounded_num_out) {
|
||||
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
|
||||
wi += w_step;
|
||||
scales += group_size;
|
||||
v += group_size;
|
||||
output += group_size;
|
||||
}
|
||||
group_size /= 2;
|
||||
w_step /= 2;
|
||||
|
||||
for (; output + group_size <= rounded_num_out; output += group_size) {
|
||||
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
|
||||
if (output + group_size <= rounded_num_out) {
|
||||
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
|
||||
wi += w_step;
|
||||
scales += group_size;
|
||||
v += group_size;
|
||||
output += group_size;
|
||||
}
|
||||
group_size /= 2;
|
||||
w_step /= 2;
|
||||
|
||||
for (; output + group_size <= rounded_num_out; output += group_size) {
|
||||
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
|
||||
wi += w_step;
|
||||
scales += group_size;
|
||||
v += group_size;
|
||||
}
|
||||
if (output + group_size <= rounded_num_out)
|
||||
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
|
||||
}
|
||||
|
||||
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
|
||||
|
Loading…
Reference in New Issue
Block a user