mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 03:30:34 +08:00
core(ocl): buffer bounds in intelblas_gemm_buffer_NT
This commit is contained in:
parent
e3f4f874c5
commit
9b4ecc96f6
@ -77,11 +77,7 @@ static bool intel_gpu_gemm(
|
||||
}
|
||||
else if(!atrans && btrans)
|
||||
{
|
||||
if (M % 128 != 0)
|
||||
return false;
|
||||
if (N % 8 != 0)
|
||||
return false;
|
||||
if (K % 512 != 0)
|
||||
if (K % 4 != 0)
|
||||
return false;
|
||||
kernelName = "intelblas_gemm_buffer_NT";
|
||||
ly = 16;
|
||||
|
@ -392,6 +392,15 @@ __kernel void intelblas_gemm_buffer_NN(
|
||||
#define TILE_N 8
|
||||
#define SLM_BLOCK 512
|
||||
|
||||
/*
|
||||
A K B.t() K D N
|
||||
----------- ----------- -----------
|
||||
| | | | | |
|
||||
M | | x N | | => M | |
|
||||
| | | | | |
|
||||
----------- ----------- -----------
|
||||
*/
|
||||
|
||||
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
|
||||
__kernel void intelblas_gemm_buffer_NT(
|
||||
const __global float *src0, int off0,
|
||||
@ -422,59 +431,79 @@ __kernel void intelblas_gemm_buffer_NT(
|
||||
float8 dot06 = 0.f;
|
||||
float8 dot07 = 0.f;
|
||||
|
||||
float4 brow0;
|
||||
float4 brow1;
|
||||
float4 brow2;
|
||||
float4 brow3;
|
||||
float4 brow4;
|
||||
float4 brow5;
|
||||
float4 brow6;
|
||||
float4 brow7;
|
||||
const int dst_row = (global_y * TILE_M);
|
||||
__global float *dst_write0 = dst + global_x + dst_row * ldC + offd;
|
||||
|
||||
__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * ldC + offd;
|
||||
const __global float *src0_read00 = src0 + off0;
|
||||
const int a_row_base = global_y * TILE_M;
|
||||
const int a_col_base = local_x * (TILE_K / 8); // <= TILE_K - 4
|
||||
|
||||
const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * ldA + off0;
|
||||
|
||||
const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * ldB + off1;
|
||||
const __global float *src1_read00 = src1 + off1;
|
||||
const int b_row_base = (group_x * TILE_N);
|
||||
//const int b_col_base = 0;
|
||||
|
||||
__local float slm_brow[8 * SLM_BLOCK];
|
||||
__local float* slm_brow0;
|
||||
|
||||
int local_index = mad24(local_y, 8, local_x) * 4;
|
||||
int w;
|
||||
for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
|
||||
int w = 0;
|
||||
for (int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK)
|
||||
{
|
||||
#define UPDATE_BROW(_row) \
|
||||
{ \
|
||||
float4 brow; \
|
||||
int b_row = b_row_base + _row; \
|
||||
int b_col = b_tile + local_index; \
|
||||
if (b_row < N && b_col <= K - 4 /*vload4*/) \
|
||||
brow = vload4(0, src1_read00 + mad24(b_row, ldB, b_col)); \
|
||||
else \
|
||||
brow = (float4)0; \
|
||||
vstore4(brow, 0, slm_brow + mad24(_row, SLM_BLOCK, local_index)); \
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
vstore4(vload4(0, src1_read0 + mad24(0, ldB, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(1, ldB, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(2, ldB, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(3, ldB, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(4, ldB, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(5, ldB, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(6, ldB, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
|
||||
vstore4(vload4(0, src1_read0 + mad24(7, ldB, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
|
||||
UPDATE_BROW(0);
|
||||
UPDATE_BROW(1);
|
||||
UPDATE_BROW(2);
|
||||
UPDATE_BROW(3);
|
||||
UPDATE_BROW(4);
|
||||
UPDATE_BROW(5);
|
||||
UPDATE_BROW(6);
|
||||
UPDATE_BROW(7);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#undef UPDATE_BROW
|
||||
|
||||
slm_brow0 = slm_brow + local_x * (TILE_K / 8);
|
||||
w = b_tile;
|
||||
int end_w = min(b_tile + SLM_BLOCK, K);
|
||||
while( w + TILE_K <= end_w ) {
|
||||
float4 arow;
|
||||
for (int k_tile_offset = 0; k_tile_offset < SLM_BLOCK; k_tile_offset += TILE_K)
|
||||
{
|
||||
int a_col = a_col_base + b_tile + k_tile_offset;
|
||||
|
||||
brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
|
||||
brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
|
||||
brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
|
||||
brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
|
||||
brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
|
||||
brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
|
||||
brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
|
||||
brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);
|
||||
if (a_col > K - 4 /*vload4*/)
|
||||
break;
|
||||
|
||||
#define MM_DOT_PRODUCT(_row,_dot) \
|
||||
arow = vload4(0, src0_read + _row * ldA); \
|
||||
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
|
||||
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
|
||||
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
|
||||
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
|
||||
int slm_brow_col = a_col_base + k_tile_offset; // <= SLM_BLOCK - 4
|
||||
#define READ_SLM_BROW(_row) \
|
||||
float4 brow##_row = vload4(0, slm_brow + mad24(_row, SLM_BLOCK, slm_brow_col));
|
||||
|
||||
READ_SLM_BROW(0);
|
||||
READ_SLM_BROW(1);
|
||||
READ_SLM_BROW(2);
|
||||
READ_SLM_BROW(3);
|
||||
READ_SLM_BROW(4);
|
||||
READ_SLM_BROW(5);
|
||||
READ_SLM_BROW(6);
|
||||
READ_SLM_BROW(7);
|
||||
#undef READ_SLM_BROW
|
||||
|
||||
#define MM_DOT_PRODUCT(_row,_dot) \
|
||||
{ \
|
||||
int a_row = a_row_base + _row; \
|
||||
if (a_row < M) { \
|
||||
float4 arow = vload4(0, src0_read00 + mad24(a_row, ldA, a_col)); \
|
||||
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
|
||||
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
|
||||
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
|
||||
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \
|
||||
} \
|
||||
}
|
||||
|
||||
MM_DOT_PRODUCT(0,dot00);
|
||||
MM_DOT_PRODUCT(1,dot01);
|
||||
@ -485,53 +514,7 @@ __kernel void intelblas_gemm_buffer_NT(
|
||||
MM_DOT_PRODUCT(6,dot06);
|
||||
MM_DOT_PRODUCT(7,dot07);
|
||||
#undef MM_DOT_PRODUCT
|
||||
|
||||
src0_read += TILE_K;
|
||||
slm_brow0 += TILE_K;
|
||||
w += TILE_K;
|
||||
}
|
||||
src1_read0 += SLM_BLOCK;
|
||||
}
|
||||
|
||||
if(w < K) {
|
||||
float4 arow;
|
||||
|
||||
#define READ_BROW(_brow,_row) \
|
||||
_brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
|
||||
_brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
|
||||
_brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
|
||||
_brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
|
||||
_brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;
|
||||
|
||||
READ_BROW(brow0,0);
|
||||
READ_BROW(brow1,1);
|
||||
READ_BROW(brow2,2);
|
||||
READ_BROW(brow3,3);
|
||||
READ_BROW(brow4,4);
|
||||
READ_BROW(brow5,5);
|
||||
READ_BROW(brow6,6);
|
||||
READ_BROW(brow7,7);
|
||||
|
||||
#define MM_DOT_PRODUCT(_row,_dot) \
|
||||
arow = vload4(0, src0_read + _row * ldA); \
|
||||
arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
|
||||
arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
|
||||
arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
|
||||
arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
|
||||
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
|
||||
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
|
||||
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
|
||||
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
|
||||
|
||||
MM_DOT_PRODUCT(0,dot00);
|
||||
MM_DOT_PRODUCT(1,dot01);
|
||||
MM_DOT_PRODUCT(2,dot02);
|
||||
MM_DOT_PRODUCT(3,dot03);
|
||||
MM_DOT_PRODUCT(4,dot04);
|
||||
MM_DOT_PRODUCT(5,dot05);
|
||||
MM_DOT_PRODUCT(6,dot06);
|
||||
MM_DOT_PRODUCT(7,dot07);
|
||||
#undef MM_DOT_PRODUCT
|
||||
}
|
||||
|
||||
#define REDUCE(_dot) \
|
||||
@ -572,21 +555,22 @@ __kernel void intelblas_gemm_buffer_NT(
|
||||
output = (local_x == 5) ? _dot.s5 : output; \
|
||||
output = (local_x == 6) ? _dot.s6 : output; \
|
||||
output = (local_x == 7) ? _dot.s7 : output; \
|
||||
if (beta != 0.0) \
|
||||
if (beta != 0.0f) \
|
||||
dst_write0[0] = mad(output, (float)alpha, ((float)beta * dst_write0[0])); \
|
||||
else \
|
||||
dst_write0[0] = output * (float)alpha; \
|
||||
dst_write0 += ldC;
|
||||
|
||||
if(global_x < N && global_y * 8 < M) {
|
||||
OUTPUT(dot00);
|
||||
if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
|
||||
if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
|
||||
if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
|
||||
if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
|
||||
if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
|
||||
if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
|
||||
if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
|
||||
if (global_x < N && dst_row < M)
|
||||
{
|
||||
/*if (dst_row + 0 < M)*/ { OUTPUT(dot00); }
|
||||
if (dst_row + 1 < M) { OUTPUT(dot01); }
|
||||
if (dst_row + 2 < M) { OUTPUT(dot02); }
|
||||
if (dst_row + 3 < M) { OUTPUT(dot03); }
|
||||
if (dst_row + 4 < M) { OUTPUT(dot04); }
|
||||
if (dst_row + 5 < M) { OUTPUT(dot05); }
|
||||
if (dst_row + 6 < M) { OUTPUT(dot06); }
|
||||
if (dst_row + 7 < M) { OUTPUT(dot07); }
|
||||
}
|
||||
#undef OUTPUT
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user