tesseract/arch/intsimdmatrix.cpp

134 lines
5.4 KiB
C++
Raw Normal View History

///////////////////////////////////////////////////////////////////////
// File: intsimdmatrix.cpp
// Description: Base class for 8-bit int SIMD matrix multipliers.
// Author: Ray Smith
// Created: Tue Aug 15 08:01:32 PST 2017
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#include "intsimdmatrix.h"
#include "intsimdmatrixavx2.h"
#include "intsimdmatrixsse.h"
#include "simddetect.h"
namespace tesseract {
// Factory makes and returns an IntSimdMatrix (sub)class of the best
// available type for the current architecture.
/* static */
IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
IntSimdMatrix* multiplier = nullptr;
if (SIMDDetect::IsAVX2Available()) {
multiplier = new IntSimdMatrixAVX2();
} else if (SIMDDetect::IsSSEAvailable()) {
multiplier = new IntSimdMatrixSSE();
} else {
// Default c++ implementation.
multiplier = new IntSimdMatrix();
}
return multiplier;
}
// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w) {
if (partial_funcs_.empty()) return;
int num_out = w.dim1();
int num_in = w.dim2() - 1;
// The rounded-up sizes of the reshaped weight matrix, excluding biases.
int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
int rounded_num_out = RoundOutputs(num_out);
// Add the bias and compute the required size.
shaped_w_.resize((rounded_num_in + 1) * rounded_num_out, 0);
int shaped_index = 0;
int output = 0;
// Each number of registers needs a different format! Iterates over the
// different numbers of registers (each a power of 2).
for (int num_registers = max_output_registers_; num_registers >= 1;
num_registers /= 2) {
// The number of outputs that we will generate with this many registers.
int num_outputs_per_register_set =
num_registers * num_outputs_per_register_;
// Use the max number of registers until we have to go fewer.
while (output + num_outputs_per_register_set <= rounded_num_out) {
// Accumulating outputs in registers saves iterating over the inputs, so
// we only have to do it once per output register set.
for (int input = 0; input < num_in; input += num_inputs_per_group_) {
// Iterate over the number of outputs in a register set.
for (int j = 0; j < num_outputs_per_register_set; ++j) {
// Inner-most loop corresponds to the number of inputs in an input
// group.
for (int i = 0; i < num_inputs_per_group_; ++i) {
int8_t weight = 0;
if (output + j < num_out && input + i < num_in)
weight = w(output + j, input + i);
shaped_w_[shaped_index++] = weight;
}
}
}
// Append the bias weights for the register set.
for (int j = 0; j < num_outputs_per_register_set; ++j) {
int8_t weight = 0;
if (output + j < num_out) weight = w(output + j, num_in);
shaped_w_[shaped_index++] = weight;
}
output += num_outputs_per_register_set;
}
}
}
// Computes matrix.vector v = Wu.
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
const GenericVector<double>& scales,
const int8_t* u, double* v) const {
int num_out = w.dim1();
int num_in = w.dim2() - 1;
if (partial_funcs_.empty()) {
// Base implementation.
for (int i = 0; i < num_out; ++i) {
const int8_t* wi = w[i];
int total = 0;
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
// Add in the bias and correct for integer values.
Use POSIX data types and macros (#878) * api: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccmain: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccstruct: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * classify: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * cutil: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * dict: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * textord: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * training: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * wordrec: Replace Tesseract data types by POSIX data types Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccutil: Replace Tesseract data types by POSIX data types Now all Tesseract data types which are no longer needed can be removed from ccutil/host.h. Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccmain: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccstruct: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * classify: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * dict: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * lstm: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * textord: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * wordrec: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * ccutil: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX Remove the macros which are now unused from ccutil/host.h. Remove also the obsolete history comments. Signed-off-by: Stefan Weil <sw@weilnetz.de> * Fix build error caused by ambiguous ClipToRange Error message vom Appveyor CI: C:\projects\tesseract\ccstruct\coutln.cpp(818): error C2672: 'ClipToRange': no matching overloaded function found [C:\projects\tesseract\build\libtesseract.vcxproj] C:\projects\tesseract\ccstruct\coutln.cpp(818): error C2782: 'T ClipToRange(const T &,const T &,const T &)': template parameter 'T' is ambiguous [C:\projects\tesseract\build\libtesseract.vcxproj] c:\projects\tesseract\ccutil\helpers.h(122): note: see declaration of 'ClipToRange' C:\projects\tesseract\ccstruct\coutln.cpp(818): note: could be 'char' C:\projects\tesseract\ccstruct\coutln.cpp(818): note: or 'int' Signed-off-by: Stefan Weil <sw@weilnetz.de> * unittest: Replace Tesseract's MAX_INT8 by POSIX INT8_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de> * arch: Replace Tesseract's MAX_INT8 by POSIX INT8_MAX Signed-off-by: Stefan Weil <sw@weilnetz.de>
2018-03-14 04:36:30 +08:00
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
}
} else {
const int8_t* w_data = shaped_w_.data();
const double* scales_data = &scales[0];
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
int group_size = num_outputs_per_register_ * max_output_registers_;
int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
int rounded_num_out = RoundOutputs(num_out);
int output = 0;
for (auto fn : partial_funcs_) {
// The amount of w_data consumed by each call to fn.
int w_step = (rounded_num_in + 1) * group_size;
// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
(*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v);
w_data += w_step;
scales_data += group_size;
v += group_size;
}
group_size /= 2;
}
}
}
} // namespace tesseract