Open Source Computer Vision Library https://opencv.org/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1004 lines
45 KiB

/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
/* Build-time configuration: the TYPE macro (TYPE_HALF / TYPE_FLOAT) selects a
 * half-precision or single-precision instantiation of every kernel below. */
#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
/* Kernel-name mangling helpers: TEMPLATE(foo, float) expands to foo_float. */
#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)
/* alpha/beta are always passed from the host as float, even in half builds. */
#define KERNEL_ARG_DTYPE float
#define TYPE_FLOAT 1
#define TYPE_HALF 2
/* Generic element type: Dtype{2,4,8,16} and as_Dtype* map onto the half or
 * float vector types and reinterpret casts for the selected build. */
#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16
#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16
#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif
/* Subgroup block reads/shuffles below require the Intel subgroups extension. */
#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif
/* Each work-group produces a TILE_M x TILE_N output tile, stepping TILE_K
 * along the reduced (K) dimension per loop iteration. */
#define TILE_M 32
#define TILE_K 8
// common block to calculate (alpha * AxB + beta * C) and output to destination image.
/* half build: 16-wide SIMD, ushort-sized (16-bit) block reads;
 * float build: 8-wide SIMD, uint-sized (32-bit) block reads. */
#if TYPE == TYPE_HALF
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord )
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(ushort)
#define SIMD_SIZE_GEMM 16
#define TILE_N 16
#else
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord )
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(uint)
#define SIMD_SIZE_GEMM 8
#define TILE_N 8
#endif
//#define USE_IMAGE_C
/* Access layer for matrix C and the destination.  With USE_IMAGE_C defined the
 * 8-row tile moves through subgroup block reads/writes on image2d objects;
 * otherwise C is a plain buffer and every element access is bounds-checked
 * against the matrix dimensions M (rows) and N (columns). */
#ifdef USE_IMAGE_C
#if TYPE == TYPE_HALF
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )
#else
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )
#endif
#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))
#else
/* Buffer fallback: each lane handles one column (_coordC.x + local id); a lane
 * gathers/scatters an 8-element column slice.  Out-of-range elements read as 0
 * and are never written, so edge tiles are safe. */
#define BLOCKC_READ8( _C, _coordC ) \
(Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
(_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)
#define BLOCKC_WRITE8( _C, _coordC, _val) do {\
if (_coordC.x + get_local_id(0) < N) { \
if (_coordC.y < M) \
_C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; \
if (_coordC.y + 1 < M) \
_C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; \
if (_coordC.y + 2 < M) \
_C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; \
if (_coordC.y + 3 < M) \
_C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; \
if (_coordC.y + 4 < M) \
_C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; \
if (_coordC.y + 5 < M) \
_C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; \
if (_coordC.y + 6 < M) \
_C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; \
if (_coordC.y + 7 < M) \
_C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \
}} while(0)
/* Buffer variant: C carries its own offset, dimensions and leading stride. */
#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)
#endif
/* Shared epilogue: combines the four accumulator rows blockAxB00..03 with C
 * as alpha*AxB + beta*C (compile-time specialized by ALPHA1 / BETA_NOT0) and
 * writes the TILE_N x TILE_M tile to _dst.  beta (or the zeroing, when
 * BETA_NOT0 == 0) is applied only when isFirstColBlock is set; otherwise the
 * previous partial result in _C is read back and accumulated.
 * NOTE(review): this implies the host splits K across column blocks and
 * launches this tile more than once - confirm against the caller. */
#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \
int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \
int2 coordC = coordDst; \
Dtype8 blockC00; \
Dtype8 blockC01; \
Dtype8 blockC02; \
Dtype8 blockC03; \
if (BETA_NOT0) { \
blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \
if (!ALPHA1) { \
blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
} else { \
blockC00 += blockAxB00; \
blockC01 += blockAxB01; \
blockC02 += blockAxB02; \
blockC03 += blockAxB03; \
} \
} else { \
blockC00 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC01 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC02 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
blockC03 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); \
if (!ALPHA1) { \
blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
} else { \
blockC00 += blockAxB00; \
blockC01 += blockAxB01; \
blockC02 += blockAxB02; \
blockC03 += blockAxB03; \
} \
} \
BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \
BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; \
BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; \
BLOCKC_WRITE8( _dst, coordDst, blockC03 );
/* Gather column _col of an 8-row block that is distributed across the
 * subgroup (one element per lane per row): each shuffle broadcasts lane
 * _col's value for one row, yielding the 8-element column as a Dtype8.
 * Note: the expansion ends with ';', so call sites produce a harmless
 * extra empty statement. */
