mirror of
https://github.com/tesseract-ocr/tesseract.git
synced 2025-01-18 06:30:14 +08:00
Use constructor with parameters for IntSimdMatrix
Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
parent
e237a38405
commit
26be7c5d2e
@ -36,7 +36,7 @@ const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
|
||||
multiplier = new IntSimdMatrixSSE();
|
||||
} else {
|
||||
// Default c++ implementation.
|
||||
multiplier = new IntSimdMatrix();
|
||||
multiplier = new IntSimdMatrix(1, 1, 1, 1, 1, {});
|
||||
}
|
||||
return multiplier;
|
||||
}
|
||||
|
@ -60,14 +60,30 @@ namespace tesseract {
|
||||
// is required to allow the base class implementation to do all the work.
|
||||
class IntSimdMatrix {
|
||||
public:
|
||||
// Constructor should set the data members to indicate the sizes.
|
||||
// NOTE: Base constructor public only for test purposes.
|
||||
IntSimdMatrix()
|
||||
: num_outputs_per_register_(1),
|
||||
max_output_registers_(1),
|
||||
num_inputs_per_register_(1),
|
||||
num_inputs_per_group_(1),
|
||||
num_input_groups_(1) {}
|
||||
// Function to compute part of a matrix.vector multiplication. The weights
|
||||
// are in a very specific order (see above) in w, which is multiplied by
|
||||
// u of length num_in, to produce output v after scaling the integer results
|
||||
// by the corresponding member of scales.
|
||||
// The amount of w and scales consumed is fixed and not available to the
|
||||
// caller. The number of outputs written to v will be at most num_out.
|
||||
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
double* v);
|
||||
|
||||
// Constructor that takes every size parameter and the list of partial
// functions explicitly. Each argument is stored in the data member of the
// same name with a trailing underscore; the constructor body is empty.
// The call sites visible in this change pass (1, 1, 1, 1, 1, {}) for the
// default C++ implementation, i.e. all sizes 1 and no partial functions.
// NOTE(review): partial_funcs is received by value and then copied into
// partial_funcs_; moving it (or taking a const reference) would avoid one
// vector copy -- confirm before changing.
IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
|
||||
// Number of 32 bit outputs held in each register.
|
||||
num_outputs_per_register_(num_outputs_per_register),
|
||||
// Maximum number of registers that we will use to hold outputs.
|
||||
max_output_registers_(max_output_registers),
|
||||
// Number of 8 bit inputs in the inputs register.
|
||||
num_inputs_per_register_(num_inputs_per_register),
|
||||
// Number of inputs in each weight group.
|
||||
num_inputs_per_group_(num_inputs_per_group),
|
||||
// Number of groups of inputs to be broadcast.
|
||||
num_input_groups_(num_input_groups),
|
||||
// A series of functions to compute a partial result.
|
||||
partial_funcs_(partial_funcs)
|
||||
{}
|
||||
|
||||
// Factory makes and returns an IntSimdMatrix (sub)class of the best
|
||||
// available type for the current architecture.
|
||||
@ -100,16 +116,6 @@ class IntSimdMatrix {
|
||||
double* v) const;
|
||||
|
||||
protected:
|
||||
// Function to compute part of a matrix.vector multiplication. The weights
|
||||
// are in a very specific order (see above) in w, which is multiplied by
|
||||
// u of length num_in, to produce output v after scaling the integer results
|
||||
// by the corresponding member of scales.
|
||||
// The amount of w and scales consumed is fixed and not available to the
|
||||
// caller. The number of outputs written to v will be at most num_out.
|
||||
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
|
||||
const int8_t* u, int num_in, int num_out,
|
||||
double* v);
|
||||
|
||||
// Rounds the input up to a multiple of the given factor.
|
||||
static int Roundup(int input, int factor) {
|
||||
return (input + factor - 1) / factor * factor;
|
||||
|
@ -269,16 +269,14 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
|
||||
namespace tesseract {
|
||||
#endif // __AVX2__
|
||||
|
||||
IntSimdMatrixAVX2::IntSimdMatrixAVX2() {
|
||||
IntSimdMatrixAVX2::IntSimdMatrixAVX2()
|
||||
#ifdef __AVX2__
|
||||
num_outputs_per_register_ = kNumOutputsPerRegister;
|
||||
max_output_registers_ = kMaxOutputRegisters;
|
||||
num_inputs_per_register_ = kNumInputsPerRegister;
|
||||
num_inputs_per_group_ = kNumInputsPerGroup;
|
||||
num_input_groups_ = kNumInputGroups;
|
||||
partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
|
||||
PartialMatrixDotVector16, PartialMatrixDotVector8};
|
||||
: IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
|
||||
PartialMatrixDotVector16, PartialMatrixDotVector8})
|
||||
#else
|
||||
: IntSimdMatrix(1, 1, 1, 1, 1, {})
|
||||
#endif // __AVX2__
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -33,10 +33,13 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
|
||||
}
|
||||
#endif // __SSE4_1__
|
||||
|
||||
IntSimdMatrixSSE::IntSimdMatrixSSE() {
|
||||
IntSimdMatrixSSE::IntSimdMatrixSSE()
|
||||
#ifdef __SSE4_1__
|
||||
partial_funcs_ = {PartialMatrixDotVector1};
|
||||
: IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1})
|
||||
#else
|
||||
: IntSimdMatrix(1, 1, 1, 1, 1, {})
|
||||
#endif // __SSE4_1__
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace tesseract.
|
||||
|
@ -82,7 +82,7 @@ class IntSimdMatrixTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
TRand random_;
|
||||
IntSimdMatrix base_;
|
||||
IntSimdMatrix base_ = IntSimdMatrix(1, 1, 1, 1, 1, {});
|
||||
};
|
||||
|
||||
// Test the C++ implementation without SIMD.
|
||||
|
Loading…
Reference in New Issue
Block a user