dnn: hotfixes for fast gemm (#24315)

* remove Conformance from test names

* integrate neon optimization into default

* quick fix: define CV_NEON_AARCH64 0 for non NEON platforms

* remove var batch that leads to memory leak

* put neon code back to fast_gemm_kernels.simd

* reorganize code to reduce duplicate code
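For context on the quick fix above: CV_NEON_AARCH64 is an existing OpenCV macro that is only defined for AArch64 builds with NEON, and the NEON gemm path presumably tests it with #if, so non-NEON builds need it defined to 0. A minimal sketch of the kind of guard the commit message describes (the actual hotfix lives in fast_gemm_kernels.simd.hpp and may differ):

    // Sketch only: keep "#if CV_NEON_AARCH64" well-defined (and quiet under -Wundef)
    // when building for platforms without NEON support.
    #ifndef CV_NEON_AARCH64
    #define CV_NEON_AARCH64 0
    #endif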
Author: Yuantao Feng, committed via GitHub
commit 590f150d5e (parent 5fb3869775)
Changed files:
1. modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp (220 lines changed)
2. modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp (974 lines changed)
3. modules/dnn/src/layers/gemm_layer.cpp (6 lines changed)
4. modules/dnn/test/test_onnx_importer.cpp (22 lines changed)

modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp
@@ -12,16 +12,16 @@
 #include <opencv2/core/hal/intrin.hpp>
 #include <opencv2/core/utility.hpp> // parallel_for_
-#define FAST_GEMM_DEFAULT_STORAGE (1<<20) // 2^20
-#define FAST_GEMM_DEFAULT_MAX_STACKBUF (1 << 14)
-#define FAST_GEMM_DEFAULT_F32_MC 64
-#define FAST_GEMM_DEFAULT_F32_NC 240
-#define FAST_GEMM_DEFAULT_F32_MR 8
-#define FAST_GEMM_DEFAULT_F32_NR 12
-#define FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K 256
-#define FAST_GEMM_DEFAULT_IMPLEMENT_PACK(N, suffix, styp, dtyp) \
+#define FAST_GEMM_STORAGE (1<<20) // 2^20
+#define FAST_GEMM_MAX_STACKBUF (1 << 14)
+#define FAST_GEMM_F32_MC 64
+#define FAST_GEMM_F32_NC 240
+#define FAST_GEMM_F32_MR 8
+#define FAST_GEMM_F32_NR 12
+#define FAST_GEMM_F32_PACKED_STRIDE_K 64
+#define FAST_GEMM_IMPLEMENT_PACK(N, suffix, styp, dtyp) \
 static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \
 int lda0, int lda1, void* packA_ ) \
 { \
@@ -32,47 +32,47 @@ static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \
 const styp* a_ptr = A + lda0*i; \
 for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \
 { \
-FAST_GEMM_DEFAULT_LOAD_TO_BUF_##N(styp); \
-FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \
+FAST_GEMM_LOAD_TO_BUF_##N(styp); \
+FAST_GEMM_PACK##suffix##_##N(buf, packA); \
 } \
 } else { \
 const styp* a_ptr[N]; \
 for (int k = 0; k < N; k++) a_ptr[k] = A + lda0*(i+k < m ? i+k : i); \
 for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \
 { \
-FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_##N(styp); \
-FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \
+FAST_GEMM_LOAD_TO_BUF_BORDERS_##N(styp); \
+FAST_GEMM_PACK##suffix##_##N(buf, packA); \
 } \
 } \
 } \
 }
-#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_8(styp) \
+#define FAST_GEMM_LOAD_TO_BUF_8(styp) \
 styp buf[] = { \
 a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \
 a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7] }
-#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_8(styp) \
+#define FAST_GEMM_LOAD_TO_BUF_BORDERS_8(styp) \
 styp buf[] = { \
 a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \
 a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j] }
-#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_12(styp) \
+#define FAST_GEMM_LOAD_TO_BUF_12(styp) \
 styp buf[] = { \
 a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \
 a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \
 a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11] }
-#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_12(styp) \
+#define FAST_GEMM_LOAD_TO_BUF_BORDERS_12(styp) \
 styp buf[] = { \
 a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \
 a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \
 a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j] }
-#define FAST_GEMM_DEFAULT_PACK_COPY(src, dst, N) \
+#define FAST_GEMM_PACK_COPY(src, dst, N) \
 memcpy((dst), (src), N*sizeof(src[0]))
-#define FAST_GEMM_DEFAULT_PACK_f32_8(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 8)
-#define FAST_GEMM_DEFAULT_PACK_f32_12(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 12)
+#define FAST_GEMM_PACK_f32_8(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 8)
+#define FAST_GEMM_PACK_f32_12(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 12)
 namespace cv { namespace dnn { namespace cpu_baseline {
@@ -88,20 +88,20 @@ void fastGemmKernel(int M, int N, int K,
 float alpha, const char *A, int lda0, int lda1,
 const char *packed_B, float beta, char *C, int ldc, int esz);
-FAST_GEMM_DEFAULT_IMPLEMENT_PACK(8, _f32, float, float)
-FAST_GEMM_DEFAULT_IMPLEMENT_PACK(12, _f32, float, float)
+FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float)
+FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float)
 int fastGemmPackBSize(int N, int K) {
-int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
+int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR;
 int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
 return static_cast<int>((N + NC - 1) / NC) * NC * K;
 }
 void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) {
-int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
+int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR;
 int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
-int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K);
+int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
 int n_tiles = (N + NC - 1) / NC;
 for (int r = 0; r < n_tiles; ++r) {
@@ -116,140 +116,50 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0,
 }
 }
-#if CV_SIMD128
-static void fast_gemm8x12_f32(int k, const char *a_, const char *b_,
+static inline void fast_gemm_f32(int k, const char *a_, const char *b_,
 char *c_, int ldc, float alpha) {
 const float* a = (const float*)a_;
 const float* b = (const float*)b_;
 float* c = (float*)c_;
-v_float32x4 s00 = v_setzero_f32(), s01 = s00, s02 = s00;
-v_float32x4 s10 = s00, s11 = s00, s12 = s00;
-v_float32x4 s20 = s00, s21 = s00, s22 = s00;
-v_float32x4 s30 = s00, s31 = s00, s32 = s00;
-v_float32x4 s40 = s00, s41 = s00, s42 = s00;
-v_float32x4 s50 = s00, s51 = s00, s52 = s00;
-v_float32x4 s60 = s00, s61 = s00, s62 = s00;
-v_float32x4 s70 = s00, s71 = s00, s72 = s00;
-for(int p = 0; p < k; p++, a += FAST_GEMM_DEFAULT_F32_MR, b += FAST_GEMM_DEFAULT_F32_NR) {
-v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8);
-v_float32x4 a0 = v_setall_f32(*a);
-s00 = v_fma(b0, a0, s00);
-s01 = v_fma(b1, a0, s01);
-s02 = v_fma(b2, a0, s02);
-v_float32x4 a1 = v_setall_f32(*(a + 1));
-s10 = v_fma(b0, a1, s10);
-s11 = v_fma(b1, a1, s11);
-s12 = v_fma(b2, a1, s12);
-v_float32x4 a2 = v_setall_f32(*(a + 2));
-s20 = v_fma(b0, a2, s20);
-s21 = v_fma(b1, a2, s21);
-s22 = v_fma(b2, a2, s22);
-v_float32x4 a3 = v_setall_f32(*(a + 3));
-s30 = v_fma(b0, a3, s30);
-s31 = v_fma(b1, a3, s31);
-s32 = v_fma(b2, a3, s32);
-a0 = v_setall_f32(*(a + 4));
-s40 = v_fma(b0, a0, s40);
-s41 = v_fma(b1, a0, s41);
-s42 = v_fma(b2, a0, s42);
-a1 = v_setall_f32(*(a + 5));
-s50 = v_fma(b0, a1, s50);
-s51 = v_fma(b1, a1, s51);
-s52 = v_fma(b2, a1, s52);
-a2 = v_setall_f32(*(a + 6));
-s60 = v_fma(b0, a2, s60);
-s61 = v_fma(b1, a2, s61);
-s62 = v_fma(b2, a2, s62);
-a3 = v_setall_f32(*(a + 7));
-s70 = v_fma(b0, a3, s70);
-s71 = v_fma(b1, a3, s71);
-s72 = v_fma(b2, a3, s72);
-}
-v_float32x4 c0, c1, c2, c3, c4, c5, v_alpha = v_setall_f32(alpha);
-#define FAST_GEMM_FINALE(row0, row1) \
-c0 = v_load(c + row0 * ldc); \
-c1 = v_load(c + row0 * ldc + 4); \
-c2 = v_load(c + row0 * ldc + 8); \
-c3 = v_load(c + row1 * ldc); \
-c4 = v_load(c + row1 * ldc + 4); \
-c5 = v_load(c + row1 * ldc + 8); \
-c0 = v_fma(s##row0##0, v_alpha, c0); \
-c1 = v_fma(s##row0##1, v_alpha, c1); \
-c2 = v_fma(s##row0##2, v_alpha, c2); \
-c3 = v_fma(s##row1##0, v_alpha, c3); \
-c4 = v_fma(s##row1##1, v_alpha, c4); \
-c5 = v_fma(s##row1##2, v_alpha, c5); \
-v_store(c + row0 * ldc, c0); \
-v_store(c + row0 * ldc + 4, c1); \
-v_store(c + row0 * ldc + 8, c2); \
-v_store(c + row1 * ldc, c3); \
-v_store(c + row1 * ldc + 4, c4); \
-v_store(c + row1 * ldc + 8, c5);
-FAST_GEMM_FINALE(0, 1);
-FAST_GEMM_FINALE(2, 3);
-FAST_GEMM_FINALE(4, 5);
-FAST_GEMM_FINALE(6, 7);
-#undef FAST_GEMM_FINALE
-}
-#else
-static void fast_gemm_f32(int k, const char *a_, const char *b_,
-char *c_, int ldc, float alpha) {
-const float* a = (const float*)a_;
-const float* b = (const float*)b_;
-float* c = (float*)c_;
-float sbuf[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR];
+float sbuf[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR];
 memset(sbuf, 0, sizeof(sbuf));
 for(int p = 0; p < k; p++) {
-for( int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++ ) {
-float ai = a[FAST_GEMM_DEFAULT_F32_MR * p + i];
-for( int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++ )
-sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j] += b[FAST_GEMM_DEFAULT_F32_NR * p + j] * ai;
+for( int i = 0; i < FAST_GEMM_F32_MR; i++ ) {
+float ai = a[FAST_GEMM_F32_MR * p + i];
+for( int j = 0; j < FAST_GEMM_F32_NR; j++ )
+sbuf[i * FAST_GEMM_F32_NR + j] += b[FAST_GEMM_F32_NR * p + j] * ai;
 }
 }
-for (int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++) {
-for (int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++)
-c[i * ldc + j] += alpha * sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j];
+for (int i = 0; i < FAST_GEMM_F32_MR; i++) {
+for (int j = 0; j < FAST_GEMM_F32_NR; j++)
+c[i * ldc + j] += alpha * sbuf[i * FAST_GEMM_F32_NR + j];
 }
 }
-#endif // CV_SIMD128
 static void fast_gemm_macro_kernel(int m, int n, int k,
 const char *packed_A, const char *packed_B,
 float alpha, char *c, int ldc0, int esz) {
 int ldc0_esz = ldc0 * esz;
-double tempC[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; // make sure the buffer is big enough
-for(int i = 0; i < m; i += FAST_GEMM_DEFAULT_F32_MR) {
-for(int j = 0; j < n; j += FAST_GEMM_DEFAULT_F32_NR) {
+double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough
+for(int i = 0; i < m; i += FAST_GEMM_F32_MR) {
+for(int j = 0; j < n; j += FAST_GEMM_F32_NR) {
 char* cptr0 = &c[i * ldc0_esz + j * esz];
 char* cptr = cptr0;
 int ldc = ldc0;
-int mr = m - i < FAST_GEMM_DEFAULT_F32_MR ? m - i : FAST_GEMM_DEFAULT_F32_MR;
-int nr = n - j < FAST_GEMM_DEFAULT_F32_NR ? n - j : FAST_GEMM_DEFAULT_F32_NR;
+int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR;
+int nr = n - j < FAST_GEMM_F32_NR ? n - j : FAST_GEMM_F32_NR;
 int nr_esz = nr * esz;
-bool partial = (bool)((mr < FAST_GEMM_DEFAULT_F32_MR) | (nr < FAST_GEMM_DEFAULT_F32_NR));
+bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR));
 if (partial) {
 memset(tempC, 0, sizeof(tempC));
 cptr = (char *)tempC;
-ldc = FAST_GEMM_DEFAULT_F32_NR;
+ldc = FAST_GEMM_F32_NR;
 for(int p = 0; p < mr; p++)
 memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz);
 }
-#if CV_SIMD128
-fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha);
-#else
 fast_gemm_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha);
-#endif
 if (partial) {
 for(int p = 0; p < mr; p++)
@@ -263,19 +173,19 @@ void fastGemmKernel(int M, int N, int K,
 float alpha, const char *A, int lda0, int lda1,
 const char *B, int ldb0, int ldb1,
 float beta, char *C, int ldc, int esz) {
-int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC,
-GEMM_NC = FAST_GEMM_DEFAULT_F32_NC,
-GEMM_MR = FAST_GEMM_DEFAULT_F32_MR,
-GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
+int GEMM_MC = FAST_GEMM_F32_MC,
+GEMM_NC = FAST_GEMM_F32_NC,
+GEMM_MR = FAST_GEMM_F32_MR,
+GEMM_NR = FAST_GEMM_F32_NR;
 int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
 int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
-int KC = FAST_GEMM_DEFAULT_STORAGE / ((MC + NC) * esz);
+int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz);
 KC = KC > 8 ? KC : 8;
 KC = KC < K ? KC : K;
 size_t buff_size = KC * (MC + NC) * esz;
-bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF;
+bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
 int m_tiles = (M + MC - 1) / MC;
 int n_tiles = (N + NC - 1) / NC;
 int total_tiles = m_tiles * n_tiles;
@@ -328,17 +238,17 @@ void fastGemmKernel(int M, int N, int K,
 void fastGemmKernel(int M, int N, int K,
 float alpha, const char *A, int lda0, int lda1,
 const char *packed_B, float beta, char *C, int ldc, int esz) {
-int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC,
-GEMM_NC = FAST_GEMM_DEFAULT_F32_NC,
-GEMM_MR = FAST_GEMM_DEFAULT_F32_MR,
-GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
+int GEMM_MC = FAST_GEMM_F32_MC,
+GEMM_NC = FAST_GEMM_F32_NC,
+GEMM_MR = FAST_GEMM_F32_MR,
+GEMM_NR = FAST_GEMM_F32_NR;
 int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
 int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
-int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K);
+int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
 size_t buff_size = KC * MC * esz;
-bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF;
+bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
 int m_tiles = (M + MC - 1) / MC;
 int n_tiles = (N + NC - 1) / NC;
 int total_tiles = m_tiles * n_tiles;
@@ -391,3 +301,29 @@ void fastGemmKernel(int M, int N, int K,
 }
 }}} // cv::dnn::cpu_baseline
+#undef FAST_GEMM_STORAGE
+#undef FAST_GEMM_MAX_STACKBUF
+#ifdef FAST_GEMM_F32_MC
+#undef FAST_GEMM_F32_MC
+#endif
+#ifdef FAST_GEMM_F32_NC
+#undef FAST_GEMM_F32_NC
+#endif
+#ifdef FAST_GEMM_F32_MR
+#undef FAST_GEMM_F32_MR
+#endif
+#ifdef FAST_GEMM_F32_NR
+#undef FAST_GEMM_F32_NR
+#endif
+#ifdef FAST_GEMM_F32_PACKED_STRIDE_K
+#undef FAST_GEMM_F32_PACKED_STRIDE_K
+#endif
+#undef FAST_GEMM_IMPLEMENT_PACK
+#undef FAST_GEMM_LOAD_TO_BUF_8
+#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_8
+#undef FAST_GEMM_LOAD_TO_BUF_12
+#undef FAST_GEMM_LOAD_TO_BUF_BORDERS_12
+#undef FAST_GEMM_PACK_COPY
+#undef FAST_GEMM_PACK_f32_8
+#undef FAST_GEMM_PACK_f32_12
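A note on the blocking constants used above: fastGemmKernel tiles C into MC x NC panels (rounded up to whole 8x12 micro-kernel tiles) and chooses the depth KC so that one packed panel of A plus one packed panel of B fits into FAST_GEMM_STORAGE (1 MiB) of scratch memory. The standalone snippet below mirrors that computation from the diff; the helper name computeBlocking and the example sizes are mine, for illustration only:

    #include <algorithm>
    #include <cstdio>

    struct Blocking { int MC, NC, KC; };

    // Mirrors the MC/NC/KC computation in cpu_baseline::fastGemmKernel above.
    static Blocking computeBlocking(int M, int N, int K, int esz = (int)sizeof(float)) {
        const int GEMM_MC = 64, GEMM_NC = 240, GEMM_MR = 8, GEMM_NR = 12; // FAST_GEMM_F32_*
        // Clip the panel sizes to the problem size, then round up to whole micro-tiles.
        int MC = ((std::min(GEMM_MC, M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
        int NC = ((std::min(GEMM_NC, N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
        // Pick KC so the packed A (MC x KC) and packed B (KC x NC) panels together
        // stay within FAST_GEMM_STORAGE (1 << 20 bytes) of scratch space.
        int KC = (1 << 20) / ((MC + NC) * esz);
        KC = std::max(KC, 8);
        KC = std::min(KC, K);
        return {MC, NC, KC};
    }

    int main() {
        Blocking b = computeBlocking(1024, 1024, 1024);
        std::printf("MC=%d NC=%d KC=%d\n", b.MC, b.NC, b.KC); // MC=64 NC=240 KC=862
        return 0;
    }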

modules/dnn/src/layers/gemm_layer.cpp
@@ -191,7 +191,6 @@ public:
 size_t dims_Y = shape_Y.size();
 int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1];
 int K = trans_a ? ma : na;
-int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies<int>());
 // broadcast C and copy C to output
 if (have_bias) {
@@ -201,9 +200,7 @@ public:
 int step = M * N;
 CV_CheckEQ(broadcast_C.size(), static_cast<size_t>(step), "DNN/Gemm: C is not broadcast properly");
 float *ptr_y = Y.ptr<float>();
-for (int i = 0; i < batches; i++) {
-std::memcpy(ptr_y + i * step, broadcast_C.data(), step * sizeof(float));
-}
+std::memcpy(ptr_y, broadcast_C.data(), step * sizeof(float));
 } else { // initialization
 float *ptr_y = Y.ptr<float>();
 size_t total = Y.total();
@@ -212,7 +209,6 @@ public:
 if (const_B) {
 CV_CheckGT(packed_B.size(), static_cast<size_t>(0), "DNN/Gemm: constant B is not pre-packed");
-M *= batches;
 fastGemm(trans_a, M, N, K, alpha, A.ptr<const float>(), na, packed_B.data(), 1.f, Y.ptr<float>(), N, opt);
 } else {
 fastGemmBatched(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt);
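For reference, the ONNX Gemm operator this layer implements is defined on 2-D operands: Y = alpha * A' * B' + beta * C, where A' and B' are A and B optionally transposed, A' is M x K, B' is K x N, and C is broadcast to M x N. A plain scalar reference of those semantics, row-major, with names of my own choosing (illustration only, not OpenCV API):

    #include <cstdio>
    #include <vector>

    // Naive reference: Y = alpha * op(A) * op(B) + beta * C,
    // with C already expanded to M x N (or empty).
    static void gemmRef(bool transA, bool transB, int M, int N, int K,
                        float alpha, const std::vector<float>& A,
                        const std::vector<float>& B, float beta,
                        const std::vector<float>& C, std::vector<float>& Y) {
        Y.assign((size_t)M * N, 0.f);
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = 0.f;
                for (int k = 0; k < K; ++k) {
                    float a = transA ? A[(size_t)k * M + m] : A[(size_t)m * K + k];
                    float b = transB ? B[(size_t)n * K + k] : B[(size_t)k * N + n];
                    acc += a * b;
                }
                Y[(size_t)m * N + n] = alpha * acc + beta * (C.empty() ? 0.f : C[(size_t)m * N + n]);
            }
    }

    int main() {
        std::vector<float> A = {1, 2, 3, 4};   // 2x2
        std::vector<float> B = {5, 6, 7, 8};   // 2x2
        std::vector<float> C = {1, 1, 1, 1};   // already broadcast to 2x2
        std::vector<float> Y;
        gemmRef(false, false, 2, 2, 2, 1.f, A, B, 1.f, C, Y);
        std::printf("%g %g\n%g %g\n", Y[0], Y[1], Y[2], Y[3]); // 20 23 / 44 51
        return 0;
    }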

modules/dnn/test/test_onnx_importer.cpp
@@ -2675,37 +2675,37 @@ TEST_P(Test_ONNX_layers, where_node)
 testONNXModels("where_layer");
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_all_attributes) {
+TEST_P(Test_ONNX_layers, Gemm_all_attributes) {
 testONNXModels("test_gemm_all_attributes", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_alpha) {
+TEST_P(Test_ONNX_layers, Gemm_alpha) {
 testONNXModels("test_gemm_alpha", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_beta) {
+TEST_P(Test_ONNX_layers, Gemm_beta) {
 testONNXModels("test_gemm_beta", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_matrix_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_matrix_bias) {
 testONNXModels("test_gemm_default_matrix_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_no_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_no_bias) {
 testONNXModels("test_gemm_default_no_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_scalar_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_scalar_bias) {
 testONNXModels("test_gemm_default_scalar_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_single_elem_vector_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_single_elem_vector_bias) {
 testONNXModels("test_gemm_default_single_elem_vector_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_vector_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_vector_bias) {
 testONNXModels("test_gemm_default_vector_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_default_zero_bias) {
+TEST_P(Test_ONNX_layers, Gemm_default_zero_bias) {
 testONNXModels("test_gemm_default_zero_bias", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeA) {
+TEST_P(Test_ONNX_layers, Gemm_transposeA) {
 testONNXModels("test_gemm_transposeA", pb, 0, 0, false, true, 2);
 }
-TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeB) {
+TEST_P(Test_ONNX_layers, Gemm_transposeB) {
 testONNXModels("test_gemm_transposeB", pb, 0, 0, false, true, 2);
 }