#define TRANSPOSE_BLOCK_8( _block, _col ) \
(Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), \
intel_sub_group_shuffle( _block.s1, _col ), \
intel_sub_group_shuffle( _block.s2, _col ), \
intel_sub_group_shuffle( _block.s3, _col ), \
intel_sub_group_shuffle( _block.s4, _col ), \
intel_sub_group_shuffle( _block.s5, _col ), \
intel_sub_group_shuffle( _block.s6, _col ), \
intel_sub_group_shuffle( _block.s7, _col ) );
// A's column block multiply B 's row block.
/* half build (SIMD 16): one A block holds 16 k-columns, so two consecutive
 * B row blocks (_blockB00 for k 0..7, _blockB01 for k 8..15) are consumed
 * per call; accumulates 16 rank-1 updates into _result via mad. */
#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) \
{ \
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
_result = mad( (Dtype8)(_blockB00.s0), acol0, _result ); \
_result = mad( (Dtype8)(_blockB00.s1), acol1, _result ); \
_result = mad( (Dtype8)(_blockB00.s2), acol2, _result ); \
_result = mad( (Dtype8)(_blockB00.s3), acol3, _result ); \
_result = mad( (Dtype8)(_blockB00.s4), acol4, _result ); \
_result = mad( (Dtype8)(_blockB00.s5), acol5, _result ); \
_result = mad( (Dtype8)(_blockB00.s6), acol6, _result ); \
_result = mad( (Dtype8)(_blockB00.s7), acol7, _result ); \
_result = mad( (Dtype8)(_blockB01.s0), acol8, _result ); \
_result = mad( (Dtype8)(_blockB01.s1), acol9, _result ); \
_result = mad( (Dtype8)(_blockB01.s2), acola, _result ); \
_result = mad( (Dtype8)(_blockB01.s3), acolb, _result ); \
_result = mad( (Dtype8)(_blockB01.s4), acolc, _result ); \
_result = mad( (Dtype8)(_blockB01.s5), acold, _result ); \
_result = mad( (Dtype8)(_blockB01.s6), acole, _result ); \
_result = mad( (Dtype8)(_blockB01.s7), acolf, _result ); \
}
#else
/* float build (SIMD 8): one A block holds 8 k-columns; a single B row block
 * supplies the matching 8 k-steps. */
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
{ \
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
_result = mad( (Dtype8)(_blockB.s0), acol0, _result ); \
_result = mad( (Dtype8)(_blockB.s1), acol1, _result ); \
_result = mad( (Dtype8)(_blockB.s2), acol2, _result ); \
_result = mad( (Dtype8)(_blockB.s3), acol3, _result ); \
_result = mad( (Dtype8)(_blockB.s4), acol4, _result ); \
_result = mad( (Dtype8)(_blockB.s5), acol5, _result ); \
_result = mad( (Dtype8)(_blockB.s6), acol6, _result ); \
_result = mad( (Dtype8)(_blockB.s7), acol7, _result ); \
}
#endif
/* GEMM, neither matrix transposed: C = alpha*A*B (+ beta*C), A and B as
 * image2d.  One kernel instance per (ALPHA1, BETA_NOT0) specialization;
 * width0 is the padded K extent in B rows, isFirstColBlock selects whether
 * beta is applied (see GEMM_OUTPUT_EXT). */
