From 9b4ecc96f63d64a07ac043ad06fa44a1fd02b18b Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Tue, 7 Sep 2021 04:39:28 +0000 Subject: [PATCH] core(ocl): buffer bounds in intelblas_gemm_buffer_NT --- modules/core/src/intel_gpu_gemm.inl.hpp | 6 +- modules/core/src/opencl/intel_gemm.cl | 186 +++++++++++------------- 2 files changed, 86 insertions(+), 106 deletions(-) diff --git a/modules/core/src/intel_gpu_gemm.inl.hpp b/modules/core/src/intel_gpu_gemm.inl.hpp index fa66856f5e..28cc4ab9b9 100644 --- a/modules/core/src/intel_gpu_gemm.inl.hpp +++ b/modules/core/src/intel_gpu_gemm.inl.hpp @@ -77,11 +77,7 @@ static bool intel_gpu_gemm( } else if(!atrans && btrans) { - if (M % 128 != 0) - return false; - if (N % 8 != 0) - return false; - if (K % 512 != 0) + if (K % 4 != 0) return false; kernelName = "intelblas_gemm_buffer_NT"; ly = 16; diff --git a/modules/core/src/opencl/intel_gemm.cl b/modules/core/src/opencl/intel_gemm.cl index 6cea8d7efd..53ae790779 100644 --- a/modules/core/src/opencl/intel_gemm.cl +++ b/modules/core/src/opencl/intel_gemm.cl @@ -392,6 +392,15 @@ __kernel void intelblas_gemm_buffer_NN( #define TILE_N 8 #define SLM_BLOCK 512 +/* + A K B.t() K D N + ----------- ----------- ----------- + | | | | | | + M | | x N | | => M | | + | | | | | | + ----------- ----------- ----------- +*/ + __attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) __kernel void intelblas_gemm_buffer_NT( const __global float *src0, int off0, @@ -422,59 +431,79 @@ __kernel void intelblas_gemm_buffer_NT( float8 dot06 = 0.f; float8 dot07 = 0.f; - float4 brow0; - float4 brow1; - float4 brow2; - float4 brow3; - float4 brow4; - float4 brow5; - float4 brow6; - float4 brow7; - - __global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * ldC + offd; + const int dst_row = (global_y * TILE_M); + __global float *dst_write0 = dst + global_x + dst_row * ldC + offd; - const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * ldA + off0; + const __global float *src0_read00 = src0 + off0; + const int a_row_base = global_y * TILE_M; + const int a_col_base = local_x * (TILE_K / 8); // <= TILE_K - 4 - const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * ldB + off1; + const __global float *src1_read00 = src1 + off1; + const int b_row_base = (group_x * TILE_N); + //const int b_col_base = 0; __local float slm_brow[8 * SLM_BLOCK]; - __local float* slm_brow0; int local_index = mad24(local_y, 8, local_x) * 4; - int w; - for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) { + int w = 0; + for (int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) + { +#define UPDATE_BROW(_row) \ + { \ + float4 brow; \ + int b_row = b_row_base + _row; \ + int b_col = b_tile + local_index; \ + if (b_row < N && b_col <= K - 4 /*vload4*/) \ + brow = vload4(0, src1_read00 + mad24(b_row, ldB, b_col)); \ + else \ + brow = (float4)0; \ + vstore4(brow, 0, slm_brow + mad24(_row, SLM_BLOCK, local_index)); \ + } + barrier(CLK_LOCAL_MEM_FENCE); - vstore4(vload4(0, src1_read0 + mad24(0, ldB, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(1, ldB, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(2, ldB, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(3, ldB, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(4, ldB, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(5, ldB, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(6, ldB, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index)); - vstore4(vload4(0, src1_read0 + mad24(7, ldB, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index)); + UPDATE_BROW(0); + UPDATE_BROW(1); + UPDATE_BROW(2); + UPDATE_BROW(3); + UPDATE_BROW(4); + UPDATE_BROW(5); + UPDATE_BROW(6); + UPDATE_BROW(7); barrier(CLK_LOCAL_MEM_FENCE); - - slm_brow0 = slm_brow + local_x * (TILE_K / 8); - w = b_tile; - int end_w = min(b_tile + SLM_BLOCK, K); - while( w + TILE_K <= end_w ) { - float4 arow; - - brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK); - brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK); - brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK); - brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK); - brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK); - brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK); - brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK); - brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK); - -#define MM_DOT_PRODUCT(_row,_dot) \ - arow = vload4(0, src0_read + _row * ldA); \ - _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ - _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ - _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ - _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); +#undef UPDATE_BROW + + for (int k_tile_offset = 0; k_tile_offset < SLM_BLOCK; k_tile_offset += TILE_K) + { + int a_col = a_col_base + b_tile + k_tile_offset; + + if (a_col > K - 4 /*vload4*/) + break; + + int slm_brow_col = a_col_base + k_tile_offset; // <= SLM_BLOCK - 4 +#define READ_SLM_BROW(_row) \ + float4 brow##_row = vload4(0, slm_brow + mad24(_row, SLM_BLOCK, slm_brow_col)); + + READ_SLM_BROW(0); + READ_SLM_BROW(1); + READ_SLM_BROW(2); + READ_SLM_BROW(3); + READ_SLM_BROW(4); + READ_SLM_BROW(5); + READ_SLM_BROW(6); + READ_SLM_BROW(7); +#undef READ_SLM_BROW + +#define MM_DOT_PRODUCT(_row,_dot) \ + { \ + int a_row = a_row_base + _row; \ + if (a_row < M) { \ + float4 arow = vload4(0, src0_read00 + mad24(a_row, ldA, a_col)); \ + _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ + _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ + _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ + _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ + } \ + } MM_DOT_PRODUCT(0,dot00); MM_DOT_PRODUCT(1,dot01); @@ -485,53 +514,7 @@ __kernel void intelblas_gemm_buffer_NT( MM_DOT_PRODUCT(6,dot06); MM_DOT_PRODUCT(7,dot07); #undef MM_DOT_PRODUCT - - src0_read += TILE_K; - slm_brow0 += TILE_K; - w += TILE_K; } - src1_read0 += SLM_BLOCK; - } - - if(w < K) { - float4 arow; - -#define READ_BROW(_brow,_row) \ - _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \ - _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \ - _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \ - _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \ - _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f; - - READ_BROW(brow0,0); - READ_BROW(brow1,1); - READ_BROW(brow2,2); - READ_BROW(brow3,3); - READ_BROW(brow4,4); - READ_BROW(brow5,5); - READ_BROW(brow6,6); - READ_BROW(brow7,7); - -#define MM_DOT_PRODUCT(_row,_dot) \ - arow = vload4(0, src0_read + _row * ldA); \ - arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \ - arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \ - arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \ - arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \ - _dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ - _dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ - _dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ - _dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); - - MM_DOT_PRODUCT(0,dot00); - MM_DOT_PRODUCT(1,dot01); - MM_DOT_PRODUCT(2,dot02); - MM_DOT_PRODUCT(3,dot03); - MM_DOT_PRODUCT(4,dot04); - MM_DOT_PRODUCT(5,dot05); - MM_DOT_PRODUCT(6,dot06); - MM_DOT_PRODUCT(7,dot07); -#undef MM_DOT_PRODUCT } #define REDUCE(_dot) \ @@ -572,21 +555,22 @@ __kernel void intelblas_gemm_buffer_NT( output = (local_x == 5) ? _dot.s5 : output; \ output = (local_x == 6) ? _dot.s6 : output; \ output = (local_x == 7) ? _dot.s7 : output; \ - if (beta != 0.0) \ + if (beta != 0.0f) \ dst_write0[0] = mad(output, (float)alpha, ((float)beta * dst_write0[0])); \ else \ dst_write0[0] = output * (float)alpha; \ dst_write0 += ldC; - if(global_x < N && global_y * 8 < M) { - OUTPUT(dot00); - if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); } - if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); } - if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); } - if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); } - if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); } - if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); } - if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); } + if (global_x < N && dst_row < M) + { + /*if (dst_row + 0 < M)*/ { OUTPUT(dot00); } + if (dst_row + 1 < M) { OUTPUT(dot01); } + if (dst_row + 2 < M) { OUTPUT(dot02); } + if (dst_row + 3 < M) { OUTPUT(dot03); } + if (dst_row + 4 < M) { OUTPUT(dot04); } + if (dst_row + 5 < M) { OUTPUT(dot05); } + if (dst_row + 6 < M) { OUTPUT(dot06); } + if (dst_row + 7 < M) { OUTPUT(dot07); } } #undef OUTPUT }