mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2024-11-27 12:49:35 +08:00
Rejig intsimdmatrix to reduce FP ops.
Avoid 1) floating point division by 127, 2) conversion of the bias to double, and 3) FP addition, in favour of 1) integer multiplication by 127 and 2) integer addition. (This also costs extra work in the serialisation/deserialisation of the scale values, and in the conversion of weights to int formats, but these are all one-off costs.)
This commit is contained in:
parent
aba1800f69
commit
872816897a
@ -316,6 +316,11 @@ class GenericVector {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Multiplies each of the first size_used_ elements of data_ by factor,
// in place.
// NOTE(review): the method is declared const even though it writes through
// data_. The const WeightMatrix::Serialize in this same change calls it on
// scales_ (scale up, write, scale back down), so the qualifier appears to be
// deliberate -- confirm before changing it.
void scale(T factor) const {
|
||||
for (int i = 0; i < size_used_; ++i) {
|
||||
data_[i] *= factor;
|
||||
}
|
||||
}
|
||||
protected:
|
||||
// Internal recursive version of choose_nth_item.
|
||||
int choose_nth_item(int target_index, int start, int end, unsigned int* seed);
|
||||
|
@ -89,7 +89,7 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
|
||||
int total = 0;
|
||||
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
|
||||
// Add in the bias and correct for integer values.
|
||||
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
|
||||
v[i] = (total + wi[num_in] * INT8_MAX) * scales[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
|
||||
// _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
|
||||
auto res = ((int32_t*)&result)[0];
|
||||
#endif
|
||||
*v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (res + *wi++ * INT8_MAX) * *scales++;
|
||||
// Rotate the results in int32_t units, so the next result is ready.
|
||||
result = _mm256_permutevar8x32_epi32(result, shift_id);
|
||||
}
|
||||
|
@ -112,21 +112,21 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
|
||||
u += 8;
|
||||
wi += 64;
|
||||
}
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 0)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 1)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 1)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result0123), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 2)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 0)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result0123), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 3)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 1)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result0123), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 4)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 0)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 5)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 1)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_low_s32 (result4567), 1) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 6)
|
||||
*v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 0)) / INT8_MAX + *wi++) * *scales++;
|
||||
*v++ = (vget_lane_s32(vget_high_s32(result4567), 0) + *wi++ * INT8_MAX) * *scales++;
|
||||
if (num_out > 7)
|
||||
*v = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 1)) / INT8_MAX + *wi ) * *scales;
|
||||
*v = (vget_lane_s32(vget_high_s32(result4567), 1) + *wi * INT8_MAX) * *scales;
|
||||
}
|
||||
|
||||
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
|
@ -74,7 +74,7 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
|
||||
double* v) {
|
||||
double total = IntDotProductSSE(u, wi, num_in);
|
||||
// Add in the bias and correct for integer values.
|
||||
*v = (total / INT8_MAX + wi[num_in]) * *scales;
|
||||
*v = (total + wi[num_in] * INT8_MAX) * *scales;
|
||||
}
|
||||
|
||||
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
||||
|
@ -135,7 +135,7 @@ void WeightMatrix::ConvertToInt() {
|
||||
if (abs_val > max_abs) max_abs = abs_val;
|
||||
}
|
||||
double scale = max_abs / INT8_MAX;
|
||||
scales_[t] = scale;
|
||||
scales_[t] = scale / INT8_MAX;
|
||||
if (scale == 0.0) scale = 1.0;
|
||||
for (int f = 0; f < dim2; ++f) {
|
||||
i_line[f] = IntCastRounded(f_line[f] / scale);
|
||||
@ -177,7 +177,12 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
|
||||
if (!fp->Serialize(&mode)) return false;
|
||||
if (int_mode_) {
|
||||
if (!wi_.Serialize(fp)) return false;
|
||||
/* The scales stored in memory have an extra factor applied to them
|
||||
* to allow faster operation. We have to remove that factor here
|
||||
* before writing to disc, and put it back afterwards. */
|
||||
scales_.scale(INT8_MAX);
|
||||
if (!scales_.Serialize(fp)) return false;
|
||||
scales_.scale(1.0/INT8_MAX);
|
||||
} else {
|
||||
if (!wf_.Serialize(fp)) return false;
|
||||
if (training && !updates_.Serialize(fp)) return false;
|
||||
@ -197,6 +202,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
|
||||
if (int_mode_) {
|
||||
if (!wi_.DeSerialize(fp)) return false;
|
||||
if (!scales_.DeSerialize(fp)) return false;
|
||||
scales_.scale(1.0/INT8_MAX);
|
||||
if (IntSimdMatrix::intSimdMatrix) {
|
||||
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
|
||||
}
|
||||
|
@ -68,9 +68,15 @@ class IntSimdMatrixTest : public ::testing::Test {
|
||||
GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
|
||||
std::vector<int8_t> u = RandomVector(num_in, matrix);
|
||||
GenericVector<double> scales = RandomScales(num_out);
|
||||
std::vector<double> base_result(num_out);
|
||||
scales.scale(1.0/INT8_MAX);
|
||||
int ro = num_out;
|
||||
if (IntSimdMatrix::intSimdMatrix)
|
||||
ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
|
||||
std::vector<double> base_result(ro);
|
||||
base_result.resize(num_out);
|
||||
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
|
||||
std::vector<double> test_result(num_out);
|
||||
std::vector<double> test_result(ro);
|
||||
test_result.resize(num_out);
|
||||
std::vector<int8_t> shaped_wi;
|
||||
matrix.Init(w, shaped_wi, scales);
|
||||
if (matrix.matrixDotVectorFunction) {
|
||||
|
Loading…
Reference in New Issue
Block a user