#if TYPE == TYPE_HALF
/* half build: A advances TILE_K*2 (=16) elements along K per iteration, so
 * two consecutive 8-row B blocks are needed.
 * BUGFIX: advance coordBTemp between the two B block reads - previously
 * blockB01 re-read the same 8 rows as blockB00 (stale coordinate), feeding
 * wrong data into the acol8..acolf half of MULTIPLY_BLOCKS_8x8. */
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
__read_only image2d_t A, \
__read_only image2d_t B, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int width0, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0); \
const int group_y = get_group_id(1); \
Dtype8 blockAxB00 = 0; \
Dtype8 blockAxB01 = 0; \
Dtype8 blockAxB02 = 0; \
Dtype8 blockAxB03 = 0; \
int2 coordA = (int2)( 0, group_y * TILE_M ); \
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
do \
{ \
int2 coordBTemp = coordB; \
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; coordBTemp.y += TILE_K; \
Dtype8 blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
int2 coordATemp = coordA; \
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); \
} \
while( coordB.y < width0 ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
/* float build: one 8-row B block per iteration, A advances TILE_K elements. */
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
__read_only image2d_t A, \
__read_only image2d_t B, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int width0, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0); \
const int group_y = get_group_id(1); \
Dtype8 blockAxB00 = 0.0f; \
Dtype8 blockAxB01 = 0.0f; \
Dtype8 blockAxB02 = 0.0f; \
Dtype8 blockAxB03 = 0.0f; \
int2 coordA = (int2)( 0, group_y * TILE_M ); \
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
do \
{ \
int2 coordBTemp = coordB; \
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
int2 coordATemp = coordA; \
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
} \
while( coordB.y < width0 ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
/* Instantiate all four alpha/beta specializations, then drop the helpers so
 * the next kernel family can redefine them. */
GEMM_NN(1, 0) // ALPHA == 1, BETA == 0
GEMM_NN(1, 1) // ALPHA == 1, BETA != 0
GEMM_NN(0, 0) // ALPHA != 1, BETA == 0
GEMM_NN(0, 1) // ALPHA != 1, BETA != 0
#undef TRANSPOSE_BLOCK_8
#undef MULTIPLY_BLOCKS_8x8
#undef GEMM_NN
/* Broadcast variant for the TN kernels: builds a Dtype8 from lanes
 * _col.._col+7 of a single scalar-per-lane value _vec, i.e. replicates one
 * row of the transposed A block as a column vector. */
#define TRANSPOSE_BLOCK_8(_vec, _col) \
(Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
intel_sub_group_shuffle(_vec, _col + 1), \
intel_sub_group_shuffle(_vec, _col + 2), \
intel_sub_group_shuffle(_vec, _col + 3), \
intel_sub_group_shuffle(_vec, _col + 4), \
intel_sub_group_shuffle(_vec, _col + 5), \
intel_sub_group_shuffle(_vec, _col + 6), \
intel_sub_group_shuffle(_vec, _col + 7) )
/* 8 rank-1 updates: _col selects which 8-lane slice of the A block to use
 * (0 or 8 in the half build, always 0 for float). */
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
{ \
_result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \
_result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \
}
/* GEMM with A transposed: C = alpha*A^T*B (+ beta*C).  A is addressed along
 * x by output row (M) and along y by K, so per iteration coordA.y advances by
 * TILE_K while the M extent is covered by wider x reads. */
#if TYPE == TYPE_HALF
/* half build: one 16-lane A block covers 16 output rows; MULTIPLY is called
 * with _col 0 and 8 to split it into two 8-row halves. */
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
__read_only image2d_t A, \
__read_only image2d_t B, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int width0, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0);\
const int group_y = get_group_id(1);\
Dtype8 blockAxB00 = 0;\
Dtype8 blockAxB01 = 0;\
Dtype8 blockAxB02 = 0;\
Dtype8 blockAxB03 = 0;\
int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\
do\
{\
int2 coordBTemp = coordB;\
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\
int2 coordATemp = coordA;\
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
} \
while( coordB.y < width0 ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
/* float build: four consecutive 8-lane A blocks cover the 32 output rows. */
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
__read_only image2d_t A, \
__read_only image2d_t B, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int width0, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0);\
const int group_y = get_group_id(1);\
Dtype8 blockAxB00 = 0.0f;\
Dtype8 blockAxB01 = 0.0f;\
Dtype8 blockAxB02 = 0.0f;\
Dtype8 blockAxB03 = 0.0f;\
int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\
int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\
do\
{\
int2 coordBTemp = coordB;\
Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\
int2 coordATemp = coordA;\
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); \
} \
while( coordB.y < width0 ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
/* Instantiate the four alpha/beta specializations, then drop the helpers. */
GEMM_TN(1, 0) // ALPHA == 1, BETA == 0
GEMM_TN(1, 1) // ALPHA == 1, BETA != 0
GEMM_TN(0, 0) // ALPHA != 1, BETA == 0
GEMM_TN(0, 1) // ALPHA != 1, BETA != 0
#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_TN
// The same as GEMM_NN
/* Column gather for the NT kernels: identical in shape to the first
 * TRANSPOSE_BLOCK_8 above (column _col of an 8-row distributed block). */
#define TRANSPOSE_BLOCK_8( _block, _col ) \
(Dtype8)( intel_sub_group_shuffle( _block.s0, _col), \
intel_sub_group_shuffle( _block.s1, _col), \
intel_sub_group_shuffle( _block.s2, _col), \
intel_sub_group_shuffle( _block.s3, _col), \
intel_sub_group_shuffle( _block.s4, _col), \
intel_sub_group_shuffle( _block.s5, _col), \
intel_sub_group_shuffle( _block.s6, _col), \
intel_sub_group_shuffle( _block.s7, _col) )
/* half build: _blockB is a Dtype16 holding 16 k-values per lane, matched
 * against the 16 subgroup columns of _blockA. */
#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
{ \
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
_result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
_result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
_result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
_result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
_result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
_result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
_result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
_result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
_result = mad( (Dtype8)_blockB.s8, acol8, _result ); \
_result = mad( (Dtype8)_blockB.s9, acol9, _result ); \
_result = mad( (Dtype8)_blockB.sa, acola, _result ); \
_result = mad( (Dtype8)_blockB.sb, acolb, _result ); \
_result = mad( (Dtype8)_blockB.sc, acolc, _result ); \
_result = mad( (Dtype8)_blockB.sd, acold, _result ); \
_result = mad( (Dtype8)_blockB.se, acole, _result ); \
_result = mad( (Dtype8)_blockB.sf, acolf, _result ); \
}
#else
/* float build: _blockB is a Dtype8 with 8 k-values per lane. */
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
{ \
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
_result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
_result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
_result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
_result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
_result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
_result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
_result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
_result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
}
#endif
/* GEMM with B transposed: C = alpha*A*B^T (+ beta*C).  B's representation is
 * abstracted behind BLOCKB_READ8 / MATB_PARAMETER, which are (re)defined
 * before each group of instantiations below (image-vec4, raw buffer, or
 * image-scalar), selected by the VECSCALAR/VECSIZE instantiation arguments. */
#if TYPE == TYPE_HALF
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
__read_only image2d_t A, \
MATB_PARAMETER, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int padded_k, \
int k, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0); \
const int group_y = get_group_id(1); \
Dtype8 blockAxB00 = 0; \
Dtype8 blockAxB01 = 0; \
Dtype8 blockAxB02 = 0; \
Dtype8 blockAxB03 = 0; \
int2 coordA = (int2)( 0, group_y * TILE_M ); \
int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
do \
{ \
Dtype16 blockB00; \
BLOCKB_READ8(blockB00, B, coordB); \
int2 coordATemp = coordA; \
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
} \
while( coordB.x < padded_k / VECSIZE ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
/* float build: B block is a Dtype8 (8 k-values per lane). */
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
__read_only image2d_t A, \
MATB_PARAMETER, \
MATC_PARAMETER, \
KERNEL_ARG_DTYPE alpha_in, \
KERNEL_ARG_DTYPE beta_in, \
int padded_k, \
int k, \
int isFirstColBlock) \
{ \
const Dtype alpha = (Dtype)alpha_in; \
const Dtype beta = (Dtype)beta_in; \
const int group_x = get_group_id(0); \
const int group_y = get_group_id(1); \
Dtype8 blockAxB00 = 0.0f; \
Dtype8 blockAxB01 = 0.0f; \
Dtype8 blockAxB02 = 0.0f; \
Dtype8 blockAxB03 = 0.0f; \
int2 coordA = (int2)( 0, group_y * TILE_M ); \
int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
do \
{ \
Dtype8 blockB00; \
BLOCKB_READ8(blockB00, B, coordB); \
int2 coordATemp = coordA; \
Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
} \
while( coordB.x < padded_k / VECSIZE ); \
GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
/* B reader #1 (VEC4): B as image2d, one lane row per work-item; reads 4-wide
 * pixels until the lane's 16 (half) or 8 (float) k-values are filled, then
 * advances the shared coordB by the pixels consumed. */
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;
#endif
#define MATB_PARAMETER __read_only image2d_t B
GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* B reader #2 (BUFFER): B as a raw buffer with offset offB and leading
 * stride ldb (both in elements).  In the half build the row is loaded as 8
 * floats and bit-reinterpreted into 16 halfs. */
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
_blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); \
_coordB.x += TILE_K * 2;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
_blockb = vload8(0, B_read); \
_coordB.x += TILE_K;
#endif
#define MATB_PARAMETER __global Dtype *B, int offB, int ldb
GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* B reader #3 (SCALAR): B as image2d read one pixel at a time, keeping only
 * the first channel of each pixel - used when B's image width is not a
 * multiple of the vector width. */
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
Dtype4 temp; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s0 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s1 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s2 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s3 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s5 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s6 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s7 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s8 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s9 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.sa = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.sb = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.sc = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.sd = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.se = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.sf = temp.s0; \
_coordB.x += 16;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
Dtype4 temp; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s0 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s1 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s2 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s3 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s5 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s6 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s7 = temp.s0; \
_coordB.x += 8;
#endif
#define MATB_PARAMETER __read_only image2d_t B
GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_NT
//The same as GEMM_TN.
/* Broadcast helper for the TT kernels: lanes _col.._col+7 of a scalar value
 * become a Dtype8 column.  Note the expansion ends with ';', so call sites
 * below produce a harmless extra empty statement. */
