@ -453,14 +453,14 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
int w ;
int w ;
for ( int b_tile = 0 ; b_tile < K; b_tile += SLM_BLOCK) {
for ( int b_tile = 0 ; b_tile < K; b_tile += SLM_BLOCK) {
barrier ( CLK_LOCAL_MEM_FENCE ) ;
barrier ( CLK_LOCAL_MEM_FENCE ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 0 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 0 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 0 , K, local_index ) ) , 0 , slm_brow + mad24 ( 0 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 1 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 1 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 1 , K, local_index ) ) , 0 , slm_brow + mad24 ( 1 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 2 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 2 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 2 , K, local_index ) ) , 0 , slm_brow + mad24 ( 2 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 3 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 3 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 3 , K, local_index ) ) , 0 , slm_brow + mad24 ( 3 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 4 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 4 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 4 , K, local_index ) ) , 0 , slm_brow + mad24 ( 4 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 5 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 5 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 5 , K, local_index ) ) , 0 , slm_brow + mad24 ( 5 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 6 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 6 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 6 , K, local_index ) ) , 0 , slm_brow + mad24 ( 6 , SLM_BLOCK, local_index ) ) ;
vstore4 ( vload4 ( 0 , ( __global float * ) ( src1_read0 + mad24 ( 7 , K, local_index ) ) ) , 0 , ( __local float * ) ( slm_brow + mad24 ( 7 , SLM_BLOCK, local_index ) ) ) ;
vstore8 ( vload8 ( 0 , src1_read0 + mad24 ( 7 , K, local_index ) ) , 0 , slm_brow + mad24 ( 7 , SLM_BLOCK, local_index ) ) ;
barrier ( CLK_LOCAL_MEM_FENCE ) ;
barrier ( CLK_LOCAL_MEM_FENCE ) ;
slm_brow0 = slm_brow + local_x * ( TILE_K / 8 ) ;
slm_brow0 = slm_brow + local_x * ( TILE_K / 8 ) ;
@ -469,17 +469,17 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
while ( w + TILE_K <= end_w ) {
while ( w + TILE_K <= end_w ) {
Dtype8 arow ;
Dtype8 arow ;
brow0 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 0 * SLM_BLOCK ) ) ) ;
brow0 = vload8 ( 0 , slm_brow0 + 0 * SLM_BLOCK ) ;
brow1 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 1 * SLM_BLOCK ) ) ) ;
brow1 = vload8 ( 0 , slm_brow0 + 1 * SLM_BLOCK ) ;
brow2 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 2 * SLM_BLOCK ) ) ) ;
brow2 = vload8 ( 0 , slm_brow0 + 2 * SLM_BLOCK ) ;
brow3 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 3 * SLM_BLOCK ) ) ) ;
brow3 = vload8 ( 0 , slm_brow0 + 3 * SLM_BLOCK ) ;
brow4 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 4 * SLM_BLOCK ) ) ) ;
brow4 = vload8 ( 0 , slm_brow0 + 4 * SLM_BLOCK ) ;
brow5 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 5 * SLM_BLOCK ) ) ) ;
brow5 = vload8 ( 0 , slm_brow0 + 5 * SLM_BLOCK ) ;
brow6 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 6 * SLM_BLOCK ) ) ) ;
brow6 = vload8 ( 0 , slm_brow0 + 6 * SLM_BLOCK ) ;
brow7 = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + 7 * SLM_BLOCK ) ) ) ;
brow7 = vload8 ( 0 , slm_brow0 + 7 * SLM_BLOCK ) ;
# define MM_DOT_PRODUCT ( _row, _dot ) \
# define MM_DOT_PRODUCT ( _row, _dot ) \
arow = as_half8 ( vload4 ( 0 , ( __global float * ) ( src0_read + _row * K ) ) ) ; \
arow = vload8 ( 0 , src0_read + _row * K ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s0 ) , ( Dtype8 ) ( brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0 ) , _dot ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s0 ) , ( Dtype8 ) ( brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0 ) , _dot ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s1 ) , ( Dtype8 ) ( brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1 ) , _dot ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s1 ) , ( Dtype8 ) ( brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1 ) , _dot ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s2 ) , ( Dtype8 ) ( brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2 ) , _dot ) ; \
_dot = mad ( ( Dtype8 ) ( arow.s2 ) , ( Dtype8 ) ( brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2 ) , _dot ) ; \
@ -510,7 +510,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
Dtype8 arow ;
Dtype8 arow ;
# define READ_BROW ( _brow, _row ) \
# define READ_BROW ( _brow, _row ) \
_brow = as_half8 ( vload4 ( 0 , ( __local float * ) ( slm_brow0 + _row * SLM_BLOCK ) ) ) ; \
_brow = vload8 ( 0 , slm_brow0 + _row * SLM_BLOCK ) ; \
_brow.s0 = ( mad24 ( local_x, 8 , w ) < K ) ? _brow.s0 : 0.0f ; \
_brow.s0 = ( mad24 ( local_x, 8 , w ) < K ) ? _brow.s0 : 0.0f ; \
_brow.s1 = ( mad24 ( local_x, 8 , w + 1 ) < K ) ? _brow.s1 : 0.0f ; \
_brow.s1 = ( mad24 ( local_x, 8 , w + 1 ) < K ) ? _brow.s1 : 0.0f ; \
_brow.s2 = ( mad24 ( local_x, 8 , w + 2 ) < K ) ? _brow.s2 : 0.0f ; \
_brow.s2 = ( mad24 ( local_x, 8 , w + 2 ) < K ) ? _brow.s2 : 0.0f ; \
@ -532,7 +532,7 @@ __kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
# undef READ_BROW
# undef READ_BROW
# define MM_DOT_PRODUCT ( _row, _dot ) \
# define MM_DOT_PRODUCT ( _row, _dot ) \
arow = as_half8 ( vload4 ( 0 , ( __global float * ) ( src0_read + _row * K ) ) ) ; \
arow = vload8 ( 0 , src0_read + _row * K ) ; \
arow.s0 = ( mad24 ( local_x, 8 , w ) < K ) ? arow.s0 : 0.0f ; \
arow.s0 = ( mad24 ( local_x, 8 , w ) < K ) ? arow.s0 : 0.0f ; \
arow.s1 = ( mad24 ( local_x, 8 , w + 1 ) < K ) ? arow.s1 : 0.0f ; \
arow.s1 = ( mad24 ( local_x, 8 , w + 1 ) < K ) ? arow.s1 : 0.0f ; \
arow.s2 = ( mad24 ( local_x, 8 , w + 2 ) < K ) ? arow.s2 : 0.0f ; \
arow.s2 = ( mad24 ( local_x, 8 , w + 2 ) < K ) ? arow.s2 : 0.0f ; \