mirror of https://github.com/tesseract-ocr/tesseract.git (synced 2025-01-22 18:13:42 +08:00)

Merge pull request #2994 from robinwatts/pushback11

Improve speed of tesseract by optimising for intSimdMatrix case

Commit 514a7893f4
@@ -316,6 +316,11 @@ class GenericVector {
     return true;
   }
 
+  void scale(T factor) const {
+    for (int i = 0; i < size_used_; ++i) {
+      data_[i] *= factor;
+    }
+  }
  protected:
   // Internal recursive version of choose_nth_item.
   int choose_nth_item(int target_index, int start, int end, unsigned int* seed);
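Note: GenericVector::scale() multiplies every element in place. The WeightMatrix hunks further down use it to fold an extra 1.0/INT8_MAX factor into scales_ and to temporarily undo that folding around serialization. A rough usage sketch (the include path is assumed, values are hypothetical, not part of the patch):

#include <cstdint>
#include "genericvector.h"  // tesseract's GenericVector; header path assumed

void FoldAndUnfold(GenericVector<double>& scales) {
  scales.scale(1.0 / INT8_MAX);  // in-memory, pre-folded form used by the SIMD kernels
  scales.scale(INT8_MAX);        // back to the historical on-disk form
}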
@@ -27,7 +27,8 @@ const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr;
 
 // Computes a reshaped copy of the weight matrix w.
 void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
-                         std::vector<int8_t>& shaped_w) const {
+                         std::vector<int8_t>& shaped_w,
+                         GenericVector<double>& scales) const {
   const int num_out = w.dim1();
   const int num_in = w.dim2() - 1;
   // The rounded-up sizes of the reshaped weight matrix, excluding biases.
@@ -35,6 +36,7 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
   int rounded_num_out = RoundOutputs(num_out);
   // Add the bias and compute the required size.
   shaped_w.resize((rounded_num_in + 1) * rounded_num_out, 0);
+  scales.resize_no_init(rounded_num_out);
   int shaped_index = 0;
   int output = 0;
   // Each number of registers needs a different format! Iterates over the
@@ -87,7 +89,7 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
     int total = 0;
     for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
     // Add in the bias and correct for integer values.
-    v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
+    v[i] = (total + wi[num_in] * INT8_MAX) * scales[i];
   }
 }
 
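Note: this is the core of the optimisation. The per-output scales now carry an extra 1.0/INT8_MAX factor (applied in WeightMatrix::ConvertToInt and DeSerialize below), so the hot loop adds wi[num_in] * INT8_MAX to the integer total instead of dividing the total by INT8_MAX for every output. A standalone sketch of the equivalence, with hypothetical values (not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const int total = 1234;      // hypothetical integer dot product
  const int bias = 7;          // hypothetical bias weight wi[num_in]
  const double scale = 0.015;  // hypothetical unfolded per-output scale

  const double old_v = (static_cast<double>(total) / INT8_MAX + bias) * scale;
  const double new_v = (total + bias * INT8_MAX) * (scale / INT8_MAX);

  assert(std::fabs(old_v - new_v) < 1e-12);  // same result, no divide in the hot path
  return 0;
}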
@@ -62,7 +62,8 @@ namespace tesseract {
 struct IntSimdMatrix {
   // Computes a reshaped copy of the weight matrix w.
   void Init(const GENERIC_2D_ARRAY<int8_t>& w,
-            std::vector<int8_t>& shaped_w) const;
+            std::vector<int8_t>& shaped_w,
+            GenericVector<double>& scales) const;
 
   // Rounds the size up to a multiple of the input register size (in int8_t).
   int RoundInputs(int size) const {
@@ -80,24 +80,73 @@ static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
   result = _mm256_add_epi32(result, weights);
 }
 
-// Extracts and converts 8x32-bit results from result, adding the bias from wi
-// and scaling by scales, before storing in *v. Note that wi, scales and v are
-// expected to contain 8 consecutive elements or num_out if less.
-static inline void ExtractResults(__m256i& result, __m256i& shift_id,
-                                  const int8_t*& wi, const double*& scales,
-                                  int num_out, double*& v) {
-  for (int out = 0; out < num_out; ++out) {
-#ifndef _MSC_VER
-    auto res = _mm256_extract_epi32(result, 0);
-#else
-    // Workaround MSVC's ICE
-    // _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
-    auto res = ((int32_t*)&result)[0];
-#endif
-    *v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
-    // Rotate the results in int32_t units, so the next result is ready.
-    result = _mm256_permutevar8x32_epi32(result, shift_id);
-  }
-}
-
+// Load 64 bits into the bottom of a 128bit register.
+// We don't actually care what the top 64bits are, but this ends
+// up with them being zero.
+static inline __m128i load64_to_128(const int8_t *wi_)
+{
+  const int64_t *wi = reinterpret_cast<const int64_t *>(wi_);
+  return _mm_set_epi64x(0, wi[0]);
+}
+
+static inline void ExtractResults8(__m256i result,
+                                   const int8_t* wi,
+                                   const double* scales,
+                                   double* v) {
+  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
+  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
+  __m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
+  __m256d scale0123 = _mm256_loadu_pd(scales);
+  __m256d scale4567 = _mm256_loadu_pd(scales+4);
+  __m256d bias0123, bias4567, res0123, res4567;
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
+  result = _mm256_add_epi32(result, w256);     // result += bias * 127
+  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
+  result = _mm256_permute4x64_epi64(result, 2+(3<<2));
+  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
+  res0123 = _mm256_mul_pd(res0123, scale0123);
+  res4567 = _mm256_mul_pd(res4567, scale4567);
+  _mm256_storeu_pd(v, res0123);
+  _mm256_storeu_pd(v+4, res4567);
+}
+
+static inline void ExtractResults16(__m256i result0,
+                                    __m256i result1,
+                                    const int8_t*& wi,
+                                    const double*& scales,
+                                    double*& v) {
+  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
+                                             // 8x8bit vals in bottom of 128bit reg
+  const __m256i bias_scale = _mm256_set_epi32(127,127,127,127,127,127,127,127);
+  __m256i w256 = _mm256_cvtepi8_epi32(w8);   // 8x32bit vals in 256bit reg
+  __m256d scale0123 = _mm256_loadu_pd(scales);
+  __m256d scale4567 = _mm256_loadu_pd(scales+4);
+  __m256d bias0123, bias4567, res0123, res4567;
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
+  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
+  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
+  result0 = _mm256_permute4x64_epi64(result0, 2+(3<<2));
+  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
+  res0123 = _mm256_mul_pd(res0123, scale0123);
+  res4567 = _mm256_mul_pd(res4567, scale4567);
+  _mm256_storeu_pd(v, res0123);
+  _mm256_storeu_pd(v+4, res4567);
+  w8 = _mm_shuffle_epi32(w8,2+(3<<2));
+  w256 = _mm256_cvtepi8_epi32(w8);           // 8x32bit vals in 256bit reg
+  scale0123 = _mm256_loadu_pd(scales+8);
+  scale4567 = _mm256_loadu_pd(scales+12);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
+  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
+  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
+  result1 = _mm256_permute4x64_epi64(result1, 2+(3<<2));
+  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
+  res0123 = _mm256_mul_pd(res0123, scale0123);
+  res4567 = _mm256_mul_pd(res4567, scale4567);
+  _mm256_storeu_pd(v+8, res0123);
+  _mm256_storeu_pd(v+12, res4567);
+  wi += 16;
+  scales += 16;
+  v += 16;
+}
+
 // Computes part of matrix.vector v = Wu. Computes N=64 results.
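Note: ExtractResults8/ExtractResults16 replace the old one-element-at-a-time extraction loop. The eight bias bytes are widened to 32 bits, multiplied by 127 and added to the accumulators in one shot, then the whole register is converted to doubles and multiplied by the pre-folded scales. A scalar equivalent of ExtractResults8, as a sketch (not part of the patch):

#include <cassert>
#include <cstdint>

// Per-lane effect: add bias * 127 to the int32 accumulator, then multiply by
// the scale that has already been divided by INT8_MAX.
static void ExtractResults8Scalar(const int32_t result[8], const int8_t* wi,
                                  const double* scales, double* v) {
  for (int k = 0; k < 8; ++k) {
    v[k] = (result[k] + wi[k] * 127) * scales[k];
  }
}

int main() {
  const int32_t result[8] = {1000, -500, 0, 250, 8, 16, 32, 64};
  const int8_t wi[8] = {1, 2, 3, 4, 5, 6, 7, -8};
  const double scales[8] = {0.01, 0.01, 0.01, 0.01, 0.02, 0.02, 0.02, 0.02};
  double v[8];
  ExtractResults8Scalar(result, wi, scales, v);
  assert(v[0] == (1000 + 1 * 127) * 0.01);
  return 0;
}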
@@ -108,7 +157,7 @@ static inline void ExtractResults(__m256i& result, __m256i& shift_id,
 // u must be padded out with zeros to
 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
 static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
-                                     const int8_t* u, int num_in, int num_out,
+                                     const int8_t* u, int num_in,
                                      double* v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
@@ -149,22 +198,16 @@ static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
       MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
     }
   }
-  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  num_out -= kNumOutputsPerRegister * 7;
-  ExtractResults(result7, shift_id, wi, scales,
-                 std::min(kNumOutputsPerRegister, num_out), v);
+  ExtractResults16(result0, result1, wi, scales, v);
+  ExtractResults16(result2, result3, wi, scales, v);
+  ExtractResults16(result4, result5, wi, scales, v);
+  ExtractResults16(result6, result7, wi, scales, v);
 }
 
 // Computes part of matrix.vector v = Wu. Computes N=32 results.
 // For details see PartialMatrixDotVector64 with N=32.
 static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
-                                     const int8_t* u, int num_in, int num_out,
+                                     const int8_t* u, int num_in,
                                      double* v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
@@ -197,18 +240,14 @@ static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
       MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
     }
   }
-  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  num_out -= kNumOutputsPerRegister * 3;
-  ExtractResults(result3, shift_id, wi, scales,
-                 std::min(kNumOutputsPerRegister, num_out), v);
+  ExtractResults16(result0, result1, wi, scales, v);
+  ExtractResults16(result2, result3, wi, scales, v);
 }
 
 // Computes part of matrix.vector v = Wu. Computes N=16 results.
 // For details see PartialMatrixDotVector64 with N=16.
 static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
-                                     const int8_t* u, int num_in, int num_out,
+                                     const int8_t* u, int num_in,
                                      double* v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
@@ -237,17 +276,18 @@ static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
     }
   }
-  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
-  num_out -= kNumOutputsPerRegister;
-  ExtractResults(result1, shift_id, wi, scales,
-                 std::min(kNumOutputsPerRegister, num_out), v);
+  ExtractResults16(result0, result1, wi, scales, v);
 }
 
 // Computes part of matrix.vector v = Wu. Computes N=8 results.
 // For details see PartialMatrixDotVector64 with N=8.
-static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
-                                    const int8_t* u, int num_in, int num_out,
-                                    double* v) {
+static inline void PartialMatrixDotVector8(const int8_t *wi,
+                                           const double *scales,
+                                           const int8_t *u,
+                                           int num_in,
+                                           double *v) {
+  double *ov = v;
+  double temp[8];
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones =
@@ -273,7 +313,7 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
       MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
     }
   }
-  ExtractResults(result0, shift_id, wi, scales, num_out, v);
+  ExtractResults8(result0, wi, scales, v);
 }
 
 static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
@@ -294,7 +334,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
   // Run with this group size, until it would produce too much output, then
   // switch to a smaller size.
   for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
+    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
     wi += w_step;
     scales += group_size;
     v += group_size;
@@ -302,30 +342,28 @@ static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
   group_size /= 2;
   w_step /= 2;
 
-  for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
+  if (output + group_size <= rounded_num_out) {
+    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
     wi += w_step;
     scales += group_size;
     v += group_size;
+    output += group_size;
   }
   group_size /= 2;
   w_step /= 2;
 
-  for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
+  if (output + group_size <= rounded_num_out) {
+    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
     wi += w_step;
     scales += group_size;
     v += group_size;
+    output += group_size;
   }
   group_size /= 2;
   w_step /= 2;
 
-  for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-  }
+  if (output + group_size <= rounded_num_out)
+    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
 }
 
 const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
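Note on why the 32/16/8 loops become plain ifs: rounded_num_out is a multiple of 8, and whatever is left after the 64-wide groups is less than 64, so it can contain at most one 32-wide group, one 16-wide group and one 8-wide group. A standalone check of that reasoning (not part of the patch):

#include <cassert>

int main() {
  for (int rounded_num_out = 0; rounded_num_out <= 256; rounded_num_out += 8) {
    int remainder = rounded_num_out % 64;   // what is left after the 64-wide loop
    if (remainder >= 32) remainder -= 32;   // at most one 32-wide group
    if (remainder >= 16) remainder -= 16;   // at most one 16-wide group
    if (remainder >= 8) remainder -= 8;     // at most one 8-wide group
    assert(remainder == 0);                 // nothing is ever left over
  }
  return 0;
}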
@@ -61,6 +61,7 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
   // Initialize all the results to 0.
   int32x4_t result0123 = { 0, 0, 0, 0 };
   int32x4_t result4567 = { 0, 0, 0, 0 };
+  int8x8_t bias_scale = { 127, 127, 127, 127, 127, 127, 127, 127 };
   // Iterate over the input (u), one registerful at a time.
   for (int j = 0; j < num_in; j += 8) {
     int8x8_t vu = vld1_s8(u); // vu = u0 u1 u2 u3 u4 u5 u6 u7
@@ -112,21 +113,27 @@ PartialMatrixDotVector8(const int8_t* __restrict wi,
     u += 8;
     wi += 64;
   }
-  *v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 0)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 1)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result0123), 1)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 2)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 0)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 3)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result0123), 1)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 4)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 0)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 5)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_low_s32 (result4567), 1)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 6)
-    *v++ = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 0)) / INT8_MAX + *wi++) * *scales++;
-  if (num_out > 7)
-    *v = (static_cast<double>(vget_lane_s32(vget_high_s32(result4567), 1)) / INT8_MAX + *wi ) * *scales;
+  {
+    int8x8_t bias = vld1_s8(wi); // vw0 = b0 b1 b2 b3 b4 b5 b6 b7
+    int16x8_t scaled_bias = vmull_s8(bias, bias_scale);
+    result0123 = vaddw_s16(result0123, vget_low_s16(scaled_bias));
+    result4567 = vaddw_s16(result4567, vget_high_s16(scaled_bias));
+    *v++ = vget_lane_s32(vget_low_s32 (result0123), 0) * *scales++;
+    if (num_out > 1)
+      *v++ = vget_lane_s32(vget_low_s32 (result0123), 1) * *scales++;
+    if (num_out > 2)
+      *v++ = vget_lane_s32(vget_high_s32(result0123), 0) * *scales++;
+    if (num_out > 3)
+      *v++ = vget_lane_s32(vget_high_s32(result0123), 1) * *scales++;
+    if (num_out > 4)
+      *v++ = vget_lane_s32(vget_low_s32 (result4567), 0) * *scales++;
+    if (num_out > 5)
+      *v++ = vget_lane_s32(vget_low_s32 (result4567), 1) * *scales++;
+    if (num_out > 6)
+      *v++ = vget_lane_s32(vget_high_s32(result4567), 0) * *scales++;
+    if (num_out > 7)
+      *v = vget_lane_s32(vget_high_s32(result4567), 1) * *scales;
+  }
 }
 
 static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
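Note: on NEON the same folding is done with a widening multiply. vmull_s8 multiplies each int8 bias by 127 while widening to int16, vaddw_s16 widens again and adds into the int32 accumulators, and the pre-folded scales finish the job, so the per-lane divide by INT8_MAX disappears here too. A small standalone check that the widening step cannot overflow (not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // An int8 bias times 127 always fits in int16, which is what vmull_s8 relies on.
  for (int bias = -128; bias <= 127; ++bias) {
    const int scaled = bias * 127;
    assert(scaled >= INT16_MIN && scaled <= INT16_MAX);
  }
  return 0;
}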
@@ -74,7 +74,7 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
                                     double* v) {
   double total = IntDotProductSSE(u, wi, num_in);
   // Add in the bias and correct for integer values.
-  *v = (total / INT8_MAX + wi[num_in]) * *scales;
+  *v = (total + wi[num_in] * INT8_MAX) * *scales;
 }
 
 static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
@@ -132,8 +132,11 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
   temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
   GenericVector<NetworkScratch::FloatVec> curr_input;
   curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
+  int ro = no_;
+  if (IntSimdMatrix::intSimdMatrix)
+    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
   for (int i = 0; i < kNumThreads; ++i) {
-    temp_lines[i].Init(no_, scratch);
+    temp_lines[i].Init(no_, ro, scratch);
     curr_input[i].Init(ni_, scratch);
   }
 #ifdef _OPENMP
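Note: the SIMD kernels now always store a whole register of results, so each per-thread output buffer has to be at least RoundOutputs(no_) doubles long even though only no_ of them are used. A sketch of the rounding arithmetic (the helper below and the register width of 8 outputs are assumptions based on the AVX2 code above, not part of the patch):

#include <cassert>

static int RoundUpToMultiple(int size, int factor) {
  return (size + factor - 1) / factor * factor;
}

int main() {
  assert(RoundUpToMultiple(20, 8) == 24);  // e.g. no_ = 20 -> ro = 24
  assert(RoundUpToMultiple(24, 8) == 24);  // already a multiple stays unchanged
  return 0;
}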
@@ -264,7 +264,10 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
   ResizeForward(input);
   // Temporary storage of forward computation for each gate.
   NetworkScratch::FloatVec temp_lines[WT_COUNT];
-  for (auto & temp_line : temp_lines) temp_line.Init(ns_, scratch);
+  int ro = ns_;
+  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix)
+    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
+  for (auto & temp_line : temp_lines) temp_line.Init(ns_, ro, scratch);
   // Single timestep buffers for the current/recurrent output and state.
   NetworkScratch::FloatVec curr_state, curr_output;
   curr_state.Init(ns_, scratch);
@@ -144,15 +144,24 @@ class NetworkScratch {
       if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
     }
 
-    void Init(int size, NetworkScratch* scratch) {
+    void Init(int size, int reserve, NetworkScratch* scratch) {
       if (scratch_space_ != nullptr && vec_ != nullptr)
         scratch_space_->vec_stack_.Return(vec_);
       scratch_space_ = scratch;
       vec_ = scratch_space_->vec_stack_.Borrow();
+      // Abuse vec_ here; first resize to 'reserve', which is larger
+      // than 'size' (i.e. it's size rounded up) then resize down again
+      // to the desired size. This assumes that the implementation does
+      // not shrink the storage on a resize.
+      vec_->resize_no_init(reserve);
       vec_->resize_no_init(size);
       data_ = &(*vec_)[0];
     }
 
+    void Init(int size, NetworkScratch *scratch) {
+      Init(size, size, scratch);
+    }
+
     // Use the cast operator instead of operator[] so the FloatVec can be used
     // as a double* argument to a function call.
     operator double*() const { return data_; }
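Note: the new three-argument Init borrows a vector, grows it to the rounded-up 'reserve' size and then shrinks it back to 'size', relying on resize not releasing storage; the two-argument overload keeps the old behaviour. The same grow-then-shrink trick with std::vector, as a standalone check of that assumption (not part of the patch):

#include <cassert>
#include <vector>

int main() {
  std::vector<double> vec;
  vec.resize(24);                // reserve: the rounded-up size
  vec.resize(20);                // size: what the rest of the code sees
  assert(vec.size() == 20);
  assert(vec.capacity() >= 24);  // resizing down never shrinks the storage
  return 0;
}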
@@ -135,7 +135,7 @@ void WeightMatrix::ConvertToInt() {
       if (abs_val > max_abs) max_abs = abs_val;
     }
     double scale = max_abs / INT8_MAX;
-    scales_[t] = scale;
+    scales_[t] = scale / INT8_MAX;
     if (scale == 0.0) scale = 1.0;
     for (int f = 0; f < dim2; ++f) {
       i_line[f] = IntCastRounded(f_line[f] / scale);
@@ -144,7 +144,7 @@ void WeightMatrix::ConvertToInt() {
   wf_.Resize(1, 1, 0.0);
   int_mode_ = true;
   if (IntSimdMatrix::intSimdMatrix) {
-    IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
+    IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
   }
 }
 
@@ -177,7 +177,12 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
   if (!fp->Serialize(&mode)) return false;
   if (int_mode_) {
     if (!wi_.Serialize(fp)) return false;
+    /* The scales stored in memory have an extra factor applied to them
+     * to allow faster operation. We have to remove that factor here
+     * before writing to disc, and put it back afterwards. */
+    scales_.scale(INT8_MAX);
     if (!scales_.Serialize(fp)) return false;
+    scales_.scale(1.0/INT8_MAX);
   } else {
     if (!wf_.Serialize(fp)) return false;
     if (training && !updates_.Serialize(fp)) return false;
@@ -197,8 +202,9 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
   if (int_mode_) {
     if (!wi_.DeSerialize(fp)) return false;
     if (!scales_.DeSerialize(fp)) return false;
+    scales_.scale(1.0/INT8_MAX);
     if (IntSimdMatrix::intSimdMatrix) {
-      IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
+      IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, scales_);
     }
   } else {
     if (!wf_.DeSerialize(fp)) return false;
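Note: the on-disk format is unchanged. Serialize multiplies the folded scales by INT8_MAX before writing and divides again afterwards, and DeSerialize divides right after reading, so old and new model files stay compatible. A standalone check of that round trip with a hypothetical value (not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const double on_disk = 0.015;                 // hypothetical stored scale
  const double in_memory = on_disk / INT8_MAX;  // folded form used at run time
  const double written = in_memory * INT8_MAX;  // what Serialize writes back out
  assert(std::fabs(written - on_disk) < 1e-15);
  return 0;
}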
@@ -68,11 +68,17 @@ class IntSimdMatrixTest : public ::testing::Test {
     GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
     std::vector<int8_t> u = RandomVector(num_in, matrix);
     GenericVector<double> scales = RandomScales(num_out);
-    std::vector<double> base_result(num_out);
+    scales.scale(1.0/INT8_MAX);
+    int ro = num_out;
+    if (IntSimdMatrix::intSimdMatrix)
+      ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
+    std::vector<double> base_result(ro);
+    base_result.resize(num_out);
     IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
-    std::vector<double> test_result(num_out);
+    std::vector<double> test_result(ro);
+    test_result.resize(num_out);
     std::vector<int8_t> shaped_wi;
-    matrix.Init(w, shaped_wi);
+    matrix.Init(w, shaped_wi, scales);
     if (matrix.matrixDotVectorFunction) {
       matrix.matrixDotVectorFunction(w.dim1(), w.dim2(), &shaped_wi[0],
                                      &scales[0], &u[0], &test_result[0]);