#define TRANSPOSE_BLOCK_8(_vec, _col) \
(Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
intel_sub_group_shuffle(_vec, _col + 1), \
intel_sub_group_shuffle(_vec, _col + 2), \
intel_sub_group_shuffle(_vec, _col + 3), \
intel_sub_group_shuffle(_vec, _col + 4), \
intel_sub_group_shuffle(_vec, _col + 5), \
intel_sub_group_shuffle(_vec, _col + 6), \
intel_sub_group_shuffle(_vec, _col + 7) );
/* 8 rank-1 updates with _col selecting the 8-lane slice of the A block. */
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
{ \
const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \
const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \
const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \
const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \
const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \
const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \
const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \
const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \
_result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
_result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
_result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
_result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
_result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
_result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
_result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
_result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
}
/* GEMM_TT kernel generator.  Each expansion emits one SIMD_SIZE_GEMM-wide
 * kernel accumulating a tile of C ("TT" in the kernel name presumably
 * encodes the both-transposed configuration -- the A-side transpose is
 * realized in registers by MULTIPLY_BLOCKS_8x8's sub-group shuffles;
 * TODO confirm against the host-side dispatcher).
 * ALPHA1/BETA_NOT0 are compile-time alpha/beta flags consumed by
 * GEMM_OUTPUT; VECSCALAR/VECSIZE describe how BLOCKB_READ8 fetches B.
 * A is read via SUBGROUP_BLOCK_READ8 from an image; the do/while walks the
 * K dimension until coordB.x reaches padded_k / VECSIZE. */
