diff --git a/src/arch/intsimdmatrix.cpp b/src/arch/intsimdmatrix.cpp index fa4afa7c..9929e8cc 100644 --- a/src/arch/intsimdmatrix.cpp +++ b/src/arch/intsimdmatrix.cpp @@ -80,7 +80,32 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY &w, int num_out = w.dim1(); int num_in = w.dim2() - 1; // Base implementation. - for (int i = 0; i < num_out; ++i) { + int i; + // Break up into chunks of four to facilitate vectorization + for (i = 0; i < (num_out / 4) * 4; i += 4) { + const int8_t *wi0 = w[i + 0]; + const int8_t *wi1 = w[i + 1]; + const int8_t *wi2 = w[i + 2]; + const int8_t *wi3 = w[i + 3]; + int total0 = 0; + int total1 = 0; + int total2 = 0; + int total3 = 0; + for (int j = 0; j < num_in; ++j) { + total0 += wi0[j] * u[j]; + total1 += wi1[j] * u[j]; + total2 += wi2[j] * u[j]; + total3 += wi3[j] * u[j]; + } + // Add in the bias and correct for integer values. + v[i + 0] = (total0 + wi0[num_in] * INT8_MAX) * scales[i + 0]; + v[i + 1] = (total1 + wi1[num_in] * INT8_MAX) * scales[i + 1]; + v[i + 2] = (total2 + wi2[num_in] * INT8_MAX) * scales[i + 2]; + v[i + 3] = (total3 + wi3[num_in] * INT8_MAX) * scales[i + 3]; + } + + // Capture the remainder mod four + for (; i < num_out; ++i) { const int8_t *wi = w[i]; int total = 0; for (int j = 0; j < num_in; ++j) {