commit d92d3d9f98: Merge 232093f046 into de095fc074
@@ -170,6 +170,14 @@ libtesseract_la_LIBADD += libtesseract_avx512.la
 noinst_LTLIBRARIES += libtesseract_avx512.la
 endif
+
+if HAVE_AVX512VNNI
+libtesseract_avx512vnni_la_CXXFLAGS = -mavx512vnni -mavx512vl
+libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
+libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp
+libtesseract_la_LIBADD += libtesseract_avx512vnni.la
+noinst_LTLIBRARIES += libtesseract_avx512vnni.la
+endif
 
 if HAVE_FMA
 libtesseract_fma_la_CXXFLAGS = -mfma
 libtesseract_fma_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
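The Automake fragment above (presumably Makefile.am) builds the new source file into its own noinst convenience library so that -mavx512vnni -mavx512vl apply to that one translation unit only; the rest of libtesseract keeps its baseline flags. As a quick stand-alone way to confirm that a toolchain accepts this flag combination, a probe along these lines can be compiled with the same flags (illustrative only, not part of the diff; the file name is made up):

// probe_vnni.cpp (hypothetical): g++ -mavx512vnni -mavx512vl -O2 probe_vnni.cpp
#include <immintrin.h>
#include <cstdint>

int main() {
  __m256i acc = _mm256_setzero_si256();
  __m256i a = _mm256_set1_epi8(1); // treated as unsigned 8-bit by vpdpbusd
  __m256i b = _mm256_set1_epi8(2); // treated as signed 8-bit
  // The 256-bit form of vpdpbusd needs both AVX512VNNI and AVX512VL.
  acc = _mm256_dpbusd_epi32(acc, a, b);
  int32_t out[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i *>(out), acc);
  return out[0] == 4 * 1 * 2 ? 0 : 1; // each 32-bit lane sums four 1*2 products
}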
@@ -128,6 +128,7 @@ AX_CHECK_COMPILE_FLAG([-Werror=unused-command-line-argument], [WERROR=-Werror=un
 AM_CONDITIONAL([HAVE_AVX], false)
 AM_CONDITIONAL([HAVE_AVX2], false)
 AM_CONDITIONAL([HAVE_AVX512F], false)
+AM_CONDITIONAL([HAVE_AVX512VNNI], false)
 AM_CONDITIONAL([HAVE_FMA], false)
 AM_CONDITIONAL([HAVE_SSE4_1], false)
 AM_CONDITIONAL([HAVE_NEON], false)

@@ -155,6 +156,12 @@ case "${host_cpu}" in
   AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions])
 fi
 
+AX_CHECK_COMPILE_FLAG([-mavx512vnni], [avx512vnni=true], [avx512vnni=false], [$WERROR])
+AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni)
+if $avx512vnni; then
+  AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions])
+fi
+
 AX_CHECK_COMPILE_FLAG([-mfma], [fma=true], [fma=false], [$WERROR])
 AM_CONDITIONAL([HAVE_FMA], $fma)
 if $fma; then
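In these configure hunks, AX_CHECK_COMPILE_FLAG (from autoconf-archive) only verifies that the compiler accepts -mavx512vnni; AM_CONDITIONAL exposes the result to Automake, which is what enables the libtesseract_avx512vnni.la block above, and AC_DEFINE puts HAVE_AVX512VNNI into config.h. A hedged sketch of the distinction this creates on the C++ side, between "the build has the VNNI kernel" and "this translation unit is itself compiled with VNNI enabled" (illustrative, not code from the diff):

#include <cstdio>

// Illustrative only. HAVE_AVX512VNNI comes from config.h when configure finds a
// capable compiler; __AVX512VNNI__/__AVX512VL__ are defined by the compiler only
// in files built with -mavx512vnni -mavx512vl, which is why the new source file
// guards on the latter.
void report_build_flags() {
#if defined(HAVE_AVX512VNNI)
  std::puts("build was configured with the AVX512VNNI kernel");
#else
  std::puts("build was configured without AVX512VNNI support");
#endif
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
  std::puts("this translation unit is compiled with -mavx512vnni -mavx512vl");
#endif
}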
@@ -119,6 +119,7 @@ struct TESS_API IntSimdMatrix {
   static const IntSimdMatrix intSimdMatrixRVV;
   // Only available with AVX2 / AVX / FMA / SSE.
   static const IntSimdMatrix intSimdMatrixAVX2;
+  static const IntSimdMatrix intSimdMatrixAVX512VNNI;
   static const IntSimdMatrix intSimdMatrixSSE;
 };
 
src/arch/intsimdmatrixavx512vnni.cpp (new file, 590 lines)
@@ -0,0 +1,590 @@
///////////////////////////////////////////////////////////////////////
// File:        intsimdmatrixavx512vnni.cpp
// Description: matrix-vector product for 8-bit data on avx512vnni.
// Author:      Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrix.h"

#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__)
# if defined(__i686__) || defined(__x86_64__)
# error Implementation only for AVX512VNNI capable architectures
# endif
#else
# include <immintrin.h>
# include <algorithm>
# include <cstdint>
# include <vector>

# if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
     defined(_WIN32) && !defined(_WIN64)
// Optimize for size (/Os) instead of using the default optimization for some
// versions of the 32 bit Visual Studio compiler which generate buggy code.
# pragma optimize("", off)
# pragma optimize("s", on)
# endif

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);

  // VNNI instruction. It replaces 3 AVX2 instructions:
  //   weights = _mm256_maddubs_epi16(weights, reps);
  //   weights = _mm256_madd_epi16(weights, ones);
  //   result = _mm256_add_epi32(result, weights);
  result = _mm256_dpbusd_epi32(result, weights, reps);
}

// Load 64 bits into the bottom of a 128bit register.
|
||||||
|
// We don't actually care what the top 64bits are, but this ends
|
||||||
|
// up with them being zero.
|
||||||
|
static inline __m128i load64_to_128(const int8_t *wi_) {
|
||||||
|
const auto *wi = reinterpret_cast<const int64_t *>(wi_);
|
||||||
|
return _mm_set_epi64x(0, wi[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(FAST_FLOAT)
|
||||||
|
|
||||||
|
static inline void ExtractResults8(__m256i result, const int8_t *wi,
|
||||||
|
const float *scales, float *v) {
|
||||||
|
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
|
||||||
|
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
|
||||||
|
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
|
||||||
|
__m256 scale01234567 = _mm256_loadu_ps(scales);
|
||||||
|
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||||
|
result = _mm256_add_epi32(result, w256); // result += bias * 127
|
||||||
|
__m256 res01234567 = _mm256_cvtepi32_ps(result);
|
||||||
|
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
|
||||||
|
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
|
||||||
|
_mm256_storeu_ps(v, res01234567);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void ExtractResults16(__m256i result0, __m256i result1,
|
||||||
|
const int8_t *&wi, const float *&scales,
|
||||||
|
float *&v) {
|
||||||
|
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
|
||||||
|
// 8x8bit vals in bottom of 128bit reg
|
||||||
|
const __m256i bias_scale =
|
||||||
|
_mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
|
||||||
|
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
|
||||||
|
__m256 scale01234567 = _mm256_loadu_ps(scales);
|
||||||
|
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||||
|
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
|
||||||
|
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
|
||||||
|
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
|
||||||
|
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
|
||||||
|
_mm256_storeu_ps(v, res01234567);
|
||||||
|
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
|
||||||
|
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
|
||||||
|
scale01234567 = _mm256_loadu_ps(scales + 8);
|
||||||
|
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
|
||||||
|
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
|
||||||
|
res01234567 = _mm256_cvtepi32_ps(result1);
|
||||||
|
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
|
||||||
|
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
|
||||||
|
_mm256_storeu_ps(v + 8, res01234567);
|
||||||
|
wi += 16;
|
||||||
|
scales += 16;
|
||||||
|
v += 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
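After the integer accumulation, the next 8 (or 16) bytes at wi are the bias weights for the outputs just computed; ExtractResults8/16 multiply them by 127 (the bias_scale vector of 127s) before adding them in, because in this int8 scheme 127 stands in for a constant input of 1.0, and only then convert to float and apply the per-output scales. A scalar model for a single output, following the comments in the code above (illustrative only):

#include <cstdint>

// acc   : 32-bit accumulator for one output
// bias  : the matching int8 bias weight stored right after the regular weights
// scale : the per-output scale from scales[]
float extract_one(int32_t acc, int8_t bias, float scale) {
  acc += int32_t(bias) * 127; // "result += bias * 127"
  return float(acc) * scale;  // convert and scale, as _mm256_cvtepi32_ps and _mm256_mul_ps do
}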
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

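The weight order the comment insists on can be read off MultiplyGroup: for every group of 4 inputs, the kernel consumes 8 registers x 32 bytes of weights, i.e. 4 weights for each of the 64 outputs handled by one call, before moving to the next input group. A scalar model of the traversal (illustrative; it captures how much data each step consumes and how it is grouped, not the exact byte order within a 32-byte block):

#include <cstdint>

void partial_dot_64_model(const int8_t *wi, const int8_t *u, int rounded_num_in,
                          int32_t acc[64]) {
  const int8_t *w = wi;
  for (int j = 0; j < rounded_num_in; j += 4) { // one group of 4 inputs
    for (int out = 0; out < 64; ++out) {        // 8 result registers x 8 lanes
      for (int k = 0; k < 4; ++k) {
        acc[out] += int32_t(w[k]) * int32_t(u[j + k]);
      }
      w += 4; // the next output's 4 weights follow immediately
    }
  }
  // After the loop, 64 bias weights follow at w; ExtractResults16 consumes them.
}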
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}

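matrixDotVector starts with the widest kernel (64 outputs per call, kNumOutputsPerRegister * kMaxOutputRegisters) and halves the group size until the rounded output count is covered; w_step halves along with it because each group of weights occupies (rounded_num_in + 1) * group_size bytes, the +1 covering the bias row. A small trace of the cascade for num_out = 100, so rounded_num_out = 104 (illustrative, it mirrors the control flow above without doing any SIMD work):

#include <cstdio>

int main() {
  const int rounded_num_out = 104; // Roundup(100, 8)
  int group_size = 8 * 8;          // kNumOutputsPerRegister * kMaxOutputRegisters
  int output = 0;
  for (; output + group_size <= rounded_num_out; output += group_size) {
    std::printf("PartialMatrixDotVector64 covers outputs [%d, %d)\n", output, output + group_size);
  }
  for (group_size /= 2; group_size >= 8; group_size /= 2) {
    if (output + group_size <= rounded_num_out) {
      std::printf("PartialMatrixDotVector%d covers outputs [%d, %d)\n", group_size, output,
                  output + group_size);
      output += group_size;
    }
  }
  // Prints one 64-wide, one 32-wide and one 8-wide call: 64 + 32 + 8 = 104.
  return 0;
}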
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 8x8bit vals in bottom of 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX512VNNI = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};

} // namespace tesseract.

#endif
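The exported IntSimdMatrix table at the end of the file uses the same geometry as the AVX2 kernel (8 outputs per 256-bit register, up to 8 registers, 32 inputs per register, input groups of 4): the VNNI path still operates on 256-bit vectors through AVX512VL, and only the inner multiply-accumulate changes. Read together with the extract helpers, the file computes v[i] = scales[i] * (sum over j of W[i][j]*u[j] + 127*bias[i]). A plain scalar reference of that quantity, on unshuffled row-major weights (hedged: this mirrors what can be read off this file; the SIMD kernels themselves consume a reshuffled weight layout prepared elsewhere):

#include <cstdint>

// dim1 rows of (dim2 - 1) int8 weights, each followed by one int8 bias.
void matrixDotVectorReference(int dim1, int dim2, const int8_t *w, const float *scales,
                              const int8_t *u, float *v) {
  const int num_in = dim2 - 1;
  for (int i = 0; i < dim1; ++i) {
    const int8_t *row = w + i * dim2;
    int32_t total = 0;
    for (int j = 0; j < num_in; ++j) {
      total += int32_t(row[j]) * int32_t(u[j]);
    }
    total += int32_t(row[num_in]) * 127; // bias weight times the int8 stand-in for 1.0
    v[i] = float(total) * scales[i];
  }
}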
@@ -179,9 +179,12 @@ SIMDDetect::SIMDDetect() {
       // be used inside an if.
       __cpuid_count(7, 0, eax, ebx, ecx, edx);
       avx2_available_ = (ebx & 0x00000020) != 0;
-      avx512F_available_ = (ebx & 0x00010000) != 0;
-      avx512BW_available_ = (ebx & 0x40000000) != 0;
-      avx512VNNI_available_ = (ecx & 0x00000800) != 0;
+      if ((xgetbv() & 0xe0) == 0xe0) {
+        // OS supports AVX512.
+        avx512F_available_ = (ebx & 0x00010000) != 0;
+        avx512BW_available_ = (ebx & 0x40000000) != 0;
+        avx512VNNI_available_ = (ecx & 0x00000800) != 0;
+      }
     }
 # endif
   }
@@ -210,9 +213,12 @@ SIMDDetect::SIMDDetect() {
     if (max_function_id >= 7) {
       __cpuid(cpuInfo, 7);
       avx2_available_ = (cpuInfo[1] & 0x00000020) != 0;
-      avx512F_available_ = (cpuInfo[1] & 0x00010000) != 0;
-      avx512BW_available_ = (cpuInfo[1] & 0x40000000) != 0;
-      avx512VNNI_available_ = (cpuInfo[2] & 0x00000800) != 0;
+      if ((_xgetbv(0) & 0xe0) == 0xe0) {
+        // OS supports AVX512.
+        avx512F_available_ = (cpuInfo[1] & 0x00010000) != 0;
+        avx512BW_available_ = (cpuInfo[1] & 0x40000000) != 0;
+        avx512VNNI_available_ = (cpuInfo[2] & 0x00000800) != 0;
+      }
     }
 # endif
   }
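Both simddetect.cpp hunks wrap the AVX-512 feature bits in an OS check: CPUID only reports what the CPU implements, while XCR0, read with xgetbv, reports which register state the operating system actually saves and restores. The mask 0xe0 covers bits 5, 6 and 7 (opmask registers, the upper halves of ZMM0-15, and ZMM16-31); unless all three are enabled, AVX-512 instructions would fault, so the availability flags must stay false. A compact sketch of the check (illustrative; the Unix path in the diff uses a local xgetbv() helper, the MSVC path the _xgetbv intrinsic, and reading XCR0 is only meaningful once CPUID has confirmed OSXSAVE):

#include <cstdint>

// XCR0 state-component bits that AVX-512 needs the OS to manage.
constexpr uint64_t kOpmaskState = 1u << 5;   // k0-k7
constexpr uint64_t kZmmHi256State = 1u << 6; // upper 256 bits of ZMM0-15
constexpr uint64_t kHi16ZmmState = 1u << 7;  // ZMM16-31
constexpr uint64_t kAvx512Mask = kOpmaskState | kZmmHi256State | kHi16ZmmState; // == 0xe0

bool os_saves_avx512_state(uint64_t xcr0) {
  return (xcr0 & kAvx512Mask) == kAvx512Mask;
}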
@@ -256,7 +262,18 @@ SIMDDetect::SIMDDetect() {
 #if defined(HAVE_AVX512F)
   } else if (avx512F_available_) {
     // AVX512F detected.
+# if defined(HAVE_AVX512VNNI)
+    if (avx512VNNI_available_) {
+      printf("mit VNNI\n");
+      SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX512VNNI);
+    } else {
+      printf("ohne VNNI\n");
+      SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX2);
+    }
+# else
+    printf("ohne VNNI???\n");
     SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX2);
+# endif
 #endif
 #if defined(HAVE_AVX2)
   } else if (avx2_available_) {
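With this hunk the dot-product routine for the AVX512F branch stays DotProductAVX512F; only the int8 matrix kernel switches to intSimdMatrixAVX512VNNI when the CPU reports VNNI, and falls back to intSimdMatrixAVX2 otherwise. The printf calls look like leftover debug output in German ("mit VNNI" = "with VNNI", "ohne VNNI" = "without VNNI") that a reviewer would probably ask to have removed. A stand-alone sketch of the same selection without the debug prints (illustrative; the types and names are simplified stand-ins, not the real SIMDDetect/SetDotProduct API):

struct IntKernel {};
const IntKernel kIntSimdMatrixAVX512VNNI{};
const IntKernel kIntSimdMatrixAVX2{};

// Mirrors the branch above: AVX512F picks the dot-product implementation, and
// VNNI availability only decides which integer matrix kernel is paired with it.
const IntKernel *PickIntKernelForAvx512f(bool built_with_vnni, bool cpu_has_vnni) {
  if (built_with_vnni && cpu_has_vnni) {
    return &kIntSimdMatrixAVX512VNNI;
  }
  return &kIntSimdMatrixAVX2;
}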
@@ -134,4 +134,18 @@ TEST_F(IntSimdMatrixTest, AVX2) {
 #endif
 }
+
+// Tests that the AVX512VNNI implementation gets the same result as the vanilla.
+TEST_F(IntSimdMatrixTest, AVX512VNNI) {
+#if defined(HAVE_AVX512VNNI)
+  if (!SIMDDetect::IsAVX512VNNIAvailable()) {
+    GTEST_LOG_(INFO) << "No AVX512VNNI found! Not tested!";
+    GTEST_SKIP();
+  }
+  ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX512VNNI);
+#else
+  GTEST_LOG_(INFO) << "AVX512VNNI unsupported! Not tested!";
+  GTEST_SKIP();
+#endif
+}
 
 } // namespace tesseract
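The new test follows the pattern of the existing per-ISA tests in this file: when the binary is built with HAVE_AVX512VNNI but the machine cannot run it, the test logs and skips rather than silently passing, and without the build flag the body reduces to a skip. ExpectEqualResults, defined elsewhere in the test, is what compares the kernel under test against the plain implementation. For reference, the general shape of such a feature-gated googletest (illustrative; FeatureAvailable is a made-up stand-in for a probe like SIMDDetect::IsAVX512VNNIAvailable):

#include "gtest/gtest.h"

static bool FeatureAvailable() {
  return false; // stand-in; a real probe would query the CPU at run time
}

TEST(FeatureGatedTest, RunsOnlyWhereSupported) {
  if (!FeatureAvailable()) {
    GTEST_LOG_(INFO) << "Feature not present on this machine, skipping.";
    GTEST_SKIP();
  }
  // Exercise the feature-specific code path and compare against a reference here.
  SUCCEED();
}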