intsimdmatrixavx2: Do biasing in SIMD.

We also now rely on both scales and output having been
padded to accommodate writing more results than are actually
needed here; provision for this was made a few commits back.
Robin Watts 2020-05-11 14:56:53 +01:00
parent 872816897a
commit d1e49d6dd2
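
For reference, the per-output computation being moved into SIMD is the scalar one from the old ExtractResults in the diff below: each 32-bit accumulator gets a bias of wi[i] * INT8_MAX added, then is scaled to a double. A minimal scalar sketch (the ExtractResultsScalar name and plain-array arguments are illustrative, not part of the commit):

#include <cstdint>

// Scalar equivalent of one 8-output extraction step: bias each
// accumulated dot product by wi[i] * INT8_MAX, then scale to double.
static void ExtractResultsScalar(const int32_t result[8], const int8_t* wi,
                                 const double* scales, double* v) {
  for (int i = 0; i < 8; ++i)
    v[i] = (result[i] + wi[i] * INT8_MAX) * scales[i];
}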


@@ -80,24 +80,80 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
  result = _mm256_add_epi32(result, weights);
}
// Extracts and converts 8x32-bit results from result, adding the bias from wi
// and scaling by scales, before storing in *v. Note that wi, scales and v are
// expected to contain 8 consecutive elements or num_out if less.
static inline void ExtractResults(__m256i& result, __m256i& shift_id,
                                  const int8_t*& wi, const double*& scales,
                                  int num_out, double*& v) {
  for (int out = 0; out < num_out; ++out) {
#ifndef _MSC_VER
    auto res = _mm256_extract_epi32(result, 0);
#else
    // Workaround for MSVC's ICE:
    // _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
    auto res = ((int32_t*)&result)[0];
#endif
    *v++ = (res + *wi++ * INT8_MAX) * *scales++;
    // Rotate the results in int32_t units, so the next result is ready.
    result = _mm256_permutevar8x32_epi32(result, shift_id);
  }
}
// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64 bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t* wi_) {
#if 1
  const int64_t* wi = reinterpret_cast<const int64_t*>(wi_);
  return _mm_set_epi64x(0, wi[0]);
#else
  // This version doesn't work in MSVC's 32-bit mode.
  __m64 wi64 = _m_from_int64(*reinterpret_cast<const __int64*>(wi_));
  // wi64 = 8 x 8-bit values
  return _mm_movpi64_epi64(wi64); // 8x8-bit vals in a 128-bit reg
#endif
}
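
The top half is irrelevant to the _mm256_cvtepi8_epi32 that follows (it reads only the low 8 bytes), and the low half is easy to verify (a hypothetical test helper, not part of the commit; _mm_cvtsi128_si64 is SSE2 and x86-64 only):

#include <immintrin.h>
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical check: the low 64 bits of the register must equal the
// 8 input bytes; _mm_set_epi64x(0, ...) zeroes the high 64 bits.
static void CheckLoad64To128(const int8_t* p) {
  __m128i r = load64_to_128(p);
  int64_t expect;
  std::memcpy(&expect, p, sizeof(expect));
  assert(_mm_cvtsi128_si64(r) == expect);
}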
static inline void ExtractResults8(__m256i result,
                                   const int8_t* wi,
                                   const double* scales,
                                   double* v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  __m256d res0123, res4567;
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}
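
The 2 + (3 << 2) immediate is doing the heavy lifting here: _mm256_cvtepi32_pd converts only the low four 32-bit lanes, so the permute must bring results 4..7 into the low half for the second conversion. A scalar model of the selector (illustrative only, not the committed code):

#include <cstdint>

// Scalar model of _mm256_permute4x64_epi64(x, 2 + (3 << 2)): two selector
// bits per destination lane, so imm = 0b1110 gives out[0] = in[2] and
// out[1] = in[3]; out[2] and out[3] both receive in[0] and are unused here.
static void Permute4x64Model(const int64_t in[4], int64_t out[4]) {
  const int imm = 2 + (3 << 2);
  for (int lane = 0; lane < 4; ++lane)
    out[lane] = in[(imm >> (2 * lane)) & 3];
}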
static inline void ExtractResults16(__m256i result0,
                                    __m256i result1,
                                    const int8_t*& wi,
                                    const double*& scales,
                                    double*& v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
  // 16x8bit vals in 128bit reg; the bottom 8 are converted first
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  __m256d res0123, res4567;
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));  // move bytes 8..15 to the bottom
  w256 = _mm256_cvtepi8_epi32(w8);           // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}
// Computes part of matrix.vector v = Wu. Computes N=64 results.
@@ -108,7 +164,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     const int8_t* u, int num_in,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
@@ -149,22 +205,16 @@ static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 7;
  ExtractResults(result7, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
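
Each ExtractResults16 call drains two accumulator registers and advances wi, scales and v by 16 itself, which is why the explicit num_out bookkeeping could go: the four calls always cover all 64 outputs. A trivial sanity check (assuming kNumOutputsPerRegister == 8, as the 8x32-bit lanes of a __m256i imply):

// 8 registers x 8 lanes = 64 outputs = 4 calls x 16 outputs each.
static_assert(8 * 8 == 4 * 16, "four ExtractResults16 calls cover N=64");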
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     const int8_t* u, int num_in,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
@@ -197,18 +247,14 @@ static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 3;
  ExtractResults(result3, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     const int8_t* u, int num_in,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
@@ -237,17 +283,18 @@ static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister;
  ExtractResults(result1, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
  ExtractResults16(result0, result1, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
                                    const int8_t* u, int num_in, int num_out,
                                    double* v) {
static inline void PartialMatrixDotVector8(const int8_t *wi,
                                           const double *scales,
                                           const int8_t *u,
                                           int num_in,
                                           double *v) {
  double *ov = v;
  double temp[8];
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
@@ -273,7 +320,7 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, num_out, v);
  ExtractResults8(result0, wi, scales, v);
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
@@ -294,7 +341,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
@@ -302,30 +349,28 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
  group_size /= 2;
  w_step /= 2;
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  if (output + group_size <= rounded_num_out)
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
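
The for-loops for the 32/16/8 sizes could become plain if-blocks because rounded_num_out is a multiple of 8 and fewer than 64 outputs remain after the 64-wide loop, so each halved group size matches at most once. A small sketch of why the remainder always drains to zero (hypothetical helper, not part of the commit):

#include <cassert>

// Any multiple of 8 below 64 is consumed by at most one pass each at
// group sizes 32, 16 and 8, e.g. 56 -> 24 -> 8 -> 0.
static void CheckCascadeDrains(int remaining /* multiple of 8, < 64 */) {
  for (int group_size = 32; group_size >= 8; group_size /= 2) {
    if (remaining >= group_size) // mirrors the if-blocks above
      remaining -= group_size;
  }
  assert(remaining == 0);
}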
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {