// opencv/modules/dnn/src/opencl/gemm_buffer.cl
// Buffer-based OpenCL GEMM kernels for the OpenCV DNN module.
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
// fp16 support is optional; builds with TYPE == TYPE_HALF require it.
#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
// TEMPLATE(name, type) appends a type suffix to a kernel name, producing
// e.g. gemm_buffer_NN_float or gemm_buffer_NN_half.
#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)
// The host always passes alpha/beta as float, regardless of element type.
#define KERNEL_ARG_DTYPE float
#define TYPE_FLOAT 1
#define TYPE_HALF 2
// Map the generic Dtype* names used below onto half or float scalar and
// vector types, selected by the TYPE build option.
#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16
#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16
#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif
// For the half build, values are shuffled across the sub-group as ushort
// bit patterns (SHUFFLE_TYPE*); the sub-group (SIMD) width also differs.
#if TYPE == TYPE_HALF
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define SIMD_SIZE_GEMM 16
#else
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define SIMD_SIZE_GEMM 8
#endif
#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif
// Tiling parameters for gemm_buffer_NN: each work-item produces a
// TILE_M x VEC_SIZE block of C; a sub-group covers TILE_M x TILE_N.
#define VEC_SIZE 4
#define LWG_HEIGHT 4
#define TILE_M 8
#if TYPE == TYPE_HALF
#define TILE_K 32
#define TILE_N 64
#else
#define TILE_K 16
#define TILE_N 32
#endif
// gemm_buffer_NN: C = alpha * A * B + beta * C, with A (MxK) and B (KxN)
// both row-major and non-transposed.
//
// Work partitioning: each work-item accumulates a TILE_M x VEC_SIZE (8x4)
// block of C; a SIMD_SIZE_GEMM-wide sub-group covers a TILE_M x TILE_N
// tile. The K dimension is consumed in slices of at most 256 elements per
// enqueue (`start_index` selects the slice), so the beta scaling of C is
// applied only on the first slice; later slices accumulate onto the
// partial results already stored in C.
//
// Arguments:
//   src0, off0        - A buffer and its element offset (M x K)
//   src1, off1        - B buffer and its element offset (K x N)
//   dst,  offd        - C buffer and its element offset (M x N)
//   alpha_in, beta_in - GEMM scale factors (always passed as float)
//   start_index       - first K index of the slice handled by this enqueue
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))
__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
const __global Dtype *src0, int off0,
const __global Dtype *src1, int off1,
__global Dtype *dst, int offd,
int M,
int N,
int K,
KERNEL_ARG_DTYPE alpha_in,
KERNEL_ARG_DTYPE beta_in,
int start_index)
{
const Dtype alpha = (Dtype)alpha_in;
const Dtype beta = (Dtype)beta_in;
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
const int global_x = get_global_id(0);
const int global_y = get_global_id(1);
Dtype4 brow;
Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;
// Per-work-item base pointers: this lane's 4 consecutive C columns, the
// tile's 8 A rows (positioned at the current K slice, each lane offset by
// its 2-element stripe), and the matching 4 B columns.
__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;
const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;
// Output rows beyond M are redirected to A's row 0: the `border` offset
// cancels the tile's row offset baked into src0_read, keeping loads in
// bounds. Results for such rows are computed but never written back.
int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);
int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border;
int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border;
int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border;
int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border;
int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border;
int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border;
int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;
// Seed the accumulators from C: the first slice (start_index == 0)
// applies beta to the existing values; later slices resume from the
// partial sums the previous enqueue stored.
Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);
// This enqueue covers at most 256 K elements starting at start_index.
int end_index = min(start_index + 256, K);
int w = start_index;
// Full TILE_K passes: each lane loads 2 consecutive A values per row
// (alpha pre-applied); sub-group shuffles then broadcast every lane's
// pair to all lanes, so one pass consumes TILE_K values of K.
while( w + TILE_K <= end_index ) {
arow0 = alpha * vload2(0, src0_read + row0 * K);
arow1 = alpha * vload2(0, src0_read + row1 * K);
arow2 = alpha * vload2(0, src0_read + row2 * K);
arow3 = alpha * vload2(0, src0_read + row3 * K);
arow4 = alpha * vload2(0, src0_read + row4 * K);
arow5 = alpha * vload2(0, src0_read + row5 * K);
arow6 = alpha * vload2(0, src0_read + row6 * K);
arow7 = alpha * vload2(0, src0_read + row7 * K);
// Broadcast component `suffix` of lane `index`'s A pair to all lanes and
// FMA it against the current B row; src1_read0 advances one B row per use.
#define MM_DOT_PRODUCT( index, suffix ) \
brow = vload4(0, src1_read0); src1_read0 += N; \
dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );
MM_DOT_PRODUCT(0, 0);
MM_DOT_PRODUCT(0, 1);
MM_DOT_PRODUCT(1, 0);
MM_DOT_PRODUCT(1, 1);
MM_DOT_PRODUCT(2, 0);
MM_DOT_PRODUCT(2, 1);
MM_DOT_PRODUCT(3, 0);
MM_DOT_PRODUCT(3, 1);
MM_DOT_PRODUCT(4, 0);
MM_DOT_PRODUCT(4, 1);
MM_DOT_PRODUCT(5, 0);
MM_DOT_PRODUCT(5, 1);
MM_DOT_PRODUCT(6, 0);
MM_DOT_PRODUCT(6, 1);
MM_DOT_PRODUCT(7, 0);
MM_DOT_PRODUCT(7, 1);
// The half build has a 16-wide sub-group, so 16 lanes must be expanded.
#if TYPE == TYPE_HALF
MM_DOT_PRODUCT(8, 0);
MM_DOT_PRODUCT(8, 1);
MM_DOT_PRODUCT(9, 0);
MM_DOT_PRODUCT(9, 1);
MM_DOT_PRODUCT(10, 0);
MM_DOT_PRODUCT(10, 1);
MM_DOT_PRODUCT(11, 0);
MM_DOT_PRODUCT(11, 1);
MM_DOT_PRODUCT(12, 0);
MM_DOT_PRODUCT(12, 1);
MM_DOT_PRODUCT(13, 0);
MM_DOT_PRODUCT(13, 1);
MM_DOT_PRODUCT(14, 0);
MM_DOT_PRODUCT(14, 1);
MM_DOT_PRODUCT(15, 0);
MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT
src0_read += TILE_K;
w += TILE_K;
}
// K tail (fewer than TILE_K elements left in this slice): reload A with
// per-element bounds checks, zero-padding values past K.
if(w < end_index) {
arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f;
arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f;
arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f;
arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f;
arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f;
arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f;
arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f;
arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f;
arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f;
arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f;
arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f;
arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f;
arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f;
arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f;
arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;
arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;
// Same broadcast/FMA as above, but B rows past K contribute zeros and
// `w` is advanced per invocation to track the bounds check.
#define MM_DOT_PRODUCT( index, suffix ) \
brow = (w < K) ? vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; w++; \
dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );
MM_DOT_PRODUCT(0, 0);
MM_DOT_PRODUCT(0, 1);
MM_DOT_PRODUCT(1, 0);
MM_DOT_PRODUCT(1, 1);
MM_DOT_PRODUCT(2, 0);
MM_DOT_PRODUCT(2, 1);
MM_DOT_PRODUCT(3, 0);
MM_DOT_PRODUCT(3, 1);
MM_DOT_PRODUCT(4, 0);
MM_DOT_PRODUCT(4, 1);
MM_DOT_PRODUCT(5, 0);
MM_DOT_PRODUCT(5, 1);
MM_DOT_PRODUCT(6, 0);
MM_DOT_PRODUCT(6, 1);
MM_DOT_PRODUCT(7, 0);
MM_DOT_PRODUCT(7, 1);
#if TYPE == TYPE_HALF
MM_DOT_PRODUCT(8, 0);
MM_DOT_PRODUCT(8, 1);
MM_DOT_PRODUCT(9, 0);
MM_DOT_PRODUCT(9, 1);
MM_DOT_PRODUCT(10, 0);
MM_DOT_PRODUCT(10, 1);
MM_DOT_PRODUCT(11, 0);
MM_DOT_PRODUCT(11, 1);
MM_DOT_PRODUCT(12, 0);
MM_DOT_PRODUCT(12, 1);
MM_DOT_PRODUCT(13, 0);
MM_DOT_PRODUCT(13, 1);
MM_DOT_PRODUCT(14, 0);
MM_DOT_PRODUCT(14, 1);
MM_DOT_PRODUCT(15, 0);
MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT
}
// Write back the 8x4 result block with edge handling on both axes:
// choose a 4-, 3-, 2- or 1-column store pattern depending on how many
// columns remain before N, and stop (return) at the first row >= M.
if(global_x * 4 < N && global_y * 8 < M) {
if(mad24(global_x, 4, 3) < N) {
vstore4(dot00, 0, dst_write0); dst_write0 += N;
if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }
} else if(mad24(global_x, 4, 2) < N) {
// 3 columns remain: vstore2 + scalar store.
vstore2(dot00.xy, 0, dst_write0);
dst_write0[2] = dot00.z;
dst_write0 += N;
if(mad24(global_y, 8, 1) < M) {
vstore2(dot01.xy, 0, dst_write0);
dst_write0[2] = dot01.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 2) < M) {
vstore2(dot02.xy, 0, dst_write0);
dst_write0[2] = dot02.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 3) < M) {
vstore2(dot03.xy, 0, dst_write0);
dst_write0[2] = dot03.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 4) < M) {
vstore2(dot04.xy, 0, dst_write0);
dst_write0[2] = dot04.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 5) < M) {
vstore2(dot05.xy, 0, dst_write0);
dst_write0[2] = dot05.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 6) < M) {
vstore2(dot06.xy, 0, dst_write0);
dst_write0[2] = dot06.z;
dst_write0 += N;
} else
return;
if(mad24(global_y, 8, 7) < M) {
vstore2(dot07.xy, 0, dst_write0);
dst_write0[2] = dot07.z;
}
} else if(mad24(global_x, 4, 1) < N) {
// 2 columns remain.
vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;
if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }
else return;
if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }
} else {
// 1 column remains.
dst_write0[0] = dot00.x; dst_write0 += N;
if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }
else return;
if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }
}
}
}
// Re-tile for gemm_buffer_NT: one output column per work-item (VEC_SIZE 1),
// 8x8 output tiles, with B rows staged through SLM_BLOCK-wide chunks of
// local memory. TILE_K and the work-group height differ per element type.
#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N
#define VEC_SIZE 1
#define TILE_M 8
#define TILE_N 8
#define SLM_BLOCK 128
#if TYPE == TYPE_HALF
#define LWG_HEIGHT 2
#define TILE_K 64
#else
#define LWG_HEIGHT 4
#define TILE_K 32
#endif
#if TYPE == TYPE_HALF
// gemm_buffer_NT (half-precision variant): C = alpha * A * B^T + beta * C.
// A is M x K and B is N x K, so C[m][n] = dot(A row m, B row n).
//
// Strategy: rows of B are staged into local memory (slm_brow) in
// SLM_BLOCK-element chunks along K; each 8-lane sub-group computes an
// 8x8 tile of C, every lane accumulating a disjoint TILE_K/8-element
// stripe of K, and the per-lane partial dot products are summed across
// the sub-group with shuffles (REDUCE) before write-back.
//
// NOTE(review): half data is loaded/stored through float pointers (one
// float4 carries 8 halves) to get wider memory transactions, then
// reinterpreted with as_half8. This assumes the half buffers are 4-byte
// aligned — confirm with the host-side allocations.
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
const __global Dtype *src0, int off0,
const __global Dtype *src1, int off1,
__global Dtype *dst, int offd,
int M,
int N,
int K,
KERNEL_ARG_DTYPE alpha_in,
KERNEL_ARG_DTYPE beta_in)
{
const Dtype alpha = (Dtype)alpha_in;
const Dtype beta = (Dtype)beta_in;
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
const int global_x = get_global_id(0);
const int global_y = get_global_id(1);
// dot0r holds the partial dot products of A row r against all 8 B rows.
Dtype8 dot00 = 0.f;
Dtype8 dot01 = 0.f;
Dtype8 dot02 = 0.f;
Dtype8 dot03 = 0.f;
Dtype8 dot04 = 0.f;
Dtype8 dot05 = 0.f;
Dtype8 dot06 = 0.f;
Dtype8 dot07 = 0.f;
Dtype8 brow0;
Dtype8 brow1;
Dtype8 brow2;
Dtype8 brow3;
Dtype8 brow4;
Dtype8 brow5;
Dtype8 brow6;
Dtype8 brow7;
__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;
const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;
// Local staging area for 8 rows of B, SLM_BLOCK K-elements each.
__local Dtype slm_brow[8 * SLM_BLOCK];
__local Dtype* slm_brow0;
int local_index = mad24(local_y, 8, local_x) * 8;
int w;
for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
barrier(CLK_LOCAL_MEM_FENCE);
// Cooperative copy: each work-item moves 8 halves (one float4) of every
// B row for this K chunk into local memory.
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, (__local float *)(slm_brow + mad24(0, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index)));
vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index)));
barrier(CLK_LOCAL_MEM_FENCE);
// Each lane consumes its own TILE_K/8-element sub-range of the chunk.
slm_brow0 = slm_brow + local_x * (TILE_K / 8);
w = b_tile;
int end_w = min(b_tile + SLM_BLOCK, K);
// Full TILE_K steps within the staged chunk.
while( w + TILE_K <= end_w ) {
Dtype8 arow;
brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK)));
brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK)));
brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK)));
brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK)));
brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK)));
brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK)));
brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK)));
brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK)));
// FMA 8 A elements of row _row against the matching elements of all 8
// staged B rows, accumulating into _dot.
#define MM_DOT_PRODUCT( _row, _dot ) \
arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
_dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
_dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
_dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
_dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
_dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
_dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
_dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
_dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );
MM_DOT_PRODUCT( 0, dot00 );
MM_DOT_PRODUCT( 1, dot01 );
MM_DOT_PRODUCT( 2, dot02 );
MM_DOT_PRODUCT( 3, dot03 );
MM_DOT_PRODUCT( 4, dot04 );
MM_DOT_PRODUCT( 5, dot05 );
MM_DOT_PRODUCT( 6, dot06 );
MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
src0_read += TILE_K;
slm_brow0 += TILE_K;
w += TILE_K;
}
src1_read0 += SLM_BLOCK;
}
// K tail: re-process the final partial TILE_K chunk. slm_brow0 and
// src0_read still point at the tail data left by the last loop pass;
// per-element checks zero out contributions past K.
if(w < K) {
Dtype8 arow;
#define READ_BROW(_brow, _row) \
_brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); \
_brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; \
_brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; \
_brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; \
_brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; \
_brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; \
_brow.s5 = (mad24(local_x, 8, w + 5) < K) ? _brow.s5 : 0.0f; \
_brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; \
_brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f;
READ_BROW(brow0, 0);
READ_BROW(brow1, 1);
READ_BROW(brow2, 2);
READ_BROW(brow3, 3);
READ_BROW(brow4, 4);
READ_BROW(brow5, 5);
READ_BROW(brow6, 6);
READ_BROW(brow7, 7);
#undef READ_BROW
#define MM_DOT_PRODUCT( _row, _dot ) \
arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \
arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; \
arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; \
arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; \
arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; \
arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; \
arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; \
arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; \
_dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
_dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
_dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
_dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
_dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
_dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
_dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
_dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );
MM_DOT_PRODUCT( 0, dot00 );
MM_DOT_PRODUCT( 1, dot01 );
MM_DOT_PRODUCT( 2, dot02 );
MM_DOT_PRODUCT( 3, dot03 );
MM_DOT_PRODUCT( 4, dot04 );
MM_DOT_PRODUCT( 5, dot05 );
MM_DOT_PRODUCT( 6, dot06 );
MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
}
// Sum the per-lane partials across the 8-lane sub-group; afterwards every
// lane holds the complete sums for the whole 8x8 tile.
#define REDUCE(_dot) \
_dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));
REDUCE(dot00);
REDUCE(dot01);
REDUCE(dot02);
REDUCE(dot03);
REDUCE(dot04);
REDUCE(dot05);
REDUCE(dot06);
REDUCE(dot07);
#undef REDUCE
Dtype output = 0.0f;
// Each lane writes one output column: pick this lane's component of the
// row's Dtype8 and apply C = alpha*acc + beta*C, then advance one row.
#define OUTPUT( _dot) \
output = (local_x == 0) ? _dot.s0 : output; \
output = (local_x == 1) ? _dot.s1 : output; \
output = (local_x == 2) ? _dot.s2 : output; \
output = (local_x == 3) ? _dot.s3 : output; \
output = (local_x == 4) ? _dot.s4 : output; \
output = (local_x == 5) ? _dot.s5 : output; \
output = (local_x == 6) ? _dot.s6 : output; \
output = (local_x == 7) ? _dot.s7 : output; \
dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
dst_write0 += N;
// Guard both axes: skip columns past N and rows past M.
if(global_x < N && global_y * 8 < M) {
OUTPUT(dot00);
if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
}
#undef OUTPUT
}
#else
// gemm_buffer_NT (float variant): C = alpha * A * B^T + beta * C.
// A is M x K and B is N x K, so C[m][n] = dot(A row m, B row n).
//
// Strategy mirrors the half variant above: rows of B are staged into
// local memory (slm_brow) in SLM_BLOCK-element chunks along K; each
// 8-lane sub-group computes an 8x8 tile of C, every lane accumulating a
// disjoint 4-element (TILE_K/8) stripe of K, and the per-lane partials
// are summed across the sub-group with shuffles (REDUCE) before write-back.
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
const __global Dtype *src0, int off0,
const __global Dtype *src1, int off1,
__global Dtype *dst, int offd,
int M,
int N,
int K,
KERNEL_ARG_DTYPE alpha_in,
KERNEL_ARG_DTYPE beta_in)
{
const Dtype alpha = (Dtype)alpha_in;
const Dtype beta = (Dtype)beta_in;
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
const int global_x = get_global_id(0);
const int global_y = get_global_id(1);
// dot0r holds the partial dot products of A row r against all 8 B rows.
Dtype8 dot00 = 0.f;
Dtype8 dot01 = 0.f;
Dtype8 dot02 = 0.f;
Dtype8 dot03 = 0.f;
Dtype8 dot04 = 0.f;
Dtype8 dot05 = 0.f;
Dtype8 dot06 = 0.f;
Dtype8 dot07 = 0.f;
Dtype4 brow0;
Dtype4 brow1;
Dtype4 brow2;
Dtype4 brow3;
Dtype4 brow4;
Dtype4 brow5;
Dtype4 brow6;
Dtype4 brow7;
__global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;
const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;
const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;
// Local staging area for 8 rows of B, SLM_BLOCK K-elements each.
__local Dtype slm_brow[8 * SLM_BLOCK];
__local Dtype* slm_brow0;
int local_index = mad24(local_y, 8, local_x) * 4;
int w;
for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
barrier(CLK_LOCAL_MEM_FENCE);
// Cooperative copy: each work-item moves 4 floats of every B row for
// this K chunk into local memory.
vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
barrier(CLK_LOCAL_MEM_FENCE);
// Each lane consumes its own TILE_K/8-element sub-range of the chunk.
slm_brow0 = slm_brow + local_x * (TILE_K / 8);
w = b_tile;
int end_w = min(b_tile + SLM_BLOCK, K);
// Full TILE_K steps within the staged chunk.
while( w + TILE_K <= end_w ) {
Dtype4 arow;
brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);
// FMA 4 A elements of row _row against the matching elements of all 8
// staged B rows, accumulating into _dot.
#define MM_DOT_PRODUCT( _row, _dot ) \
arow = vload4(0, src0_read + _row * K); \
_dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
_dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
_dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
_dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
MM_DOT_PRODUCT( 0, dot00 );
MM_DOT_PRODUCT( 1, dot01 );
MM_DOT_PRODUCT( 2, dot02 );
MM_DOT_PRODUCT( 3, dot03 );
MM_DOT_PRODUCT( 4, dot04 );
MM_DOT_PRODUCT( 5, dot05 );
MM_DOT_PRODUCT( 6, dot06 );
MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
src0_read += TILE_K;
slm_brow0 += TILE_K;
w += TILE_K;
}
src1_read0 += SLM_BLOCK;
}
// K tail: re-process the final partial TILE_K chunk. slm_brow0 and
// src0_read still point at the tail data left by the last loop pass;
// per-element checks zero out contributions past K.
if(w < K) {
Dtype4 arow;
#define READ_BROW(_brow, _row) \
_brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
_brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
_brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
_brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
_brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;
READ_BROW(brow0, 0);
READ_BROW(brow1, 1);
READ_BROW(brow2, 2);
READ_BROW(brow3, 3);
READ_BROW(brow4, 4);
READ_BROW(brow5, 5);
READ_BROW(brow6, 6);
READ_BROW(brow7, 7);
#undef READ_BROW
#define MM_DOT_PRODUCT( _row, _dot ) \
arow = vload4(0, src0_read + _row * K); \
arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
_dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
_dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
_dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
_dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
MM_DOT_PRODUCT( 0, dot00 );
MM_DOT_PRODUCT( 1, dot01 );
MM_DOT_PRODUCT( 2, dot02 );
MM_DOT_PRODUCT( 3, dot03 );
MM_DOT_PRODUCT( 4, dot04 );
MM_DOT_PRODUCT( 5, dot05 );
MM_DOT_PRODUCT( 6, dot06 );
MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
}
// Sum the per-lane partials across the 8-lane sub-group; afterwards every
// lane holds the complete sums for the whole 8x8 tile.
#define REDUCE(_dot) \
_dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));
REDUCE(dot00);
REDUCE(dot01);
REDUCE(dot02);
REDUCE(dot03);
REDUCE(dot04);
REDUCE(dot05);
REDUCE(dot06);
REDUCE(dot07);
#undef REDUCE
Dtype output = 0.0f;
// Each lane writes one output column: pick this lane's component of the
// row's Dtype8 and apply C = alpha*acc + beta*C, then advance one row.
#define OUTPUT( _dot) \
output = (local_x == 0) ? _dot.s0 : output; \
output = (local_x == 1) ? _dot.s1 : output; \
output = (local_x == 2) ? _dot.s2 : output; \
output = (local_x == 3) ? _dot.s3 : output; \
output = (local_x == 4) ? _dot.s4 : output; \
output = (local_x == 5) ? _dot.s5 : output; \
output = (local_x == 6) ? _dot.s6 : output; \
output = (local_x == 7) ? _dot.s7 : output; \
dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
dst_write0 += N;
// Guard both axes: skip columns past N and rows past M.
if(global_x < N && global_y * 8 < M) {
OUTPUT(dot00);
if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
}
#undef OUTPUT
}
#endif
// Clear the NT-kernel tiling macros; the skinny-matrix (small M) kernels
// below only need the work-group size SLM_SIZE.
#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N
#undef SLM_BLOCK
// Work-items per group (and local-memory reduction width) for the
// gemm_buffer_NT_M_* kernels.
#define SLM_SIZE 64
// Edge-column helper for gemm_buffer_NT_M_2 (the M == 2 NT kernel):
// handles the last, partial group of output columns when N is not a
// multiple of 4. `rows` = N - x_gid * 4 (1..3) is the number of columns
// this group must produce, for both output rows of C (dstc0 / dstc1).
//
// Each work-item accumulates a strided share of the K dimension (stride =
// work-group size); per-item partials live in work0/work1 (Dtype4 per
// item, one component per column) and are tree-reduced through local
// memory. The reduction assumes the local size is a power of two
// (SLM_SIZE is 64).
void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
const __global Dtype* srca_read0,
const __global Dtype* srca_read1,
const __global Dtype* srcb_read,
__local Dtype4* work0,
__local Dtype4* work1,
int N,
int K,
int x_gid,
int lid,
Dtype alpha,
Dtype beta,
__global Dtype* dstc0,
__global Dtype* dstc1)
{
__local Dtype* work_each0 = (__local Dtype*)work0;
__local Dtype* work_each1 = (__local Dtype*)work1;
// Number of valid output columns in this (last) group of 4.
int rows = N - x_gid * 4;
Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
// Main loop: 4-wide accumulation over this work-item's strided share of K.
int i = lid;
while( i < K / 4) {
const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
#pragma unroll
for(int j = 0; j < rows; ++j) {
dot0[j] += b0 * vload4(i, srcb_read + j * K);
dot1[j] += b1 * vload4(i, srcb_read + j * K);
}
i += get_local_size(0);
}
// Horizontal sum of each Dtype4 accumulator into this item's local slot.
#pragma unroll
for(int j = 0; j < rows; ++j) {
work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
}
// Exactly one work-item exits the loop with i == K/4; it adds the K % 4
// scalar tail on top of its own partial sums.
if(i == K / 4) {
short tail_items = K % 4;
if(tail_items != 0) {
const __global Dtype *srcb_tail = srcb_read + i * 4;
const __global Dtype *srca_tail0 = srca_read0 + i * 4;
const __global Dtype *srca_tail1 = srca_read1 + i * 4;
#pragma unroll
for(short i = 0; i < tail_items; ++i) {
const Dtype at0 = srca_tail0[i];
const Dtype at1 = srca_tail1[i];
#pragma unroll
for(int j = 0; j < rows; ++j) {
work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];
work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];
}
}
}
}
// Tree reduction across the work-group; the barrier at the top of each
// pass makes the previous pass's (and the initial) writes visible.
for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
barrier(CLK_LOCAL_MEM_FENCE);
if(lid < stride) {
work0[lid] += work0[lid+stride];
work1[lid] += work1[lid+stride];
}
}
// lid 0 performed the final accumulation into slot 0 itself, so it can
// write C = alpha*sum + beta*C for both rows without another barrier.
if(lid == 0) {
#pragma unroll
for(int j = 0; j < rows; ++j) {
dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
}
}
}
/*
 * GEMM for M == 2, A non-transposed, B transposed:
 *   C(2 x N) = alpha * A(2 x K) * B(N x K)^T + beta * C.
 * Each work-group computes a 2 x 4 tile of C; the trailing group along N
 * (x_gid == N / 4) delegates to the edge-rows helper when N % 4 != 0.
 */
