Use constructor with parameters for IntSimdMatrix

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2019-01-12 09:00:42 +01:00
parent e237a38405
commit 26be7c5d2e
5 changed files with 37 additions and 30 deletions

View File

@ -36,7 +36,7 @@ const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
multiplier = new IntSimdMatrixSSE();
} else {
// Default c++ implementation.
multiplier = new IntSimdMatrix();
multiplier = new IntSimdMatrix(1, 1, 1, 1, 1, {});
}
return multiplier;
}

View File

@ -60,14 +60,30 @@ namespace tesseract {
// is required to allow the base class implementation to do all the work.
class IntSimdMatrix {
public:
// Constructor should set the data members to indicate the sizes.
// NOTE: Base constructor public only for test purposes.
IntSimdMatrix()
: num_outputs_per_register_(1),
max_output_registers_(1),
num_inputs_per_register_(1),
num_inputs_per_group_(1),
num_input_groups_(1) {}
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v);
IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
// Number of 32 bit outputs held in each register.
num_outputs_per_register_(num_outputs_per_register),
// Maximum number of registers that we will use to hold outputs.
max_output_registers_(max_output_registers),
// Number of 8 bit inputs in the inputs register.
num_inputs_per_register_(num_inputs_per_register),
// Number of inputs in each weight group.
num_inputs_per_group_(num_inputs_per_group),
// Number of groups of inputs to be broadcast.
num_input_groups_(num_input_groups),
// A series of functions to compute a partial result.
partial_funcs_(partial_funcs)
{}
// Factory makes and returns an IntSimdMatrix (sub)class of the best
// available type for the current architecture.
@ -100,16 +116,6 @@ class IntSimdMatrix {
double* v) const;
protected:
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v);
// Rounds the input up to a multiple of the given factor.
static int Roundup(int input, int factor) {
return (input + factor - 1) / factor * factor;

View File

@ -269,16 +269,14 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
namespace tesseract {
#endif // __AVX2__
IntSimdMatrixAVX2::IntSimdMatrixAVX2() {
IntSimdMatrixAVX2::IntSimdMatrixAVX2()
#ifdef __AVX2__
num_outputs_per_register_ = kNumOutputsPerRegister;
max_output_registers_ = kMaxOutputRegisters;
num_inputs_per_register_ = kNumInputsPerRegister;
num_inputs_per_group_ = kNumInputsPerGroup;
num_input_groups_ = kNumInputGroups;
partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8};
: IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8})
#else
: IntSimdMatrix(1, 1, 1, 1, 1, {})
#endif // __AVX2__
{
}
} // namespace tesseract.

View File

@ -33,10 +33,13 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
}
#endif // __SSE4_1__
IntSimdMatrixSSE::IntSimdMatrixSSE() {
IntSimdMatrixSSE::IntSimdMatrixSSE()
#ifdef __SSE4_1__
partial_funcs_ = {PartialMatrixDotVector1};
: IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1})
#else
: IntSimdMatrix(1, 1, 1, 1, 1, {})
#endif // __SSE4_1__
{
}
} // namespace tesseract.

View File

@ -82,7 +82,7 @@ class IntSimdMatrixTest : public ::testing::Test {
}
TRand random_;
IntSimdMatrix base_;
IntSimdMatrix base_ = IntSimdMatrix(1, 1, 1, 1, 1, {});
};
// Test the C++ implementation without SIMD.