|
|
|
@ -392,6 +392,15 @@ __kernel void intelblas_gemm_buffer_NN( |
|
|
|
|
/* Per-work-group step along N in intelblas_gemm_buffer_NT: group_x * TILE_N is
 * used both as the first src1 row a work-group reads and as part of the dst
 * column offset it writes. */
#define TILE_N 8 |
|
|
|
|
/* Length, in floats, of one row of the __local staging buffer slm_brow
 * (declared as 8 * SLM_BLOCK floats and indexed as row * SLM_BLOCK + column).
 * It is also the step by which the outer b_tile loop advances through K.
 * NOTE(review): each work-item stores 4 floats per row at
 * (local_y * 8 + local_x) * 4, so covering a whole row presumably requires
 * 8 * LWG_HEIGHT * 4 == SLM_BLOCK -- confirm against the LWG_HEIGHT definition. */
#define SLM_BLOCK 512 |
|
|
|
|
|
|
|
|
|
/*
    A K                     B.t() K                   D N
    -----------             -----------               -----------
    |         |             |         |               |         |
  M |         |  x        N |         |  =>         M |         |
    |         |             |         |               |         |
    -----------             -----------               -----------
*/
|
|
|
|
|
|
|
|
|
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1))) |
|
|
|
|
__kernel void intelblas_gemm_buffer_NT( |
|
|
|
|
const __global float *src0, int off0, |
|
|
|
@ -422,59 +431,79 @@ __kernel void intelblas_gemm_buffer_NT( |
|
|
|
|
float8 dot06 = 0.f; |
|
|
|
|
float8 dot07 = 0.f; |
|
|
|
|
|
|
|
|
|
float4 brow0; |
|
|
|
|
float4 brow1; |
|
|
|
|
float4 brow2; |
|
|
|
|
float4 brow3; |
|
|
|
|
float4 brow4; |
|
|
|
|
float4 brow5; |
|
|
|
|
float4 brow6; |
|
|
|
|
float4 brow7; |
|
|
|
|
|
|
|
|
|
__global float *dst_write0 = dst + local_x * VEC_SIZE + ( group_x * TILE_N ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * ldC + offd; |
|
|
|
|
const int dst_row = (global_y * TILE_M); |
|
|
|
|
__global float *dst_write0 = dst + global_x + dst_row * ldC + offd; |
|
|
|
|
|
|
|
|
|
const __global float *src0_read = src0 + local_x * ( TILE_K / 8 ) + ( group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M ) * ldA + off0; |
|
|
|
|
const __global float *src0_read00 = src0 + off0; |
|
|
|
|
const int a_row_base = global_y * TILE_M; |
|
|
|
|
const int a_col_base = local_x * (TILE_K / 8); // <= TILE_K - 4 |
|
|
|
|
|
|
|
|
|
const __global float *src1_read0 = src1 + ( group_x * TILE_N ) * ldB + off1; |
|
|
|
|
const __global float *src1_read00 = src1 + off1; |
|
|
|
|
const int b_row_base = (group_x * TILE_N); |
|
|
|
|
//const int b_col_base = 0; |
|
|
|
|
|
|
|
|
|
__local float slm_brow[8 * SLM_BLOCK]; |
|
|
|
|
__local float* slm_brow0; |
|
|
|
|
|
|
|
|
|
int local_index = mad24(local_y, 8, local_x) * 4; |
|
|
|
|
int w; |
|
|
|
|
for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) { |
|
|
|
|
int w = 0; |
|
|
|
|
for (int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) |
|
|
|
|
{ |
|
|
|
|
#define UPDATE_BROW(_row) \ |
|
|
|
|
{ \ |
|
|
|
|
float4 brow; \ |
|
|
|
|
int b_row = b_row_base + _row; \ |
|
|
|
|
int b_col = b_tile + local_index; \ |
|
|
|
|
if (b_row < N && b_col <= K - 4 /*vload4*/) \ |
|
|
|
|
brow = vload4(0, src1_read00 + mad24(b_row, ldB, b_col)); \ |
|
|
|
|
else \ |
|
|
|
|
brow = (float4)0; \ |
|
|
|
|
vstore4(brow, 0, slm_brow + mad24(_row, SLM_BLOCK, local_index)); \ |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(0, ldB, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(1, ldB, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(2, ldB, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(3, ldB, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(4, ldB, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(5, ldB, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(6, ldB, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index)); |
|
|
|
|
vstore4(vload4(0, src1_read0 + mad24(7, ldB, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index)); |
|
|
|
|
UPDATE_BROW(0); |
|
|
|
|
UPDATE_BROW(1); |
|
|
|
|
UPDATE_BROW(2); |
|
|
|
|
UPDATE_BROW(3); |
|
|
|
|
UPDATE_BROW(4); |
|
|
|
|
UPDATE_BROW(5); |
|
|
|
|
UPDATE_BROW(6); |
|
|
|
|
UPDATE_BROW(7); |
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE); |
|
|
|
|
|
|
|
|
|
slm_brow0 = slm_brow + local_x * (TILE_K / 8); |
|
|
|
|
w = b_tile; |
|
|
|
|
int end_w = min(b_tile + SLM_BLOCK, K); |
|
|
|
|
while( w + TILE_K <= end_w ) { |
|
|
|
|
float4 arow; |
|
|
|
|
|
|
|
|
|
brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK); |
|
|
|
|
brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK); |
|
|
|
|
brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK); |
|
|
|
|
brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK); |
|
|
|
|
brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK); |
|
|
|
|
brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK); |
|
|
|
|
brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK); |
|
|
|
|
brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK); |
|
|
|
|
|
|
|
|
|
#define MM_DOT_PRODUCT(_row,_dot) \ |
|
|
|
|
arow = vload4(0, src0_read + _row * ldA); \ |
|
|
|
|
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); |
|
|
|
|
#undef UPDATE_BROW |
|
|
|
|
|
|
|
|
|
for (int k_tile_offset = 0; k_tile_offset < SLM_BLOCK; k_tile_offset += TILE_K) |
|
|
|
|
{ |
|
|
|
|
int a_col = a_col_base + b_tile + k_tile_offset; |
|
|
|
|
|
|
|
|
|
if (a_col > K - 4 /*vload4*/) |
|
|
|
|
break; |
|
|
|
|
|
|
|
|
|
int slm_brow_col = a_col_base + k_tile_offset; // <= SLM_BLOCK - 4 |
|
|
|
|
#define READ_SLM_BROW(_row) \ |
|
|
|
|
float4 brow##_row = vload4(0, slm_brow + mad24(_row, SLM_BLOCK, slm_brow_col)); |
|
|
|
|
|
|
|
|
|
READ_SLM_BROW(0); |
|
|
|
|
READ_SLM_BROW(1); |
|
|
|
|
READ_SLM_BROW(2); |
|
|
|
|
READ_SLM_BROW(3); |
|
|
|
|
READ_SLM_BROW(4); |
|
|
|
|
READ_SLM_BROW(5); |
|
|
|
|
READ_SLM_BROW(6); |
|
|
|
|
READ_SLM_BROW(7); |
|
|
|
|
#undef READ_SLM_BROW |
|
|
|
|
|
|
|
|
|
#define MM_DOT_PRODUCT(_row,_dot) \ |
|
|
|
|
{ \ |
|
|
|
|
int a_row = a_row_base + _row; \ |
|
|
|
|
if (a_row < M) { \ |
|
|
|
|
float4 arow = vload4(0, src0_read00 + mad24(a_row, ldA, a_col)); \ |
|
|
|
|
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); \ |
|
|
|
|
} \ |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
MM_DOT_PRODUCT(0,dot00); |
|
|
|
|
MM_DOT_PRODUCT(1,dot01); |
|
|
|
@ -485,53 +514,7 @@ __kernel void intelblas_gemm_buffer_NT( |
|
|
|
|
MM_DOT_PRODUCT(6,dot06); |
|
|
|
|
MM_DOT_PRODUCT(7,dot07); |
|
|
|
|
#undef MM_DOT_PRODUCT |
|
|
|
|
|
|
|
|
|
src0_read += TILE_K; |
|
|
|
|
slm_brow0 += TILE_K; |
|
|
|
|
w += TILE_K; |
|
|
|
|
} |
|
|
|
|
src1_read0 += SLM_BLOCK; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if(w < K) { |
|
|
|
|
float4 arow; |
|
|
|
|
|
|
|
|
|
#define READ_BROW(_brow,_row) \ |
|
|
|
|
_brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \ |
|
|
|
|
_brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \ |
|
|
|
|
_brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \ |
|
|
|
|
_brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \ |
|
|
|
|
_brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f; |
|
|
|
|
|
|
|
|
|
READ_BROW(brow0,0); |
|
|
|
|
READ_BROW(brow1,1); |
|
|
|
|
READ_BROW(brow2,2); |
|
|
|
|
READ_BROW(brow3,3); |
|
|
|
|
READ_BROW(brow4,4); |
|
|
|
|
READ_BROW(brow5,5); |
|
|
|
|
READ_BROW(brow6,6); |
|
|
|
|
READ_BROW(brow7,7); |
|
|
|
|
|
|
|
|
|
#define MM_DOT_PRODUCT(_row,_dot) \ |
|
|
|
|
arow = vload4(0, src0_read + _row * ldA); \ |
|
|
|
|
arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \ |
|
|
|
|
arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \ |
|
|
|
|
arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \ |
|
|
|
|
arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \ |
|
|
|
|
_dot = mad( (float8)(arow.x), (float8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.y), (float8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.z), (float8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \ |
|
|
|
|
_dot = mad( (float8)(arow.w), (float8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot ); |
|
|
|
|
|
|
|
|
|
MM_DOT_PRODUCT(0,dot00); |
|
|
|
|
MM_DOT_PRODUCT(1,dot01); |
|
|
|
|
MM_DOT_PRODUCT(2,dot02); |
|
|
|
|
MM_DOT_PRODUCT(3,dot03); |
|
|
|
|
MM_DOT_PRODUCT(4,dot04); |
|
|
|
|
MM_DOT_PRODUCT(5,dot05); |
|
|
|
|
MM_DOT_PRODUCT(6,dot06); |
|
|
|
|
MM_DOT_PRODUCT(7,dot07); |
|
|
|
|
#undef MM_DOT_PRODUCT |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#define REDUCE(_dot) \ |
|
|
|
@ -572,21 +555,22 @@ __kernel void intelblas_gemm_buffer_NT( |
|
|
|
|
output = (local_x == 5) ? _dot.s5 : output; \ |
|
|
|
|
output = (local_x == 6) ? _dot.s6 : output; \ |
|
|
|
|
output = (local_x == 7) ? _dot.s7 : output; \ |
|
|
|
|
if (beta != 0.0) \ |
|
|
|
|
if (beta != 0.0f) \ |
|
|
|
|
dst_write0[0] = mad(output, (float)alpha, ((float)beta * dst_write0[0])); \ |
|
|
|
|
else \ |
|
|
|
|
dst_write0[0] = output * (float)alpha; \ |
|
|
|
|
dst_write0 += ldC; |
|
|
|
|
|
|
|
|
|
if(global_x < N && global_y * 8 < M) { |
|
|
|
|
OUTPUT(dot00); |
|
|
|
|
if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); } |
|
|
|
|
if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); } |
|
|
|
|
if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); } |
|
|
|
|
if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); } |
|
|
|
|
if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); } |
|
|
|
|
if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); } |
|
|
|
|
if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); } |
|
|
|
|
if (global_x < N && dst_row < M) |
|
|
|
|
{ |
|
|
|
|
/*if (dst_row + 0 < M)*/ { OUTPUT(dot00); } |
|
|
|
|
if (dst_row + 1 < M) { OUTPUT(dot01); } |
|
|
|
|
if (dst_row + 2 < M) { OUTPUT(dot02); } |
|
|
|
|
if (dst_row + 3 < M) { OUTPUT(dot03); } |
|
|
|
|
if (dst_row + 4 < M) { OUTPUT(dot04); } |
|
|
|
|
if (dst_row + 5 < M) { OUTPUT(dot05); } |
|
|
|
|
if (dst_row + 6 < M) { OUTPUT(dot06); } |
|
|
|
|
if (dst_row + 7 < M) { OUTPUT(dot07); } |
|
|
|
|
} |
|
|
|
|
#undef OUTPUT |
|
|
|
|
} |
|
|
|
|