#if TYPE == TYPE_HALF
/* Half path: two A block-reads per iteration (16 elements apart), each
 * feeding two 8-column multiplies (_col 0 and 8). */
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
}
#else
/* Float path: four A block-reads per iteration (8 elements apart), one
 * accumulator each. */
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
}
#endif
/* VEC4 variant of BLOCKB_READ8: the B image packs 4 values per texel, so
 * two texel reads fill the 8-wide block and the column cursor advances by
 * 2 texels. */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;
#define MATB_PARAMETER __read_only image2d_t B
GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
#if TYPE == TYPE_HALF
/* BUFFER variant of BLOCKB_READ8, half path: B is a raw global buffer
 * addressed row-major with stride k.  vload4 fetches 4 floats (16 bytes =
 * 8 packed halfs) which are bit-reinterpreted -- not converted -- into a
 * Dtype8 via as_ushort8/as_Dtype8.
 * NOTE(review): the row stride used is k; the ldb parameter declared by
 * MATB_PARAMETER is unused -- verify callers always pass ldb == k. */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
_blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); \
_coordB.x += TILE_K;
#else
/* BUFFER variant, float path: one vload8 grabs the whole 8-wide row
 * segment of B (same stride-k addressing as the half path). */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
_blockb = vload8(0, B_read); \
_coordB.x += TILE_K;
#endif
/* For BUFFER variants B is a plain pointer plus element offset and ldb. */
#define MATB_PARAMETER __global Dtype *B, int offB, int ldb
GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* Scalar-image variant of BLOCKB_READ8 for the TT kernels: B is an image
 * where only .s0 of each Dtype4 texel read is meaningful, so eight
 * consecutive texels along x are gathered one-by-one into the 8-wide block;
 * each work-item reads its own row (y offset by get_local_id(0)) and the
 * caller's column cursor advances by 8 texels per call.
 * Fix: the body now uses the macro parameter _B instead of the literal
 * identifier B.  The old form only worked because every call site happened
 * to pass B; referencing the parameter restores macro hygiene and matches
 * the sibling BLOCKB_READ8 definitions (NT scalar and TT VEC4 variants)
 * while expanding to byte-identical code at the existing call sites. */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
