Compile files for AVX, AVX2 or SSE only when needed

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2019-01-12 21:30:45 +01:00
parent a9a1035e55
commit 7fc7d28dd0
9 changed files with 61 additions and 34 deletions

View File

@ -53,9 +53,6 @@ libtesseract_la_LIBADD = \
../dict/libtesseract_dict.la \ ../dict/libtesseract_dict.la \
../arch/libtesseract_arch.la \ ../arch/libtesseract_arch.la \
../arch/libtesseract_native.la \ ../arch/libtesseract_native.la \
../arch/libtesseract_avx.la \
../arch/libtesseract_avx2.la \
../arch/libtesseract_sse.la \
../lstm/libtesseract_lstm.la \ ../lstm/libtesseract_lstm.la \
../ccstruct/libtesseract_ccstruct.la \ ../ccstruct/libtesseract_ccstruct.la \
../cutil/libtesseract_cutil.la \ ../cutil/libtesseract_cutil.la \
@ -63,6 +60,16 @@ libtesseract_la_LIBADD = \
../ccutil/libtesseract_ccutil.la \ ../ccutil/libtesseract_ccutil.la \
../opencl/libtesseract_opencl.la ../opencl/libtesseract_opencl.la
if AVX_OPT
libtesseract_la_LIBADD += ../arch/libtesseract_avx.la
endif
if AVX2_OPT
libtesseract_la_LIBADD += ../arch/libtesseract_avx2.la
endif
if SSE41_OPT
libtesseract_la_LIBADD += ../arch/libtesseract_sse.la
endif
libtesseract_la_LDFLAGS += -version-info $(GENERIC_LIBRARY_VERSION) $(NOUNDEFINED) libtesseract_la_LDFLAGS += -version-info $(GENERIC_LIBRARY_VERSION) $(NOUNDEFINED)
bin_PROGRAMS = tesseract bin_PROGRAMS = tesseract

View File

@ -15,8 +15,15 @@ noinst_HEADERS += intsimdmatrix.h
noinst_HEADERS += simddetect.h noinst_HEADERS += simddetect.h
noinst_LTLIBRARIES = libtesseract_native.la noinst_LTLIBRARIES = libtesseract_native.la
noinst_LTLIBRARIES += libtesseract_avx.la libtesseract_avx2.la if AVX_OPT
noinst_LTLIBRARIES += libtesseract_avx.la
endif
if AVX2_OPT
noinst_LTLIBRARIES += libtesseract_avx2.la
endif
if SSE41_OPT
noinst_LTLIBRARIES += libtesseract_sse.la noinst_LTLIBRARIES += libtesseract_sse.la
endif
noinst_LTLIBRARIES += libtesseract_arch.la noinst_LTLIBRARIES += libtesseract_arch.la
libtesseract_arch_la_CPPFLAGS = $(AM_CPPFLAGS) libtesseract_arch_la_CPPFLAGS = $(AM_CPPFLAGS)
@ -41,8 +48,14 @@ libtesseract_native_la_SOURCES = dotproduct.cpp
libtesseract_arch_la_SOURCES = intsimdmatrix.cpp simddetect.cpp libtesseract_arch_la_SOURCES = intsimdmatrix.cpp simddetect.cpp
if AVX_OPT
libtesseract_avx_la_SOURCES = dotproductavx.cpp libtesseract_avx_la_SOURCES = dotproductavx.cpp
endif
if AVX2_OPT
libtesseract_avx2_la_SOURCES = intsimdmatrixavx2.cpp libtesseract_avx2_la_SOURCES = intsimdmatrixavx2.cpp
endif
if SSE41_OPT
libtesseract_sse_la_SOURCES = dotproductsse.cpp intsimdmatrixsse.cpp libtesseract_sse_la_SOURCES = dotproductsse.cpp intsimdmatrixsse.cpp
endif

View File

@ -16,8 +16,10 @@
// limitations under the License. // limitations under the License.
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
#if defined(__AVX__) #if !defined(__AVX__)
// Implementation for avx capable archs. #error Implementation only for AVX capable architectures
#endif
#include <immintrin.h> #include <immintrin.h>
#include <cstdint> #include <cstdint>
#include "dotproductavx.h" #include "dotproductavx.h"
@ -96,5 +98,3 @@ double DotProductAVX(const double* u, const double* v, int n) {
} }
} // namespace tesseract. } // namespace tesseract.
#endif // __AVX__

View File

@ -16,7 +16,9 @@
// limitations under the License. // limitations under the License.
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
#if defined(__SSE4_1__) #if !defined(__SSE4_1__)
#error Implementation only for SSE 4.1 capable architectures
#endif
#include <emmintrin.h> #include <emmintrin.h>
#include <smmintrin.h> #include <smmintrin.h>
@ -117,5 +119,3 @@ int32_t IntDotProductSSE(const int8_t* u, const int8_t* v, int n) {
} }
} // namespace tesseract. } // namespace tesseract.
#endif // __SSE4_1__

View File