__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    const Dtype alpha = (Dtype)alpha_f;
    const Dtype beta = (Dtype)beta_f;
    const int x_gid = get_group_id(0);
    const int lid = get_local_id(0);
    const int lsize = get_local_size(0);

    /* The two rows of A, and the first of the 4 rows of B owned by this group. */
    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;
    const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;

    /* Vector views of the two output rows of C. */
    __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
    __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);

    /* Per-work-item scratch for the cross-group reduction, plus scalar aliases
       (work_eachX[lid*4 + c] is lane c of workX[lid]). */
    __local Dtype4 work0[SLM_SIZE];
    __local Dtype4 work1[SLM_SIZE];
    __local Dtype* work_each0 = (__local Dtype*)work0;
    __local Dtype* work_each1 = (__local Dtype*)work1;

    if(x_gid == N / 4) {
        /* Fewer than 4 columns remain for this group: take the edge path. */
        TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)
            (srca_read0, srca_read1, srcb_read, work0, work1,
             N, K, x_gid, lid, alpha, beta,
             (__global Dtype*)dstc0, (__global Dtype*)dstc1);
    } else {
        Dtype4 acc0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 acc1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

        /* Each work-item walks K in Dtype4 chunks with stride lsize. */
        int kid = lid;
        for(; kid < K / 4; kid += lsize) {
            const Dtype4 a0 = vload4(kid, srca_read0);
            const Dtype4 a1 = vload4(kid, srca_read1);
            #pragma unroll
            for(int col = 0; col < 4; ++col) {
                const Dtype4 b = vload4(kid, srcb_read + col * K);
                acc0[col] += a0 * b;
                acc1[col] += a1 * b;
            }
        }

        /* Collapse each vector accumulator into this work-item's scratch lanes. */
        #pragma unroll
        for(int col = 0; col < 4; ++col) {
            work_each0[lid * 4 + col] = acc0[col].x + acc0[col].y + acc0[col].z + acc0[col].w;
            work_each1[lid * 4 + col] = acc1[col].x + acc1[col].y + acc1[col].z + acc1[col].w;
        }

        /* Exactly one work-item stops with kid == K/4; it folds in the K % 4 tail. */
        if(kid == K / 4) {
            const short tail_items = K % 4;
            if(tail_items != 0) {
                const __global Dtype *srcb_tail = srcb_read + kid * 4;
                const __global Dtype *srca_tail0 = srca_read0 + kid * 4;
                const __global Dtype *srca_tail1 = srca_read1 + kid * 4;
                #pragma unroll
                for(short t = 0; t < tail_items; ++t) {
                    const Dtype at0 = srca_tail0[t];
                    const Dtype at1 = srca_tail1[t];
                    #pragma unroll
                    for(int col = 0; col < 4; ++col) {
                        work_each0[lid * 4 + col] += at0 * srcb_tail[t + col * K];
                        work_each1[lid * 4 + col] += at1 * srcb_tail[t + col * K];
                    }
                }
            }
        }

        /* Tree reduction across the work-group; barrier outside the 'if' so
           every work-item reaches it. */
        for(int stride = lsize >> 1; stride > 0; stride >>= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            if(lid < stride) {
                work0[lid] += work0[lid + stride];
                work1[lid] += work1[lid + stride];
            }
        }

        /* Work-item 0 commits the alpha/beta-scaled 2 x 4 tile. */
        if(lid == 0) {
            dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
            dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
        }
    }
}
#undef SLM_SIZE
#define SLM_SIZE 32
/*
 * Edge-column helper for gemm_buffer_NT_M_4 (C = alpha * A * B^T + beta * C, M == 4).
 * Invoked only by the last work-group along N, where fewer than 4 output columns
 * remain: rows = N - x_gid * 4 is in 1..3.  Each work-item accumulates a strided
 * slice of the four dot products per remaining column; partial sums are combined
 * across the work-group through the __local arrays work0..work3.
 *
 * NOTE(review): the tree reduction assumes get_local_size(0) is a power of two
 * and that all work-items of the group enter this function (for the barriers) —
 * presumably guaranteed by the host launch configuration; confirm there.
 */
