2017-09-08 22:06:19 +08:00
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
// File: intsimdmatrixavx2.cpp
|
|
|
|
// Description: matrix-vector product for 8-bit data on avx2.
|
|
|
|
// Author: Ray Smith
|
|
|
|
// Created: Fri Aug 04 13:26:20 PST 2017
|
|
|
|
//
|
|
|
|
// (C) Copyright 2017, Google Inc.
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
|
2019-01-13 04:30:45 +08:00
|
|
|
#if !defined(__AVX2__)
|
|
|
|
#error Implementation only for AVX2 capable architectures
|
|
|
|
#endif
|
|
|
|
|
2019-01-12 16:44:08 +08:00
|
|
|
#include "intsimdmatrix.h"
|
2017-09-08 22:06:19 +08:00
|
|
|
|
|
|
|
#include <immintrin.h>
|
2018-05-20 05:52:04 +08:00
|
|
|
#include <cstdint>
|
2018-01-24 23:45:15 +08:00
|
|
|
#include <algorithm>
|
2017-09-08 22:06:19 +08:00
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
namespace tesseract {
|
|
|
|
|
|
|
|
// Number of 32-bit outputs held in each 256-bit register: 8 x int32_t.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of output registers used by one partial product call, so the
// widest kernel produces kNumOutputsPerRegister * kMaxOutputRegisters outputs.
constexpr int kMaxOutputRegisters = 8;
// Number of 8-bit inputs in the 256-bit inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group (multiplied and summed together).
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast from one inputs register.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;

// Compile-time checks of the layout assumptions hard-coded in the kernels.
// An inputs register must split exactly into whole input groups.
static_assert(kNumInputsPerRegister % kNumInputsPerGroup == 0,
              "inputs register must hold a whole number of input groups");
// A group of 8-bit inputs is replicated with _mm256_broadcastd_epi32 and
// rotated with 32-bit lane permutes, so a group must be exactly 32 bits.
static_assert(kNumInputsPerGroup * sizeof(int8_t) == sizeof(int32_t),
              "an input group must occupy exactly one 32-bit lane");
// The int8 inputs register and the int32 outputs registers are both 256 bits.
static_assert(kNumInputsPerRegister * sizeof(int8_t) ==
                  kNumOutputsPerRegister * sizeof(int32_t),
              "inputs and outputs registers must have the same width");
|
|
|
|
|
2019-01-13 06:00:31 +08:00
|
|
|
// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (described on PartialMatrixDotVector64 below)
// in wi, which is multiplied by u of length num_in, to produce output v after
// scaling the integer results by the corresponding member of scales.
// The amount of wi and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out, provided
// num_out leaves only the last output register partially filled.
|
|
|
|
|
2017-09-08 22:06:19 +08:00
|
|
|
// Computes one set of 4x8 products of inputs and weights, adding to result.
|
|
|
|
// Horizontally adds 4 adjacent results, making 8x32-bit results.
|
|
|
|
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
|
|
|
|
// Note that wi must previously have been re-organized with blocks of 4x8
|
|
|
|
// weights in contiguous memory.
|
|
|
|
// ones is a register of 16x16-bit values all equal to 1.
|
|
|
|
// Note: wi is incremented by the amount of data read.
|
|
|
|
// weights and reps are scratch registers.
|
|
|
|
// This function must be inlined with references in order for the compiler to
|
|
|
|
// correctly use the registers declared in the caller.
|
2018-12-01 02:59:02 +08:00
|
|
|
static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
|
|
|
|
const int8_t*& wi, __m256i& weights,
|
|
|
|
__m256i& reps, __m256i& result) {
|
2017-09-08 22:06:19 +08:00
|
|
|
// Load a 4x8 block of weights.
|
|
|
|
weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
|
|
|
|
wi += kNumInputsPerRegister;
|
|
|
|
// Normalize the signs on rep_input, weights, so weights is always +ve.
|
|
|
|
reps = _mm256_sign_epi8(rep_input, weights);
|
|
|
|
weights = _mm256_sign_epi8(weights, weights);
|
|
|
|
// Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
|
|
|
|
// with adjacent pairs added.
|
|
|
|
weights = _mm256_maddubs_epi16(weights, reps);
|
|
|
|
// Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
|
|
|
|
// with adjacent pairs added. What we really want is a horizontal add of
|
|
|
|
// 16+16=32 bit result, but there is no such instruction, so multiply by
|
|
|
|
// 16-bit ones instead. It is probably faster than all the sign-extending,
|
|
|
|
// permuting and adding that would otherwise be required.
|
|
|
|
weights = _mm256_madd_epi16(weights, ones);
|
|
|
|
result = _mm256_add_epi32(result, weights);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Extracts and converts 8x32-bit results from result, adding the bias from wi
|
|
|
|
// and scaling by scales, before storing in *v. Note that wi, scales and v are
|
|
|
|
// expected to contain 8 consecutive elements or num_out if less.
|
2018-12-01 02:59:02 +08:00
|
|
|
static inline void ExtractResults(__m256i& result, __m256i& shift_id,
|
|
|
|
const int8_t*& wi, const double*& scales,
|
|
|
|
int num_out, double*& v) {
|
2017-09-08 22:06:19 +08:00
|
|
|
for (int out = 0; out < num_out; ++out) {
|
2018-01-24 23:45:15 +08:00
|
|
|
#ifndef _MSC_VER
|
2019-03-26 19:26:14 +08:00
|
|
|
auto res = _mm256_extract_epi32(result, 0);
|
2018-01-24 23:45:15 +08:00
|
|
|
#else
|
2019-03-26 19:26:14 +08:00
|
|
|
// Workaround MSVC's ICE
|
|
|
|
// _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
|
|
|
|
auto res = ((int32_t*)&result)[0];
|
2018-01-24 23:45:15 +08:00
|
|
|
#endif
|
Use POSIX data types and macros (#878)
* api: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccmain: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccstruct: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* classify: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* cutil: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* dict: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* textord: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* training: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* wordrec: Replace Tesseract data types by POSIX data types
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccutil: Replace Tesseract data types by POSIX data types
Now all Tesseract data types which are no longer needed can be removed
from ccutil/host.h.
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccmain: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccstruct: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* classify: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* dict: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* lstm: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* textord: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* wordrec: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* ccutil: Replace Tesseract's MIN_*INT, MAX_*INT* by POSIX *INT*_MIN, *INT*_MAX
Remove the macros which are now unused from ccutil/host.h.
Remove also the obsolete history comments.
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* Fix build error caused by ambiguous ClipToRange
Error message vom Appveyor CI:
C:\projects\tesseract\ccstruct\coutln.cpp(818): error C2672: 'ClipToRange': no matching overloaded function found [C:\projects\tesseract\build\libtesseract.vcxproj]
C:\projects\tesseract\ccstruct\coutln.cpp(818): error C2782: 'T ClipToRange(const T &,const T &,const T &)': template parameter 'T' is ambiguous [C:\projects\tesseract\build\libtesseract.vcxproj]
c:\projects\tesseract\ccutil\helpers.h(122): note: see declaration of 'ClipToRange'
C:\projects\tesseract\ccstruct\coutln.cpp(818): note: could be 'char'
C:\projects\tesseract\ccstruct\coutln.cpp(818): note: or 'int'
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* unittest: Replace Tesseract's MAX_INT8 by POSIX INT8_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
* arch: Replace Tesseract's MAX_INT8 by POSIX INT8_MAX
Signed-off-by: Stefan Weil <sw@weilnetz.de>
2018-03-14 04:36:30 +08:00
|
|
|
*v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
|
2017-09-08 22:06:19 +08:00
|
|
|
// Rotate the results in int32_t units, so the next result is ready.
|
|
|
|
result = _mm256_permutevar8x32_epi32(result, shift_id);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Computes part of matrix.vector v = Wu. Computes N=64 results.
|
|
|
|
// The weights *must* be arranged so that consecutive reads from wi
|
|
|
|
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
|
|
|
|
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
|
|
|
|
// bias weights, before continuing with any more weights.
|
|
|
|
// u must be padded out with zeros to
|
|
|
|
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
|
|
|
|
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
|
|
|
|
const int8_t* u, int num_in, int num_out,
|
|
|
|
double* v) {
|
|
|
|
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
|
|
|
// conversion.
|
|
|
|
__m256i ones =
|
|
|
|
_mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
|
|
|
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
|
|
|
|
// Initialize all the results to 0.
|
|
|
|
__m256i result0 = _mm256_setzero_si256();
|
|
|
|
__m256i result1 = _mm256_setzero_si256();
|
|
|
|
__m256i result2 = _mm256_setzero_si256();
|
|
|
|
__m256i result3 = _mm256_setzero_si256();
|
|
|
|
__m256i result4 = _mm256_setzero_si256();
|
|
|
|
__m256i result5 = _mm256_setzero_si256();
|
|
|
|
__m256i result6 = _mm256_setzero_si256();
|
|
|
|
__m256i result7 = _mm256_setzero_si256();
|
|
|
|
// Iterate over the input (u), one registerful at a time.
|
|
|
|
for (int j = 0; j < num_in;) {
|
|
|
|
__m256i inputs =
|
|
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
|
|
|
|
// Inputs are processed in groups of kNumInputsPerGroup, replicated
|
|
|
|
// kNumInputGroups times.
|
|
|
|
for (int ig = 0; ig < kNumInputGroups && j < num_in;
|
|
|
|
++ig, j += kNumInputsPerGroup) {
|
|
|
|
// Replicate the low 32 bits (4 inputs) 8 times.
|
|
|
|
__m256i rep_input =
|
|
|
|
_mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
|
|
|
|
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
|
|
|
|
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
|
|
|
|
__m256i weights, reps;
|
|
|
|
// Mul-add, with horizontal add of the 4 inputs to each of the results.
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
num_out -= kNumOutputsPerRegister * 7;
|
|
|
|
ExtractResults(result7, shift_id, wi, scales,
|
|
|
|
std::min(kNumOutputsPerRegister, num_out), v);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Computes part of matrix.vector v = Wu. Computes N=32 results.
|
|
|
|
// For details see PartialMatrixDotVector64 with N=32.
|
|
|
|
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
|
|
|
|
const int8_t* u, int num_in, int num_out,
|
|
|
|
double* v) {
|
|
|
|
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
|
|
|
// conversion.
|
|
|
|
__m256i ones =
|
|
|
|
_mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
|
|
|
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
|
|
|
|
// Initialize all the results to 0.
|
|
|
|
__m256i result0 = _mm256_setzero_si256();
|
|
|
|
__m256i result1 = _mm256_setzero_si256();
|
|
|
|
__m256i result2 = _mm256_setzero_si256();
|
|
|
|
__m256i result3 = _mm256_setzero_si256();
|
|
|
|
// Iterate over the input (u), one registerful at a time.
|
|
|
|
for (int j = 0; j < num_in;) {
|
|
|
|
__m256i inputs =
|
|
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
|
|
|
|
// Inputs are processed in groups of kNumInputsPerGroup, replicated
|
|
|
|
// kNumInputGroups times.
|
|
|
|
for (int ig = 0; ig < kNumInputGroups && j < num_in;
|
|
|
|
++ig, j += kNumInputsPerGroup) {
|
|
|
|
// Replicate the low 32 bits (4 inputs) 8 times.
|
|
|
|
__m256i rep_input =
|
|
|
|
_mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
|
|
|
|
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
|
|
|
|
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
|
|
|
|
__m256i weights, reps;
|
|
|
|
// Mul-add, with horizontal add of the 4 inputs to each of the results.
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
num_out -= kNumOutputsPerRegister * 3;
|
|
|
|
ExtractResults(result3, shift_id, wi, scales,
|
|
|
|
std::min(kNumOutputsPerRegister, num_out), v);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Computes part of matrix.vector v = Wu. Computes N=16 results.
|
|
|
|
// For details see PartialMatrixDotVector64 with N=16.
|
|
|
|
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
|
|
|
|
const int8_t* u, int num_in, int num_out,
|
|
|
|
double* v) {
|
|
|
|
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
|
|
|
// conversion.
|
|
|
|
__m256i ones =
|
|
|
|
_mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
|
|
|
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
|
|
|
|
// Initialize all the results to 0.
|
|
|
|
__m256i result0 = _mm256_setzero_si256();
|
|
|
|
__m256i result1 = _mm256_setzero_si256();
|
|
|
|
// Iterate over the input (u), one registerful at a time.
|
|
|
|
for (int j = 0; j < num_in;) {
|
|
|
|
__m256i inputs =
|
|
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
|
|
|
|
// Inputs are processed in groups of kNumInputsPerGroup, replicated
|
|
|
|
// kNumInputGroups times.
|
|
|
|
for (int ig = 0; ig < kNumInputGroups && j < num_in;
|
|
|
|
++ig, j += kNumInputsPerGroup) {
|
|
|
|
// Replicate the low 32 bits (4 inputs) 8 times.
|
|
|
|
__m256i rep_input =
|
|
|
|
_mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
|
|
|
|
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
|
|
|
|
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
|
|
|
|
__m256i weights, reps;
|
|
|
|
// Mul-add, with horizontal add of the 4 inputs to each of the results.
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
|
|
|
|
num_out -= kNumOutputsPerRegister;
|
|
|
|
ExtractResults(result1, shift_id, wi, scales,
|
|
|
|
std::min(kNumOutputsPerRegister, num_out), v);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Computes part of matrix.vector v = Wu. Computes N=8 results.
|
|
|
|
// For details see PartialMatrixDotVector64 with N=8.
|
|
|
|
static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
|
|
|
|
const int8_t* u, int num_in, int num_out,
|
|
|
|
double* v) {
|
|
|
|
// Register containing 16-bit ones for horizontal add with 16->32 bit
|
|
|
|
// conversion.
|
|
|
|
__m256i ones =
|
|
|
|
_mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
|
|
|
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
|
|
|
|
// Initialize all the results to 0.
|
|
|
|
__m256i result0 = _mm256_setzero_si256();
|
|
|
|
// Iterate over the input (u), one registerful at a time.
|
|
|
|
for (int j = 0; j < num_in;) {
|
|
|
|
__m256i inputs =
|
|
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
|
|
|
|
// Inputs are processed in groups of kNumInputsPerGroup, replicated
|
|
|
|
// kNumInputGroups times.
|
|
|
|
for (int ig = 0; ig < kNumInputGroups && j < num_in;
|
|
|
|
++ig, j += kNumInputsPerGroup) {
|
|
|
|
// Replicate the low 32 bits (4 inputs) 8 times.
|
|
|
|
__m256i rep_input =
|
|
|
|
_mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
|
|
|
|
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
|
|
|
|
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
|
|
|
|
__m256i weights, reps;
|
|
|
|
// Mul-add, with horizontal add of the 4 inputs to each of the results.
|
|
|
|
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ExtractResults(result0, shift_id, wi, scales, num_out, v);
|
|
|
|
}
|
|
|
|
|
2019-01-13 06:00:31 +08:00
|
|
|
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
|
|
|
|
const double* scales, const int8_t* u, double* v) {
|
|
|
|
const int num_out = dim1;
|
|
|
|
const int num_in = dim2 - 1;
|
|
|
|
// Each call to a partial_func_ produces group_size outputs, except the
|
|
|
|
// last one, which can produce less.
|
|
|
|
const int rounded_num_in =
|
|
|
|
IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
|
|
|
|
const int rounded_num_out =
|
|
|
|
IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
|
|
|
|
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
|
|
|
|
int output = 0;
|
|
|
|
|
|
|
|
int w_step = (rounded_num_in + 1) * group_size;
|
|
|
|
|
|
|
|
// Run with this group size, until it would produce too much output, then
|
|
|
|
// switch to a smaller size.
|
|
|
|
for (; output + group_size <= rounded_num_out; output += group_size) {
|
|
|
|
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
|
|
|
|
wi += w_step;
|
|
|
|
scales += group_size;
|
|
|
|
v += group_size;
|
|
|
|
}
|
|
|
|
group_size /= 2;
|
|
|
|
w_step /= 2;
|
|
|
|
|
|
|
|
for (; output + group_size <= rounded_num_out; output += group_size) {
|
|
|
|
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
|
|
|
|
wi += w_step;
|
|
|
|
scales += group_size;
|
|
|
|
v += group_size;
|
|
|
|
}
|
|
|
|
group_size /= 2;
|
|
|
|
w_step /= 2;
|
|
|
|
|
|
|
|
for (; output + group_size <= rounded_num_out; output += group_size) {
|
|
|
|
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
|
|
|
|
wi += w_step;
|
|
|
|
scales += group_size;
|
|
|
|
v += group_size;
|
|
|
|
}
|
|
|
|
group_size /= 2;
|
|
|
|
w_step /= 2;
|
|
|
|
|
|
|
|
for (; output + group_size <= rounded_num_out; output += group_size) {
|
|
|
|
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
|
|
|
|
wi += w_step;
|
|
|
|
scales += group_size;
|
|
|
|
v += group_size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// The AVX2 IntSimdMatrix: the matrix-vector entry point followed by the
// register and grouping geometry it was written for, listed in the order of
// IntSimdMatrix's members.
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    matrixDotVector,         // Function computing v = Wu.
    kNumOutputsPerRegister,  // Number of 32 bit outputs held in each register.
    kMaxOutputRegisters,     // Maximum number of registers holding outputs.
    kNumInputsPerRegister,   // Number of 8 bit inputs in the inputs register.
    kNumInputsPerGroup       // Number of inputs in each weight group.
};
|
2017-09-08 22:06:19 +08:00
|
|
|
|
|
|
|
} // namespace tesseract.
|