mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-22 01:30:49 +08:00
intsimdmatrixneon.cpp: Do biasing in SIMD.
This commit is contained in:
parent
d1e49d6dd2
commit
db10c7b577
@ -85,15 +85,8 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
|
||||
// up with them being zero.
|
||||
static inline __m128i load64_to_128(const int8_t *wi_)
|
||||
{
|
||||
#if 1
|
||||
const int64_t *wi = reinterpret_cast<const int64_t *>(wi_);
|
||||
return _mm_set_epi64x(0, wi[0]);
|
||||
#else
|
||||
// This version doesn't work on MSVC 32bits mode.
|
||||
__m64 wi64 = _m_from_int64(*reinterpret_cast<const __int64*>(wi));
|
||||
// wi64 = 8 x 8 bit values
|
||||
return _mm_movpi64_epi64(wi64); // 8x8 bit vals in 128bit reg
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline void ExtractResults8(__m256i result,
|
||||
|
@ -61,6 +61,7 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
|
||||
// Initialize all the results to 0.
|
||||
int32x4_t result0123 = { 0, 0, 0, 0 };
|
||||
int32x4_t result4567 = { 0, 0, 0, 0 };
|
||||
int8x8_t bias_scale = { 127, 127, 127, 127, 127, 127, 127, 127 };
|
||||
// Iterate over the input (u), one registerful at a time.
|
||||
for (int j = 0; j < num_in; j += 8) {
|
||||
int8x8_t vu = vld1_s8(u); // vu = u0 u1 u2 u3 u4 u5 u6 u7
|
||||
@ -112,21 +113,27 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
|
||||
u += 8;
|
||||
wi += 64;
|
||||
}
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 1)
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 2)
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result0123), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 3)
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result0123), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 4)
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 5)
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 6)
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result4567), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 7)
|
||||
*v = (vget_lane_s32(vget_high_s32(result4567), 1) + *wi * INT8_MAX) * *scales;
|
||||
{
|
||||
int8x8_t bias = vld1_s8(wi); // vw0 = b0 b1 b2 b3 b4 b5 b6 b7
|
||||
int16x8_t scaled_bias = vmull_s8(bias, bias_scale);
|
||||
result0123 = vaddw_s16(result0123, vget_low_s16(scaled_bias));
|
||||
result4567 = vaddw_s16(result4567, vget_high_s16(scaled_bias));
|
||||
*v++ = vget_lane_s32(vget_low_s32 (result0123), 0) * *scales++;
|
||||
if (num_out > 1)
|
||||
*v++ = vget_lane_s32(vget_low_s32 (result0123), 1) * *scales++;
|
||||
if (num_out > 2)
|
||||
*v++ = vget_lane_s32(vget_high_s32(result0123), 0) * *scales++;
|
||||
if (num_out > 3)
|
||||
*v++ = vget_lane_s32(vget_high_s32(result0123), 1) * *scales++;
|
||||
if (num_out > 4)
|
||||
*v++ = vget_lane_s32(vget_low_s32 (result4567), 0) * *scales++;
|
||||
if (num_out > 5)
|
||||
*v++ = vget_lane_s32(vget_low_s32 (result4567), 1) * *scales++;
|
||||
if (num_out > 6)
|
||||
*v++ = vget_lane_s32(vget_high_s32(result4567), 0) * *scales++;
|
||||
if (num_out > 7)
|
||||
*v = vget_lane_s32(vget_high_s32(result4567), 1) * *scales;
|
||||
}
|
||||
}
|
||||
|
||||
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
|
Loading…
Reference in New Issue
Block a user