///////////////////////////////////////////////////////////////////////
// File:        intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author:      Ray Smith
// Created:     Fri Aug 04 13:26:20 PST 2017
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrixavx2.h"

#ifdef __AVX2__
#include <immintrin.h>
#include <algorithm>
#include <cstdint>

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to // correctly use the registers declared in the caller. inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones, const int8_t*& wi, __m256i& weights, __m256i& reps, __m256i& result) { // Load a 4x8 block of weights. weights = _mm256_loadu_si256(reinterpret_cast(wi)); wi += kNumInputsPerRegister; // Normalize the signs on rep_input, weights, so weights is always +ve. reps = _mm256_sign_epi8(rep_input, weights); weights = _mm256_sign_epi8(weights, weights); // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results, // with adjacent pairs added. weights = _mm256_maddubs_epi16(weights, reps); // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results, // with adjacent pairs added. What we really want is a horizontal add of // 16+16=32 bit result, but there is no such instruction, so multiply by // 16-bit ones instead. It is probably faster than all the sign-extending, // permuting and adding that would otherwise be required. weights = _mm256_madd_epi16(weights, ones); result = _mm256_add_epi32(result, weights); } // Extracts and converts 8x32-bit results from result, adding the bias from wi // and scaling by scales, before storing in *v. Note that wi, scales and v are // expected to contain 8 consecutive elements or num_out if less. inline void ExtractResults(__m256i& result, __m256i& shift_id, const int8_t*& wi, const double*& scales, int num_out, double*& v) { for (int out = 0; out < num_out; ++out) { int32_t res = _mm256_extract_epi32(result, 0); *v++ = (static_cast(res) / MAX_INT8 + *wi++) * *scales++; // Rotate the results in int32_t units, so the next result is ready. result = _mm256_permutevar8x32_epi32(result, shift_id); } } // Computes part of matrix.vector v = Wu. Computes N=64 results. 
// The weights *must* be arranged so that consecutive reads from wi // provides (num_in/kNumInputsPerGroup groups of (N output dim groups of // (kNumInputsPerGroup inputs))). After that there must be N consecutive // bias weights, before continuing with any more weights. // u must be padded out with zeros to // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. static void PartialMatrixDotVector64(const int8_t* wi, const double* scales, const int8_t* u, int num_in, int num_out, double* v) { // Register containing 16-bit ones for horizontal add with 16->32 bit // conversion. __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); // Initialize all the results to 0. __m256i result0 = _mm256_setzero_si256(); __m256i result1 = _mm256_setzero_si256(); __m256i result2 = _mm256_setzero_si256(); __m256i result3 = _mm256_setzero_si256(); __m256i result4 = _mm256_setzero_si256(); __m256i result5 = _mm256_setzero_si256(); __m256i result6 = _mm256_setzero_si256(); __m256i result7 = _mm256_setzero_si256(); // Iterate over the input (u), one registerful at a time. for (int j = 0; j < num_in;) { __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); // Inputs are processed in groups of kNumInputsPerGroup, replicated // kNumInputGroups times. for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { // Replicate the low 32 bits (4 inputs) 8 times. __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); // Rotate the inputs in groups of 4, so the next 4 inputs are ready. inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); __m256i weights, reps; // Mul-add, with horizontal add of the 4 inputs to each of the results. 
MultiplyGroup(rep_input, ones, wi, weights, reps, result0); MultiplyGroup(rep_input, ones, wi, weights, reps, result1); MultiplyGroup(rep_input, ones, wi, weights, reps, result2); MultiplyGroup(rep_input, ones, wi, weights, reps, result3); MultiplyGroup(rep_input, ones, wi, weights, reps, result4); MultiplyGroup(rep_input, ones, wi, weights, reps, result5); MultiplyGroup(rep_input, ones, wi, weights, reps, result6); MultiplyGroup(rep_input, ones, wi, weights, reps, result7); } } ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v); num_out -= kNumOutputsPerRegister * 7; ExtractResults(result7, shift_id, wi, scales, std::min(kNumOutputsPerRegister, num_out), v); } // Computes part of matrix.vector v = Wu. Computes N=32 results. // For details see PartialMatrixDotVector64 with N=32. static void PartialMatrixDotVector32(const int8_t* wi, const double* scales, const int8_t* u, int num_in, int num_out, double* v) { // Register containing 16-bit ones for horizontal add with 16->32 bit // conversion. __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); // Initialize all the results to 0. __m256i result0 = _mm256_setzero_si256(); __m256i result1 = _mm256_setzero_si256(); __m256i result2 = _mm256_setzero_si256(); __m256i result3 = _mm256_setzero_si256(); // Iterate over the input (u), one registerful at a time. 
for (int j = 0; j < num_in;) { __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); // Inputs are processed in groups of kNumInputsPerGroup, replicated // kNumInputGroups times. for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { // Replicate the low 32 bits (4 inputs) 8 times. __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); // Rotate the inputs in groups of 4, so the next 4 inputs are ready. inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); __m256i weights, reps; // Mul-add, with horizontal add of the 4 inputs to each of the results. MultiplyGroup(rep_input, ones, wi, weights, reps, result0); MultiplyGroup(rep_input, ones, wi, weights, reps, result1); MultiplyGroup(rep_input, ones, wi, weights, reps, result2); MultiplyGroup(rep_input, ones, wi, weights, reps, result3); } } ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v); ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v); num_out -= kNumOutputsPerRegister * 3; ExtractResults(result3, shift_id, wi, scales, std::min(kNumOutputsPerRegister, num_out), v); } // Computes part of matrix.vector v = Wu. Computes N=16 results. // For details see PartialMatrixDotVector64 with N=16. static void PartialMatrixDotVector16(const int8_t* wi, const double* scales, const int8_t* u, int num_in, int num_out, double* v) { // Register containing 16-bit ones for horizontal add with 16->32 bit // conversion. __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); // Initialize all the results to 0. __m256i result0 = _mm256_setzero_si256(); __m256i result1 = _mm256_setzero_si256(); // Iterate over the input (u), one registerful at a time. 
for (int j = 0; j < num_in;) { __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); // Inputs are processed in groups of kNumInputsPerGroup, replicated // kNumInputGroups times. for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { // Replicate the low 32 bits (4 inputs) 8 times. __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); // Rotate the inputs in groups of 4, so the next 4 inputs are ready. inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); __m256i weights, reps; // Mul-add, with horizontal add of the 4 inputs to each of the results. MultiplyGroup(rep_input, ones, wi, weights, reps, result0); MultiplyGroup(rep_input, ones, wi, weights, reps, result1); } } ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v); num_out -= kNumOutputsPerRegister; ExtractResults(result1, shift_id, wi, scales, std::min(kNumOutputsPerRegister, num_out), v); } // Computes part of matrix.vector v = Wu. Computes N=8 results. // For details see PartialMatrixDotVector64 with N=8. static void PartialMatrixDotVector8(const int8_t* wi, const double* scales, const int8_t* u, int num_in, int num_out, double* v) { // Register containing 16-bit ones for horizontal add with 16->32 bit // conversion. __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); // Initialize all the results to 0. __m256i result0 = _mm256_setzero_si256(); // Iterate over the input (u), one registerful at a time. for (int j = 0; j < num_in;) { __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); // Inputs are processed in groups of kNumInputsPerGroup, replicated // kNumInputGroups times. for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { // Replicate the low 32 bits (4 inputs) 8 times. 
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); // Rotate the inputs in groups of 4, so the next 4 inputs are ready. inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); __m256i weights, reps; // Mul-add, with horizontal add of the 4 inputs to each of the results. MultiplyGroup(rep_input, ones, wi, weights, reps, result0); } } ExtractResults(result0, shift_id, wi, scales, num_out, v); } #else namespace tesseract { #endif // __AVX2__ IntSimdMatrixAVX2::IntSimdMatrixAVX2() { #ifdef __AVX2__ num_outputs_per_register_ = kNumOutputsPerRegister; max_output_registers_ = kMaxOutputRegisters; num_inputs_per_register_ = kNumInputsPerRegister; num_inputs_per_group_ = kNumInputsPerGroup; num_input_groups_ = kNumInputGroups; partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32, PartialMatrixDotVector16, PartialMatrixDotVector8}; #endif // __AVX2__ } } // namespace tesseract.