intsimdmatrixneon.cpp: Do biasing in SIMD.

This commit is contained in:
Robin Watts 2020-05-18 16:06:35 +01:00
parent d1e49d6dd2
commit db10c7b577
2 changed files with 22 additions and 22 deletions

View File

@ -85,15 +85,8 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_)
{
  // Loads the 8 bytes starting at wi_ into the low 64 bits of a 128-bit
  // register; the high 64 bits of the result are zero.
  //
  // _mm_loadl_epi64 performs the 64-bit load directly, so the int8_t data is
  // never dereferenced through an int64_t pointer (which would break the
  // aliasing rules and assume 8-byte alignment that an int8_t buffer does not
  // guarantee). It also avoids the __m64 intermediate used by the former
  // alternative implementation, which was noted as not working with MSVC in
  // 32-bit mode.
  return _mm_loadl_epi64(reinterpret_cast<const __m128i *>(wi_));
}
static inline void ExtractResults8(__m256i result,

View File

@ -61,6 +61,7 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
// Initialize all the results to 0.
int32x4_t result0123 = { 0, 0, 0, 0 };
int32x4_t result4567 = { 0, 0, 0, 0 };
int8x8_t bias_scale = { 127, 127, 127, 127, 127, 127, 127, 127 };
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in; j += 8) {
int8x8_t vu = vld1_s8(u); // vu = u0 u1 u2 u3 u4 u5 u6 u7
@ -112,21 +113,27 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
u += 8;
wi += 64;
}
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 1)
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 2)
*v++ = (vget_lane_s32(vget_high_s32(result0123), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 3)
*v++ = (vget_lane_s32(vget_high_s32(result0123), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 4)
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 5)
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 6)
*v++ = (vget_lane_s32(vget_high_s32(result4567), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 7)
*v = (vget_lane_s32(vget_high_s32(result4567), 1) + *wi * INT8_MAX) * *scales;
{
int8x8_t bias = vld1_s8(wi); // vw0 = b0 b1 b2 b3 b4 b5 b6 b7
int16x8_t scaled_bias = vmull_s8(bias, bias_scale);
result0123 = vaddw_s16(result0123, vget_low_s16(scaled_bias));
result4567 = vaddw_s16(result4567, vget_high_s16(scaled_bias));
*v++ = vget_lane_s32(vget_low_s32 (result0123), 0) * *scales++;
if (num_out > 1)
*v++ = vget_lane_s32(vget_low_s32 (result0123), 1) * *scales++;
if (num_out > 2)
*v++ = vget_lane_s32(vget_high_s32(result0123), 0) * *scales++;
if (num_out > 3)
*v++ = vget_lane_s32(vget_high_s32(result0123), 1) * *scales++;
if (num_out > 4)
*v++ = vget_lane_s32(vget_low_s32 (result4567), 0) * *scales++;
if (num_out > 5)
*v++ = vget_lane_s32(vget_low_s32 (result4567), 1) * *scales++;
if (num_out > 6)
*v++ = vget_lane_s32(vget_high_s32(result4567), 0) * *scales++;
if (num_out > 7)
*v = vget_lane_s32(vget_high_s32(result4567), 1) * *scales;
}
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,