Add the needed changes to support AVX512VNNI

This commit is contained in:
Amit Dovev 2022-08-09 15:22:38 +03:00 committed by Stefan Weil
parent 8d0bd68983
commit 232093f046
3 changed files with 10 additions and 14 deletions

View File

@ -171,7 +171,7 @@ noinst_LTLIBRARIES += libtesseract_avx512.la
endif
if HAVE_AVX512VNNI
libtesseract_avx512vnni_la_CXXFLAGS = -march=icelake-client
libtesseract_avx512vnni_la_CXXFLAGS = -mavx512vnni -mavx512vl
libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp
libtesseract_la_LIBADD += libtesseract_avx512vnni.la

View File

@ -157,7 +157,7 @@ case "${host_cpu}" in
AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions])
fi
AX_CHECK_COMPILE_FLAG([-march=icelake-client], [avx512vnni=true], [avx512vnni=false], [$WERROR])
AX_CHECK_COMPILE_FLAG([-mavx512vnni], [avx512vnni=true], [avx512vnni=false], [$WERROR])
AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni)
if $avx512vnni; then
AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions])

View File

@ -17,9 +17,9 @@
#include "intsimdmatrix.h"
#if !defined(__AVX2__)
#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__)
# if defined(__i686__) || defined(__x86_64__)
# error Implementation only for AVX2 capable architectures
# error Implementation only for AVX512VNNI capable architectures
# endif
#else
# include <immintrin.h>
@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones,
// Normalize the signs on rep_input, weights, so weights is always +ve.
reps = _mm256_sign_epi8(rep_input, weights);
weights = _mm256_sign_epi8(weights, weights);
// Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
// with adjacent pairs added.
weights = _mm256_maddubs_epi16(weights, reps);
// Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
// with adjacent pairs added. What we really want is a horizontal add of
// 16+16=32 bit result, but there is no such instruction, so multiply by
// 16-bit ones instead. It is probably faster than all the sign-extending,
// permuting and adding that would otherwise be required.
weights = _mm256_madd_epi16(weights, ones);
result = _mm256_add_epi32(result, weights);
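// Single VNNI instruction (vpdpbusd): it multiplies the unsigned bytes of
// weights (made non-negative above) by the signed bytes of reps and adds each
// group of four products to the matching 32-bit lane of result.
// It replaces these 3 AVX2 instructions: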
//weights = _mm256_maddubs_epi16(weights, reps);
//weights = _mm256_madd_epi16(weights, ones);
//result = _mm256_add_epi32(result, weights);
result = _mm256_dpbusd_epi32(result, weights, reps);
}
// Load 64 bits into the bottom of a 128bit register.