void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
    const __global Dtype* srca_read0,   /* row 0 of A (length K) */
    const __global Dtype* srca_read1,   /* row 1 of A */
    const __global Dtype* srca_read2,   /* row 2 of A */
    const __global Dtype* srca_read3,   /* row 3 of A */
    const __global Dtype* srcb_read,    /* first remaining row of B for this group */
    __local Dtype4* work0,              /* per-work-item partial sums, C row 0 */
    __local Dtype4* work1,              /* per-work-item partial sums, C row 1 */
    __local Dtype4* work2,              /* per-work-item partial sums, C row 2 */
    __local Dtype4* work3,              /* per-work-item partial sums, C row 3 */
    int N,
    int K,
    int x_gid,                          /* group id along N (the edge group) */
    int lid,                            /* local work-item id */
    Dtype alpha,
    Dtype beta,
    __global Dtype* dstc0,              /* C row 0 */
    __global Dtype* dstc1,              /* C row 1 */
    __global Dtype* dstc2,              /* C row 2 */
    __global Dtype* dstc3)              /* C row 3 */
{
    /* Scalar views of this work-item's Dtype4 scratch slot. */
    __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
    __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
    __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
    __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);

    int rows = N - x_gid * 4;           /* remaining output columns, 1..3 */

    Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

    /* Walk K in Dtype4 chunks; each work-item takes every get_local_size(0)-th chunk. */
    int i = lid;
    while( i < K / 4) {
        const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
        const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
        const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};
        const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};
        /* BUGFIX: was "#pragma unrol" (typo) — an unrecognized pragma, so the
           unroll hint was silently ignored.  Matches the spelling used by the
           sibling kernels. */
        #pragma unroll
        for(int j = 0; j < rows; ++j) {
            /* Load the B chunk once and reuse it for all four A rows
               (the original issued the same vload4 four times). */
            const Dtype4 b = vload4(i, srcb_read + j * K);
            dot0[j] += a0 * b;
            dot1[j] += a1 * b;
            dot2[j] += a2 * b;
            dot3[j] += a3 * b;
        }
        i += get_local_size(0);
    }

    /* Horizontal add of each vector accumulator into this work-item's scratch lanes. */
    #pragma unroll
    for(int j = 0; j < rows; ++j) {
        work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
        work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
        work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
        work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
    }

    /* Exactly one work-item exits the strided loop with i == K/4;
       it folds in the K % 4 tail elements. */
    if(i == K / 4) {
        short tail_items = K % 4;
        if(tail_items != 0) {
            const __global Dtype *srcb_tail = srcb_read + i * 4;
            const __global Dtype *srca_tail0 = srca_read0 + i * 4;
            const __global Dtype *srca_tail1 = srca_read1 + i * 4;
            const __global Dtype *srca_tail2 = srca_read2 + i * 4;
            const __global Dtype *srca_tail3 = srca_read3 + i * 4;
            #pragma unroll
            for(short t = 0; t < tail_items; ++t) {
                const Dtype at0 = srca_tail0[t];
                const Dtype at1 = srca_tail1[t];
                const Dtype at2 = srca_tail2[t];
                const Dtype at3 = srca_tail3[t];
                #pragma unroll
                for(int j = 0; j < rows; ++j) {
                    work_each0[j] += at0 * srcb_tail[t + j * K];
                    work_each1[j] += at1 * srcb_tail[t + j * K];
                    work_each2[j] += at2 * srcb_tail[t + j * K];
                    work_each3[j] += at3 * srcb_tail[t + j * K];
                }
            }
        }
    }

    /* Power-of-two tree reduction over the work-group; the barrier is outside
       the 'if' so every work-item reaches it. */
    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lid < stride) {
            work0[lid] += work0[lid+stride];
            work1[lid] += work1[lid+stride];
            work2[lid] += work2[lid+stride];
            work3[lid] += work3[lid+stride];
        }
    }

    /* Work-item 0 writes the final alpha/beta-scaled columns.  After the
       reduction the totals sit in workX[0], which work_eachX aliases only when
       lid == 0 — exactly the work-item executing this branch. */
    if(lid == 0) {
        #pragma unroll
        for(int j = 0; j < rows; ++j) {
            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
            dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
            dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
        }
    }
}
/*
 * GEMM for M == 4, A non-transposed, B transposed:
 *   C(4 x N) = alpha * A(4 x K) * B(N x K)^T + beta * C.
 * Each work-group computes a 4 x 4 tile of C; the trailing group along N
 * (x_gid == N / 4) delegates to the edge-rows helper when N % 4 != 0.
 *
 * NOTE(review): the tree reduction assumes get_local_size(0) is a power of two —
 * presumably guaranteed by the host launch configuration; confirm there.
 */
