From 3cbe60cca20901f6818d52fa35a001789147da75 Mon Sep 17 00:00:00 2001 From: Tomoaki Teshima Date: Mon, 20 Nov 2017 21:56:53 +0900 Subject: [PATCH] Merge pull request #9753 from tomoaki0705:universalMatmul * add accuracy test and performance check for matmul * add performance tests for transform and dotProduct * add test Core_TransformLargeTest for 8u version of transform * remove raw SSE2/NEON implementation from matmul.cpp * use universal intrinsic instead of raw intrinsic * remove unused templated function * add v_matmuladd which multiply 3x3 matrix and add 3x1 vector * add v_rotate_left/right in universal intrinsic * suppress intrinsic on some function and platform * add pure SW implementation of new universal intrinsics * add test for new universal intrinsics * core: prevent memory access after the end of buffer * fix perf tests --- .../include/opencv2/core/hal/intrin_cpp.hpp | 75 +- .../include/opencv2/core/hal/intrin_neon.hpp | 22 +- .../include/opencv2/core/hal/intrin_sse.hpp | 33 + modules/core/perf/opencl/perf_arithm.cpp | 28 + modules/core/perf/perf_mat.cpp | 28 + modules/core/src/matmul.cpp | 654 ++++++++---------- modules/core/test/test_intrin.cpp | 8 + modules/core/test/test_intrin_utils.hpp | 34 + modules/core/test/test_math.cpp | 55 ++ 9 files changed, 563 insertions(+), 374 deletions(-) diff --git a/modules/core/include/opencv2/core/hal/intrin_cpp.hpp b/modules/core/include/opencv2/core/hal/intrin_cpp.hpp index 0fc78ca519..945fddd801 100644 --- a/modules/core/include/opencv2/core/hal/intrin_cpp.hpp +++ b/modules/core/include/opencv2/core/hal/intrin_cpp.hpp @@ -885,12 +885,59 @@ template inline v_reg<_Tp, n> operator shift_op(const v_reg /** @brief Bitwise shift left For 16-, 32- and 64-bit integer values. */ -OPENCV_HAL_IMPL_SHIFT_OP(<<) +OPENCV_HAL_IMPL_SHIFT_OP(<< ) /** @brief Bitwise shift right For 16-, 32- and 64-bit integer values. */ -OPENCV_HAL_IMPL_SHIFT_OP(>>) +OPENCV_HAL_IMPL_SHIFT_OP(>> ) + +/** @brief Element shift left among vector + +For all type */ +#define OPENCV_HAL_IMPL_ROTATE_SHIFT_OP(suffix,opA,opB) \ +template inline v_reg<_Tp, n> v_rotate_##suffix(const v_reg<_Tp, n>& a) \ +{ \ + v_reg<_Tp, n> b; \ + for (int i = 0; i < n; i++) \ + { \ + int sIndex = i opA imm; \ + if (0 <= sIndex && sIndex < n) \ + { \ + b.s[i] = a.s[sIndex]; \ + } \ + else \ + { \ + b.s[i] = 0; \ + } \ + } \ + return b; \ +} \ +template inline v_reg<_Tp, n> v_rotate_##suffix(const v_reg<_Tp, n>& a, const v_reg<_Tp, n>& b) \ +{ \ + v_reg<_Tp, n> c; \ + for (int i = 0; i < n; i++) \ + { \ + int aIndex = i opA imm; \ + int bIndex = i opA imm opB n; \ + if (0 <= bIndex && bIndex < n) \ + { \ + c.s[i] = b.s[bIndex]; \ + } \ + else if (0 <= aIndex && aIndex < n) \ + { \ + c.s[i] = a.s[aIndex]; \ + } \ + else \ + { \ + c.s[i] = 0; \ + } \ + } \ + return c; \ +} + +OPENCV_HAL_IMPL_ROTATE_SHIFT_OP(left, -, +) +OPENCV_HAL_IMPL_ROTATE_SHIFT_OP(right, +, -) /** @brief Sum packed values @@ -1860,6 +1907,30 @@ inline v_float32x4 v_matmul(const v_float32x4& v, const v_float32x4& m0, v.s[0]*m0.s[3] + v.s[1]*m1.s[3] + v.s[2]*m2.s[3] + v.s[3]*m3.s[3]); } +/** @brief Matrix multiplication and add + +Scheme: +@code +{A0 A1 A2 } |V0| |D0| +{B0 B1 B2 } |V1| |D1| +{C0 C1 C2 } x |V2| + |D2| +==================== +{R0 R1 R2 R3}, where: +R0 = A0V0 + A1V1 + A2V2 + D0, +R1 = B0V0 + B1V1 + B2V2 + D1 +... 
+@endcode +*/ +inline v_float32x4 v_matmuladd(const v_float32x4& v, const v_float32x4& m0, + const v_float32x4& m1, const v_float32x4& m2, + const v_float32x4& m3) +{ + return v_float32x4(v.s[0]*m0.s[0] + v.s[1]*m1.s[0] + v.s[2]*m2.s[0] + m3.s[0], + v.s[0]*m0.s[1] + v.s[1]*m1.s[1] + v.s[2]*m2.s[1] + m3.s[1], + v.s[0]*m0.s[2] + v.s[1]*m1.s[2] + v.s[2]*m2.s[2] + m3.s[2], + v.s[0]*m0.s[3] + v.s[1]*m1.s[3] + v.s[2]*m2.s[3] + m3.s[3]); +} + //! @} //! @name Check SIMD support diff --git a/modules/core/include/opencv2/core/hal/intrin_neon.hpp b/modules/core/include/opencv2/core/hal/intrin_neon.hpp index 175750e06a..b824ca0c1c 100644 --- a/modules/core/include/opencv2/core/hal/intrin_neon.hpp +++ b/modules/core/include/opencv2/core/hal/intrin_neon.hpp @@ -407,6 +407,18 @@ inline v_float32x4 v_matmul(const v_float32x4& v, const v_float32x4& m0, return v_float32x4(res); } +inline v_float32x4 v_matmuladd(const v_float32x4& v, const v_float32x4& m0, + const v_float32x4& m1, const v_float32x4& m2, + const v_float32x4& a) +{ + float32x2_t vl = vget_low_f32(v.val), vh = vget_high_f32(v.val); + float32x4_t res = vmulq_lane_f32(m0.val, vl, 0); + res = vmlaq_lane_f32(res, m1.val, vl, 1); + res = vmlaq_lane_f32(res, m2.val, vh, 0); + res = vaddq_f32(res, a.val); + return v_float32x4(res); +} + #define OPENCV_HAL_IMPL_NEON_BIN_OP(bin_op, _Tpvec, intrin) \ inline _Tpvec operator bin_op (const _Tpvec& a, const _Tpvec& b) \ { \ @@ -747,7 +759,15 @@ template inline _Tpvec v_shl(const _Tpvec& a) \ template inline _Tpvec v_shr(const _Tpvec& a) \ { return _Tpvec(vshrq_n_##suffix(a.val, n)); } \ template inline _Tpvec v_rshr(const _Tpvec& a) \ -{ return _Tpvec(vrshrq_n_##suffix(a.val, n)); } +{ return _Tpvec(vrshrq_n_##suffix(a.val, n)); } \ +template inline _Tpvec v_rotate_right(const _Tpvec& a) \ +{ return _Tpvec(vextq_##suffix(a.val, vdupq_n_##suffix(0), n)); } \ +template inline _Tpvec v_rotate_left(const _Tpvec& a) \ +{ return _Tpvec(vextq_##suffix(vdupq_n_##suffix(0), a.val, _Tpvec::nlanes - n)); } \ +template inline _Tpvec v_rotate_right(const _Tpvec& a, const _Tpvec& b) \ +{ return _Tpvec(vextq_##suffix(a.val, b.val, n)); } \ +template inline _Tpvec v_rotate_left(const _Tpvec& a, const _Tpvec& b) \ +{ return _Tpvec(vextq_##suffix(b.val, a.val, _Tpvec::nlanes - n)); } OPENCV_HAL_IMPL_NEON_SHIFT_OP(v_uint8x16, u8, schar, s8) OPENCV_HAL_IMPL_NEON_SHIFT_OP(v_int8x16, s8, schar, s8) diff --git a/modules/core/include/opencv2/core/hal/intrin_sse.hpp b/modules/core/include/opencv2/core/hal/intrin_sse.hpp index 47ea2a2f54..637d49282e 100644 --- a/modules/core/include/opencv2/core/hal/intrin_sse.hpp +++ b/modules/core/include/opencv2/core/hal/intrin_sse.hpp @@ -602,6 +602,16 @@ inline v_float32x4 v_matmul(const v_float32x4& v, const v_float32x4& m0, return v_float32x4(_mm_add_ps(_mm_add_ps(v0, v1), _mm_add_ps(v2, v3))); } +inline v_float32x4 v_matmuladd(const v_float32x4& v, const v_float32x4& m0, + const v_float32x4& m1, const v_float32x4& m2, + const v_float32x4& a) +{ + __m128 v0 = _mm_mul_ps(_mm_shuffle_ps(v.val, v.val, _MM_SHUFFLE(0, 0, 0, 0)), m0.val); + __m128 v1 = _mm_mul_ps(_mm_shuffle_ps(v.val, v.val, _MM_SHUFFLE(1, 1, 1, 1)), m1.val); + __m128 v2 = _mm_mul_ps(_mm_shuffle_ps(v.val, v.val, _MM_SHUFFLE(2, 2, 2, 2)), m2.val); + + return v_float32x4(_mm_add_ps(_mm_add_ps(v0, v1), _mm_add_ps(v2, a.val))); +} #define OPENCV_HAL_IMPL_SSE_BIN_OP(bin_op, _Tpvec, intrin) \ inline _Tpvec operator bin_op (const _Tpvec& a, const _Tpvec& b) \ @@ -1011,6 +1021,29 @@ OPENCV_HAL_IMPL_SSE_SHIFT_OP(v_uint16x8, v_int16x8, 
epi16, _mm_srai_epi16) OPENCV_HAL_IMPL_SSE_SHIFT_OP(v_uint32x4, v_int32x4, epi32, _mm_srai_epi32) OPENCV_HAL_IMPL_SSE_SHIFT_OP(v_uint64x2, v_int64x2, epi64, v_srai_epi64) +template +inline _Tpvec v_rotate_right(const _Tpvec &a) +{ + return _Tpvec(_mm_srli_si128(a.val, imm*(sizeof(typename _Tpvec::lane_type)))); +} +template +inline _Tpvec v_rotate_left(const _Tpvec &a) +{ + return _Tpvec(_mm_slli_si128(a.val, imm*(sizeof(typename _Tpvec::lane_type)))); +} +template +inline _Tpvec v_rotate_right(const _Tpvec &a, const _Tpvec &b) +{ + const int cWidth = sizeof(typename _Tpvec::lane_type); + return _Tpvec(_mm_or_si128(_mm_srli_si128(a.val, imm*cWidth), _mm_slli_si128(b.val, (16 - imm*cWidth)))); +} +template +inline _Tpvec v_rotate_left(const _Tpvec &a, const _Tpvec &b) +{ + const int cWidth = sizeof(typename _Tpvec::lane_type); + return _Tpvec(_mm_or_si128(_mm_slli_si128(a.val, imm*cWidth), _mm_srli_si128(b.val, (16 - imm*cWidth)))); +} + #define OPENCV_HAL_IMPL_SSE_LOADSTORE_INT_OP(_Tpvec, _Tp) \ inline _Tpvec v_load(const _Tp* ptr) \ { return _Tpvec(_mm_loadu_si128((const __m128i*)ptr)); } \ diff --git a/modules/core/perf/opencl/perf_arithm.cpp b/modules/core/perf/opencl/perf_arithm.cpp index 3efccb07a8..40bfa1c291 100644 --- a/modules/core/perf/opencl/perf_arithm.cpp +++ b/modules/core/perf/opencl/perf_arithm.cpp @@ -1062,6 +1062,34 @@ OCL_PERF_TEST_P(ScaleAddFixture, ScaleAdd, SANITY_CHECK(dst, 1e-6); } +///////////// Transform //////////////////////// + +typedef Size_MatType TransformFixture; + +OCL_PERF_TEST_P(TransformFixture, Transform, + ::testing::Combine(OCL_TEST_SIZES, + ::testing::Values(CV_8UC3, CV_8SC3, CV_16UC3, CV_16SC3, CV_32SC3, CV_32FC3, CV_64FC3))) +{ + const Size_MatType_t params = GetParam(); + const Size srcSize = get<0>(params); + const int type = get<1>(params); + + checkDeviceMaxMemoryAllocSize(srcSize, type); + + const float transform[] = { 0.5f, 0.f, 0.86602540378f, 128, + 0.f, 1.f, 0.f, -64, + 0.86602540378f, 0.f, 0.5f, 32,}; + Mat mtx(Size(4, 3), CV_32FC1, (void*)transform); + + UMat src(srcSize, type), dst(srcSize, type); + randu(src, 0, 30); + declare.in(src).out(dst); + + OCL_TEST_CYCLE() cv::transform(src, dst, mtx); + + SANITY_CHECK(dst, 1e-6, ERROR_RELATIVE); +} + ///////////// PSNR //////////////////////// typedef Size_MatType PSNRFixture; diff --git a/modules/core/perf/perf_mat.cpp b/modules/core/perf/perf_mat.cpp index 79a3ecd1ff..7066c5badf 100644 --- a/modules/core/perf/perf_mat.cpp +++ b/modules/core/perf/perf_mat.cpp @@ -96,3 +96,31 @@ PERF_TEST_P(Size_MatType, Mat_Clone_Roi, SANITY_CHECK(destination, 1); } + +///////////// Transform //////////////////////// + +PERF_TEST_P(Size_MatType, Mat_Transform, + testing::Combine(testing::Values(TYPICAL_MAT_SIZES), + testing::Values(CV_8UC3, CV_8SC3, CV_16UC3, CV_16SC3, CV_32SC3, CV_32FC3, CV_64FC3)) + ) +{ + const Size_MatType_t params = GetParam(); + const Size srcSize0 = get<0>(params); + const Size srcSize = Size(1, srcSize0.width*srcSize0.height); + const int type = get<1>(params); + const float transform[] = { 0.5f, 0.f, 0.86602540378f, 128, + 0.f, 1.f, 0.f, -64, + 0.86602540378f, 0.f, 0.5f, 32,}; + Mat mtx(Size(4, 3), CV_32FC1, (void*)transform); + + Mat src(srcSize, type), dst(srcSize, type); + randu(src, 0, 30); + declare.in(src).out(dst); + + TEST_CYCLE() + { + cv::transform(src, dst, mtx); + } + + SANITY_CHECK(dst, 1e-6, ERROR_RELATIVE); +} diff --git a/modules/core/src/matmul.cpp b/modules/core/src/matmul.cpp index cfd7fa1eaa..f67a301086 100644 --- a/modules/core/src/matmul.cpp +++ 
b/modules/core/src/matmul.cpp @@ -1699,41 +1699,53 @@ transform_( const T* src, T* dst, const WT* m, int len, int scn, int dcn ) } } -#if CV_SSE2 - +#if CV_SIMD128 static inline void -load3x3Matrix( const float* m, __m128& m0, __m128& m1, __m128& m2, __m128& m3 ) +load3x3Matrix(const float* m, v_float32x4& m0, v_float32x4& m1, v_float32x4& m2, v_float32x4& m3) { - m0 = _mm_setr_ps(m[0], m[4], m[8], 0); - m1 = _mm_setr_ps(m[1], m[5], m[9], 0); - m2 = _mm_setr_ps(m[2], m[6], m[10], 0); - m3 = _mm_setr_ps(m[3], m[7], m[11], 0); + m0 = v_float32x4(m[0], m[4], m[8], 0); + m1 = v_float32x4(m[1], m[5], m[9], 0); + m2 = v_float32x4(m[2], m[6], m[10], 0); + m3 = v_float32x4(m[3], m[7], m[11], 0); } -static inline void -load4x4Matrix( const float* m, __m128& m0, __m128& m1, __m128& m2, __m128& m3, __m128& m4 ) +static inline v_int16x8 +v_matmulvec(const v_int16x8 &v0, const v_int16x8 &m0, const v_int16x8 &m1, const v_int16x8 &m2, const v_int32x4 &m3, const int BITS) { - m0 = _mm_setr_ps(m[0], m[5], m[10], m[15]); - m1 = _mm_setr_ps(m[1], m[6], m[11], m[16]); - m2 = _mm_setr_ps(m[2], m[7], m[12], m[17]); - m3 = _mm_setr_ps(m[3], m[8], m[13], m[18]); - m4 = _mm_setr_ps(m[4], m[9], m[14], m[19]); -} + // v0 : 0 b0 g0 r0 b1 g1 r1 ? + v_int32x4 t0 = v_dotprod(v0, m0); // a0 b0 a1 b1 + v_int32x4 t1 = v_dotprod(v0, m1); // c0 d0 c1 d1 + v_int32x4 t2 = v_dotprod(v0, m2); // e0 f0 e1 f1 + v_int32x4 t3 = v_setzero_s32(); + v_int32x4 s0, s1, s2, s3; + v_transpose4x4(t0, t1, t2, t3, s0, s1, s2, s3); + s0 = s0 + s1 + m3; // B0 G0 R0 ? + s2 = s2 + s3 + m3; // B1 G1 R1 ? + s0 = s0 >> BITS; + s2 = s2 >> BITS; + + v_int16x8 result = v_pack(s0, v_setzero_s32()); // B0 G0 R0 0 0 0 0 0 + result = v_reinterpret_as_s16(v_reinterpret_as_s64(result) << 16); // 0 B0 G0 R0 0 0 0 0 + result = result | v_pack(v_setzero_s32(), s2); // 0 B0 G0 R0 B1 G1 R1 0 + return result; +} #endif static void transform_8u( const uchar* src, uchar* dst, const float* m, int len, int scn, int dcn ) { -#if CV_SSE2 +#if CV_SIMD128 const int BITS = 10, SCALE = 1 << BITS; const float MAX_M = (float)(1 << (15 - BITS)); - if( USE_SSE2 && scn == 3 && dcn == 3 && + if( hasSIMD128() && scn == 3 && dcn == 3 && std::abs(m[0]) < MAX_M && std::abs(m[1]) < MAX_M && std::abs(m[2]) < MAX_M && std::abs(m[3]) < MAX_M*256 && std::abs(m[4]) < MAX_M && std::abs(m[5]) < MAX_M && std::abs(m[6]) < MAX_M && std::abs(m[7]) < MAX_M*256 && std::abs(m[8]) < MAX_M && std::abs(m[9]) < MAX_M && std::abs(m[10]) < MAX_M && std::abs(m[11]) < MAX_M*256 ) { + const int nChannels = 3; + const int cWidth = v_int16x8::nlanes; // faster fixed-point transformation short m00 = saturate_cast(m[0]*SCALE), m01 = saturate_cast(m[1]*SCALE), m02 = saturate_cast(m[2]*SCALE), m10 = saturate_cast(m[4]*SCALE), @@ -1743,92 +1755,50 @@ transform_8u( const uchar* src, uchar* dst, const float* m, int len, int scn, in int m03 = saturate_cast((m[3]+0.5f)*SCALE), m13 = saturate_cast((m[7]+0.5f)*SCALE ), m23 = saturate_cast((m[11]+0.5f)*SCALE); - __m128i m0 = _mm_setr_epi16(0, m00, m01, m02, m00, m01, m02, 0); - __m128i m1 = _mm_setr_epi16(0, m10, m11, m12, m10, m11, m12, 0); - __m128i m2 = _mm_setr_epi16(0, m20, m21, m22, m20, m21, m22, 0); - __m128i m3 = _mm_setr_epi32(m03, m13, m23, 0); + v_int16x8 m0 = v_int16x8(0, m00, m01, m02, m00, m01, m02, 0); + v_int16x8 m1 = v_int16x8(0, m10, m11, m12, m10, m11, m12, 0); + v_int16x8 m2 = v_int16x8(0, m20, m21, m22, m20, m21, m22, 0); + v_int32x4 m3 = v_int32x4(m03, m13, m23, 0); int x = 0; - for( ; x <= (len - 8)*3; x += 8*3 ) + for (; x <= (len - cWidth) * 
nChannels; x += cWidth * nChannels) { - __m128i z = _mm_setzero_si128(), t0, t1, t2, r0, r1; - __m128i v0 = _mm_loadl_epi64((const __m128i*)(src + x)); - __m128i v1 = _mm_loadl_epi64((const __m128i*)(src + x + 8)); - __m128i v2 = _mm_loadl_epi64((const __m128i*)(src + x + 16)), v3; - v0 = _mm_unpacklo_epi8(v0, z); // b0 g0 r0 b1 g1 r1 b2 g2 - v1 = _mm_unpacklo_epi8(v1, z); // r2 b3 g3 r3 b4 g4 r4 b5 - v2 = _mm_unpacklo_epi8(v2, z); // g5 r5 b6 g6 r6 b7 g7 r7 + // load 8 pixels + v_int16x8 v0 = v_reinterpret_as_s16(v_load_expand(src + x)); + v_int16x8 v1 = v_reinterpret_as_s16(v_load_expand(src + x + cWidth)); + v_int16x8 v2 = v_reinterpret_as_s16(v_load_expand(src + x + cWidth * 2)); + v_int16x8 v3; - v3 = _mm_srli_si128(v2, 2); // ? b6 g6 r6 b7 g7 r7 0 - v2 = _mm_or_si128(_mm_slli_si128(v2, 10), _mm_srli_si128(v1, 6)); // ? b4 g4 r4 b5 g5 r5 ? - v1 = _mm_or_si128(_mm_slli_si128(v1, 6), _mm_srli_si128(v0, 10)); // ? b2 g2 r2 b3 g3 r3 ? - v0 = _mm_slli_si128(v0, 2); // 0 b0 g0 r0 b1 g1 r1 ? + // rotate and pack + v3 = v_rotate_right<1>(v2); // 0 b6 g6 r6 b7 g7 r7 0 + v2 = v_rotate_left <5>(v2, v1); // 0 b4 g4 r4 b5 g5 r5 0 + v1 = v_rotate_left <3>(v1, v0); // 0 b2 g2 r2 b3 g3 r3 0 + v0 = v_rotate_left <1>(v0); // 0 b0 g0 r0 b1 g1 r1 0 - // process pixels 0 & 1 - t0 = _mm_madd_epi16(v0, m0); // a0 b0 a1 b1 - t1 = _mm_madd_epi16(v0, m1); // c0 d0 c1 d1 - t2 = _mm_madd_epi16(v0, m2); // e0 f0 e1 f1 - v0 = _mm_unpacklo_epi32(t0, t1); // a0 c0 b0 d0 - t0 = _mm_unpackhi_epi32(t0, t1); // a1 b1 c1 d1 - t1 = _mm_unpacklo_epi32(t2, z); // e0 0 f0 0 - t2 = _mm_unpackhi_epi32(t2, z); // e1 0 f1 0 - r0 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(v0, t1), _mm_unpackhi_epi64(v0,t1)), m3); // B0 G0 R0 0 - r1 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(t0, t2), _mm_unpackhi_epi64(t0,t2)), m3); // B1 G1 R1 0 - r0 = _mm_srai_epi32(r0, BITS); - r1 = _mm_srai_epi32(r1, BITS); - v0 = _mm_packus_epi16(_mm_packs_epi32(_mm_slli_si128(r0, 4), r1), z); // 0 B0 G0 R0 B1 G1 R1 0 + // multiply with matrix and normalize + v0 = v_matmulvec(v0, m0, m1, m2, m3, BITS); // 0 B0 G0 R0 B1 G1 R1 0 + v1 = v_matmulvec(v1, m0, m1, m2, m3, BITS); // 0 B2 G2 R2 B3 G3 R3 0 + v2 = v_matmulvec(v2, m0, m1, m2, m3, BITS); // 0 B4 G4 R4 B5 G5 R5 0 + v3 = v_matmulvec(v3, m0, m1, m2, m3, BITS); // 0 B6 G6 R6 B7 G7 R7 0 - // process pixels 2 & 3 - t0 = _mm_madd_epi16(v1, m0); // a0 b0 a1 b1 - t1 = _mm_madd_epi16(v1, m1); // c0 d0 c1 d1 - t2 = _mm_madd_epi16(v1, m2); // e0 f0 e1 f1 - v1 = _mm_unpacklo_epi32(t0, t1); // a0 c0 b0 d0 - t0 = _mm_unpackhi_epi32(t0, t1); // a1 b1 c1 d1 - t1 = _mm_unpacklo_epi32(t2, z); // e0 0 f0 0 - t2 = _mm_unpackhi_epi32(t2, z); // e1 0 f1 0 - r0 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(v1, t1), _mm_unpackhi_epi64(v1,t1)), m3); // B2 G2 R2 0 - r1 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(t0, t2), _mm_unpackhi_epi64(t0,t2)), m3); // B3 G3 R3 0 - r0 = _mm_srai_epi32(r0, BITS); - r1 = _mm_srai_epi32(r1, BITS); - v1 = _mm_packus_epi16(_mm_packs_epi32(_mm_slli_si128(r0, 4), r1), z); // 0 B2 G2 R2 B3 G3 R3 0 + // narrow down as uint8x16 + v_uint8x16 z0 = v_pack_u(v0, v_setzero_s16()); // 0 B0 G0 R0 B1 G1 R1 0 0 0 0 0 0 0 0 0 + v_uint8x16 z1 = v_pack_u(v1, v_setzero_s16()); // 0 B2 G2 R2 B3 G3 R3 0 0 0 0 0 0 0 0 0 + v_uint8x16 z2 = v_pack_u(v2, v_setzero_s16()); // 0 B4 G4 R4 B5 G5 R5 0 0 0 0 0 0 0 0 0 + v_uint8x16 z3 = v_pack_u(v3, v_setzero_s16()); // 0 B6 G6 R6 B7 G7 R7 0 0 0 0 0 0 0 0 0 - // process pixels 4 & 5 - t0 = _mm_madd_epi16(v2, m0); // a0 b0 a1 b1 - t1 = _mm_madd_epi16(v2, m1); // 
c0 d0 c1 d1 - t2 = _mm_madd_epi16(v2, m2); // e0 f0 e1 f1 - v2 = _mm_unpacklo_epi32(t0, t1); // a0 c0 b0 d0 - t0 = _mm_unpackhi_epi32(t0, t1); // a1 b1 c1 d1 - t1 = _mm_unpacklo_epi32(t2, z); // e0 0 f0 0 - t2 = _mm_unpackhi_epi32(t2, z); // e1 0 f1 0 - r0 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(v2, t1), _mm_unpackhi_epi64(v2,t1)), m3); // B4 G4 R4 0 - r1 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(t0, t2), _mm_unpackhi_epi64(t0,t2)), m3); // B5 G5 R5 0 - r0 = _mm_srai_epi32(r0, BITS); - r1 = _mm_srai_epi32(r1, BITS); - v2 = _mm_packus_epi16(_mm_packs_epi32(_mm_slli_si128(r0, 4), r1), z); // 0 B4 G4 R4 B5 G5 R5 0 + // rotate and pack + z0 = v_reinterpret_as_u8(v_reinterpret_as_u64(z0) >> 8) | v_reinterpret_as_u8(v_reinterpret_as_u64(z1) << 40); // B0 G0 R0 B1 G1 R1 B2 G2 0 0 0 0 0 0 0 0 + z1 = v_reinterpret_as_u8(v_reinterpret_as_u64(z1) >> 24) | v_reinterpret_as_u8(v_reinterpret_as_u64(z2) << 24); // R2 B3 G3 R3 B4 G4 R4 B5 0 0 0 0 0 0 0 0 + z2 = v_reinterpret_as_u8(v_reinterpret_as_u64(z2) >> 40) | v_reinterpret_as_u8(v_reinterpret_as_u64(z3) << 8); // G5 R6 B6 G6 R6 B7 G7 R7 0 0 0 0 0 0 0 0 - // process pixels 6 & 7 - t0 = _mm_madd_epi16(v3, m0); // a0 b0 a1 b1 - t1 = _mm_madd_epi16(v3, m1); // c0 d0 c1 d1 - t2 = _mm_madd_epi16(v3, m2); // e0 f0 e1 f1 - v3 = _mm_unpacklo_epi32(t0, t1); // a0 c0 b0 d0 - t0 = _mm_unpackhi_epi32(t0, t1); // a1 b1 c1 d1 - t1 = _mm_unpacklo_epi32(t2, z); // e0 0 f0 0 - t2 = _mm_unpackhi_epi32(t2, z); // e1 0 f1 0 - r0 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(v3, t1), _mm_unpackhi_epi64(v3,t1)), m3); // B6 G6 R6 0 - r1 = _mm_add_epi32(_mm_add_epi32(_mm_unpacklo_epi64(t0, t2), _mm_unpackhi_epi64(t0,t2)), m3); // B7 G7 R7 0 - r0 = _mm_srai_epi32(r0, BITS); - r1 = _mm_srai_epi32(r1, BITS); - v3 = _mm_packus_epi16(_mm_packs_epi32(_mm_slli_si128(r0, 4), r1), z); // 0 B6 G6 R6 B7 G7 R7 0 - - v0 = _mm_or_si128(_mm_srli_si128(v0, 1), _mm_slli_si128(v1, 5)); - v1 = _mm_or_si128(_mm_srli_si128(v1, 3), _mm_slli_si128(v2, 3)); - v2 = _mm_or_si128(_mm_srli_si128(v2, 5), _mm_slli_si128(v3, 1)); - _mm_storel_epi64((__m128i*)(dst + x), v0); - _mm_storel_epi64((__m128i*)(dst + x + 8), v1); - _mm_storel_epi64((__m128i*)(dst + x + 16), v2); + // store on memory + v_store_low(dst + x, z0); + v_store_low(dst + x + cWidth, z1); + v_store_low(dst + x + cWidth * 2, z2); } - for( ; x < len*3; x += 3 ) + for( ; x < len * nChannels; x += nChannels ) { int v0 = src[x], v1 = src[x+1], v2 = src[x+2]; uchar t0 = saturate_cast((m00*v0 + m01*v1 + m02*v2 + m03)>>BITS); @@ -1846,61 +1816,63 @@ transform_8u( const uchar* src, uchar* dst, const float* m, int len, int scn, in static void transform_16u( const ushort* src, ushort* dst, const float* m, int len, int scn, int dcn ) { -#if CV_SSE2 - if( USE_SSE2 && scn == 3 && dcn == 3 ) +#if CV_SIMD128 && !defined(__aarch64__) + if( hasSIMD128() && scn == 3 && dcn == 3 ) { - __m128 m0, m1, m2, m3; - __m128i delta = _mm_setr_epi16(0,-32768,-32768,-32768,-32768,-32768,-32768,0); + const int nChannels = 3; + const int cWidth = v_float32x4::nlanes; + v_int16x8 delta = v_int16x8(0, -32768, -32768, -32768, -32768, -32768, -32768, 0); + v_float32x4 m0, m1, m2, m3; load3x3Matrix(m, m0, m1, m2, m3); - m3 = _mm_sub_ps(m3, _mm_setr_ps(32768.f, 32768.f, 32768.f, 0.f)); + m3 -= v_float32x4(32768.f, 32768.f, 32768.f, 0.f); int x = 0; - for( ; x <= (len - 4)*3; x += 4*3 ) + for( ; x <= (len - cWidth) * nChannels; x += cWidth * nChannels ) { - __m128i z = _mm_setzero_si128(); - __m128i v0 = _mm_loadu_si128((const __m128i*)(src + x)), v1; - 
__m128i v2 = _mm_loadl_epi64((const __m128i*)(src + x + 8)), v3; - v1 = _mm_unpacklo_epi16(_mm_srli_si128(v0, 6), z); // b1 g1 r1 - v3 = _mm_unpacklo_epi16(_mm_srli_si128(v2, 2), z); // b3 g3 r3 - v2 = _mm_or_si128(_mm_srli_si128(v0, 12), _mm_slli_si128(v2, 4)); - v0 = _mm_unpacklo_epi16(v0, z); // b0 g0 r0 - v2 = _mm_unpacklo_epi16(v2, z); // b2 g2 r2 - __m128 x0 = _mm_cvtepi32_ps(v0), x1 = _mm_cvtepi32_ps(v1); - __m128 x2 = _mm_cvtepi32_ps(v2), x3 = _mm_cvtepi32_ps(v3); - __m128 y0 = _mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(2,2,2,2)))), m3); - __m128 y1 = _mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x1,x1,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x1,x1,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x1,x1,_MM_SHUFFLE(2,2,2,2)))), m3); - __m128 y2 = _mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x2,x2,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x2,x2,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x2,x2,_MM_SHUFFLE(2,2,2,2)))), m3); - __m128 y3 = _mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x3,x3,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x3,x3,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x3,x3,_MM_SHUFFLE(2,2,2,2)))), m3); - v0 = _mm_cvtps_epi32(y0); v1 = _mm_cvtps_epi32(y1); - v2 = _mm_cvtps_epi32(y2); v3 = _mm_cvtps_epi32(y3); + // load 4 pixels + v_uint16x8 v0_16 = v_load(src + x); // b0 g0 r0 b1 g1 r1 b2 g2 + v_uint16x8 v2_16 = v_load_low(src + x + cWidth * 2); // r2 b3 g3 r3 ? ? ? ? - v0 = _mm_add_epi16(_mm_packs_epi32(_mm_slli_si128(v0,4), v1), delta); // 0 b0 g0 r0 b1 g1 r1 0 - v2 = _mm_add_epi16(_mm_packs_epi32(_mm_slli_si128(v2,4), v3), delta); // 0 b2 g2 r2 b3 g3 r3 0 - v1 = _mm_or_si128(_mm_srli_si128(v0,2), _mm_slli_si128(v2,10)); // b0 g0 r0 b1 g1 r1 b2 g2 - v2 = _mm_srli_si128(v2, 6); // r2 b3 g3 r3 0 0 0 0 - _mm_storeu_si128((__m128i*)(dst + x), v1); - _mm_storel_epi64((__m128i*)(dst + x + 8), v2); + // expand to 4 vectors + v_uint32x4 v0_32, v1_32, v2_32, v3_32, dummy_32; + v_expand(v_rotate_right<3>(v0_16), v1_32, dummy_32); // b1 g1 r1 + v_expand(v_rotate_right<1>(v2_16), v3_32, dummy_32); // b3 g3 r3 + v_expand(v_rotate_right<6>(v0_16, v2_16), v2_32, dummy_32); // b2 g2 r2 + v_expand(v0_16, v0_32, dummy_32); // b0 g0 r0 + + // convert to float32x4 + v_float32x4 x0 = v_cvt_f32(v_reinterpret_as_s32(v0_32)); // b0 g0 r0 + v_float32x4 x1 = v_cvt_f32(v_reinterpret_as_s32(v1_32)); // b1 g1 r1 + v_float32x4 x2 = v_cvt_f32(v_reinterpret_as_s32(v2_32)); // b2 g2 r2 + v_float32x4 x3 = v_cvt_f32(v_reinterpret_as_s32(v3_32)); // b3 g3 r3 + + // multiply and convert back to int32x4 + v_int32x4 y0, y1, y2, y3; + y0 = v_round(v_matmuladd(x0, m0, m1, m2, m3)); // B0 G0 R0 + y1 = v_round(v_matmuladd(x1, m0, m1, m2, m3)); // B1 G1 R1 + y2 = v_round(v_matmuladd(x2, m0, m1, m2, m3)); // B2 G2 R2 + y3 = v_round(v_matmuladd(x3, m0, m1, m2, m3)); // B3 G3 R3 + + // narrow down to int16x8 + v_int16x8 v0 = v_add_wrap(v_pack(v_rotate_left<1>(y0), y1), delta); // 0 B0 G0 R0 B1 G1 R1 0 + v_int16x8 v2 = v_add_wrap(v_pack(v_rotate_left<1>(y2), y3), delta); // 0 B2 G2 R2 B3 G3 R3 0 + + // rotate and pack + v0 = v_rotate_right<1>(v0) | v_rotate_left<5>(v2); // B0 G0 R0 B1 G1 R1 B2 G2 + v2 = v_rotate_right<3>(v2); // R2 B3 G3 R3 0 0 0 0 + + // store 4 pixels + v_store(dst + x, v_reinterpret_as_u16(v0)); + 
v_store_low(dst + x + cWidth * 2, v_reinterpret_as_u16(v2)); } - for( ; x < len*3; x += 3 ) + for( ; x < len * nChannels; x += nChannels ) { - float v0 = src[x], v1 = src[x+1], v2 = src[x+2]; - ushort t0 = saturate_cast(m[0]*v0 + m[1]*v1 + m[2]*v2 + m[3]); - ushort t1 = saturate_cast(m[4]*v0 + m[5]*v1 + m[6]*v2 + m[7]); - ushort t2 = saturate_cast(m[8]*v0 + m[9]*v1 + m[10]*v2 + m[11]); - dst[x] = t0; dst[x+1] = t1; dst[x+2] = t2; + float v0 = src[x], v1 = src[x + 1], v2 = src[x + 2]; + ushort t0 = saturate_cast(m[0] * v0 + m[1] * v1 + m[2] * v2 + m[3]); + ushort t1 = saturate_cast(m[4] * v0 + m[5] * v1 + m[6] * v2 + m[7]); + ushort t2 = saturate_cast(m[8] * v0 + m[9] * v1 + m[10] * v2 + m[11]); + dst[x] = t0; dst[x + 1] = t1; dst[x + 2] = t2; } return; } @@ -1909,31 +1881,28 @@ transform_16u( const ushort* src, ushort* dst, const float* m, int len, int scn, transform_(src, dst, m, len, scn, dcn); } - static void transform_32f( const float* src, float* dst, const float* m, int len, int scn, int dcn ) { -#if CV_SSE2 - if( USE_SSE2 ) +#if CV_SIMD128 && !defined(__aarch64__) + if( hasSIMD128() ) { int x = 0; if( scn == 3 && dcn == 3 ) { - __m128 m0, m1, m2, m3; + const int cWidth = 3; + v_float32x4 m0, m1, m2, m3; load3x3Matrix(m, m0, m1, m2, m3); - for( ; x < (len - 1)*3; x += 3 ) + for( ; x < (len - 1)*cWidth; x += cWidth ) { - __m128 x0 = _mm_loadu_ps(src + x); - __m128 y0 = _mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(2,2,2,2)))), m3); - _mm_storel_pi((__m64*)(dst + x), y0); - _mm_store_ss(dst + x + 2, _mm_movehl_ps(y0,y0)); + v_float32x4 x0 = v_load(src + x); + v_float32x4 y0 = v_matmuladd(x0, m0, m1, m2, m3); + v_store_low(dst + x, y0); + dst[x + 2] = v_combine_high(y0, y0).get0(); } - for( ; x < len*3; x += 3 ) + for( ; x < len*cWidth; x += cWidth ) { float v0 = src[x], v1 = src[x+1], v2 = src[x+2]; float t0 = saturate_cast(m[0]*v0 + m[1]*v1 + m[2]*v2 + m[3]); @@ -1946,18 +1915,18 @@ transform_32f( const float* src, float* dst, const float* m, int len, int scn, i if( scn == 4 && dcn == 4 ) { - __m128 m0, m1, m2, m3, m4; - load4x4Matrix(m, m0, m1, m2, m3, m4); + const int cWidth = 4; + v_float32x4 m0 = v_float32x4(m[0], m[5], m[10], m[15]); + v_float32x4 m1 = v_float32x4(m[1], m[6], m[11], m[16]); + v_float32x4 m2 = v_float32x4(m[2], m[7], m[12], m[17]); + v_float32x4 m3 = v_float32x4(m[3], m[8], m[13], m[18]); + v_float32x4 m4 = v_float32x4(m[4], m[9], m[14], m[19]); - for( ; x < len*4; x += 4 ) + for( ; x < len*cWidth; x += cWidth ) { - __m128 x0 = _mm_loadu_ps(src + x); - __m128 y0 = _mm_add_ps(_mm_add_ps(_mm_add_ps(_mm_add_ps( - _mm_mul_ps(m0, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(0,0,0,0))), - _mm_mul_ps(m1, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(1,1,1,1)))), - _mm_mul_ps(m2, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(2,2,2,2)))), - _mm_mul_ps(m3, _mm_shuffle_ps(x0,x0,_MM_SHUFFLE(3,3,3,3)))), m4); - _mm_storeu_ps(dst + x, y0); + v_float32x4 x0 = v_load(src + x); + v_float32x4 y0 = v_matmul(x0, m0, m1, m2, m3) + m4; + v_store(dst + x, y0); } return; } @@ -2342,58 +2311,21 @@ static void scaleAdd_32f(const float* src1, const float* src2, float* dst, { float alpha = *_alpha; int i = 0; -#if CV_SSE2 - if( USE_SSE2 ) +#if CV_SIMD128 + if (hasSIMD128()) { - __m128 a4 = _mm_set1_ps(alpha); - if( (((size_t)src1|(size_t)src2|(size_t)dst) & 15) == 0 ) - for( ; i <= len - 8; i += 8 ) - { - __m128 x0, x1, y0, y1, t0, t1; - x0 = _mm_load_ps(src1 + i); x1 
= _mm_load_ps(src1 + i + 4); - y0 = _mm_load_ps(src2 + i); y1 = _mm_load_ps(src2 + i + 4); - t0 = _mm_add_ps(_mm_mul_ps(x0, a4), y0); - t1 = _mm_add_ps(_mm_mul_ps(x1, a4), y1); - _mm_store_ps(dst + i, t0); - _mm_store_ps(dst + i + 4, t1); - } - else - for( ; i <= len - 8; i += 8 ) - { - __m128 x0, x1, y0, y1, t0, t1; - x0 = _mm_loadu_ps(src1 + i); x1 = _mm_loadu_ps(src1 + i + 4); - y0 = _mm_loadu_ps(src2 + i); y1 = _mm_loadu_ps(src2 + i + 4); - t0 = _mm_add_ps(_mm_mul_ps(x0, a4), y0); - t1 = _mm_add_ps(_mm_mul_ps(x1, a4), y1); - _mm_storeu_ps(dst + i, t0); - _mm_storeu_ps(dst + i + 4, t1); - } - } - else -#elif CV_NEON - if (true) - { - for ( ; i <= len - 4; i += 4) + v_float32x4 v_alpha = v_setall_f32(alpha); + const int cWidth = v_float32x4::nlanes; + for (; i <= len - cWidth; i += cWidth) { - float32x4_t v_src1 = vld1q_f32(src1 + i), v_src2 = vld1q_f32(src2 + i); - vst1q_f32(dst + i, vaddq_f32(vmulq_n_f32(v_src1, alpha), v_src2)); + v_float32x4 v_src1 = v_load(src1 + i); + v_float32x4 v_src2 = v_load(src2 + i); + v_store(dst + i, (v_src1 * v_alpha) + v_src2); } } - else #endif - //vz why do we need unroll here? - for( ; i <= len - 4; i += 4 ) - { - float t0, t1; - t0 = src1[i]*alpha + src2[i]; - t1 = src1[i+1]*alpha + src2[i+1]; - dst[i] = t0; dst[i+1] = t1; - t0 = src1[i+2]*alpha + src2[i+2]; - t1 = src1[i+3]*alpha + src2[i+3]; - dst[i+2] = t0; dst[i+3] = t1; - } - for(; i < len; i++ ) - dst[i] = src1[i]*alpha + src2[i]; + for (; i < len; i++) + dst[i] = src1[i] * alpha + src2[i]; } @@ -2402,36 +2334,25 @@ static void scaleAdd_64f(const double* src1, const double* src2, double* dst, { double alpha = *_alpha; int i = 0; -#if CV_SSE2 - if( USE_SSE2 && (((size_t)src1|(size_t)src2|(size_t)dst) & 15) == 0 ) +#if CV_SIMD128_64F + if (hasSIMD128()) { - __m128d a2 = _mm_set1_pd(alpha); - for( ; i <= len - 4; i += 4 ) + v_float64x2 a2 = v_setall_f64(alpha); + const int cWidth = v_float64x2::nlanes; + for (; i <= len - cWidth * 2; i += cWidth * 2) { - __m128d x0, x1, y0, y1, t0, t1; - x0 = _mm_load_pd(src1 + i); x1 = _mm_load_pd(src1 + i + 2); - y0 = _mm_load_pd(src2 + i); y1 = _mm_load_pd(src2 + i + 2); - t0 = _mm_add_pd(_mm_mul_pd(x0, a2), y0); - t1 = _mm_add_pd(_mm_mul_pd(x1, a2), y1); - _mm_store_pd(dst + i, t0); - _mm_store_pd(dst + i + 2, t1); + v_float64x2 x0, x1, y0, y1, t0, t1; + x0 = v_load(src1 + i); x1 = v_load(src1 + i + cWidth); + y0 = v_load(src2 + i); y1 = v_load(src2 + i + cWidth); + t0 = x0 * a2 + y0; + t1 = x1 * a2 + y1; + v_store(dst + i, t0); + v_store(dst + i + cWidth, t1); } } - else #endif - //vz why do we need unroll here? 
- for( ; i <= len - 4; i += 4 ) - { - double t0, t1; - t0 = src1[i]*alpha + src2[i]; - t1 = src1[i+1]*alpha + src2[i+1]; - dst[i] = t0; dst[i+1] = t1; - t0 = src1[i+2]*alpha + src2[i+2]; - t1 = src1[i+3]*alpha + src2[i+3]; - dst[i+2] = t0; dst[i+3] = t1; - } - for(; i < len; i++ ) - dst[i] = src1[i]*alpha + src2[i]; + for (; i < len; i++) + dst[i] = src1[i] * alpha + src2[i]; } typedef void (*ScaleAddFunc)(const uchar* src1, const uchar* src2, uchar* dst, int len, const void* alpha); @@ -3105,43 +3026,36 @@ static double dotProd_8u(const uchar* src1, const uchar* src2, int len) #endif int i = 0; -#if CV_SSE2 - if( USE_SSE2 ) +#if CV_SIMD128 + if (hasSIMD128()) { - int j, len0 = len & -4, blockSize0 = (1 << 13), blockSize; - __m128i z = _mm_setzero_si128(); - CV_DECL_ALIGNED(16) int buf[4]; + int len0 = len & -8, blockSize0 = (1 << 15), blockSize; - while( i < len0 ) + while (i < len0) { blockSize = std::min(len0 - i, blockSize0); - __m128i s = z; - j = 0; - for( ; j <= blockSize - 16; j += 16 ) + v_int32x4 v_sum = v_setzero_s32(); + const int cWidth = v_uint16x8::nlanes; + + int j = 0; + for (; j <= blockSize - cWidth * 2; j += cWidth * 2) { - __m128i b0 = _mm_loadu_si128((const __m128i*)(src1 + j)); - __m128i b1 = _mm_loadu_si128((const __m128i*)(src2 + j)); - __m128i s0, s1, s2, s3; - s0 = _mm_unpacklo_epi8(b0, z); - s2 = _mm_unpackhi_epi8(b0, z); - s1 = _mm_unpacklo_epi8(b1, z); - s3 = _mm_unpackhi_epi8(b1, z); - s0 = _mm_madd_epi16(s0, s1); - s2 = _mm_madd_epi16(s2, s3); - s = _mm_add_epi32(s, s0); - s = _mm_add_epi32(s, s2); + v_uint16x8 v_src10, v_src20, v_src11, v_src21; + v_expand(v_load(src1 + j), v_src10, v_src11); + v_expand(v_load(src2 + j), v_src20, v_src21); + + v_sum += v_dotprod(v_reinterpret_as_s16(v_src10), v_reinterpret_as_s16(v_src20)); + v_sum += v_dotprod(v_reinterpret_as_s16(v_src11), v_reinterpret_as_s16(v_src21)); } - for( ; j < blockSize; j += 4 ) + for (; j <= blockSize - cWidth; j += cWidth) { - __m128i s0 = _mm_unpacklo_epi8(_mm_cvtsi32_si128(*(const int*)(src1 + j)), z); - __m128i s1 = _mm_unpacklo_epi8(_mm_cvtsi32_si128(*(const int*)(src2 + j)), z); - s0 = _mm_madd_epi16(s0, s1); - s = _mm_add_epi32(s, s0); - } + v_int16x8 v_src10 = v_reinterpret_as_s16(v_load_expand(src1 + j)); + v_int16x8 v_src20 = v_reinterpret_as_s16(v_load_expand(src2 + j)); - _mm_store_si128((__m128i*)buf, s); - r += buf[0] + buf[1] + buf[2] + buf[3]; + v_sum += v_dotprod(v_src10, v_src20); + } + r += (double)v_reduce_sum(v_sum); src1 += blockSize; src2 += blockSize; @@ -3149,43 +3063,46 @@ static double dotProd_8u(const uchar* src1, const uchar* src2, int len) } } #elif CV_NEON - int len0 = len & -8, blockSize0 = (1 << 15), blockSize; - uint32x4_t v_zero = vdupq_n_u32(0u); - CV_DECL_ALIGNED(16) uint buf[4]; - - while( i < len0 ) + if( cv::checkHardwareSupport(CV_CPU_NEON) ) { - blockSize = std::min(len0 - i, blockSize0); - uint32x4_t v_sum = v_zero; + int len0 = len & -8, blockSize0 = (1 << 15), blockSize; + uint32x4_t v_zero = vdupq_n_u32(0u); + CV_DECL_ALIGNED(16) uint buf[4]; - int j = 0; - for( ; j <= blockSize - 16; j += 16 ) + while( i < len0 ) { - uint8x16_t v_src1 = vld1q_u8(src1 + j), v_src2 = vld1q_u8(src2 + j); + blockSize = std::min(len0 - i, blockSize0); + uint32x4_t v_sum = v_zero; - uint16x8_t v_src10 = vmovl_u8(vget_low_u8(v_src1)), v_src20 = vmovl_u8(vget_low_u8(v_src2)); - v_sum = vmlal_u16(v_sum, vget_low_u16(v_src10), vget_low_u16(v_src20)); - v_sum = vmlal_u16(v_sum, vget_high_u16(v_src10), vget_high_u16(v_src20)); + int j = 0; + for( ; j <= blockSize - 16; j += 16 
) + { + uint8x16_t v_src1 = vld1q_u8(src1 + j), v_src2 = vld1q_u8(src2 + j); - v_src10 = vmovl_u8(vget_high_u8(v_src1)); - v_src20 = vmovl_u8(vget_high_u8(v_src2)); - v_sum = vmlal_u16(v_sum, vget_low_u16(v_src10), vget_low_u16(v_src20)); - v_sum = vmlal_u16(v_sum, vget_high_u16(v_src10), vget_high_u16(v_src20)); + uint16x8_t v_src10 = vmovl_u8(vget_low_u8(v_src1)), v_src20 = vmovl_u8(vget_low_u8(v_src2)); + v_sum = vmlal_u16(v_sum, vget_low_u16(v_src10), vget_low_u16(v_src20)); + v_sum = vmlal_u16(v_sum, vget_high_u16(v_src10), vget_high_u16(v_src20)); + + v_src10 = vmovl_u8(vget_high_u8(v_src1)); + v_src20 = vmovl_u8(vget_high_u8(v_src2)); + v_sum = vmlal_u16(v_sum, vget_low_u16(v_src10), vget_low_u16(v_src20)); + v_sum = vmlal_u16(v_sum, vget_high_u16(v_src10), vget_high_u16(v_src20)); + } + + for( ; j <= blockSize - 8; j += 8 ) + { + uint16x8_t v_src1 = vmovl_u8(vld1_u8(src1 + j)), v_src2 = vmovl_u8(vld1_u8(src2 + j)); + v_sum = vmlal_u16(v_sum, vget_low_u16(v_src1), vget_low_u16(v_src2)); + v_sum = vmlal_u16(v_sum, vget_high_u16(v_src1), vget_high_u16(v_src2)); + } + + vst1q_u32(buf, v_sum); + r += buf[0] + buf[1] + buf[2] + buf[3]; + + src1 += blockSize; + src2 += blockSize; + i += blockSize; } - - for( ; j <= blockSize - 8; j += 8 ) - { - uint16x8_t v_src1 = vmovl_u8(vld1_u8(src1 + j)), v_src2 = vmovl_u8(vld1_u8(src2 + j)); - v_sum = vmlal_u16(v_sum, vget_low_u16(v_src1), vget_low_u16(v_src2)); - v_sum = vmlal_u16(v_sum, vget_high_u16(v_src1), vget_high_u16(v_src2)); - } - - vst1q_u32(buf, v_sum); - r += buf[0] + buf[1] + buf[2] + buf[3]; - - src1 += blockSize; - src2 += blockSize; - i += blockSize; } #endif return r + dotProd_(src1, src2, len - i); @@ -3194,48 +3111,39 @@ static double dotProd_8u(const uchar* src1, const uchar* src2, int len) static double dotProd_8s(const schar* src1, const schar* src2, int len) { - int i = 0; double r = 0.0; + int i = 0; -#if CV_SSE2 - if( USE_SSE2 ) +#if CV_SIMD128 + if (hasSIMD128()) { - int j, len0 = len & -4, blockSize0 = (1 << 13), blockSize; - __m128i z = _mm_setzero_si128(); - CV_DECL_ALIGNED(16) int buf[4]; + int len0 = len & -8, blockSize0 = (1 << 14), blockSize; - while( i < len0 ) + while (i < len0) { blockSize = std::min(len0 - i, blockSize0); - __m128i s = z; - j = 0; - for( ; j <= blockSize - 16; j += 16 ) + v_int32x4 v_sum = v_setzero_s32(); + const int cWidth = v_int16x8::nlanes; + + int j = 0; + for (; j <= blockSize - cWidth * 2; j += cWidth * 2) { - __m128i b0 = _mm_loadu_si128((const __m128i*)(src1 + j)); - __m128i b1 = _mm_loadu_si128((const __m128i*)(src2 + j)); - __m128i s0, s1, s2, s3; - s0 = _mm_srai_epi16(_mm_unpacklo_epi8(b0, b0), 8); - s2 = _mm_srai_epi16(_mm_unpackhi_epi8(b0, b0), 8); - s1 = _mm_srai_epi16(_mm_unpacklo_epi8(b1, b1), 8); - s3 = _mm_srai_epi16(_mm_unpackhi_epi8(b1, b1), 8); - s0 = _mm_madd_epi16(s0, s1); - s2 = _mm_madd_epi16(s2, s3); - s = _mm_add_epi32(s, s0); - s = _mm_add_epi32(s, s2); + v_int16x8 v_src10, v_src20, v_src11, v_src21; + v_expand(v_load(src1 + j), v_src10, v_src11); + v_expand(v_load(src2 + j), v_src20, v_src21); + + v_sum += v_dotprod(v_src10, v_src20); + v_sum += v_dotprod(v_src11, v_src21); } - for( ; j < blockSize; j += 4 ) + for (; j <= blockSize - cWidth; j += cWidth) { - __m128i s0 = _mm_cvtsi32_si128(*(const int*)(src1 + j)); - __m128i s1 = _mm_cvtsi32_si128(*(const int*)(src2 + j)); - s0 = _mm_srai_epi16(_mm_unpacklo_epi8(s0, s0), 8); - s1 = _mm_srai_epi16(_mm_unpacklo_epi8(s1, s1), 8); - s0 = _mm_madd_epi16(s0, s1); - s = _mm_add_epi32(s, s0); - } + v_int16x8 v_src10 = 
v_load_expand(src1 + j); + v_int16x8 v_src20 = v_load_expand(src2 + j); - _mm_store_si128((__m128i*)buf, s); - r += buf[0] + buf[1] + buf[2] + buf[3]; + v_sum += v_dotprod(v_src10, v_src20); + } + r += (double)v_reduce_sum(v_sum); src1 += blockSize; src2 += blockSize; @@ -3243,43 +3151,46 @@ static double dotProd_8s(const schar* src1, const schar* src2, int len) } } #elif CV_NEON - int len0 = len & -8, blockSize0 = (1 << 14), blockSize; - int32x4_t v_zero = vdupq_n_s32(0); - CV_DECL_ALIGNED(16) int buf[4]; - - while( i < len0 ) + if( cv::checkHardwareSupport(CV_CPU_NEON) ) { - blockSize = std::min(len0 - i, blockSize0); - int32x4_t v_sum = v_zero; + int len0 = len & -8, blockSize0 = (1 << 14), blockSize; + int32x4_t v_zero = vdupq_n_s32(0); + CV_DECL_ALIGNED(16) int buf[4]; - int j = 0; - for( ; j <= blockSize - 16; j += 16 ) + while( i < len0 ) { - int8x16_t v_src1 = vld1q_s8(src1 + j), v_src2 = vld1q_s8(src2 + j); + blockSize = std::min(len0 - i, blockSize0); + int32x4_t v_sum = v_zero; - int16x8_t v_src10 = vmovl_s8(vget_low_s8(v_src1)), v_src20 = vmovl_s8(vget_low_s8(v_src2)); - v_sum = vmlal_s16(v_sum, vget_low_s16(v_src10), vget_low_s16(v_src20)); - v_sum = vmlal_s16(v_sum, vget_high_s16(v_src10), vget_high_s16(v_src20)); + int j = 0; + for( ; j <= blockSize - 16; j += 16 ) + { + int8x16_t v_src1 = vld1q_s8(src1 + j), v_src2 = vld1q_s8(src2 + j); - v_src10 = vmovl_s8(vget_high_s8(v_src1)); - v_src20 = vmovl_s8(vget_high_s8(v_src2)); - v_sum = vmlal_s16(v_sum, vget_low_s16(v_src10), vget_low_s16(v_src20)); - v_sum = vmlal_s16(v_sum, vget_high_s16(v_src10), vget_high_s16(v_src20)); + int16x8_t v_src10 = vmovl_s8(vget_low_s8(v_src1)), v_src20 = vmovl_s8(vget_low_s8(v_src2)); + v_sum = vmlal_s16(v_sum, vget_low_s16(v_src10), vget_low_s16(v_src20)); + v_sum = vmlal_s16(v_sum, vget_high_s16(v_src10), vget_high_s16(v_src20)); + + v_src10 = vmovl_s8(vget_high_s8(v_src1)); + v_src20 = vmovl_s8(vget_high_s8(v_src2)); + v_sum = vmlal_s16(v_sum, vget_low_s16(v_src10), vget_low_s16(v_src20)); + v_sum = vmlal_s16(v_sum, vget_high_s16(v_src10), vget_high_s16(v_src20)); + } + + for( ; j <= blockSize - 8; j += 8 ) + { + int16x8_t v_src1 = vmovl_s8(vld1_s8(src1 + j)), v_src2 = vmovl_s8(vld1_s8(src2 + j)); + v_sum = vmlal_s16(v_sum, vget_low_s16(v_src1), vget_low_s16(v_src2)); + v_sum = vmlal_s16(v_sum, vget_high_s16(v_src1), vget_high_s16(v_src2)); + } + + vst1q_s32(buf, v_sum); + r += buf[0] + buf[1] + buf[2] + buf[3]; + + src1 += blockSize; + src2 += blockSize; + i += blockSize; } - - for( ; j <= blockSize - 8; j += 8 ) - { - int16x8_t v_src1 = vmovl_s8(vld1_s8(src1 + j)), v_src2 = vmovl_s8(vld1_s8(src2 + j)); - v_sum = vmlal_s16(v_sum, vget_low_s16(v_src1), vget_low_s16(v_src2)); - v_sum = vmlal_s16(v_sum, vget_high_s16(v_src1), vget_high_s16(v_src2)); - } - - vst1q_s32(buf, v_sum); - r += buf[0] + buf[1] + buf[2] + buf[3]; - - src1 += blockSize; - src2 += blockSize; - i += blockSize; } #endif @@ -3322,26 +3233,27 @@ static double dotProd_32f(const float* src1, const float* src2, int len) #endif int i = 0; -#if CV_NEON - int len0 = len & -4, blockSize0 = (1 << 13), blockSize; - float32x4_t v_zero = vdupq_n_f32(0.0f); - CV_DECL_ALIGNED(16) float buf[4]; - - while( i < len0 ) +#if CV_SIMD128 + if (hasSIMD128()) { - blockSize = std::min(len0 - i, blockSize0); - float32x4_t v_sum = v_zero; + int len0 = len & -4, blockSize0 = (1 << 13), blockSize; - int j = 0; - for( ; j <= blockSize - 4; j += 4 ) - v_sum = vmlaq_f32(v_sum, vld1q_f32(src1 + j), vld1q_f32(src2 + j)); + while (i < len0) + { + blockSize = 
std::min(len0 - i, blockSize0); + v_float32x4 v_sum = v_setzero_f32(); - vst1q_f32(buf, v_sum); - r += buf[0] + buf[1] + buf[2] + buf[3]; + int j = 0; + int cWidth = v_float32x4::nlanes; + for (; j <= blockSize - cWidth; j += cWidth) + v_sum = v_muladd(v_load(src1 + j), v_load(src2 + j), v_sum); - src1 += blockSize; - src2 += blockSize; - i += blockSize; + r += v_reduce_sum(v_sum); + + src1 += blockSize; + src2 += blockSize; + i += blockSize; + } } #endif return r + dotProd_(src1, src2, len - i); diff --git a/modules/core/test/test_intrin.cpp b/modules/core/test/test_intrin.cpp index 09cf196c30..d73252984a 100644 --- a/modules/core/test/test_intrin.cpp +++ b/modules/core/test/test_intrin.cpp @@ -33,6 +33,7 @@ TEST(hal_intrin, uint8x16) { .test_pack_u<1>().test_pack_u<2>().test_pack_u<3>().test_pack_u<8>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<8>().test_extract<15>() + .test_rotate<0>().test_rotate<1>().test_rotate<8>().test_rotate<15>() ; } @@ -54,6 +55,7 @@ TEST(hal_intrin, int8x16) { .test_pack<1>().test_pack<2>().test_pack<3>().test_pack<8>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<8>().test_extract<15>() + .test_rotate<0>().test_rotate<1>().test_rotate<8>().test_rotate<15>() ; } @@ -81,6 +83,7 @@ TEST(hal_intrin, uint16x8) { .test_pack_u<1>().test_pack_u<2>().test_pack_u<7>().test_pack_u<16>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<4>().test_extract<7>() + .test_rotate<0>().test_rotate<1>().test_rotate<4>().test_rotate<7>() ; } @@ -107,6 +110,7 @@ TEST(hal_intrin, int16x8) { .test_pack<1>().test_pack<2>().test_pack<7>().test_pack<16>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<4>().test_extract<7>() + .test_rotate<0>().test_rotate<1>().test_rotate<4>().test_rotate<7>() ; } @@ -132,6 +136,7 @@ TEST(hal_intrin, uint32x4) { .test_pack<1>().test_pack<2>().test_pack<15>().test_pack<32>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<2>().test_extract<3>() + .test_rotate<0>().test_rotate<1>().test_rotate<2>().test_rotate<3>() .test_transpose() ; } @@ -155,6 +160,7 @@ TEST(hal_intrin, int32x4) { .test_pack<1>().test_pack<2>().test_pack<15>().test_pack<32>() .test_unpack() .test_extract<0>().test_extract<1>().test_extract<2>().test_extract<3>() + .test_rotate<0>().test_rotate<1>().test_rotate<2>().test_rotate<3>() .test_float_cvt32() .test_float_cvt64() .test_transpose() @@ -170,6 +176,7 @@ TEST(hal_intrin, uint64x2) { .test_shift<1>().test_shift<8>() .test_logic() .test_extract<0>().test_extract<1>() + .test_rotate<0>().test_rotate<1>() ; } @@ -180,6 +187,7 @@ TEST(hal_intrin, int64x2) { .test_shift<1>().test_shift<8>() .test_logic() .test_extract<0>().test_extract<1>() + .test_rotate<0>().test_rotate<1>() ; } diff --git a/modules/core/test/test_intrin_utils.hpp b/modules/core/test/test_intrin_utils.hpp index 249ef38947..678bbd4628 100644 --- a/modules/core/test/test_intrin_utils.hpp +++ b/modules/core/test/test_intrin_utils.hpp @@ -793,6 +793,30 @@ template struct TheTest return *this; } + template + TheTest & test_rotate() + { + Data dataA, dataB; + dataB *= 10; + R a = dataA, b = dataB; + + Data resC = v_rotate_right(a); + Data resD = v_rotate_right(a, b); + + for (int i = 0; i < R::nlanes; ++i) + { + if (i + s >= R::nlanes) + { + EXPECT_EQ((LaneType)0, resC[i]); + EXPECT_EQ(dataB[i - R::nlanes + s], resD[i]); + } + else + EXPECT_EQ(dataA[i + s], resC[i]); + } + + return *this; + } + TheTest & test_float_math() { typedef typename V_RegTrait128::int_reg Ri; @@ -882,6 
+906,16 @@ template struct TheTest + dataV[3] * dataD[i]; EXPECT_DOUBLE_EQ(val, res[i]); } + + Data resAdd = v_matmuladd(v, a, b, c, d); + for (int i = 0; i < R::nlanes; ++i) + { + LaneType val = dataV[0] * dataA[i] + + dataV[1] * dataB[i] + + dataV[2] * dataC[i] + + dataD[i]; + EXPECT_DOUBLE_EQ(val, resAdd[i]); + } return *this; } diff --git a/modules/core/test/test_math.cpp b/modules/core/test/test_math.cpp index ed960dcf98..411f8b6fff 100644 --- a/modules/core/test/test_math.cpp +++ b/modules/core/test/test_math.cpp @@ -904,6 +904,60 @@ void Core_TransformTest::prepare_to_validation( int ) cvtest::transform( test_mat[INPUT][0], test_mat[REF_OUTPUT][0], transmat, shift ); } +class Core_TransformLargeTest : public Core_TransformTest +{ +public: + typedef Core_MatrixTest Base; +protected: + void get_test_array_types_and_sizes(int test_case_idx, vector >& sizes, vector >& types); +}; + +void Core_TransformLargeTest::get_test_array_types_and_sizes(int test_case_idx, vector >& sizes, vector >& types) +{ + RNG& rng = ts->get_rng(); + int bits = cvtest::randInt(rng); + int depth, dst_cn, mat_cols, mattype; + Base::get_test_array_types_and_sizes(test_case_idx, sizes, types); + for (unsigned int j = 0; j < sizes.size(); j++) + { + for (unsigned int i = 0; i < sizes[j].size(); i++) + { + sizes[j][i].width *= 4; + } + } + + mat_cols = CV_MAT_CN(types[INPUT][0]); + depth = CV_MAT_DEPTH(types[INPUT][0]); + dst_cn = cvtest::randInt(rng) % 4 + 1; + types[OUTPUT][0] = types[REF_OUTPUT][0] = CV_MAKETYPE(depth, dst_cn); + + mattype = depth < CV_32S ? CV_32F : depth == CV_64F ? CV_64F : bits & 1 ? CV_32F : CV_64F; + types[INPUT][1] = mattype; + types[INPUT][2] = CV_MAKETYPE(mattype, dst_cn); + + scale = 1. / ((cvtest::randInt(rng) % 4) * 50 + 1); + + if (bits & 2) + { + sizes[INPUT][2] = Size(0, 0); + mat_cols += (bits & 4) != 0; + } + else if (bits & 4) + sizes[INPUT][2] = Size(1, 1); + else + { + if (bits & 8) + sizes[INPUT][2] = Size(dst_cn, 1); + else + sizes[INPUT][2] = Size(1, dst_cn); + types[INPUT][2] &= ~CV_MAT_CN_MASK; + } + diagMtx = (bits & 16) != 0; + + sizes[INPUT][1] = Size(mat_cols, dst_cn); +} + + ///////////////// PerspectiveTransform ///////////////////// @@ -2691,6 +2745,7 @@ TEST(Core_Invert, accuracy) { Core_InvertTest test; test.safe_run(); } TEST(Core_Mahalanobis, accuracy) { Core_MahalanobisTest test; test.safe_run(); } TEST(Core_MulTransposed, accuracy) { Core_MulTransposedTest test; test.safe_run(); } TEST(Core_Transform, accuracy) { Core_TransformTest test; test.safe_run(); } +TEST(Core_TransformLarge, accuracy) { Core_TransformLargeTest test; test.safe_run(); } TEST(Core_PerspectiveTransform, accuracy) { Core_PerspectiveTransformTest test; test.safe_run(); } TEST(Core_Pow, accuracy) { Core_PowTest test; test.safe_run(); } TEST(Core_SolveLinearSystem, accuracy) { Core_SolveTest test; test.safe_run(); }
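
Note (not part of the patch): the sketch below illustrates what the new v_matmuladd intrinsic computes, mirroring the C++ fallback added to intrin_cpp.hpp above. It assumes the opencv2/core/hal/intrin.hpp header from this branch is available; the column layout of m0..m2 and the translation vector in m3 match what load3x3Matrix() produces in matmul.cpp.

// Sketch: 3x3 matrix multiply plus 3x1 vector add via v_matmuladd.
// m0..m2 hold the matrix columns (lane 3 padded with zero), m3 holds the offsets.
#include <opencv2/core/hal/intrin.hpp>
#include <cstdio>

int main()
{
    using namespace cv;

    // 3x4 affine transform, row-major as cv::transform expects it:
    // | 1 0 0  10 |
    // | 0 2 0  20 |
    // | 0 0 3  30 |
    const float m[12] = { 1, 0, 0, 10,
                          0, 2, 0, 20,
                          0, 0, 3, 30 };

    v_float32x4 m0(m[0], m[4], m[8],  0);   // first column
    v_float32x4 m1(m[1], m[5], m[9],  0);   // second column
    v_float32x4 m2(m[2], m[6], m[10], 0);   // third column
    v_float32x4 m3(m[3], m[7], m[11], 0);   // translation

    v_float32x4 v(1.f, 2.f, 3.f, 0.f);      // input (b, g, r, pad)
    v_float32x4 y = v_matmuladd(v, m0, m1, m2, m3);

    float out[4];
    v_store(out, y);
    printf("%g %g %g\n", out[0], out[1], out[2]);  // expected: 11 24 39
    return 0;
}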
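
Note (not part of the patch): the lane semantics of the new v_rotate_left / v_rotate_right intrinsics, as defined by the C++ fallback in intrin_cpp.hpp above, are shown in the sketch below. The printed expectations are derived from that fallback; the example assumes the intrin.hpp header from this branch.

// Sketch: lane behaviour of v_rotate_right / v_rotate_left, one- and two-vector forms.
#include <opencv2/core/hal/intrin.hpp>
#include <cstdio>

int main()
{
    using namespace cv;

    int a_buf[4] = { 10, 11, 12, 13 };
    int b_buf[4] = { 20, 21, 22, 23 };
    v_int32x4 a = v_load(a_buf), b = v_load(b_buf);
    int out[4];

    // One-vector form: lanes move toward index 0 (right) or away from it (left);
    // vacated lanes are filled with zero.
    v_store(out, v_rotate_right<1>(a));    // 11 12 13 0
    printf("rotate_right<1>(a)  : %d %d %d %d\n", out[0], out[1], out[2], out[3]);
    v_store(out, v_rotate_left<1>(a));     // 0 10 11 12
    printf("rotate_left<1>(a)   : %d %d %d %d\n", out[0], out[1], out[2], out[3]);

    // Two-vector form: acts like an extract from the concatenation of the two
    // registers, so the vacated lanes are taken from the second argument.
    v_store(out, v_rotate_right<1>(a, b)); // 11 12 13 20
    printf("rotate_right<1>(a,b): %d %d %d %d\n", out[0], out[1], out[2], out[3]);
    v_store(out, v_rotate_left<1>(a, b));  // 23 10 11 12
    printf("rotate_left<1>(a,b) : %d %d %d %d\n", out[0], out[1], out[2], out[3]);

    return 0;
}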
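
Note (not part of the patch): the rewritten dotProd_8u above follows the widen-then-v_dotprod pattern. The function below is a minimal sketch of that pattern; the name dot_8u_sketch is illustrative only, and the blocking that the real code uses to keep the 32-bit accumulator from overflowing on long inputs is omitted.

// Sketch: 8-bit dot product via v_load_expand + v_dotprod + v_reduce_sum.
#include <opencv2/core/hal/intrin.hpp>
#include <cstdio>

static double dot_8u_sketch(const unsigned char* a, const unsigned char* b, int len)
{
    double r = 0;
    int i = 0;
#if CV_SIMD128
    const int cWidth = cv::v_uint16x8::nlanes;          // 8 elements per step
    cv::v_int32x4 sum = cv::v_setzero_s32();
    for (; i <= len - cWidth; i += cWidth)
    {
        // widen uchar lanes to 16 bits, then multiply adjacent pairs and add into 32-bit lanes
        cv::v_int16x8 x = cv::v_reinterpret_as_s16(cv::v_load_expand(a + i));
        cv::v_int16x8 y = cv::v_reinterpret_as_s16(cv::v_load_expand(b + i));
        sum += cv::v_dotprod(x, y);
    }
    r = (double)cv::v_reduce_sum(sum);
#endif
    for (; i < len; i++)                                 // scalar tail
        r += (double)a[i] * b[i];
    return r;
}

int main()
{
    unsigned char a[19], b[19];
    for (int i = 0; i < 19; i++) { a[i] = (unsigned char)i; b[i] = (unsigned char)(2 * i); }
    printf("%.0f\n", dot_8u_sketch(a, b, 19));           // sum of 2*i*i for i = 0..18
    return 0;
}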
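
Note (not part of the patch): transform_8u above works in fixed point, scaling the coefficients by 2^BITS and folding the 0.5 rounding term into the offsets before shifting back down. The scalar sketch below restates that scheme for a single pixel; the helper name transform_one_px_8u is hypothetical and only mirrors the tail loop of the patched function.

// Sketch of the fixed-point scheme in transform_8u (BITS = 10).
#include <opencv2/core/saturate.hpp>
#include <cstdio>

static void transform_one_px_8u(const unsigned char src[3], unsigned char dst[3], const float m[12])
{
    const int BITS = 10, SCALE = 1 << BITS;

    // coefficients scaled by 2^BITS; offsets also carry +0.5*SCALE for round-to-nearest
    short m00 = cv::saturate_cast<short>(m[0]*SCALE), m01 = cv::saturate_cast<short>(m[1]*SCALE),
          m02 = cv::saturate_cast<short>(m[2]*SCALE), m10 = cv::saturate_cast<short>(m[4]*SCALE),
          m11 = cv::saturate_cast<short>(m[5]*SCALE), m12 = cv::saturate_cast<short>(m[6]*SCALE),
          m20 = cv::saturate_cast<short>(m[8]*SCALE), m21 = cv::saturate_cast<short>(m[9]*SCALE),
          m22 = cv::saturate_cast<short>(m[10]*SCALE);
    int   m03 = cv::saturate_cast<int>((m[3]  + 0.5f)*SCALE),
          m13 = cv::saturate_cast<int>((m[7]  + 0.5f)*SCALE),
          m23 = cv::saturate_cast<int>((m[11] + 0.5f)*SCALE);

    int v0 = src[0], v1 = src[1], v2 = src[2];
    dst[0] = cv::saturate_cast<unsigned char>((m00*v0 + m01*v1 + m02*v2 + m03) >> BITS);
    dst[1] = cv::saturate_cast<unsigned char>((m10*v0 + m11*v1 + m12*v2 + m13) >> BITS);
    dst[2] = cv::saturate_cast<unsigned char>((m20*v0 + m21*v1 + m22*v2 + m23) >> BITS);
}

int main()
{
    // same 3x4 matrix as the Transform perf tests added in this patch
    const float m[12] = { 0.5f, 0.f, 0.86602540378f, 128,
                          0.f,  1.f, 0.f,            -64,
                          0.86602540378f, 0.f, 0.5f,  32 };
    unsigned char px[3] = { 10, 20, 30 }, out[3];
    transform_one_px_8u(px, out, m);
    printf("%d %d %d\n", out[0], out[1], out[2]);
    return 0;
}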