core(ocl): buffer bounds in intelblas_gemm_buffer_NT

This commit is contained in:
Alexander Alekhin 2021-09-07 04:39:28 +00:00
parent e3f4f874c5
commit 9b4ecc96f6
2 changed files with 83 additions and 103 deletions

View File

@ -77,11 +77,7 @@ static bool intel_gpu_gemm(
}
else if(!atrans && btrans)
{
if (M % 128 != 0)
return false;
if (N % 8 != 0)
return false;
if (K % 512 != 0)
if (K % 4 != 0)
return false;
kernelName = "intelblas_gemm_buffer_NT";
ly = 16;

View File

@ -392,6 +392,15 @@ __kernel void intelblas_gemm_buffer_NN(
#define TILE_N 8
#define SLM_BLOCK 512
/*
    A         K         B.t()     K         D         N
    -----------         -----------         -----------
    |         |         |         |         |         |
  M |         |  x    N |         |  =>   M |         |
    |         |         |         |         |         |
    -----------         -----------         -----------
*/
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__kernel void intelblas_gemm_buffer_NT(
const __global float *src0, int off0,
@ -422,59 +431,79 @@ __kernel void intelblas_gemm_buffer_NT(
float8 dot06 = 0.f;
float8 dot07 = 0.f;
float4 brow0;
float4 brow1;
float4 brow2;
float4 brow3;
float4 brow4;
float4 brow5;
float4 brow6;
float4 brow7;
const int dst_row = (global_y * TILE_M);
__global float *dst_write0 = dst + global_x + dst_row * ldC + offd;
__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * ldC + offd;
const __global float *src0_read00 = src0 + off0;
const int a_row_base = global_y * TILE_M;
const int a_col_base = local_x * (TILE_K / 8); // <= TILE_K - 4
const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * ldA + off0;
const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * ldB + off1;
const __global float *src1_read00 = src1 + off1;
const int b_row_base = (group_x * TILE_N);
//const int b_col_base = 0;
__local float slm_brow[8 * SLM_BLOCK];
__local float* slm_brow0;
int local_index = mad24(local_y, 8, local_x) * 4;
int w;
for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
int w = 0;
for (int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK)
{
#define UPDATE_BROW(_row) \
{ \
float4 brow; \
int b_row = b_row_base + _row; \
int b_col = b_tile + local_index; \
if (b_row < N && b_col <= K - 4 /*vload4*/) \
brow = vload4(0, src1_read00 + mad24(b_row, ldB, b_col)); \
else \
brow = (float4)0; \
vstore4(brow, 0, slm_brow + mad24(_row, SLM_BLOCK, local_index)); \
}
barrier(CLK_LOCAL_MEM_FENCE);
vstore4(vload4(0, src1_read0 + mad24(0, ldB, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(1, ldB, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(2, ldB, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(3, ldB, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(4, ldB, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(5, ldB, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(6, ldB, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
vstore4(vload4(0, src1_read0 + mad24(7, ldB, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
UPDATE_BROW(0);
UPDATE_BROW(1);
UPDATE_BROW(2);
UPDATE_BROW(3);
UPDATE_BROW(4);
UPDATE_BROW(5);
UPDATE_BROW(6);
UPDATE_BROW(7);
barrier(CLK_LOCAL_MEM_FENCE);
#undef UPDATE_BROW
slm_brow0 = slm_brow + local_x * (TILE_K / 8);
w = b_tile;
int end_w = min(b_tile + SLM_BLOCK, K);
while( w + TILE_K <= end_w ) {
float4 arow;
for (int k_tile_offset = 0; k_tile_offset < SLM_BLOCK; k_tile_offset += TILE_K)
{
int a_col = a_col_base + b_tile + k_tile_offset;
brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);
if (a_col > K - 4 /*vload4*/)
break;
#define MM_DOT_PRODUCT(_row,_dot) \
arow = vload4(0, src0_read + _row * ldA); \
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
int slm_brow_col = a_col_base + k_tile_offset; // <= SLM_BLOCK - 4
#define READ_SLM_BROW(_row) \
float4 brow##_row = vload4(0, slm_brow + mad24(_row, SLM_BLOCK, slm_brow_col));
READ_SLM_BROW(0);
READ_SLM_BROW(1);
READ_SLM_BROW(2);
READ_SLM_BROW(3);
READ_SLM_BROW(4);
READ_SLM_BROW(5);
READ_SLM_BROW(6);
READ_SLM_BROW(7);
#undef READ_SLM_BROW
#define MM_DOT_PRODUCT(_row,_dot) \
{ \
int a_row = a_row_base + _row; \
if (a_row < M) { \
float4 arow = vload4(0, src0_read00 + mad24(a_row, ldA, a_col)); \
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \
} \
}
MM_DOT_PRODUCT(0,dot00);
MM_DOT_PRODUCT(1,dot01);
@ -485,53 +514,7 @@ __kernel void intelblas_gemm_buffer_NT(
MM_DOT_PRODUCT(6,dot06);
MM_DOT_PRODUCT(7,dot07);
#undef MM_DOT_PRODUCT
src0_read += TILE_K;
slm_brow0 += TILE_K;
w += TILE_K;
}
src1_read0 += SLM_BLOCK;
}
if(w < K) {
float4 arow;
#define READ_BROW(_brow,_row) \
_brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
_brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
_brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
_brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
_brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;
READ_BROW(brow0,0);
READ_BROW(brow1,1);
READ_BROW(brow2,2);
READ_BROW(brow3,3);
READ_BROW(brow4,4);
READ_BROW(brow5,5);
READ_BROW(brow6,6);
READ_BROW(brow7,7);
#define MM_DOT_PRODUCT(_row,_dot) \
arow = vload4(0, src0_read + _row * ldA); \
arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );
MM_DOT_PRODUCT(0,dot00);
MM_DOT_PRODUCT(1,dot01);
MM_DOT_PRODUCT(2,dot02);
MM_DOT_PRODUCT(3,dot03);
MM_DOT_PRODUCT(4,dot04);
MM_DOT_PRODUCT(5,dot05);
MM_DOT_PRODUCT(6,dot06);
MM_DOT_PRODUCT(7,dot07);
#undef MM_DOT_PRODUCT
}
#define REDUCE(_dot) \
@ -572,21 +555,22 @@ __kernel void intelblas_gemm_buffer_NT(
output = (local_x == 5) ? _dot.s5 : output; \
output = (local_x == 6) ? _dot.s6 : output; \
output = (local_x == 7) ? _dot.s7 : output; \
if (beta != 0.0) \
if (beta != 0.0f) \
dst_write0[0] = mad(output, (float)alpha, ((float)beta * dst_write0[0])); \
else \
dst_write0[0] = output * (float)alpha; \
dst_write0 += ldC;
if(global_x < N && global_y * 8 < M) {
OUTPUT(dot00);
if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
if (global_x < N && dst_row < M)
{
/*if (dst_row + 0 < M)*/ { OUTPUT(dot00); }
if (dst_row + 1 < M) { OUTPUT(dot01); }
if (dst_row + 2 < M) { OUTPUT(dot02); }
if (dst_row + 3 < M) { OUTPUT(dot03); }
if (dst_row + 4 < M) { OUTPUT(dot04); }
if (dst_row + 5 < M) { OUTPUT(dot05); }
if (dst_row + 6 < M) { OUTPUT(dot06); }
if (dst_row + 7 < M) { OUTPUT(dot07); }
}
#undef OUTPUT
}