__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    Dtype alpha = (Dtype)alpha_f;
    Dtype beta = (Dtype)beta_f;
    int x_gid = get_group_id(0);
    int lid = get_local_id(0);
    int lsize = get_local_size(0);

    /* The four rows of A, and the first of the 4 rows of B owned by this group. */
    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;
    const __global Dtype *srca_read2 = srca_read1 + K;
    const __global Dtype *srca_read3 = srca_read2 + K;
    const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;

    /* Vector views of the four output rows of C. */
    __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
    __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);
    __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);
    __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);

    /* Per-work-item scratch for the cross-group reduction, plus scalar aliases
       of this work-item's slot (work_eachX[c] is lane c of workX[lid]). */
    __local Dtype4 work0[SLM_SIZE];
    __local Dtype4 work1[SLM_SIZE];
    __local Dtype4 work2[SLM_SIZE];
    __local Dtype4 work3[SLM_SIZE];
    __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
    __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
    __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
    __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);

    if(x_gid == N / 4) {
        /* Fewer than 4 columns remain for this group: take the edge path. */
        TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \
            (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \
             work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \
             (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);
    } else {
        Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

        /* Walk K in Dtype4 chunks; each work-item takes every lsize-th chunk. */
        int kid = lid;
        while( kid < K / 4) {
            const Dtype4 b0 = vload4(kid, srca_read0);
            const Dtype4 b1 = vload4(kid, srca_read1);
            const Dtype4 b2 = vload4(kid, srca_read2);
            const Dtype4 b3 = vload4(kid, srca_read3);
            #pragma unroll
            for(int j = 0; j < 4; ++j) {
                Dtype4 a = vload4(kid, srcb_read + j * K);
                dot0[j] += b0 * a;
                dot1[j] += b1 * a;
                dot2[j] += b2 * a;
                dot3[j] += b3 * a;
            }
            kid += lsize;
        }

        /* Horizontal add of each vector accumulator into this work-item's scratch lanes. */
        #pragma unroll
        for(int j = 0; j < 4; ++j) {
            work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
            work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
            work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
            work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
        }

        /* Exactly one work-item stops with kid == K/4 (K >> 2);
           it folds in the K % 4 tail elements. */
        if(kid == (K >> 2)) {
            short tail_items = K % 4;
            if(tail_items != 0) {
                int offset = kid << 2;   /* scalar offset of the tail = (K/4)*4 */
                const __global Dtype *srcb_tail = srcb_read + offset;
                const __global Dtype *srca_tail0 = srca_read0 + offset;
                const __global Dtype *srca_tail1 = srca_read1 + offset;
                const __global Dtype *srca_tail2 = srca_read2 + offset;
                const __global Dtype *srca_tail3 = srca_read3 + offset;
                #pragma unroll
                for(short i = 0; i < tail_items; ++i) {
                    const Dtype at0 = srca_tail0[i];
                    const Dtype at1 = srca_tail1[i];
                    const Dtype at2 = srca_tail2[i];
                    const Dtype at3 = srca_tail3[i];
                    #pragma unroll
                    for(int j = 0; j < 4; ++j) {
                        work_each0[j] += at0 * srcb_tail[i + j * K];
                        work_each1[j] += at1 * srcb_tail[i + j * K];
                        work_each2[j] += at2 * srcb_tail[i + j * K];
                        work_each3[j] += at3 * srcb_tail[i + j * K];
                    }
                }
            }
        }

        /* Tree reduction across the work-group; barrier outside the 'if' so
           every work-item reaches it. */
        for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            if(lid < stride) {
                work0[lid] += work0[lid+stride];
                work1[lid] += work1[lid+stride];
                work2[lid] += work2[lid+stride];
                work3[lid] += work3[lid+stride];
            }
        }

        /* Work-item 0 commits the alpha/beta-scaled 4 x 4 tile. */
        if(lid == 0) {
            dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
            dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
            dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
            dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
        }
    }
}
#undef SLM_SIZE
#define SLM_SIZE 16
/*
 * GEMM for M == 8, A non-transposed, B transposed:
 *   C(8 x N) = alpha * A(8 x K) * B(N x K)^T + beta * C.
 * Each work-group computes one column of the 8-row output: the dot product of
 * every row of A with the single row of B owned by this group.
 */
