bugfixing the AVX2 Extract8+16 codes

There's lines like `__m256d scale01234567 = _mm256_loadu_ps(scales)`,
i.e. loading float vectors into double vector types.

[sw] Formatted commit message
This commit is contained in:
Ger Hobbelt 2021-07-13 09:56:11 +02:00 committed by Stefan Weil
parent 24a29b79e5
commit 79e8b4f344

View File

@ -86,54 +86,45 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
}
#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
static inline void ExtractResults8(__m256i result, const int8_t *wi,
const float *scales, float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256d scale01234567 = _mm256_loadu_ps(scales);
//~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res01234567 = _mm256_mul_pd(res01234567, scale01234567);
//~ res4567 = _mm256_mul_pd(res4567, scale4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
//~ _mm256_storeu_pd(v + 4, res4567);
}
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const float *&scales, float *&v) {
static inline void ExtractResults16(__m256i result0, __m256i result1,
const int8_t *&wi, const float *&scales,
float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
const __m256i bias_scale =
_mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_ps(scales);
__m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res0123);
_mm256_storeu_ps(v + 8, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_ps(scales + 16);
scale4567 = _mm256_loadu_ps(scales + 24);
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v + 16, res0123);
_mm256_storeu_ps(v + 24, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;