diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp index 6a8ef6b590..b9362bb4d5 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp @@ -12,16 +12,16 @@ #include #include // parallel_for_ -#define FAST_GEMM_DEFAULT_STORAGE (1<<20) // 2^20 -#define FAST_GEMM_DEFAULT_MAX_STACKBUF (1 << 14) +#define FAST_GEMM_STORAGE (1<<20) // 2^20 +#define FAST_GEMM_MAX_STACKBUF (1 << 14) -#define FAST_GEMM_DEFAULT_F32_MC 64 -#define FAST_GEMM_DEFAULT_F32_NC 240 -#define FAST_GEMM_DEFAULT_F32_MR 8 -#define FAST_GEMM_DEFAULT_F32_NR 12 -#define FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K 256 +#define FAST_GEMM_F32_MC 64 +#define FAST_GEMM_F32_NC 240 +#define FAST_GEMM_F32_MR 8 +#define FAST_GEMM_F32_NR 12 +#define FAST_GEMM_F32_PACKED_STRIDE_K 64 -#define FAST_GEMM_DEFAULT_IMPLEMENT_PACK(N, suffix, styp, dtyp) \ +#define FAST_GEMM_IMPLEMENT_PACK(N, suffix, styp, dtyp) \ static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ int lda0, int lda1, void* packA_ ) \ { \ @@ -32,47 +32,47 @@ static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ const styp* a_ptr = A + lda0*i; \ for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ { \ - FAST_GEMM_DEFAULT_LOAD_TO_BUF_##N(styp); \ - FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \ + FAST_GEMM_LOAD_TO_BUF_##N(styp); \ + FAST_GEMM_PACK##suffix##_##N(buf, packA); \ } \ } else { \ const styp* a_ptr[N]; \ for (int k = 0; k < N; k++) a_ptr[k] = A + lda0*(i+k < m ? i+k : i); \ for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ { \ - FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_##N(styp); \ - FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \ + FAST_GEMM_LOAD_TO_BUF_BORDERS_##N(styp); \ + FAST_GEMM_PACK##suffix##_##N(buf, packA); \ } \ } \ } \ } -#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_8(styp) \ +#define FAST_GEMM_LOAD_TO_BUF_8(styp) \ styp buf[] = { \ a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7] } -#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_8(styp) \ +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_8(styp) \ styp buf[] = { \ a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j] } -#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_12(styp) \ +#define FAST_GEMM_LOAD_TO_BUF_12(styp) \ styp buf[] = { \ a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \ a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11] } -#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_12(styp) \ +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_12(styp) \ styp buf[] = { \ a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \ a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j] } -#define FAST_GEMM_DEFAULT_PACK_COPY(src, dst, N) \ +#define FAST_GEMM_PACK_COPY(src, dst, N) \ memcpy((dst), (src), N*sizeof(src[0])) -#define FAST_GEMM_DEFAULT_PACK_f32_8(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 8) -#define FAST_GEMM_DEFAULT_PACK_f32_12(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 12) +#define FAST_GEMM_PACK_f32_8(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 8) +#define FAST_GEMM_PACK_f32_12(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 12) namespace cv { namespace dnn { namespace cpu_baseline { @@ -88,20 +88,20 @@ void fastGemmKernel(int M, 
int N, int K, float alpha, const char *A, int lda0, int lda1, const char *packed_B, float beta, char *C, int ldc, int esz); -FAST_GEMM_DEFAULT_IMPLEMENT_PACK(8, _f32, float, float) -FAST_GEMM_DEFAULT_IMPLEMENT_PACK(12, _f32, float, float) +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) int fastGemmPackBSize(int N, int K) { - int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; return static_cast((N + NC - 1) / NC) * NC * K; } void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { - int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K); + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); int n_tiles = (N + NC - 1) / NC; for (int r = 0; r < n_tiles; ++r) { @@ -116,140 +116,50 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, } } -#if CV_SIMD128 -static void fast_gemm8x12_f32(int k, const char *a_, const char *b_, - char *c_, int ldc, float alpha) { +static inline void fast_gemm_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { const float* a = (const float*)a_; const float* b = (const float*)b_; float* c = (float*)c_; - v_float32x4 s00 = v_setzero_f32(), s01 = s00, s02 = s00; - v_float32x4 s10 = s00, s11 = s00, s12 = s00; - v_float32x4 s20 = s00, s21 = s00, s22 = s00; - v_float32x4 s30 = s00, s31 = s00, s32 = s00; - v_float32x4 s40 = s00, s41 = s00, s42 = s00; - v_float32x4 s50 = s00, s51 = s00, s52 = s00; - v_float32x4 s60 = s00, s61 = s00, s62 = s00; - v_float32x4 s70 = s00, s71 = s00, s72 = s00; - - for(int p = 0; p < k; p++, a += FAST_GEMM_DEFAULT_F32_MR, b += FAST_GEMM_DEFAULT_F32_NR) { - v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); - - v_float32x4 a0 = v_setall_f32(*a); - s00 = v_fma(b0, a0, s00); - s01 = v_fma(b1, a0, s01); - s02 = v_fma(b2, a0, s02); - v_float32x4 a1 = v_setall_f32(*(a + 1)); - s10 = v_fma(b0, a1, s10); - s11 = v_fma(b1, a1, s11); - s12 = v_fma(b2, a1, s12); - - v_float32x4 a2 = v_setall_f32(*(a + 2)); - s20 = v_fma(b0, a2, s20); - s21 = v_fma(b1, a2, s21); - s22 = v_fma(b2, a2, s22); - v_float32x4 a3 = v_setall_f32(*(a + 3)); - s30 = v_fma(b0, a3, s30); - s31 = v_fma(b1, a3, s31); - s32 = v_fma(b2, a3, s32); - - a0 = v_setall_f32(*(a + 4)); - s40 = v_fma(b0, a0, s40); - s41 = v_fma(b1, a0, s41); - s42 = v_fma(b2, a0, s42); - a1 = v_setall_f32(*(a + 5)); - s50 = v_fma(b0, a1, s50); - s51 = v_fma(b1, a1, s51); - s52 = v_fma(b2, a1, s52); - - a2 = v_setall_f32(*(a + 6)); - s60 = v_fma(b0, a2, s60); - s61 = v_fma(b1, a2, s61); - s62 = v_fma(b2, a2, s62); - a3 = v_setall_f32(*(a + 7)); - s70 = v_fma(b0, a3, s70); - s71 = v_fma(b1, a3, s71); - s72 = v_fma(b2, a3, s72); - } - - v_float32x4 c0, c1, c2, c3, c4, c5, v_alpha = v_setall_f32(alpha); -#define FAST_GEMM_FINALE(row0, row1) \ - c0 = v_load(c + row0 * ldc); \ - c1 = v_load(c + row0 * ldc + 4); \ - c2 = v_load(c + row0 * ldc + 8); \ - c3 = v_load(c + row1 * ldc); \ - c4 = v_load(c + row1 * ldc + 4); \ - c5 = v_load(c + row1 * ldc + 8); \ - c0 = v_fma(s##row0##0, v_alpha, c0); \ - c1 = v_fma(s##row0##1, v_alpha, c1); \ - c2 = v_fma(s##row0##2, v_alpha, c2); \ - 
c3 = v_fma(s##row1##0, v_alpha, c3); \ - c4 = v_fma(s##row1##1, v_alpha, c4); \ - c5 = v_fma(s##row1##2, v_alpha, c5); \ - v_store(c + row0 * ldc, c0); \ - v_store(c + row0 * ldc + 4, c1); \ - v_store(c + row0 * ldc + 8, c2); \ - v_store(c + row1 * ldc, c3); \ - v_store(c + row1 * ldc + 4, c4); \ - v_store(c + row1 * ldc + 8, c5); - - FAST_GEMM_FINALE(0, 1); - FAST_GEMM_FINALE(2, 3); - FAST_GEMM_FINALE(4, 5); - FAST_GEMM_FINALE(6, 7); -#undef FAST_GEMM_FINALE -} - -#else -static void fast_gemm_f32(int k, const char *a_, const char *b_, - char *c_, int ldc, float alpha) { - const float* a = (const float*)a_; - const float* b = (const float*)b_; - float* c = (float*)c_; - - float sbuf[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; + float sbuf[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; memset(sbuf, 0, sizeof(sbuf)); for(int p = 0; p < k; p++) { - for( int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++ ) { - float ai = a[FAST_GEMM_DEFAULT_F32_MR * p + i]; - for( int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++ ) - sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j] += b[FAST_GEMM_DEFAULT_F32_NR * p + j] * ai; + for( int i = 0; i < FAST_GEMM_F32_MR; i++ ) { + float ai = a[FAST_GEMM_F32_MR * p + i]; + for( int j = 0; j < FAST_GEMM_F32_NR; j++ ) + sbuf[i * FAST_GEMM_F32_NR + j] += b[FAST_GEMM_F32_NR * p + j] * ai; } } - for (int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++) { - for (int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++) - c[i * ldc + j] += alpha * sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j]; + for (int i = 0; i < FAST_GEMM_F32_MR; i++) { + for (int j = 0; j < FAST_GEMM_F32_NR; j++) + c[i * ldc + j] += alpha * sbuf[i * FAST_GEMM_F32_NR + j]; } } -#endif // CV_SIMD128 static void fast_gemm_macro_kernel(int m, int n, int k, const char *packed_A, const char *packed_B, float alpha, char *c, int ldc0, int esz) { int ldc0_esz = ldc0 * esz; - double tempC[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; // make sure the buffer is big enough - for(int i = 0; i < m; i += FAST_GEMM_DEFAULT_F32_MR) { - for(int j = 0; j < n; j += FAST_GEMM_DEFAULT_F32_NR) { + double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough + for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { + for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { char* cptr0 = &c[i * ldc0_esz + j * esz]; char* cptr = cptr0; int ldc = ldc0; - int mr = m - i < FAST_GEMM_DEFAULT_F32_MR ? m - i : FAST_GEMM_DEFAULT_F32_MR; - int nr = n - j < FAST_GEMM_DEFAULT_F32_NR ? n - j : FAST_GEMM_DEFAULT_F32_NR; + int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; + int nr = n - j < FAST_GEMM_F32_NR ? 
n - j : FAST_GEMM_F32_NR; int nr_esz = nr * esz; - bool partial = (bool)((mr < FAST_GEMM_DEFAULT_F32_MR) | (nr < FAST_GEMM_DEFAULT_F32_NR)); + bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); if (partial) { memset(tempC, 0, sizeof(tempC)); cptr = (char *)tempC; - ldc = FAST_GEMM_DEFAULT_F32_NR; + ldc = FAST_GEMM_F32_NR; for(int p = 0; p < mr; p++) memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); } -#if CV_SIMD128 - fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); -#else fast_gemm_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); -#endif if (partial) { for(int p = 0; p < mr; p++) @@ -263,19 +173,19 @@ void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC, - GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, - GEMM_MR = FAST_GEMM_DEFAULT_F32_MR, - GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = FAST_GEMM_DEFAULT_STORAGE / ((MC + NC) * esz); + int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); KC = KC > 8 ? KC : 8; KC = KC < K ? KC : K; size_t buff_size = KC * (MC + NC) * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; int m_tiles = (M + MC - 1) / MC; int n_tiles = (N + NC - 1) / NC; int total_tiles = m_tiles * n_tiles; @@ -328,17 +238,17 @@ void fastGemmKernel(int M, int N, int K, void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *packed_B, float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC, - GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, - GEMM_MR = FAST_GEMM_DEFAULT_F32_MR, - GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; int NC = (((GEMM_NC < N ? 
GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K); + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); size_t buff_size = KC * MC * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; int m_tiles = (M + MC - 1) / MC; int n_tiles = (N + NC - 1) / NC; int total_tiles = m_tiles * n_tiles; @@ -391,3 +301,29 @@ void fastGemmKernel(int M, int N, int K, } }}} // cv::dnn::cpu_baseline + +#undef FAST_GEMM_STORAGE +#undef FAST_GEMM_MAX_STACKBUF +#ifdef FAST_GEMM_F32_MC +#undef FAST_GEMM_F32_MC +#endif +#ifdef FAST_GEMM_F32_NC +#undef FAST_GEMM_F32_NC +#endif +#ifdef FAST_GEMM_F32_MR +#undef FAST_GEMM_F32_MR +#endif +#ifdef FAST_GEMM_F32_NR +#undef FAST_GEMM_F32_NR +#endif +#ifdef FAST_GEMM_F32_PACKED_STRIDE_K +#undef FAST_GEMM_F32_PACKED_STRIDE_K +#endif +#undef FAST_GEMM_IMPLEMENT_PACK +#undef FAST_GEMM_LOAD_TO_BUF_8 +#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_8 +#undef FAST_GEMM_LOAD_TO_BUF_12 +#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_12 +#undef FAST_GEMM_PACK_COPY +#undef FAST_GEMM_PACK_f32_8 +#undef FAST_GEMM_PACK_f32_12 diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp index 99a7d3b2d7..7d123ed9b5 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp @@ -15,37 +15,31 @@ #define FAST_GEMM_STORAGE (1<<20) // 2^20 #define FAST_GEMM_MAX_STACKBUF (1 << 14) -#if CV_NEON -#define FAST_GEMM_F32_MC 64 -#define FAST_GEMM_F32_NC 240 -#elif CV_AVX +#if CV_AVX #define FAST_GEMM_F32_MC 60 #define FAST_GEMM_F32_NC 320 #elif CV_LASX #define FAST_GEMM_F32_MC 48 #define FAST_GEMM_F32_NC 128 +#else // CV_NEON_AARCH64, SIMD128 +#define FAST_GEMM_F32_MC 64 +#define FAST_GEMM_F32_NC 240 #endif -// micro kernel size -#if CV_NEON && CV_NEON_AARCH64 -#define FAST_GEMM_F32_MR 8 -#define FAST_GEMM_F32_NR 12 -#elif CV_NEON -#define FAST_GEMM_F32_MR 4 -#define FAST_GEMM_F32_NR 12 -#elif CV_AVX +#if CV_AVX #define FAST_GEMM_F32_MR 12 #define FAST_GEMM_F32_NR 8 #elif CV_LASX #define FAST_GEMM_F32_MR 12 #define FAST_GEMM_F32_NR 16 +#else // CV_NEON_AARCH64, CV_SIMD128 +#define FAST_GEMM_F32_MR 8 +#define FAST_GEMM_F32_NR 12 #endif -#if CV_NEON -#define FAST_GEMM_F32_PACKED_STRIDE_K 64 -#elif CV_AVX +#if CV_AVX #define FAST_GEMM_F32_PACKED_STRIDE_K 128 -#elif CV_LASX +#else // CV_LASX, CV_NEON_AARCH64, CV_SIMD128 #define FAST_GEMM_F32_PACKED_STRIDE_K 64 #endif @@ -75,14 +69,6 @@ static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ } \ } -#define FAST_GEMM_LOAD_TO_BUF_4(styp) \ - styp buf[] = { \ - a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3] } - -#define FAST_GEMM_LOAD_TO_BUF_BORDERS_4(styp) \ - styp buf[] = { \ - a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j] } - #define FAST_GEMM_LOAD_TO_BUF_8(styp) \ styp buf[] = { \ a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ @@ -121,7 +107,6 @@ static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ #define FAST_GEMM_PACK_COPY(src, dst, N) \ memcpy((dst), (src), N*sizeof(src[0])) -#define FAST_GEMM_PACK_f32_4(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 4) #define FAST_GEMM_PACK_f32_8(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 8) #define FAST_GEMM_PACK_f32_12(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 12) #define FAST_GEMM_PACK_f32_16(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 16) @@ -130,7 +115,6 @@ namespace cv 
{ namespace dnn { CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN -// TODO: type to size_t int fastGemmPackBSize(int N, int K); void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz); @@ -143,44 +127,18 @@ void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *packed_B, float beta, char *C, int ldc, int esz); -// NEON (AARCH64: 32 x 128-bit registers, armv7: 16 x 128-bit registers) -#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_NEON +#ifndef CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY -#if CV_NEON_AARCH64 -FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) -#else -FAST_GEMM_IMPLEMENT_PACK(4, _f32, float, float) -#endif -FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) - -int fastGemmPackBSize(int N, int K) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - - return static_cast((N + NC - 1) / NC) * NC * K; -} - -void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); +/* + Compute kernels that optimized for different platforms +*/ +#if CV_NEON && CV_NEON_AARCH64 // AARCH64: 32 x 128-bit registers - int n_tiles = (N + NC - 1) / NC; - for (int r = 0; r < n_tiles; ++r) { - int j0 = r * NC; - int nc = N - j0 < NC ? N - j0 : NC; - int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; - for (int k = 0; k < K; k += KC) { - int kc = K - k < KC ? K - k : KC; - fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); - packed_B += _nc * kc; - } - } -} +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) // a packer +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) // b packer -#if CV_NEON_AARCH64 -static void fast_gemm8x12_f32(int k, const char *a_, const char *b_, - char *c_, int ldc, float alpha) { +static inline void fast_gemm8x12_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { const float* a = (const float*)a_; const float* b = (const float*)b_; float* c = (float*)c_; @@ -219,583 +177,131 @@ static void fast_gemm8x12_f32(int k, const char *a_, const char *b_, s41 = vfmaq_laneq_f32(s41, b1, a0, 0); s42 = vfmaq_laneq_f32(s42, b2, a0, 0); s50 = vfmaq_laneq_f32(s50, b0, a0, 1); - s51 = vfmaq_laneq_f32(s51, b1, a0, 1); - s52 = vfmaq_laneq_f32(s52, b2, a0, 1); - - s60 = vfmaq_laneq_f32(s60, b0, a0, 2); - s61 = vfmaq_laneq_f32(s61, b1, a0, 2); - s62 = vfmaq_laneq_f32(s62, b2, a0, 2); - s70 = vfmaq_laneq_f32(s70, b0, a0, 3); - s71 = vfmaq_laneq_f32(s71, b1, a0, 3); - s72 = vfmaq_laneq_f32(s72, b2, a0, 3); - } - - float32x4_t c0, c1, c2, c3, c4, c5, v_alpha = vdupq_n_f32(alpha); -#define FAST_GEMM_FINALE(row0, row1) \ - c0 = vld1q_f32(c + row0 * ldc); \ - c1 = vld1q_f32(c + row0 * ldc + 4); \ - c2 = vld1q_f32(c + row0 * ldc + 8); \ - c3 = vld1q_f32(c + row1 * ldc); \ - c4 = vld1q_f32(c + row1 * ldc + 4); \ - c5 = vld1q_f32(c + row1 * ldc + 8); \ - c0 = vfmaq_f32(c0, s##row0##0, v_alpha); \ - c1 = vfmaq_f32(c1, s##row0##1, v_alpha); \ - c2 = vfmaq_f32(c2, s##row0##2, v_alpha); \ - c3 = vfmaq_f32(c3, s##row1##0, v_alpha); \ - c4 = vfmaq_f32(c4, s##row1##1, v_alpha); \ - c5 = vfmaq_f32(c5, s##row1##2, v_alpha); \ - vst1q_f32(c + row0 * ldc, c0); \ - vst1q_f32(c + row0 * ldc + 4, c1); \ - vst1q_f32(c + row0 * ldc + 8, c2); \ - 
vst1q_f32(c + row1 * ldc, c3); \ - vst1q_f32(c + row1 * ldc + 4, c4); \ - vst1q_f32(c + row1 * ldc + 8, c5); - - FAST_GEMM_FINALE(0, 1); - FAST_GEMM_FINALE(2, 3); - FAST_GEMM_FINALE(4, 5); - FAST_GEMM_FINALE(6, 7); -#undef FAST_GEMM_FINALE -} - -#else // CV_NEON_AARCH64 -static void fast_gemm4x12_f32(int k, const char *a_, const char *b_, - char *c_, int ldc, float alpha) { - const float* a = (const float*)a_; - const float* b = (const float*)b_; - float* c = (float*)c_; - - float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00, - s10 = s00, s11 = s00, s12 = s00, - s20 = s00, s21 = s00, s22 = s00, - s30 = s00, s31 = s00, s32 = s00; - - for(int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) - { - float32x4_t b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); - - float32x4_t a0 = vld1q_dup_f32(a); - s00 = vmlaq_f32(a0, b0, s00); - s01 = vmlaq_f32(a0, b1, s01); - s02 = vmlaq_f32(a0, b2, s02); - - a0 = vld1q_dup_f32(a + 1); - s10 = vmlaq_f32(a0, b0, s10); - s11 = vmlaq_f32(a0, b1, s11); - s12 = vmlaq_f32(a0, b2, s12); - - a0 = vld1q_dup_f32(a + 2); - s20 = vmlaq_f32(a0, b0, s20); - s21 = vmlaq_f32(a0, b1, s21); - s22 = vmlaq_f32(a0, b2, s22); - - a0 = vld1q_dup_f32(a + 3); - s30 = vmlaq_f32(a0, b0, s30); - s31 = vmlaq_f32(a0, b1, s31); - s32 = vmlaq_f32(a0, b2, s32); - } - - float32x4_t c0, c1, c2, v_alpha = vdupq_n_f32(alpha); -#define FAST_GEMM_FINALE(row0) \ - c0 = vld1q_f32(c + row0 * ldc); \ - c1 = vld1q_f32(c + row0 * ldc + 4); \ - c2 = vld1q_f32(c + row0 * ldc + 8); \ - c0 = vmlaq_f32(c0, s##row0##0, v_alpha); \ - c1 = vmlaq_f32(c1, s##row0##1, v_alpha); \ - c2 = vmlaq_f32(c2, s##row0##2, v_alpha); \ - vst1q_f32(c + row0 * ldc, c0); \ - vst1q_f32(c + row0 * ldc + 4, c1); \ - vst1q_f32(c + row0 * ldc + 8, c2); - - FAST_GEMM_FINALE(0); - FAST_GEMM_FINALE(1); - FAST_GEMM_FINALE(2); - FAST_GEMM_FINALE(3); -#undef FAST_GEMM_FINALE -} - -#endif // micro kernel CV_NEON_AARCH64 - -static void fast_gemm_macro_kernel(int m, int n, int k, - const char *packed_A, const char *packed_B, - float alpha, char *c, int ldc0, int esz) { - int ldc0_esz = ldc0 * esz; - - double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough - for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { - for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { - char* cptr0 = &c[i * ldc0_esz + j * esz]; - char* cptr = cptr0; - int ldc = ldc0; - int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; - int nr = n - j < FAST_GEMM_F32_NR ? n - j : FAST_GEMM_F32_NR; - int nr_esz = nr * esz; - bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); - if (partial) { - memset(tempC, 0, sizeof(tempC)); - cptr = (char *)tempC; - ldc = FAST_GEMM_F32_NR; - for(int p = 0; p < mr; p++) - memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); - } -#if CV_NEON_AARCH64 - fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); -#else - fast_gemm4x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); -#endif - - if (partial) { - for(int p = 0; p < mr; p++) - memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); - } - } - } -} - -void fastGemmKernel(int M, int N, int K, - float alpha, const char *A, int lda0, int lda1, - const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_F32_MC, - GEMM_NC = FAST_GEMM_F32_NC, - GEMM_MR = FAST_GEMM_F32_MR, - GEMM_NR = FAST_GEMM_F32_NR; - - int MC = (((GEMM_MC < M ? 
GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); - KC = KC > 8 ? KC : 8; - KC = KC < K ? KC : K; - - size_t buff_size = KC * (MC + NC) * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; - int m_tiles = (M + MC - 1) / MC; - int n_tiles = (N + NC - 1) / NC; - int total_tiles = m_tiles * n_tiles; - - auto fn = [&](const Range &r) { - char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); - char* packed_b = packed_a + KC * MC * esz; - int start = r.start; - int end = r.end; - - for (int tile_idx = start; tile_idx < end; tile_idx++) { - int i0 = (tile_idx / n_tiles) * MC; - int j0 = (tile_idx % n_tiles) * NC; - int mc = M - i0 < MC ? M - i0 : MC; - int nc = N - j0 < NC ? N - j0 : NC; - int ldc_block = ldc; - char* c_block = C + (i0 * ldc + j0) * esz; - - if (beta == 0.f) { - for(int i = 0; i < mc; i++) - memset(c_block + i * ldc_block * esz, 0, nc * esz); - } else if (beta != 1.f) { - for(int i = 0; i < mc; i++) { - float* c_i = (float*)c_block + i * ldc_block; - for(int j = 0; j < nc; j++) - c_i[j] *= beta; - } - } - - for(int k0 = 0; k0 < K; k0 += KC) - { - int kc = K - k0 < KC ? K - k0 : KC; -#if CV_NEON_AARCH64 - fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); -#else - fast_gemm_pack4_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); -#endif - fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); - fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); - } - } - - if (!use_stackbuff) { - free(packed_a); - } - }; - - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); -} - -void fastGemmKernel(int M, int N, int K, - float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_F32_MC, - GEMM_NC = FAST_GEMM_F32_NC, - GEMM_MR = FAST_GEMM_F32_MR, - GEMM_NR = FAST_GEMM_F32_NR; - - int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); - - size_t buff_size = KC * MC * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; - int m_tiles = (M + MC - 1) / MC; - int n_tiles = (N + NC - 1) / NC; - int total_tiles = m_tiles * n_tiles; - - auto fn = [&](const Range &r) { - char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer - const char *packed_b_ = packed_B; - int start = r.start; - int end = r.end; - - for (int tile_idx = start; tile_idx < end; tile_idx++) { - int i0 = (tile_idx / n_tiles) * MC; - int j0 = (tile_idx % n_tiles) * NC; - int mc = M - i0 < MC ? M - i0 : MC; - int nc = N - j0 < NC ? 
N - j0 : NC; - int ldc_block = ldc; - char* c_block = C + (i0 * ldc + j0) * esz; - packed_b_ = packed_B + j0 * K * esz; - - if (beta == 0.f) { - for(int i = 0; i < mc; i++) - memset(c_block + i * ldc_block * esz, 0, nc * esz); - } else if (beta != 1.f) { - for(int i = 0; i < mc; i++) { - float* c_i = (float*)c_block + i * ldc_block; - for(int j = 0; j < nc; j++) - c_i[j] *= beta; - } - } - - int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; - for(int k0 = 0; k0 < K; k0 += KC) - { - int kc = K - k0 < KC ? K - k0 : KC; -#if CV_NEON_AARCH64 - fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); -#else - fast_gemm_pack4_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); -#endif - fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); - packed_b_ += _nc * kc; - } - } - - if (!use_stackbuff) { - free(packed_a); - } - }; - - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); -} - -#endif // CV_NEON, CV_NEON_AARCH64 - -// AVX and AVX2 (16 x 256-bit registers) -#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_AVX - -FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) -FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) - -int fastGemmPackBSize(int N, int K) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - - return static_cast((N + NC - 1) / NC) * NC * K; -} - -void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); - - int n_tiles = (N + NC - 1) / NC; - for (int r = 0; r < n_tiles; ++r) { - int j0 = r * NC; - int nc = N - j0 < NC ? N - j0 : NC; - int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; - for (int k = 0; k < K; k += KC) { - int kc = K - k < KC ? 
K - k : KC; - fast_gemm_pack8_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); - packed_B += _nc * kc; - } - } -} - -#if !CV_FMA3 // AVX workaround for FMA -#undef _mm256_fmadd_ps -#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b)) -#endif - -static void fast_gemm12x8_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { - const float* a = (const float*)a_; - const float* b = (const float*)b_; - float* c = (float*)c_; - - __m256 s00 = _mm256_setzero_ps(), - s10 = _mm256_setzero_ps(), - s20 = _mm256_setzero_ps(), - s30 = _mm256_setzero_ps(), - s40 = _mm256_setzero_ps(), - s50 = _mm256_setzero_ps(), - s60 = _mm256_setzero_ps(), - s70 = _mm256_setzero_ps(), - s80 = _mm256_setzero_ps(), - s90 = _mm256_setzero_ps(), - s100 = _mm256_setzero_ps(), - s110 = _mm256_setzero_ps(); - for (int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) { - __m256 b0 = _mm256_loadu_ps(b); - - __m256 a0 = _mm256_set1_ps(*a); - s00 = _mm256_fmadd_ps(b0, a0, s00); - __m256 a1 = _mm256_set1_ps(*(a + 1)); - s10 = _mm256_fmadd_ps(b0, a1, s10); - __m256 a2 = _mm256_set1_ps(*(a + 2)); - s20 = _mm256_fmadd_ps(b0, a2, s20); - - a0 = _mm256_set1_ps(*(a + 3)); - s30 = _mm256_fmadd_ps(b0, a0, s30); - a1 = _mm256_set1_ps(*(a + 4)); - s40 = _mm256_fmadd_ps(b0, a1, s40); - a2 = _mm256_set1_ps(*(a + 5)); - s50 = _mm256_fmadd_ps(b0, a2, s50); - - a0 = _mm256_set1_ps(*(a + 6)); - s60 = _mm256_fmadd_ps(b0, a0, s60); - a1 = _mm256_set1_ps(*(a + 7)); - s70 = _mm256_fmadd_ps(b0, a1, s70); - a2 = _mm256_set1_ps(*(a + 8)); - s80 = _mm256_fmadd_ps(b0, a2, s80); - - a0 = _mm256_set1_ps(*(a + 9)); - s90 = _mm256_fmadd_ps(b0, a0, s90); - a1 = _mm256_set1_ps(*(a + 10)); - s100 = _mm256_fmadd_ps(b0, a1, s100); - a2 = _mm256_set1_ps(*(a + 11)); - s110 = _mm256_fmadd_ps(b0, a2, s110); - } - - __m256 c0, c1, c2, c3, v_alpha = _mm256_set1_ps(alpha); -#define FAST_GEMM_FINALE(row0, row1, row2, row3) \ - c0 = _mm256_loadu_ps(c + row0 * ldc); \ - c1 = _mm256_loadu_ps(c + row1 * ldc); \ - c2 = _mm256_loadu_ps(c + row2 * ldc); \ - c3 = _mm256_loadu_ps(c + row3 * ldc); \ - c0 = _mm256_fmadd_ps(s##row0##0, v_alpha, c0); \ - c1 = _mm256_fmadd_ps(s##row1##0, v_alpha, c1); \ - c2 = _mm256_fmadd_ps(s##row2##0, v_alpha, c2); \ - c3 = _mm256_fmadd_ps(s##row3##0, v_alpha, c3); \ - _mm256_storeu_ps(c + row0 * ldc, c0); \ - _mm256_storeu_ps(c + row1 * ldc, c1); \ - _mm256_storeu_ps(c + row2 * ldc, c2); \ - _mm256_storeu_ps(c + row3 * ldc, c3); \ - - FAST_GEMM_FINALE(0, 1, 2, 3); - FAST_GEMM_FINALE(4, 5, 6, 7); - FAST_GEMM_FINALE(8, 9, 10, 11); -#undef FAST_GEMM_FINALE -} - -static void fast_gemm_macro_kernel(int m, int n, int k, - const char *packed_A, const char *packed_B, - float alpha, char *c, int ldc0, int esz) { - int ldc0_esz = ldc0 * esz; - - double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough - for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { - for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { - char* cptr0 = &c[i * ldc0_esz + j * esz]; - char* cptr = cptr0; - int ldc = ldc0; - int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; - int nr = n - j < FAST_GEMM_F32_NR ? 
n - j : FAST_GEMM_F32_NR; - int nr_esz = nr * esz; - bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); - if (partial) { - memset(tempC, 0, sizeof(tempC)); - cptr = (char *)tempC; - ldc = FAST_GEMM_F32_NR; - for(int p = 0; p < mr; p++) - memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); - } - fast_gemm12x8_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); - - if (partial) { - for(int p = 0; p < mr; p++) - memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); - } - } - } -} - -void fastGemmKernel(int M, int N, int K, - float alpha, const char *A, int lda0, int lda1, - const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_F32_MC, - GEMM_NC = FAST_GEMM_F32_NC, - GEMM_MR = FAST_GEMM_F32_MR, - GEMM_NR = FAST_GEMM_F32_NR; - - int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); - KC = KC > 8 ? KC : 8; - KC = KC < K ? KC : K; - - size_t buff_size = KC * (MC + NC) * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; - int m_tiles = (M + MC - 1) / MC; - int n_tiles = (N + NC - 1) / NC; - int total_tiles = m_tiles * n_tiles; - - auto fn = [&](const Range &r) { - char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); - char* packed_b = packed_a + KC * MC * esz; - int start = r.start; - int end = r.end; - - for (int tile_idx = start; tile_idx < end; tile_idx++) { - int i0 = (tile_idx / n_tiles) * MC; - int j0 = (tile_idx % n_tiles) * NC; - int mc = M - i0 < MC ? M - i0 : MC; - int nc = N - j0 < NC ? N - j0 : NC; - int ldc_block = ldc; - char* c_block = C + (i0 * ldc + j0) * esz; - - if (beta == 0.f) { - for(int i = 0; i < mc; i++) - memset(c_block + i * ldc_block * esz, 0, nc * esz); - } else if (beta != 1.f) { - for(int i = 0; i < mc; i++) { - float* c_i = (float*)c_block + i * ldc_block; - for(int j = 0; j < nc; j++) - c_i[j] *= beta; - } - } - - for(int k0 = 0; k0 < K; k0 += KC) - { - int kc = K - k0 < KC ? K - k0 : KC; - fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); - fast_gemm_pack8_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); - fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); - } - } - - if (!use_stackbuff) { - free(packed_a); - } - }; - - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); -} - -void fastGemmKernel(int M, int N, int K, - float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz) { - int GEMM_MC = FAST_GEMM_F32_MC, - GEMM_NC = FAST_GEMM_F32_NC, - GEMM_MR = FAST_GEMM_F32_MR, - GEMM_NR = FAST_GEMM_F32_NR; + s51 = vfmaq_laneq_f32(s51, b1, a0, 1); + s52 = vfmaq_laneq_f32(s52, b2, a0, 1); - int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; - int NC = (((GEMM_NC < N ? 
GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + s60 = vfmaq_laneq_f32(s60, b0, a0, 2); + s61 = vfmaq_laneq_f32(s61, b1, a0, 2); + s62 = vfmaq_laneq_f32(s62, b2, a0, 2); + s70 = vfmaq_laneq_f32(s70, b0, a0, 3); + s71 = vfmaq_laneq_f32(s71, b1, a0, 3); + s72 = vfmaq_laneq_f32(s72, b2, a0, 3); + } - size_t buff_size = KC * MC * esz; - bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; - int m_tiles = (M + MC - 1) / MC; - int n_tiles = (N + NC - 1) / NC; - int total_tiles = m_tiles * n_tiles; + float32x4_t c0, c1, c2, c3, c4, c5, v_alpha = vdupq_n_f32(alpha); +#define FAST_GEMM_FINALE(row0, row1) \ + c0 = vld1q_f32(c + row0 * ldc); \ + c1 = vld1q_f32(c + row0 * ldc + 4); \ + c2 = vld1q_f32(c + row0 * ldc + 8); \ + c3 = vld1q_f32(c + row1 * ldc); \ + c4 = vld1q_f32(c + row1 * ldc + 4); \ + c5 = vld1q_f32(c + row1 * ldc + 8); \ + c0 = vfmaq_f32(c0, s##row0##0, v_alpha); \ + c1 = vfmaq_f32(c1, s##row0##1, v_alpha); \ + c2 = vfmaq_f32(c2, s##row0##2, v_alpha); \ + c3 = vfmaq_f32(c3, s##row1##0, v_alpha); \ + c4 = vfmaq_f32(c4, s##row1##1, v_alpha); \ + c5 = vfmaq_f32(c5, s##row1##2, v_alpha); \ + vst1q_f32(c + row0 * ldc, c0); \ + vst1q_f32(c + row0 * ldc + 4, c1); \ + vst1q_f32(c + row0 * ldc + 8, c2); \ + vst1q_f32(c + row1 * ldc, c3); \ + vst1q_f32(c + row1 * ldc + 4, c4); \ + vst1q_f32(c + row1 * ldc + 8, c5); - auto fn = [&](const Range &r) { - char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer - const char *packed_b_ = packed_B; - int start = r.start; - int end = r.end; + FAST_GEMM_FINALE(0, 1); + FAST_GEMM_FINALE(2, 3); + FAST_GEMM_FINALE(4, 5); + FAST_GEMM_FINALE(6, 7); +#undef FAST_GEMM_FINALE +} - for (int tile_idx = start; tile_idx < end; tile_idx++) { - int i0 = (tile_idx / n_tiles) * MC; - int j0 = (tile_idx % n_tiles) * NC; - int mc = M - i0 < MC ? M - i0 : MC; - int nc = N - j0 < NC ? N - j0 : NC; - int ldc_block = ldc; - char* c_block = C + (i0 * ldc + j0) * esz; - packed_b_ = packed_B + j0 * K * esz; +#elif CV_AVX // AVX and AVX2 (16 x 256-bit registers) - if (beta == 0.f) { - for(int i = 0; i < mc; i++) - memset(c_block + i * ldc_block * esz, 0, nc * esz); - } else if (beta != 1.f) { - for(int i = 0; i < mc; i++) { - float* c_i = (float*)c_block + i * ldc_block; - for(int j = 0; j < nc; j++) - c_i[j] *= beta; - } - } +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) // a packer +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) // b packer - int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; - for(int k0 = 0; k0 < K; k0 += KC) - { - int kc = K - k0 < KC ? 
K - k0 : KC; - fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); - fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); - packed_b_ += _nc * kc; - } - } +#if !CV_FMA3 // AVX workaround for FMA +#undef _mm256_fmadd_ps +#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b)) +#endif - if (!use_stackbuff) { - free(packed_a); - } - }; +static inline void fast_gemm12x8_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); -} + __m256 s00 = _mm256_setzero_ps(), + s10 = _mm256_setzero_ps(), + s20 = _mm256_setzero_ps(), + s30 = _mm256_setzero_ps(), + s40 = _mm256_setzero_ps(), + s50 = _mm256_setzero_ps(), + s60 = _mm256_setzero_ps(), + s70 = _mm256_setzero_ps(), + s80 = _mm256_setzero_ps(), + s90 = _mm256_setzero_ps(), + s100 = _mm256_setzero_ps(), + s110 = _mm256_setzero_ps(); + for (int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) { + __m256 b0 = _mm256_loadu_ps(b); -#endif // CV_AVX, CV_AVX2 + __m256 a0 = _mm256_set1_ps(*a); + s00 = _mm256_fmadd_ps(b0, a0, s00); + __m256 a1 = _mm256_set1_ps(*(a + 1)); + s10 = _mm256_fmadd_ps(b0, a1, s10); + __m256 a2 = _mm256_set1_ps(*(a + 2)); + s20 = _mm256_fmadd_ps(b0, a2, s20); -// LASX (32 x 256-bit registers) -#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_LASX + a0 = _mm256_set1_ps(*(a + 3)); + s30 = _mm256_fmadd_ps(b0, a0, s30); + a1 = _mm256_set1_ps(*(a + 4)); + s40 = _mm256_fmadd_ps(b0, a1, s40); + a2 = _mm256_set1_ps(*(a + 5)); + s50 = _mm256_fmadd_ps(b0, a2, s50); -FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) -FAST_GEMM_IMPLEMENT_PACK(16, _f32, float, float) + a0 = _mm256_set1_ps(*(a + 6)); + s60 = _mm256_fmadd_ps(b0, a0, s60); + a1 = _mm256_set1_ps(*(a + 7)); + s70 = _mm256_fmadd_ps(b0, a1, s70); + a2 = _mm256_set1_ps(*(a + 8)); + s80 = _mm256_fmadd_ps(b0, a2, s80); -int fastGemmPackBSize(int N, int K) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? 
GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + a0 = _mm256_set1_ps(*(a + 9)); + s90 = _mm256_fmadd_ps(b0, a0, s90); + a1 = _mm256_set1_ps(*(a + 10)); + s100 = _mm256_fmadd_ps(b0, a1, s100); + a2 = _mm256_set1_ps(*(a + 11)); + s110 = _mm256_fmadd_ps(b0, a2, s110); + } - return static_cast((N + NC - 1) / NC) * NC * K; + __m256 c0, c1, c2, c3, v_alpha = _mm256_set1_ps(alpha); +#define FAST_GEMM_FINALE(row0, row1, row2, row3) \ + c0 = _mm256_loadu_ps(c + row0 * ldc); \ + c1 = _mm256_loadu_ps(c + row1 * ldc); \ + c2 = _mm256_loadu_ps(c + row2 * ldc); \ + c3 = _mm256_loadu_ps(c + row3 * ldc); \ + c0 = _mm256_fmadd_ps(s##row0##0, v_alpha, c0); \ + c1 = _mm256_fmadd_ps(s##row1##0, v_alpha, c1); \ + c2 = _mm256_fmadd_ps(s##row2##0, v_alpha, c2); \ + c3 = _mm256_fmadd_ps(s##row3##0, v_alpha, c3); \ + _mm256_storeu_ps(c + row0 * ldc, c0); \ + _mm256_storeu_ps(c + row1 * ldc, c1); \ + _mm256_storeu_ps(c + row2 * ldc, c2); \ + _mm256_storeu_ps(c + row3 * ldc, c3); \ + + FAST_GEMM_FINALE(0, 1, 2, 3); + FAST_GEMM_FINALE(4, 5, 6, 7); + FAST_GEMM_FINALE(8, 9, 10, 11); +#undef FAST_GEMM_FINALE } -void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { - int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; - int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; - int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); +#elif CV_LASX // LASX (32 x 256-bit registers) - int n_tiles = (N + NC - 1) / NC; - for (int r = 0; r < n_tiles; ++r) { - int j0 = r * NC; - int nc = N - j0 < NC ? N - j0 : NC; - int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; - for (int k = 0; k < K; k += KC) { - int kc = K - k < KC ? K - k : KC; - fast_gemm_pack16_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); - packed_B += _nc * kc; - } - } -} +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) // a packer +FAST_GEMM_IMPLEMENT_PACK(16, _f32, float, float) // b packer -static void fast_gemm12x16_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { +static inline void fast_gemm12x16_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { const float* a = (const float*)a_; const float* b = (const float*)b_; float* c = (float*)c_; @@ -889,9 +395,99 @@ static void fast_gemm12x16_f32(int k, const char *a_, const char *b_, char *c_, #undef FAST_GEMM_FINALE } -static void fast_gemm_macro_kernel(int m, int n, int k, - const char *packed_A, const char *packed_B, - float alpha, char *c, int ldc0, int esz) { +#elif CV_SIMD128 // armv7: 16 x 128-bit registers + +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) // a packer +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) // b packer + +static inline void fast_gemm8x12_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + v_float32x4 s00 = v_setzero_f32(), s01 = s00, s02 = s00; + v_float32x4 s10 = s00, s11 = s00, s12 = s00; + v_float32x4 s20 = s00, s21 = s00, s22 = s00; + v_float32x4 s30 = s00, s31 = s00, s32 = s00; + v_float32x4 s40 = s00, s41 = s00, s42 = s00; + v_float32x4 s50 = s00, s51 = s00, s52 = s00; + v_float32x4 s60 = s00, s61 = s00, s62 = s00; + v_float32x4 s70 = s00, s71 = s00, s72 = s00; + + for(int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) { + v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); + + v_float32x4 a0 = v_setall_f32(*a); + s00 = v_fma(b0, a0, s00); + s01 = 
v_fma(b1, a0, s01); + s02 = v_fma(b2, a0, s02); + v_float32x4 a1 = v_setall_f32(*(a + 1)); + s10 = v_fma(b0, a1, s10); + s11 = v_fma(b1, a1, s11); + s12 = v_fma(b2, a1, s12); + + v_float32x4 a2 = v_setall_f32(*(a + 2)); + s20 = v_fma(b0, a2, s20); + s21 = v_fma(b1, a2, s21); + s22 = v_fma(b2, a2, s22); + v_float32x4 a3 = v_setall_f32(*(a + 3)); + s30 = v_fma(b0, a3, s30); + s31 = v_fma(b1, a3, s31); + s32 = v_fma(b2, a3, s32); + + a0 = v_setall_f32(*(a + 4)); + s40 = v_fma(b0, a0, s40); + s41 = v_fma(b1, a0, s41); + s42 = v_fma(b2, a0, s42); + a1 = v_setall_f32(*(a + 5)); + s50 = v_fma(b0, a1, s50); + s51 = v_fma(b1, a1, s51); + s52 = v_fma(b2, a1, s52); + + a2 = v_setall_f32(*(a + 6)); + s60 = v_fma(b0, a2, s60); + s61 = v_fma(b1, a2, s61); + s62 = v_fma(b2, a2, s62); + a3 = v_setall_f32(*(a + 7)); + s70 = v_fma(b0, a3, s70); + s71 = v_fma(b1, a3, s71); + s72 = v_fma(b2, a3, s72); + } + + v_float32x4 c0, c1, c2, c3, c4, c5, v_alpha = v_setall_f32(alpha); +#define FAST_GEMM_FINALE(row0, row1) \ + c0 = v_load(c + row0 * ldc); \ + c1 = v_load(c + row0 * ldc + 4); \ + c2 = v_load(c + row0 * ldc + 8); \ + c3 = v_load(c + row1 * ldc); \ + c4 = v_load(c + row1 * ldc + 4); \ + c5 = v_load(c + row1 * ldc + 8); \ + c0 = v_fma(s##row0##0, v_alpha, c0); \ + c1 = v_fma(s##row0##1, v_alpha, c1); \ + c2 = v_fma(s##row0##2, v_alpha, c2); \ + c3 = v_fma(s##row1##0, v_alpha, c3); \ + c4 = v_fma(s##row1##1, v_alpha, c4); \ + c5 = v_fma(s##row1##2, v_alpha, c5); \ + v_store(c + row0 * ldc, c0); \ + v_store(c + row0 * ldc + 4, c1); \ + v_store(c + row0 * ldc + 8, c2); \ + v_store(c + row1 * ldc, c3); \ + v_store(c + row1 * ldc + 4, c4); \ + v_store(c + row1 * ldc + 8, c5); + + FAST_GEMM_FINALE(0, 1); + FAST_GEMM_FINALE(2, 3); + FAST_GEMM_FINALE(4, 5); + FAST_GEMM_FINALE(6, 7); +#undef FAST_GEMM_FINALE +} + +#endif + +static inline void fast_gemm_macro_kernel(int m, int n, int k, + const char *packed_A, const char *packed_B, + float alpha, char *c, int ldc0, int esz) { int ldc0_esz = ldc0 * esz; double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough @@ -911,7 +507,15 @@ static void fast_gemm_macro_kernel(int m, int n, int k, for(int p = 0; p < mr; p++) memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); } +#if CV_NEON && CV_NEON_AARCH64 + fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#elif CV_AVX + fast_gemm12x8_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#elif CV_LASX fast_gemm12x16_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#elif CV_SIMD128 + fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#endif if (partial) { for(int p = 0; p < mr; p++) @@ -921,6 +525,39 @@ static void fast_gemm_macro_kernel(int m, int n, int k, } } +int fastGemmPackBSize(int N, int K) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + + return static_cast((N + NC - 1) / NC) * NC * K; +} + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + int n_tiles = (N + NC - 1) / NC; + for (int r = 0; r < n_tiles; ++r) { + int j0 = r * NC; + int nc = N - j0 < NC ? 
N - j0 : NC; + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for (int k = 0; k < K; k += KC) { + int kc = K - k < KC ? K - k : KC; +#if CV_NEON && CV_NEON_AARCH64 + fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); +#elif CV_AVX + fast_gemm_pack8_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); +#elif CV_LASX + fast_gemm_pack16_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); +#elif CV_SIMD128 + fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); +#endif + packed_B += _nc * kc; + } + } +} + void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, @@ -970,8 +607,29 @@ void fastGemmKernel(int M, int N, int K, for(int k0 = 0; k0 < K; k0 += KC) { int kc = K - k0 < KC ? K - k0 : KC; + // pack a +#if CV_NEON && CV_NEON_AARCH64 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_AVX + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_LASX fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_SIMD128 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#endif + + // pack b +#if CV_NEON && CV_NEON_AARCH64 + fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); +#elif CV_AVX + fast_gemm_pack8_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); +#elif CV_LASX fast_gemm_pack16_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); +#elif CV_SIMD128 + fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); +#endif + + // run kernel fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); } } @@ -1035,7 +693,18 @@ void fastGemmKernel(int M, int N, int K, for(int k0 = 0; k0 < K; k0 += KC) { int kc = K - k0 < KC ? 
K - k0 : KC; + // pack a +#if CV_NEON && CV_NEON_AARCH64 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_AVX + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_LASX fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#elif CV_SIMD128 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#endif + + // run kernel fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); packed_b_ += _nc * kc; } @@ -1052,8 +721,37 @@ void fastGemmKernel(int M, int N, int K, parallel_for_(Range(0, total), fn, nstripes); } -#endif // CV_LASX +#endif // CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY CV_CPU_OPTIMIZATION_NAMESPACE_END }} // cv::dnn + +#undef FAST_GEMM_STORAGE +#undef FAST_GEMM_MAX_STACKBUF +#ifdef FAST_GEMM_F32_MC +#undef FAST_GEMM_F32_MC +#endif +#ifdef FAST_GEMM_F32_NC +#undef FAST_GEMM_F32_NC +#endif +#ifdef FAST_GEMM_F32_MR +#undef FAST_GEMM_F32_MR +#endif +#ifdef FAST_GEMM_F32_NR +#undef FAST_GEMM_F32_NR +#endif +#ifdef FAST_GEMM_F32_PACKED_STRIDE_K +#undef FAST_GEMM_F32_PACKED_STRIDE_K +#endif +#undef FAST_GEMM_IMPLEMENT_PACK +#undef FAST_GEMM_LOAD_TO_BUF_8 +#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_8 +#undef FAST_GEMM_LOAD_TO_BUF_12 +#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_12 +#undef FAST_GEMM_LOAD_TO_BUF_16 +#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_16 +#undef FAST_GEMM_PACK_COPY +#undef FAST_GEMM_PACK_f32_8 +#undef FAST_GEMM_PACK_f32_12 +#undef FAST_GEMM_PACK_f32_16 diff --git a/modules/dnn/src/layers/gemm_layer.cpp b/modules/dnn/src/layers/gemm_layer.cpp index 0a58abce5d..a553f97568 100644 --- a/modules/dnn/src/layers/gemm_layer.cpp +++ b/modules/dnn/src/layers/gemm_layer.cpp @@ -191,7 +191,6 @@ public: size_t dims_Y = shape_Y.size(); int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1]; int K = trans_a ? 
ma : na; - int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies()); // broadcast C and copy C to output if (have_bias) { @@ -201,9 +200,7 @@ public: int step = M * N; CV_CheckEQ(broadcast_C.size(), static_cast(step), "DNN/Gemm: C is not broadcast properly"); float *ptr_y = Y.ptr(); - for (int i = 0; i < batches; i++) { - std::memcpy(ptr_y + i * step, broadcast_C.data(), step * sizeof(float)); - } + std::memcpy(ptr_y, broadcast_C.data(), step * sizeof(float)); } else { // initialization float *ptr_y = Y.ptr(); size_t total = Y.total(); @@ -212,7 +209,6 @@ public: if (const_B) { CV_CheckGT(packed_B.size(), static_cast(0), "DNN/Gemm: constant B is not pre-packed"); - M *= batches; fastGemm(trans_a, M, N, K, alpha, A.ptr(), na, packed_B.data(), 1.f, Y.ptr(), N, opt); } else { fastGemmBatched(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 37a4afcf8e..6aa6dc672e 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2675,37 +2675,37 @@ TEST_P(Test_ONNX_layers, where_node) testONNXModels("where_layer"); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_all_attributes) { +TEST_P(Test_ONNX_layers, Gemm_all_attributes) { testONNXModels("test_gemm_all_attributes", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_alpha) { +TEST_P(Test_ONNX_layers, Gemm_alpha) { testONNXModels("test_gemm_alpha", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_beta) { +TEST_P(Test_ONNX_layers, Gemm_beta) { testONNXModels("test_gemm_beta", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_matrix_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_matrix_bias) { testONNXModels("test_gemm_default_matrix_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_no_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_no_bias) { testONNXModels("test_gemm_default_no_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_scalar_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_scalar_bias) { testONNXModels("test_gemm_default_scalar_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_single_elem_vector_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_single_elem_vector_bias) { testONNXModels("test_gemm_default_single_elem_vector_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_vector_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_vector_bias) { testONNXModels("test_gemm_default_vector_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_default_zero_bias) { +TEST_P(Test_ONNX_layers, Gemm_default_zero_bias) { testONNXModels("test_gemm_default_zero_bias", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeA) { +TEST_P(Test_ONNX_layers, Gemm_transposeA) { testONNXModels("test_gemm_transposeA", pb, 0, 0, false, true, 2); } -TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeB) { +TEST_P(Test_ONNX_layers, Gemm_transposeB) { testONNXModels("test_gemm_transposeB", pb, 0, 0, false, true, 2); }
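This patch routes all ISA variants (CV_NEON_AARCH64, CV_AVX, CV_LASX, plus a generic CV_SIMD128 fallback for armv7, which replaces the old 4x12 NEON micro-kernel) through a single copy of the packing, macro-kernel and tiling code, so the blocking arithmetic is now shared across platforms. For reference, below is a minimal standalone sketch of that arithmetic: the B panel width NC is rounded up to the micro-kernel width NR, the packed-B buffer holds ceil(N / NC) panels of NC * K elements, KC caps the depth of one packed slice, and each parallel_for_ tile maps to one MC x NC block of C. The constants are the baseline/AArch64 values set by this patch, and the helper names (roundUp, packedBSizeF32) are illustrative only, not OpenCV API.

```cpp
// Standalone sketch of the blocking math shared by the unified kernels above.
// Constants mirror the baseline/AArch64 values from this patch; helper names are hypothetical.
#include <algorithm>
#include <cstdio>

static constexpr int GEMM_NC = 240;         // FAST_GEMM_F32_NC
static constexpr int GEMM_NR = 12;          // FAST_GEMM_F32_NR
static constexpr int PACKED_STRIDE_K = 64;  // FAST_GEMM_F32_PACKED_STRIDE_K

static int roundUp(int x, int r) { return ((x + r - 1) / r) * r; }

// Same arithmetic as fastGemmPackBSize: B is packed in column panels of width NC,
// each panel padded to a multiple of NR columns and K rows deep (result in floats).
static int packedBSizeF32(int N, int K) {
    int NC = roundUp(std::min(GEMM_NC, N), GEMM_NR);
    int n_panels = (N + NC - 1) / NC;
    return n_panels * NC * K;
}

int main() {
    int N = 1000, K = 500;
    int NC = roundUp(std::min(GEMM_NC, N), GEMM_NR);
    int KC = std::min(PACKED_STRIDE_K, K);  // depth of one packed slice of B
    std::printf("NC=%d KC=%d packed_B floats=%d\n", NC, KC, packedBSizeF32(N, K));

    // Tile -> (row block, column block) mapping used inside the parallel_for_ body.
    int M = 512, MC = 64;                   // FAST_GEMM_F32_MC (baseline value)
    int m_tiles = (M + MC - 1) / MC, n_tiles = (N + NC - 1) / NC;
    for (int tile_idx = 0; tile_idx < m_tiles * n_tiles; ++tile_idx) {
        int i0 = (tile_idx / n_tiles) * MC; // first row of the C block
        int j0 = (tile_idx % n_tiles) * NC; // first column of the C block
        (void)i0; (void)j0;
    }
    return 0;
}
```

Because NC is kept a multiple of NR, every full packed-B panel is already aligned to the micro-kernel width; ragged edges are handled in a single place, the tempC copy path of fast_gemm_macro_kernel, instead of in each ISA-specific micro-kernel.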