Rejig intsimdmatrix to reduce FP ops.

Avoid 1) floating-point division by 127, 2) conversion of
bias to double, 3) FP addition, in favour of 1) integer
multiplication by 127, and 2) integer addition.

(Also costs extra work in the serialisation/deserialisation of
the scale values, and in the conversion of weights to int formats,
but these are all one-offs.)
This commit is contained in:
Robin Watts 2020-05-18 15:36:42 +01:00
parent aba1800f69
commit 872816897a
7 changed files with 31 additions and 14 deletions

View File

@ -316,6 +316,11 @@ class GenericVector {
return true;
}
// Multiplies every element of the vector in place by |factor|.
// NOTE(review): declared const yet mutates elements — presumably data_ is a
// raw pointer member, so writing through it is legal in a const member
// function; confirm this matches GenericVector's const-mutator convention.
void scale(T factor) const {
  T* p = data_;
  T* const end = data_ + size_used_;
  for (; p != end; ++p) {
    *p *= factor;
  }
}
protected:
// Internal recursive version of choose_nth_item.
int choose_nth_item(int target_index, int start, int end, unsigned int* seed);

View File

@ -89,7 +89,7 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
int total = 0;
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
// Add in the bias and correct for integer values.
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
v[i] = (total + wi[num_in] * INT8_MAX) * scales[i];
}
}

View File

@ -94,7 +94,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
// _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
auto res = ((int32_t*)&result)[0];
#endif
*v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
*v++ = (res + *wi++ * INT8_MAX) * *scales++;
// Rotate the results in int32_t units, so the next result is ready.
result = _mm256_permutevar8x32_epi32(result, shift_id);
}

View File

@ -112,21 +112,21 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
u += 8;
wi += 64;
}
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 0)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 1)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 1)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 2)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 0)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_high_s32(result0123), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 3)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 1)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_high_s32(result0123), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 4)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 0)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 5)
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 1)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 1) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 6)
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 0)) / INT8_MAX + *wi++) * *scales++;
*v++ = (vget_lane_s32(vget_high_s32(result4567), 0) + *wi++ * INT8_MAX) * *scales++;
if (num_out > 7)
*v = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 1)) / INT8_MAX + *wi ) * *scales;
*v = (vget_lane_s32(vget_high_s32(result4567), 1) + *wi * INT8_MAX) * *scales;
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,

View File

@ -74,7 +74,7 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
double* v) {
double total = IntDotProductSSE(u, wi, num_in);
// Add in the bias and correct for integer values.
*v = (total / INT8_MAX + wi[num_in]) * *scales;
*v = (total + wi[num_in] * INT8_MAX) * *scales;
}
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,

View File

@ -135,7 +135,7 @@ void WeightMatrix::ConvertToInt() {
if (abs_val > max_abs) max_abs = abs_val;
}
double scale = max_abs / INT8_MAX;
scales_[t] = scale;
scales_[t] = scale / INT8_MAX;
if (scale == 0.0) scale = 1.0;
for (int f = 0; f < dim2; ++f) {
i_line[f] = IntCastRounded(f_line[f] / scale);
@ -177,7 +177,12 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
if (!fp->Serialize(&mode)) return false;
if (int_mode_) {
if (!wi_.Serialize(fp)) return false;
/* The scales stored in memory have an extra factor applied to them
* to allow faster operation. We have to remove that factor here
* before writing to disc, and put it back afterwards. */
scales_.scale(INT8_MAX);
if (!scales_.Serialize(fp)) return false;
scales_.scale(1.0/INT8_MAX);
} else {
if (!wf_.Serialize(fp)) return false;
if (training && !updates_.Serialize(fp)) return false;
@ -197,6 +202,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if (int_mode_) {
if (!wi_.DeSerialize(fp)) return false;
if (!scales_.DeSerialize(fp)) return false;
scales_.scale(1.0/INT8_MAX);
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
}

View File

@ -68,9 +68,15 @@ class IntSimdMatrixTest : public ::testing::Test {
GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
std::vector<int8_t> u = RandomVector(num_in, matrix);
GenericVector<double> scales = RandomScales(num_out);
std::vector<double> base_result(num_out);
scales.scale(1.0/INT8_MAX);
int ro = num_out;
if (IntSimdMatrix::intSimdMatrix)
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
std::vector<double> base_result(ro);
base_result.resize(num_out);
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
std::vector<double> test_result(num_out);
std::vector<double> test_result(ro);
test_result.resize(num_out);
std::vector<int8_t> shaped_wi;
matrix.Init(w, shaped_wi, scales);
if (matrix.matrixDotVectorFunction) {