|
|
|
@@ -90,6 +90,12 @@
 #pragma OPENCL EXTENSION cl_intel_subgroups : enable
 #endif
 
+#ifdef ZERO_BETA
+#define BETA_ZERO_CHECK(b0, v) (b0)
+#else
+#define BETA_ZERO_CHECK(b0, v) (v)
+#endif
+
 #define VEC_SIZE 4
 #define LWG_HEIGHT 4
 #define TILE_M 8
@@ -143,14 +149,14 @@ __kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
     int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
     int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;
 
-    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
-    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
-    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
-    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
-    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
-    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
-    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
-    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);
+    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0));
+    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 1 * N));
+    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 2 * N));
+    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 3 * N));
+    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 4 * N));
+    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 5 * N));
+    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 6 * N));
+    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : BETA_ZERO_CHECK((Dtype4)0, beta * vload4(0, dst_write0 + 7 * N));
 
     int end_index = min(start_index + 256, K);
     int w = start_index;
@@ -579,7 +585,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
         output = (local_x == 5) ? _dot.s5 : output; \
         output = (local_x == 6) ? _dot.s6 : output; \
         output = (local_x == 7) ? _dot.s7 : output; \
-        dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
+        dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \
         dst_write0 += N;
 
     if(global_x < N && global_y * 8 < M) {
@ -765,7 +771,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)( |
|
|
|
|
output = (local_x == 5) ? _dot.s5 : output; \ |
|
|
|
|
output = (local_x == 6) ? _dot.s6 : output; \ |
|
|
|
|
output = (local_x == 7) ? _dot.s7 : output; \ |
|
|
|
|
dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \ |
|
|
|
|
dst_write0[0] = BETA_ZERO_CHECK(alpha * output, mad(output, alpha, beta * dst_write0[0])); \ |
|
|
|
|
dst_write0 += N; |
|
|
|
|
|
|
|
|
|
if(global_x < N && global_y * 8 < M) { |
|
|
|
@@ -819,8 +825,9 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
         const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
 #pragma unroll
         for(int j = 0; j < rows; ++j) {
-            dot0[j] += b0 * vload4(i, srcb_read + j * K);
-            dot1[j] += b1 * vload4(i, srcb_read + j * K);
+            Dtype4 a = vload4(i, srcb_read + j * K);
+            dot0[j] += b0 * a;
+            dot1[j] += b1 * a;
         }
 
         i += get_local_size(0);
@@ -859,11 +866,19 @@ void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
             }
         }
 
+    barrier(CLK_LOCAL_MEM_FENCE);
     if(lid == 0) {
 #pragma unroll
         for(int j = 0; j < rows; ++j) {
-            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
-            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+#ifdef ZERO_BETA
+            Dtype a0 = alpha * work_each0[j];
+            Dtype a1 = alpha * work_each1[j];
+#else
+            Dtype a0 = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
+            Dtype a1 = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
+#endif
+            dstc0[(x_gid * 4 + j)] = a0;
+            dstc1[(x_gid * 4 + j)] = a1;
         }
     }
 }
@@ -952,9 +967,15 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
             }
         }
 
-        if(lid == 0) {
+        if(lid == 0)
+        {
+#ifdef ZERO_BETA
+            dstc0[x_gid] = alpha * work0[0];
+            dstc1[x_gid] = alpha * work1[0];
+#else
             dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
             dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
+#endif
         }
     }
 }
@@ -1058,10 +1079,17 @@ void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
     if(lid == 0) {
 #pragma unroll
         for(int j = 0; j < rows; ++j) {
+#ifdef ZERO_BETA
+            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j];
+            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j];
+            dstc2[(x_gid * 4 + j)] = alpha * work_each2[j];
+            dstc3[(x_gid * 4 + j)] = alpha * work_each3[j];
+#else
             dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
             dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
             dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
             dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
+#endif
         }
     }
 }
@@ -1179,10 +1207,17 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
         }
 
         if(lid == 0) {
+#ifdef ZERO_BETA
+            dstc0[x_gid] = alpha * work0[0];
+            dstc1[x_gid] = alpha * work1[0];
+            dstc2[x_gid] = alpha * work2[0];
+            dstc3[x_gid] = alpha * work3[0];
+#else
             dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
             dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
             dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
             dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
+#endif
         }
     }
 }
@@ -1320,6 +1355,16 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
     }
 
     if(lid == 0) {
+#ifdef ZERO_BETA
+        dstc0[x_gid] = alpha * work0[0];
+        dstc1[x_gid] = alpha * work1[0];
+        dstc2[x_gid] = alpha * work2[0];
+        dstc3[x_gid] = alpha * work3[0];
+        dstc4[x_gid] = alpha * work4[0];
+        dstc5[x_gid] = alpha * work5[0];
+        dstc6[x_gid] = alpha * work6[0];
+        dstc7[x_gid] = alpha * work7[0];
+#else
         dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
         dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
         dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
@@ -1328,6 +1373,7 @@ __kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
         dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];
         dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];
         dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];
+#endif
     }
 }
 #undef SLM_SIZE
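
Note on driving the new ZERO_BETA path from the host (not part of the patch itself): BETA_ZERO_CHECK(b0, v) expands to b0 only when the program is compiled with -DZERO_BETA, so the caller is expected to request that define when it already knows beta == 0, letting the kernels skip reading the destination buffer C before overwriting it. The sketch below illustrates that idea with the standard OpenCL host API; the build_gemm_program helper, the option buffer, and the omitted Dtype/TEMPLATE defines are assumptions for illustration only.

/* Minimal host-side sketch (assumption, not part of this patch): compile the
 * gemm_buffer program with -DZERO_BETA when beta == 0, so BETA_ZERO_CHECK(b0, v)
 * picks the b0 branch and the kernels never read C. */
#include <CL/cl.h>
#include <stdio.h>
#include <string.h>

static cl_program build_gemm_program(cl_context ctx, cl_device_id dev,
                                     const char *source, size_t source_len,
                                     float beta)
{
    cl_int err = CL_SUCCESS;
    cl_program prog = clCreateProgramWithSource(ctx, 1, &source, &source_len, &err);
    if (err != CL_SUCCESS)
        return NULL;

    /* The real framework also passes its Dtype/TEMPLATE defines here; they are
     * omitted to keep the sketch focused on the beta == 0 switch. */
    char options[256] = "";
    if (beta == 0.0f)
        strcat(options, "-DZERO_BETA=1");

    err = clBuildProgram(prog, 1, &dev, options, NULL, NULL);
    if (err != CL_SUCCESS) {
        fprintf(stderr, "clBuildProgram failed: %d\n", err);
        clReleaseProgram(prog);
        return NULL;
    }
    return prog;
}

In practice the two program variants (with and without ZERO_BETA) would be cached and selected per call rather than rebuilt each time; that plumbing is framework-specific and not shown here.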