__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    const Dtype alpha = (Dtype)alpha_f;
    const Dtype beta = (Dtype)beta_f;
    const int x_gid = get_group_id(0);
    const int lid = get_local_id(0);
    const int lsize = get_local_size(0);

    /* The eight consecutive rows of A, and this group's single row of B. */
    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;
    const __global Dtype *srca_read2 = srca_read1 + K;
    const __global Dtype *srca_read3 = srca_read2 + K;
    const __global Dtype *srca_read4 = srca_read3 + K;
    const __global Dtype *srca_read5 = srca_read4 + K;
    const __global Dtype *srca_read6 = srca_read5 + K;
    const __global Dtype *srca_read7 = srca_read6 + K;
    const __global Dtype *srcb_read = B + x_gid * K + offB;

    /* The eight output rows of C. */
    __global Dtype *dstc0 = C + offC;
    __global Dtype *dstc1 = dstc0 + N;
    __global Dtype *dstc2 = dstc1 + N;
    __global Dtype *dstc3 = dstc2 + N;
    __global Dtype *dstc4 = dstc3 + N;
    __global Dtype *dstc5 = dstc4 + N;
    __global Dtype *dstc6 = dstc5 + N;
    __global Dtype *dstc7 = dstc6 + N;

    /* One scalar partial sum per work-item per output row. */
    __local Dtype work0[SLM_SIZE];
    __local Dtype work1[SLM_SIZE];
    __local Dtype work2[SLM_SIZE];
    __local Dtype work3[SLM_SIZE];
    __local Dtype work4[SLM_SIZE];
    __local Dtype work5[SLM_SIZE];
    __local Dtype work6[SLM_SIZE];
    __local Dtype work7[SLM_SIZE];

    Dtype4 acc0 = (Dtype4)(0.);
    Dtype4 acc1 = (Dtype4)(0.);
    Dtype4 acc2 = (Dtype4)(0.);
    Dtype4 acc3 = (Dtype4)(0.);
    Dtype4 acc4 = (Dtype4)(0.);
    Dtype4 acc5 = (Dtype4)(0.);
    Dtype4 acc6 = (Dtype4)(0.);
    Dtype4 acc7 = (Dtype4)(0.);

    /* Each work-item walks K in Dtype4 chunks with stride lsize, sharing the
       single B load across all eight A rows. */
    int k = lid;
    for(; k < K / 4; k += lsize) {
        const Dtype4 bvec = vload4(k, srcb_read);
        acc0 += vload4(k, srca_read0) * bvec;
        acc1 += vload4(k, srca_read1) * bvec;
        acc2 += vload4(k, srca_read2) * bvec;
        acc3 += vload4(k, srca_read3) * bvec;
        acc4 += vload4(k, srca_read4) * bvec;
        acc5 += vload4(k, srca_read5) * bvec;
        acc6 += vload4(k, srca_read6) * bvec;
        acc7 += vload4(k, srca_read7) * bvec;
    }

    /* Horizontal add of each vector accumulator into this work-item's slot. */
    work0[lid] = acc0.x + acc0.y + acc0.z + acc0.w;
    work1[lid] = acc1.x + acc1.y + acc1.z + acc1.w;
    work2[lid] = acc2.x + acc2.y + acc2.z + acc2.w;
    work3[lid] = acc3.x + acc3.y + acc3.z + acc3.w;
    work4[lid] = acc4.x + acc4.y + acc4.z + acc4.w;
    work5[lid] = acc5.x + acc5.y + acc5.z + acc5.w;
    work6[lid] = acc6.x + acc6.y + acc6.z + acc6.w;
    work7[lid] = acc7.x + acc7.y + acc7.z + acc7.w;

    /* Exactly one work-item stops with k == K/4; it folds in the K % 4 tail. */
    if(k == (K >> 2)) {
        const short tail_items = K % 4;
        if(tail_items != 0) {
            const int offset = k << 2;   /* scalar offset of the tail */
            const __global Dtype *srcb_tail = srcb_read + offset;
            const __global Dtype *srca_tail0 = srca_read0 + offset;
            const __global Dtype *srca_tail1 = srca_read1 + offset;
            const __global Dtype *srca_tail2 = srca_read2 + offset;
            const __global Dtype *srca_tail3 = srca_read3 + offset;
            const __global Dtype *srca_tail4 = srca_read4 + offset;
            const __global Dtype *srca_tail5 = srca_read5 + offset;
            const __global Dtype *srca_tail6 = srca_read6 + offset;
            const __global Dtype *srca_tail7 = srca_read7 + offset;
            #pragma unroll
            for(short t = 0; t < tail_items; ++t) {
                work0[lid] += srca_tail0[t] * srcb_tail[t];
                work1[lid] += srca_tail1[t] * srcb_tail[t];
                work2[lid] += srca_tail2[t] * srcb_tail[t];
                work3[lid] += srca_tail3[t] * srcb_tail[t];
                work4[lid] += srca_tail4[t] * srcb_tail[t];
                work5[lid] += srca_tail5[t] * srcb_tail[t];
                work6[lid] += srca_tail6[t] * srcb_tail[t];
                work7[lid] += srca_tail7[t] * srcb_tail[t];
            }
        }
    }

    /* Tree reduction across the work-group; barrier outside the 'if' so every
       work-item reaches it. */
    for(int stride = get_local_size(0) >> 1; stride > 0; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lid < stride) {
            work0[lid] += work0[lid + stride];
            work1[lid] += work1[lid + stride];
            work2[lid] += work2[lid + stride];
            work3[lid] += work3[lid + stride];
            work4[lid] += work4[lid + stride];
            work5[lid] += work5[lid + stride];
            work6[lid] += work6[lid + stride];
            work7[lid] += work7[lid + stride];
        }
    }

    /* Work-item 0 commits the alpha/beta-scaled 8 x 1 column of C. */
    if(lid == 0) {
        dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
        dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
        dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
        dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
        dstc4[x_gid] = alpha * work4[0] + beta * dstc4[x_gid];
        dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];
        dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];
        dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];
    }
}
#undef SLM_SIZE
#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N
#undef SIMD_SIZE_GEMM
#undef SHUFFLE_TYPE2
#undef SHUFFLE_TYPE8