@ -25,12 +25,11 @@ namespace tesseract {
const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr; const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr;
// Computes a reshaped copy of the weight matrix w. If there are no // Computes a reshaped copy of the weight matrix w.
// partial_funcs_, it does nothing. void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const { std::vector<int8_t>& shaped_w) const {
if (partial_funcs_.empty()) return; const int num_out = w.dim1();
int num_out = w.dim1(); const int num_in = w.dim2() - 1;
int num_in = w.dim2() - 1;
// The rounded-up sizes of the reshaped weight matrix, excluding biases. // The rounded-up sizes of the reshaped weight matrix, excluding biases.
int rounded_num_in = Roundup(num_in, num_inputs_per_group_); int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
int rounded_num_out = RoundOutputs(num_out); int rounded_num_out = RoundOutputs(num_out);

View File

@ -16,9 +16,12 @@
// limitations under the License. // limitations under the License.
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
#if !defined(__AVX2__)
#error Implementation only for AVX2 capable architectures
#endif
#include "intsimdmatrix.h" #include "intsimdmatrix.h"
#ifdef __AVX2__
#include <immintrin.h> #include <immintrin.h>
#include <cstdint> #include <cstdint>
#include <algorithm> #include <algorithm>
@ -265,16 +268,9 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
} }
ExtractResults(result0, shift_id, wi, scales, num_out, v); ExtractResults(result0, shift_id, wi, scales, num_out, v);
} }
#else
namespace tesseract {
#endif // __AVX2__
#ifdef __AVX2__
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 = const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 =
IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32, IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8}); PartialMatrixDotVector16, PartialMatrixDotVector8});
#else
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 = IntSimdMatrix(1, 1, 1, 1, 1, {});
#endif // __AVX2__
} // namespace tesseract. } // namespace tesseract.

View File

@ -15,6 +15,10 @@
// limitations under the License. // limitations under the License.
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
#if !defined(__SSE4_1__)
#error Implementation only for SSE 4.1 capable architectures
#endif
#include "intsimdmatrix.h" #include "intsimdmatrix.h"
#include <cstdint> #include <cstdint>
@ -22,7 +26,6 @@
namespace tesseract { namespace tesseract {
#ifdef __SSE4_1__
// Computes part of matrix.vector v = Wu. Computes 1 result. // Computes part of matrix.vector v = Wu. Computes 1 result.
static void PartialMatrixDotVector1(const int8_t* wi, const double* scales, static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
const int8_t* u, int num_in, int num_out, const int8_t* u, int num_in, int num_out,
@ -31,14 +34,8 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
// Add in the bias and correct for integer values. // Add in the bias and correct for integer values.
*v = (total / INT8_MAX + wi[num_in]) * *scales; *v = (total / INT8_MAX + wi[num_in]) * *scales;
} }
#endif // __SSE4_1__
#ifdef __SSE4_1__
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE = const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE =
IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1}); IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1});
#else
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE =
IntSimdMatrix(1, 1, 1, 1, 1, {});
#endif // __SSE4_1__
} // namespace tesseract. } // namespace tesseract.

View File

@ -175,6 +175,13 @@ intfeaturemap_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS)
intsimdmatrix_test_SOURCES = intsimdmatrix_test.cc intsimdmatrix_test_SOURCES = intsimdmatrix_test.cc
intsimdmatrix_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS) intsimdmatrix_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS)
intsimdmatrix_test_CPPFLAGS = $(AM_CPPFLAGS)
if AVX2_OPT
intsimdmatrix_test_CPPFLAGS += -DAVX2
endif
if SSE41_OPT
intsimdmatrix_test_CPPFLAGS += -DSSE4_1
endif
lang_model_test_SOURCES = lang_model_test.cc lang_model_test_SOURCES = lang_model_test.cc
lang_model_test_LDADD = $(ABSEIL_LIBS) $(GTEST_LIBS) $(TRAINING_LIBS) $(TESS_LIBS) $(ICU_I18N_LIBS) $(ICU_UC_LIBS) lang_model_test_LDADD = $(ABSEIL_LIBS) $(GTEST_LIBS) $(TRAINING_LIBS) $(TESS_LIBS) $(ICU_I18N_LIBS) $(ICU_UC_LIBS)

View File

@ -92,24 +92,32 @@ TEST_F(IntSimdMatrixTest, C) {
// Tests that the SSE implementation gets the same result as the vanilla. // Tests that the SSE implementation gets the same result as the vanilla.
TEST_F(IntSimdMatrixTest, SSE) { TEST_F(IntSimdMatrixTest, SSE) {
#if defined(SSE4_1)
if (SIMDDetect::IsSSEAvailable()) { if (SIMDDetect::IsSSEAvailable()) {
tprintf("SSE found! Continuing..."); tprintf("SSE found! Continuing...");
} else { } else {
tprintf("No SSE found! Not Tested!"); tprintf("No SSE found! Not tested!");
return; return;
} }
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixSSE); ExpectEqualResults(IntSimdMatrix::IntSimdMatrixSSE);
#else
tprintf("SSE unsupported! Not tested!");
#endif
} }
// Tests that the AVX2 implementation gets the same result as the vanilla. // Tests that the AVX2 implementation gets the same result as the vanilla.
TEST_F(IntSimdMatrixTest, AVX2) { TEST_F(IntSimdMatrixTest, AVX2) {
#if defined(AVX2)
if (SIMDDetect::IsAVX2Available()) { if (SIMDDetect::IsAVX2Available()) {
tprintf("AVX2 found! Continuing..."); tprintf("AVX2 found! Continuing...");
} else { } else {
tprintf("No AVX2 found! Not Tested!"); tprintf("No AVX2 found! Not tested!");
return; return;
} }
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixAVX2); ExpectEqualResults(IntSimdMatrix::IntSimdMatrixAVX2);
#else
tprintf("AVX2 unsupported! Not tested!");
#endif
} }
} // namespace } // namespace