int2 _coordBTemp = _coordB; \
_coordBTemp.y += get_local_id(0); \
Dtype4 temp; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s0 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s1 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s2 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s3 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s4 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s5 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s6 = temp.s0; \
temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
_blockb.s7 = temp.s0; \
_coordB.x += 8;
/* For the SCALAR variants B is passed as a read-only image. */
#define MATB_PARAMETER __read_only image2d_t B
GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* End of the generated-GEMM section: remove all tile/helper macros so the
 * remaining plain kernels below are not affected by them. */
#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_TT
#undef TILE_M
#undef TILE_K
#undef TILE_N
#undef SUBGROUP_BLOCK_READ8
#undef READ_IMAGE
#undef SIZE_OF_ELEMENT
/* Copies a width x height region of buffer A (starting at element offA,
 * row stride ldA) into image ImA, one element per work-item, splatting the
 * scalar across the texel.  Out-of-range work-items exit early.
 * (The "transpose" in the name refers to how the host sizes/strides this
 * launch; the kernel itself copies element (x, y) -> texel (x, y).) */
__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(
                       __global Dtype* A,
                       __write_only image2d_t ImA,
                       int offA,
                       int width,
                       int height,
                       int ldA)
{
    const int col = get_global_id(0);
    const int row = get_global_id(1);
    if (col >= width || row >= height)
        return;
    const __global Dtype* src = A + offA;
    const Dtype value = src[row * ldA + col];
    const int2 dst = (int2)(col, row);
#if TYPE == TYPE_HALF
    write_imageh(ImA, dst, (Dtype4)value);
#else
    write_imagef(ImA, dst, (Dtype4)value);
#endif
}
/* Copies buffer A (element offset offA, row stride ldA) into image ImA,
 * zero-padding texels that lie inside the padded_width x padded_height
 * image but outside the width x height source region.
 * Half build: writes through write_imageh.  Other builds: the element's
 * raw bytes are reinterpreted (as_uchar4) and written through
 * write_imageui, i.e. a bit-exact copy into an unsigned-int image. */
__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(
                       __global Dtype* A,
                       __write_only image2d_t ImA,
                       int offA,
                       int padded_width,
                       int padded_height,
                       int width,
                       int height,
                       int ldA)
{
    const int col = get_global_id(0);
    const int row = get_global_id(1);
    if (col >= padded_width || row >= padded_height)
        return;
    const int2 dst = (int2)(col, row);
    const bool inside_src = (col < width) && (row < height);
#if TYPE == TYPE_HALF
    if (!inside_src) {
        write_imageh(ImA, dst, 0);    /* zero padding */
        return;
    }
    const __global Dtype* src = A + offA;
    write_imageh(ImA, dst, src[row * ldA + col]);
#else
    if (!inside_src) {
        write_imageui(ImA, dst, (uint4)0);    /* zero padding */
        return;
    }
    const __global Dtype* src = A + offA;
    const uint4 bits = convert_uint4(as_uchar4(src[row * ldA + col]));
    write_imageui(ImA, dst, bits);
#endif
}