mirror of https://github.com/opencv/opencv.git
dnn: add gemm_layer in place of fully_connected_layer for onnx models (#23897)
* first commit
* turned C from input to constant; force C constant in impl; better handling 0d/1d cases
* integrate with gemm from ficus nn
* fix const inputs
* adjust threshold for int8 tryQuantize
* adjust threshold for int8 quantized 2
* support batched gemm and matmul; tune threshold for rcnn_ilsvrc13; update googlenet
* add gemm perf against innerproduct
* add perf tests for innerproduct with bias
* fix perf
* add memset
* renamings for next step
* add dedicated perf gemm
* add innerproduct in perf_gemm
* remove gemm and innerproduct perf tests from perf_layer
* add perf cases for vit sizes; prepack constants
* remove batched gemm; fix wrong trans; optimize KC
* remove prepacking for const A; several fixes for const B prepacking
* add todos and gemm expression
* add optimized branch for avx/avx2
* trigger build
* update macros and signature
* update signature
* fix macro
* fix bugs for neon aarch64 & x64
* add backends: cuda, cann, inf_ngraph and vkcom
* fix cuda backend
* test commit for cuda
* test cuda backend
* remove debug message from cuda backend
* use cpu dispatcher
* fix neon macro undef in dispatcher
* fix dispatcher
* fix inner kernel for neon aarch64
* fix compiling issue on armv7; try fixing accuracy issue on other platforms
* broadcast C with beta multiplied; improve func namings
* fix bug for avx and avx2
* put all platform-specific kernels in dispatcher
* fix typos
* attempt to fix compile issues on x64
* run old gemm when neon, avx, avx2 are all not available; add kernel for armv7 neon
* fix typo
* quick fix: add macros for pack4
* quick fix: use vmlaq_f32 for armv7
* quick fix for missing macro of fast gemm pack f32 4
* disable conformance tests when optimized branches are not supported
* disable perf tests when optimized branches are not supported
* decouple cv_try_neon and cv_neon_aarch64
* drop googlenet_2023; add fastGemmBatched
* fix step in fastGemmBatched
* cpu: fix initialization ofb; gpu: support batch
* quick followup fix for cuda
* add default kernels
* quick followup fix to avoid macro redef
* optimized kernels for lasx
* resolve mis-alignment; remove comments
* tune performance for x64 platform
* tune performance for neon aarch64
* tune for armv7
* comment time consuming tests
* quick follow-up fix

pull/24302/head
parent 70d7e83dca
commit 8a96e34e33
12 changed files with 2470 additions and 61 deletions
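With this change, fully-connected weights coming from ONNX Gemm/MatMul nodes are executed by the dedicated Gemm layer instead of InnerProduct. A minimal usage sketch, not part of the patch (the model path and the 1x3x224x224 input shape are placeholders):

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>

int main() {
    // Import an ONNX model; its Gemm/MatMul nodes are mapped to the new Gemm layer.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("model.onnx"); // placeholder path
    std::vector<int> in_shape = {1, 3, 224, 224};              // placeholder shape
    cv::Mat input(in_shape, CV_32F);
    cv::randu(input, -1.0f, 1.0f);
    net.setInput(input);
    cv::Mat out = net.forward();
    return 0;
}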
@@ -0,0 +1,251 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "perf_precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>

namespace opencv_test {

struct GemmParam_t {
    std::vector<int> a_shape;
    std::vector<int> b_shape;
    std::vector<int> c_shape;
    bool trans_a;
    bool trans_b;

    GemmParam_t(std::vector<int> a_shape_, std::vector<int> b_shape_, std::vector<int> c_shape_ = {}, bool trans_a_ = false, bool trans_b_ = false)
        : a_shape(a_shape_), b_shape(b_shape_), c_shape(c_shape_), trans_a(trans_a_), trans_b(trans_b_) {}
};

// TODO: Disable most of the test cases except vision transformers to save time
static const GemmParam_t test_gemm_configs[] = {
    // vision transformers cases
    { { 768, 768 }, { 768, 768 }, { 768 } },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 } },
    { { 50, 768 }, { 768, 2304 } },
    { { 197, 768 }, { 768, 2304 } },
    { { 50, 1024 }, { 1024, 3072 } },
    { { 197, 1024 }, { 1024, 3072 } },

    // these cases are commented out to save testing time
    /*
    // square mat
    { { 64, 64 }, { 64, 64 } },
    { { 128, 128 }, { 128, 128 } },
    { { 256, 256 }, { 256, 256 } },
    { { 512, 512 }, { 512, 512 } },
    { { 1024, 1024 }, { 1024, 1024 } },
    { { 4096, 4096 }, { 4096, 4096 } },

    // rectangular mat
    { { 256, 256 }, { 256, 1024 } },
    { { 256, 1024 }, { 1024, 256 } },
    { { 256, 1024 }, { 1024, 1024 } },
    { { 1024, 1024 }, { 1024, 256 } },
    { { 1024, 256 }, { 256, 1024 } },
    { { 1024, 256 }, { 256, 256 } },

    // with C
    { { 256, 256 }, { 256, 256 }, { 256 } },
    { { 256, 256 }, { 256, 1024 }, { 1024 } },
    { { 256, 1024 }, { 1024, 256 }, { 256 } },
    { { 256, 1024 }, { 1024, 1024 }, { 1024 } },
    { { 1024, 1024 }, { 1024, 256 }, { 256 } },
    { { 1024, 256 }, { 256, 1024 }, { 1024 } },
    { { 1024, 256 }, { 256, 256 }, { 256 } },

    // with C and trans_b
    { { 256, 256 }, { 256, 256 }, { 256 }, false, true },
    { { 256, 1024 }, { 256, 1024 }, { 256 }, false, true },
    { { 256, 1024 }, { 1024, 1024 }, { 1024 }, false, true },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 }, false, true },
    { { 1024, 256 }, { 1024, 256 }, { 1024 }, false, true },
    { { 1024, 256 }, { 256, 256 }, { 256 }, false, true },

    // with C and trans_b and trans_a
    { { 256, 256 }, { 256, 256 }, { 256 }, true, true },
    { { 1024, 256 }, { 256, 1024 }, { 256 }, true, true },
    { { 256, 1024 }, { 1024, 256 }, { 1024 }, true, true },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 }, true, true },
    */
};

struct GemmParamId
{
    enum {
        GEMM_0 = 0,
        GEMM_LAST = sizeof(test_gemm_configs) / sizeof(test_gemm_configs[0])
    };
    int val_;
    GemmParamId(int val = 0) : val_(val) {}
    operator int() const { return val_; }
    static ::testing::internal::ParamGenerator<GemmParamId> all()
    {
        enum { NUM = (int)GEMM_LAST };
        GemmParamId v_[NUM]; for (int i = 0; i < NUM; ++i) { v_[i] = GemmParamId(i); } // reduce generated code size
        return ::testing::ValuesIn(v_, v_ + NUM);
    }
};

static inline void PrintTo(const GemmParamId& v, std::ostream* os)
{
    CV_Assert((int)v >= 0); CV_Assert((int)v < GemmParamId::GEMM_LAST);
    const GemmParam_t& p = test_gemm_configs[(int)v];

    auto print_shape = [os](const std::vector<int>& shape, const std::string tag) {
        if (shape.empty()) {
            return;
        }

        *os << tag << "=[";
        for (size_t i = 0; i < shape.size(); ++i) {
            if (i == shape.size() - 1) {
                *os << shape[i] << "]";
                break;
            }
            *os << shape[i] << ", ";
        }
    };

    print_shape(p.a_shape, "A");
    print_shape(p.b_shape, ", B");
    print_shape(p.c_shape, ", C");
    *os << ", trans_a=" << p.trans_a << ", trans_b=" << p.trans_b;
}

typedef tuple<GemmParamId, tuple<Backend, Target> > GemmTestParam_t;
typedef TestBaseWithParam<GemmTestParam_t> Gemm;

PERF_TEST_P_(Gemm, gemm)
{
    int test_id = (int)get<0>(GetParam());
    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
    const GemmParam_t& params = test_gemm_configs[test_id];
    auto a_shape = params.a_shape;
    auto b_shape = params.b_shape;
    auto c_shape = params.c_shape;
    auto trans_a = params.trans_a;
    auto trans_b = params.trans_b;
    float alpha = 1.f;
    float beta = 1.f;

    Backend backend_id = get<0>(get<1>(GetParam()));
    Target target_id = get<1>(get<1>(GetParam()));

    bool have_bias = c_shape.empty() ? false : true;

    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
    randu(A, -1.0f, 1.0f);
    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
    randu(B, -1.0f, 1.0f);

    LayerParams lp;
    lp.type = "Gemm";
    lp.name = "testLayer";
    lp.set("transA", trans_a);
    lp.set("transB", trans_b);
    lp.set("alpha", alpha);
    lp.set("beta", beta);
    lp.set("real_ndims_C", static_cast<int>(c_shape.size()));

    lp.set("constB", true);
    lp.blobs.push_back(B);
    if (have_bias) {
        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
        randu(C, -1.0f, 1.0f);
        lp.set("have_bias", true);
        lp.set("constC", true);
        lp.blobs.push_back(C);
    }

    Net net;
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);
    net.setPreferableBackend(backend_id);
    net.setPreferableTarget(target_id);

    // warmup
    {
        net.setInput(A);
        Mat out = net.forward();
    }

    TEST_CYCLE()
    {
        Mat res = net.forward();
    }

    SANITY_CHECK_NOTHING();
}

PERF_TEST_P_(Gemm, innerproduct)
{
    int test_id = (int)get<0>(GetParam());
    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
    const GemmParam_t& params = test_gemm_configs[test_id];
    auto a_shape = params.a_shape;
    auto b_shape = params.b_shape;
    auto c_shape = params.c_shape;
    auto trans_a = params.trans_a;
    auto trans_b = params.trans_b;

    Backend backend_id = get<0>(get<1>(GetParam()));
    Target target_id = get<1>(get<1>(GetParam()));

    bool have_bias = c_shape.empty() ? false : true;

    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
    randu(A, -1.0f, 1.0f);
    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
    randu(B, -1.0f, 1.0f);

    LayerParams lp;
    lp.type = "InnerProduct";
    lp.name = "testLayer";
    if (trans_a) {
        cv::transpose(A, A);
    }
    if (!trans_b) {
        cv::transpose(B, B);
    }
    lp.blobs.push_back(B);
    lp.set("num_output", B.size[0]);
    if (have_bias) {
        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
        randu(C, -1.0f, 1.0f);
        lp.blobs.push_back(C);
        lp.set("bias_term", true);
    } else {
        lp.set("bias_term", false);
    }

    Net net;
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);
    net.setPreferableBackend(backend_id);
    net.setPreferableTarget(target_id);

    // warmup
    {
        std::vector<std::string> input_names(2);
        input_names[0] = "A";
        net.setInputsNames(input_names);
        net.setInput(A, input_names[0]);
        Mat out = net.forward();
    }

    TEST_CYCLE()
    {
        Mat res = net.forward();
    }

    SANITY_CHECK_NOTHING();
}

INSTANTIATE_TEST_CASE_P(/**/, Gemm, Combine(
    GemmParamId::all(),
    dnnBackendsAndTargets(false, false) // defined in ../test/test_common.hpp
));

} // namespace
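For a quick correctness cross-check of the Gemm cases above when no bias is involved, the result can be compared against cv::gemm from the core module (2-D CV_32F matrices; unlike the DNN Gemm layer, cv::gemm does not broadcast a bias term). A minimal sketch, not part of the patch:

#include <opencv2/core.hpp>

// Reference Y = alpha * op(A) * op(B), with the trans flags mapped to GEMM_1_T / GEMM_2_T.
static cv::Mat gemmReference(const cv::Mat &A, const cv::Mat &B, float alpha,
                             bool trans_a, bool trans_b)
{
    cv::Mat ref;
    int flags = (trans_a ? cv::GEMM_1_T : 0) | (trans_b ? cv::GEMM_2_T : 0);
    cv::gemm(A, B, alpha, cv::noArray(), 0.0, ref, flags);
    return ref;
}

The returned matrix can then be compared against the layer output with, for example, cv::norm(ref, out, cv::NORM_INF).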
@@ -0,0 +1,262 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// This file is modified from ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h).
// Here is the original license:
/*
    This file is a part of ficus language project.
    See ficus/LICENSE for the licensing terms
*/

#include "../../precomp.hpp"
#include "fast_gemm.hpp"

#define CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
#include "fast_gemm_kernels.simd.hpp"
#include "layers/cpu_kernels/fast_gemm_kernels.simd_declarations.hpp"
#undef CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
#include "fast_gemm_kernels.default.hpp"

namespace cv { namespace dnn {

void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt) {
    CV_CheckEQ(B.dims, 2, "fastGemmPackB: input mat should be two-dimensional");
    CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now");

    auto B_shape = shape(B);
    int K = B_shape[0], N = B_shape[1], ldb0 = N, ldb1 = 1;
    if (trans) {
        std::swap(K, N);
        std::swap(ldb0, ldb1);
    }

#if CV_TRY_NEON
    if (opt.use_neon) {
        int size_packed_B = opt_NEON::fastGemmPackBSize(N, K);
        packed_B.resize(size_packed_B);
        opt_NEON::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
    } else
#endif
#if CV_TRY_AVX2
    if (opt.use_avx2) {
        int size_packed_B = opt_AVX2::fastGemmPackBSize(N, K);
        packed_B.resize(size_packed_B);
        opt_AVX2::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
    } else
#endif
#if CV_TRY_AVX
    if (opt.use_avx) {
        int size_packed_B = opt_AVX::fastGemmPackBSize(N, K);
        packed_B.resize(size_packed_B);
        opt_AVX::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
    } else
#endif
#if CV_TRY_LASX
    if (opt.use_lasx) {
        int size_packed_B = opt_LASX::fastGemmPackBSize(N, K);
        packed_B.resize(size_packed_B);
        opt_LASX::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
    } else
#endif
    {
        int size_packed_B = cpu_baseline::fastGemmPackBSize(N, K);
        packed_B.resize(size_packed_B);
        cpu_baseline::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
    }
}

static void fast_gemm_thin(float alpha, float beta, int M, int N, int K,
                           const char *a_, int lda0, int lda1,
                           const char *b_, int ldb,
                           char *c_, int ldc) {
    const float* a = (const float*)a_;

    auto fn = [&](const Range &r) {
        for(int start = r.start ; start < r.end; start++ ) {
            float* c_i = (float*)c_ + start * ldc;
            if (beta == 0.f)
                for(int j = 0; j < N; j++ ) c_i[j] = 0.f;
            else if (beta != 1.f)
                for(int j = 0; j < N; j++ ) c_i[j] *= beta;
            for(int k = 0; k < K; k++ ) {
                const float* b_k = (const float*)b_ + k * ldb;
                float aval = alpha * a[start * lda0 + k * lda1];
                for(int j = 0; j < N; j++ )
                    c_i[j] += aval * b_k[j];
            }
        }
    };

    int total = M; // outer loops
    int cost_per_thread = static_cast<int>(K * N); // inner loops
    double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
    parallel_for_(Range(0, total), fn, nstripes);
}

void fastGemm(bool trans_a, int M, int N, int K,
              float alpha, const float *A, int lda,
              const float *packed_B, float beta,
              float *C, int ldc, FastGemmOpt &opt) {
    int lda0 = lda, lda1 = 1;
    if (trans_a) {
        std::swap(lda0, lda1);
    }

#if CV_TRY_NEON
    if (opt.use_neon) {
        opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_AVX2
    if (opt.use_avx2) {
        opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_AVX
    if (opt.use_avx) {
        opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_LASX
    if (opt.use_lasx) {
        opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
    {
        cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
    }
}

void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb,
              float alpha, const float *A, int lda0, int lda1, const float *B, int ldb0, int ldb1,
              float beta, float *C, int ldc, FastGemmOpt &opt) {

    const char *a = (const char *)A;
    const char *b = (const char *)B;
    char *c = (char *)C;

    int M = trans_a ? na : ma;
    int N = trans_b ? mb : nb;
    int K = trans_a ? ma : na;

    if (trans_a) {
        std::swap(lda0, lda1);
    }
    if (trans_b) {
        std::swap(ldb0, ldb1);
    }

    if (!trans_b && ldb1 == 1 && (M <= 4 || (uint64_t)M * N * K <= 10000)) {
        return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc);
    }

#if CV_TRY_NEON
    if (opt.use_neon) {
        opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
                                 (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_AVX2
    if (opt.use_avx2) {
        opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
                                 (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_AVX
    if (opt.use_avx) {
        opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
                                (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
#if CV_TRY_LASX
    if (opt.use_lasx) {
        opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
                                 (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
    } else
#endif
    {
        cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
                                     (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
    }
}

void fastGemm(bool trans_a, bool trans_b,
              float alpha, const Mat &A, const Mat &B,
              float beta, Mat &C, FastGemmOpt &opt) {
    CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemm: only support float32 for now");
    CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemm: A and B should have the same type");
    CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemm: B and C should have the same type");

    const auto shape_a = shape(A);
    CV_CheckEQ(shape_a.size(), static_cast<size_t>(2), "DNN/fastGemm: A must be 2-dimensional");
    const auto shape_b = shape(B);
    CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemm: B must be 2-dimensional");
    const auto shape_c = shape(C);
    CV_CheckEQ(shape_c.size(), static_cast<size_t>(2), "DNN/fastGemm: C must be 2-dimensional");

    int ma = shape_a[0], na = shape_a[1];
    int mb = shape_b[0], nb = shape_b[1];

    int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1];

    const float *a = A.ptr<const float>();
    const float *b = B.ptr<const float>();
    float *c = C.ptr<float>();

    fastGemm(trans_a, trans_b, ma, na, mb, nb,
             alpha, a, lda0, lda1, b, ldb0, ldb1,
             beta, c, ldc, opt);
}

void fastGemmBatched(bool trans_a, bool trans_b,
                     float alpha, const Mat &A, const Mat &B,
                     float beta, Mat &C, FastGemmOpt &opt) {
    CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemmBatched: A and B should have the same type");
    CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemmBatched: B and C should have the same type");
    CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemmBatched: only support float32 for now");

    const auto shape_a = shape(A);
    size_t dims_A = shape_a.size();
    CV_CheckGE(dims_A, static_cast<size_t>(2), "DNN/fastGemmBatched: A must be n-dimensional (n >= 2)");
    const auto shape_b = shape(B);
    CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatched: B must be 2-dimensional");
    const auto shape_c = shape(C);
    size_t dims_C = shape_c.size();
    CV_CheckGE(dims_C, static_cast<size_t>(2), "DNN/fastGemmBatched: C must be n-dimensional (n >= 2)");

    if (trans_a) {
        int ma = shape_a[dims_A - 2], na = shape_a[dims_A - 1];
        int mb = shape_b[0], nb = shape_b[1];

        int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1];

        const float *a = A.ptr<const float>();
        const float *b = B.ptr<const float>();
        float *c = C.ptr<float>();

        int batches = std::accumulate(shape_a.begin(), shape_a.end() - 2, 1, std::multiplies<int>());
        int step_a = ma * na, step_c = na * nb;
        for (int i = 0; i < batches; i++) {
            fastGemm(true, trans_b, ma, na, mb, nb,
                     alpha, a + i * step_a, lda0, lda1, b, ldb0, ldb1,
                     beta, c + i * step_c, ldc, opt);
        }
    } else {
        int ma = std::accumulate(shape_a.begin(), shape_a.end() - 1, 1, std::multiplies<int>()),
            na = shape_a[dims_A - 1];
        int mb = shape_b[0], nb = shape_b[1];

        int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1];

        const float *a = A.ptr<const float>();
        const float *b = B.ptr<const float>();
        float *c = C.ptr<float>();

        fastGemm(false, trans_b, ma, na, mb, nb,
                 alpha, a, lda0, lda1, b, ldb0, ldb1,
                 beta, c, ldc, opt);
    }
}

}} // cv::dnn
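A note on the lda0/lda1 convention used throughout these kernels: element (i, k) of A is read as a[i * lda0 + k * lda1], so a row-major matrix uses (lda0, lda1) = (lda, 1) and its transpose is obtained simply by swapping the two strides, which is exactly what the trans_a/trans_b branches above do. A standalone illustration, not part of the patch:

#include <cstdio>

// Read a 2x3 row-major matrix directly and as its transpose by swapping strides.
int main() {
    const float a[] = { 1, 2, 3,
                        4, 5, 6 };
    int lda0 = 3, lda1 = 1;                                    // A(i, k) = a[i * lda0 + k * lda1]
    std::printf("A(1, 2)   = %g\n", a[1 * lda0 + 2 * lda1]);   // 6
    std::printf("A^T(2, 1) = %g\n", a[2 * lda1 + 1 * lda0]);   // same element, strides swapped
    return 0;
}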
@@ -0,0 +1,65 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// This file is modified from ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h).
// Here is the original license:
/*
    This file is a part of ficus language project.
    See ficus/LICENSE for the licensing terms
*/

#ifndef OPENCV_DNN_FAST_GEMM_HPP
#define OPENCV_DNN_FAST_GEMM_HPP

#include "opencv2/core/hal/intrin.hpp"
#include <opencv2/dnn/shape_utils.hpp>

namespace cv { namespace dnn {

struct FastGemmOpt {
    bool use_avx;
    bool use_avx2;
    bool use_neon;
    bool use_lasx;

    FastGemmOpt() {
        use_avx = false;
        use_avx2 = false;
        use_neon = false;
        use_lasx = false;
    }

    void init() {
        use_avx = checkHardwareSupport(CPU_AVX);
        use_avx2 = checkHardwareSupport(CPU_AVX2);
        use_neon = checkHardwareSupport(CPU_NEON);
        use_lasx = checkHardwareSupport(CPU_LASX);
    }

    bool all() {
        return use_avx || use_avx2 || use_neon || use_lasx;
    }
};

void fastGemmPackB(const Mat &m, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt);

void fastGemm(bool trans_a, int M, int N, int K,
              float alpha, const float *A, int lda,
              const float *packed_B, float beta,
              float *C, int ldc, FastGemmOpt &opt);
void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb,
              float alpha, const float *A, int lda0, int lda1, const float *B, int ldb0, int ldb1,
              float beta, float *C, int ldc, FastGemmOpt &opt);
void fastGemm(bool trans_a, bool trans_b,
              float alpha, const Mat &A, const Mat &B,
              float beta, Mat &C, FastGemmOpt &opt);

// FIXME: B needs to be 2d for now. Support nd (n >= 2) B in the future.
void fastGemmBatched(bool trans_a, bool trans_b,
                     float alpha, const Mat &A, const Mat &B,
                     float beta, Mat &C, FastGemmOpt &opt);

}} // cv::dnn

#endif // OPENCV_DNN_FAST_GEMM_HPP
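A minimal sketch of how these entry points are used from inside the dnn module (shapes are arbitrary; this mirrors what the Gemm layer does in finalize()/forward()):

// Inside cv::dnn, with "fast_gemm.hpp" included:
// computes C = alpha * A * B + beta * C for 2-D CV_32F mats.
void exampleFastGemm()
{
    using namespace cv;
    using namespace cv::dnn;

    Mat A(64, 128, CV_32F), B(128, 32, CV_32F), C(64, 32, CV_32F, 0.f);
    randu(A, -1.f, 1.f);
    randu(B, -1.f, 1.f);

    FastGemmOpt opt;
    opt.init(); // picks AVX/AVX2/NEON/LASX kernels when available, baseline otherwise

    // One-shot product with the Mat-based overload.
    fastGemm(false, false, 1.f, A, B, 0.f, C, opt);

    // When B is constant and reused, pack it once and call the packed overload.
    std::vector<float> packed_B;
    fastGemmPackB(B, packed_B, /*trans=*/false, opt);
    fastGemm(false, 64, 32, 128, 1.f, A.ptr<const float>(), 128,
             packed_B.data(), 0.f, C.ptr<float>(), 32, opt);
}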
@@ -0,0 +1,393 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// This file is modified from ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h).
// Here is the original license:
/*
    This file is a part of ficus language project.
    See ficus/LICENSE for the licensing terms
*/

#include <opencv2/core/hal/intrin.hpp>
#include <opencv2/core/utility.hpp> // parallel_for_

#define FAST_GEMM_DEFAULT_STORAGE (1<<20) // 2^20
#define FAST_GEMM_DEFAULT_MAX_STACKBUF (1 << 14)

#define FAST_GEMM_DEFAULT_F32_MC 64
#define FAST_GEMM_DEFAULT_F32_NC 240
#define FAST_GEMM_DEFAULT_F32_MR 8
#define FAST_GEMM_DEFAULT_F32_NR 12
#define FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K 256

#define FAST_GEMM_DEFAULT_IMPLEMENT_PACK(N, suffix, styp, dtyp) \
    static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \
                                           int lda0, int lda1, void* packA_ ) \
    { \
        const styp* A = (const styp*)A_; \
        dtyp* packA = (dtyp*)packA_; \
        for( int i = 0; i < m; i += N ) { \
            if (i + N-1 < m) { \
                const styp* a_ptr = A + lda0*i; \
                for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \
                { \
                    FAST_GEMM_DEFAULT_LOAD_TO_BUF_##N(styp); \
                    FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \
                } \
            } else { \
                const styp* a_ptr[N]; \
                for (int k = 0; k < N; k++) a_ptr[k] = A + lda0*(i+k < m ? i+k : i); \
                for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \
                { \
                    FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_##N(styp); \
                    FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \
                } \
            } \
        } \
    }

#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_8(styp) \
    styp buf[] = { \
        a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \
        a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7] }

#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_8(styp) \
    styp buf[] = { \
        a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \
        a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j] }

#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_12(styp) \
    styp buf[] = { \
        a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \
        a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \
        a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11] }

#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_12(styp) \
    styp buf[] = { \
        a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \
        a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \
        a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j] }

#define FAST_GEMM_DEFAULT_PACK_COPY(src, dst, N) \
    memcpy((dst), (src), N*sizeof(src[0]))
#define FAST_GEMM_DEFAULT_PACK_f32_8(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 8)
#define FAST_GEMM_DEFAULT_PACK_f32_12(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 12)

namespace cv { namespace dnn { namespace cpu_baseline {

int fastGemmPackBSize(int N, int K);

void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz);

void fastGemmKernel(int M, int N, int K,
                    float alpha, const char *A, int lda0, int lda1,
                    const char *B, int ldb0, int ldb1,
                    float beta, char *C, int ldc, int esz);
void fastGemmKernel(int M, int N, int K,
                    float alpha, const char *A, int lda0, int lda1,
                    const char *packed_B, float beta, char *C, int ldc, int esz);

FAST_GEMM_DEFAULT_IMPLEMENT_PACK(8, _f32, float, float)
FAST_GEMM_DEFAULT_IMPLEMENT_PACK(12, _f32, float, float)

int fastGemmPackBSize(int N, int K) {
    int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
    int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;

    return static_cast<int>((N + NC - 1) / NC) * NC * K;
}

void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) {
    int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;
    int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
    int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K);

    int n_tiles = (N + NC - 1) / NC;
    for (int r = 0; r < n_tiles; ++r) {
        int j0 = r * NC;
        int nc = N - j0 < NC ? N - j0 : NC;
        int _nc = static_cast<int>((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz;
        for (int k = 0; k < K; k += KC) {
            int kc = K - k < KC ? K - k : KC;
            fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B);
            packed_B += _nc * kc;
        }
    }
}

#if CV_SIMD128
static void fast_gemm8x12_f32(int k, const char *a_, const char *b_,
                              char *c_, int ldc, float alpha) {
    const float* a = (const float*)a_;
    const float* b = (const float*)b_;
    float* c = (float*)c_;

    v_float32x4 s00 = v_setzero_f32(), s01 = s00, s02 = s00;
    v_float32x4 s10 = s00, s11 = s00, s12 = s00;
    v_float32x4 s20 = s00, s21 = s00, s22 = s00;
    v_float32x4 s30 = s00, s31 = s00, s32 = s00;
    v_float32x4 s40 = s00, s41 = s00, s42 = s00;
    v_float32x4 s50 = s00, s51 = s00, s52 = s00;
    v_float32x4 s60 = s00, s61 = s00, s62 = s00;
    v_float32x4 s70 = s00, s71 = s00, s72 = s00;

    for(int p = 0; p < k; p++, a += FAST_GEMM_DEFAULT_F32_MR, b += FAST_GEMM_DEFAULT_F32_NR) {
        v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8);

        v_float32x4 a0 = v_setall_f32(*a);
        s00 = v_fma(b0, a0, s00);
        s01 = v_fma(b1, a0, s01);
        s02 = v_fma(b2, a0, s02);
        v_float32x4 a1 = v_setall_f32(*(a + 1));
        s10 = v_fma(b0, a1, s10);
        s11 = v_fma(b1, a1, s11);
        s12 = v_fma(b2, a1, s12);

        v_float32x4 a2 = v_setall_f32(*(a + 2));
        s20 = v_fma(b0, a2, s20);
        s21 = v_fma(b1, a2, s21);
        s22 = v_fma(b2, a2, s22);
        v_float32x4 a3 = v_setall_f32(*(a + 3));
        s30 = v_fma(b0, a3, s30);
        s31 = v_fma(b1, a3, s31);
        s32 = v_fma(b2, a3, s32);

        a0 = v_setall_f32(*(a + 4));
        s40 = v_fma(b0, a0, s40);
        s41 = v_fma(b1, a0, s41);
        s42 = v_fma(b2, a0, s42);
        a1 = v_setall_f32(*(a + 5));
        s50 = v_fma(b0, a1, s50);
        s51 = v_fma(b1, a1, s51);
        s52 = v_fma(b2, a1, s52);

        a2 = v_setall_f32(*(a + 6));
        s60 = v_fma(b0, a2, s60);
        s61 = v_fma(b1, a2, s61);
        s62 = v_fma(b2, a2, s62);
        a3 = v_setall_f32(*(a + 7));
        s70 = v_fma(b0, a3, s70);
        s71 = v_fma(b1, a3, s71);
        s72 = v_fma(b2, a3, s72);
    }

    v_float32x4 c0, c1, c2, c3, c4, c5, v_alpha = v_setall_f32(alpha);
#define FAST_GEMM_FINALE(row0, row1) \
    c0 = v_load(c + row0 * ldc); \
    c1 = v_load(c + row0 * ldc + 4); \
    c2 = v_load(c + row0 * ldc + 8); \
    c3 = v_load(c + row1 * ldc); \
    c4 = v_load(c + row1 * ldc + 4); \
    c5 = v_load(c + row1 * ldc + 8); \
    c0 = v_fma(s##row0##0, v_alpha, c0); \
    c1 = v_fma(s##row0##1, v_alpha, c1); \
    c2 = v_fma(s##row0##2, v_alpha, c2); \
    c3 = v_fma(s##row1##0, v_alpha, c3); \
    c4 = v_fma(s##row1##1, v_alpha, c4); \
    c5 = v_fma(s##row1##2, v_alpha, c5); \
    v_store(c + row0 * ldc, c0); \
    v_store(c + row0 * ldc + 4, c1); \
    v_store(c + row0 * ldc + 8, c2); \
    v_store(c + row1 * ldc, c3); \
    v_store(c + row1 * ldc + 4, c4); \
    v_store(c + row1 * ldc + 8, c5);

    FAST_GEMM_FINALE(0, 1);
    FAST_GEMM_FINALE(2, 3);
    FAST_GEMM_FINALE(4, 5);
    FAST_GEMM_FINALE(6, 7);
#undef FAST_GEMM_FINALE
}

#else
static void fast_gemm_f32(int k, const char *a_, const char *b_,
                          char *c_, int ldc, float alpha) {
    const float* a = (const float*)a_;
    const float* b = (const float*)b_;
    float* c = (float*)c_;

    float sbuf[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR];
    memset(sbuf, 0, sizeof(sbuf));
    for(int p = 0; p < k; p++) {
        for( int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++ ) {
            float ai = a[FAST_GEMM_DEFAULT_F32_MR * p + i];
            for( int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++ )
                sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j] += b[FAST_GEMM_DEFAULT_F32_NR * p + j] * ai;
        }
    }
    for (int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++) {
        for (int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++)
            c[i * ldc + j] += alpha * sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j];
    }
}
#endif // CV_SIMD128

static void fast_gemm_macro_kernel(int m, int n, int k,
                                   const char *packed_A, const char *packed_B,
                                   float alpha, char *c, int ldc0, int esz) {
    int ldc0_esz = ldc0 * esz;

    double tempC[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; // make sure the buffer is big enough
    for(int i = 0; i < m; i += FAST_GEMM_DEFAULT_F32_MR) {
        for(int j = 0; j < n; j += FAST_GEMM_DEFAULT_F32_NR) {
            char* cptr0 = &c[i * ldc0_esz + j * esz];
            char* cptr = cptr0;
            int ldc = ldc0;
            int mr = m - i < FAST_GEMM_DEFAULT_F32_MR ? m - i : FAST_GEMM_DEFAULT_F32_MR;
            int nr = n - j < FAST_GEMM_DEFAULT_F32_NR ? n - j : FAST_GEMM_DEFAULT_F32_NR;
            int nr_esz = nr * esz;
            bool partial = (bool)((mr < FAST_GEMM_DEFAULT_F32_MR) | (nr < FAST_GEMM_DEFAULT_F32_NR));
            if (partial) {
                memset(tempC, 0, sizeof(tempC));
                cptr = (char *)tempC;
                ldc = FAST_GEMM_DEFAULT_F32_NR;
                for(int p = 0; p < mr; p++)
                    memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz);
            }
#if CV_SIMD128
            fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha);
#else
            fast_gemm_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha);
#endif

            if (partial) {
                for(int p = 0; p < mr; p++)
                    memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz);
            }
        }
    }
}

void fastGemmKernel(int M, int N, int K,
                    float alpha, const char *A, int lda0, int lda1,
                    const char *B, int ldb0, int ldb1,
                    float beta, char *C, int ldc, int esz) {
    int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC,
        GEMM_NC = FAST_GEMM_DEFAULT_F32_NC,
        GEMM_MR = FAST_GEMM_DEFAULT_F32_MR,
        GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;

    int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
    int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
    int KC = FAST_GEMM_DEFAULT_STORAGE / ((MC + NC) * esz);
    KC = KC > 8 ? KC : 8;
    KC = KC < K ? KC : K;

    size_t buff_size = KC * (MC + NC) * esz;
    bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF;
    int m_tiles = (M + MC - 1) / MC;
    int n_tiles = (N + NC - 1) / NC;
    int total_tiles = m_tiles * n_tiles;

    auto fn = [&](const Range &r) {
        char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
        char* packed_b = packed_a + KC * MC * esz;
        int start = r.start;
        int end = r.end;

        for (int tile_idx = start; tile_idx < end; tile_idx++) {
            int i0 = (tile_idx / n_tiles) * MC;
            int j0 = (tile_idx % n_tiles) * NC;
            int mc = M - i0 < MC ? M - i0 : MC;
            int nc = N - j0 < NC ? N - j0 : NC;
            int ldc_block = ldc;
            char* c_block = C + (i0 * ldc + j0) * esz;

            if (beta == 0.f) {
                for(int i = 0; i < mc; i++)
                    memset(c_block + i * ldc_block * esz, 0, nc * esz);
            } else if (beta != 1.f) {
                for(int i = 0; i < mc; i++) {
                    float* c_i = (float*)c_block + i * ldc_block;
                    for(int j = 0; j < nc; j++)
                        c_i[j] *= beta;
                }
            }

            for(int k0 = 0; k0 < K; k0 += KC)
            {
                int kc = K - k0 < KC ? K - k0 : KC;
                fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
                fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
                fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz);
            }
        }

        if (!use_stackbuff) {
            free(packed_a);
        }
    };

    int total = total_tiles;
    int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
    double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
    parallel_for_(Range(0, total), fn, nstripes);
}

void fastGemmKernel(int M, int N, int K,
                    float alpha, const char *A, int lda0, int lda1,
                    const char *packed_B, float beta, char *C, int ldc, int esz) {
    int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC,
        GEMM_NC = FAST_GEMM_DEFAULT_F32_NC,
        GEMM_MR = FAST_GEMM_DEFAULT_F32_MR,
        GEMM_NR = FAST_GEMM_DEFAULT_F32_NR;

    int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
    int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
    int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K);

    size_t buff_size = KC * MC * esz;
    bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF;
    int m_tiles = (M + MC - 1) / MC;
    int n_tiles = (N + NC - 1) / NC;
    int total_tiles = m_tiles * n_tiles;

    auto fn = [&](const Range &r) {
        char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer
        const char *packed_b_ = packed_B;
        int start = r.start;
        int end = r.end;

        for (int tile_idx = start; tile_idx < end; tile_idx++) {
            int i0 = (tile_idx / n_tiles) * MC;
            int j0 = (tile_idx % n_tiles) * NC;
            int mc = M - i0 < MC ? M - i0 : MC;
            int nc = N - j0 < NC ? N - j0 : NC;
            int ldc_block = ldc;
            char* c_block = C + (i0 * ldc + j0) * esz;
            packed_b_ = packed_B + j0 * K * esz;

            if (beta == 0.f) {
                for(int i = 0; i < mc; i++)
                    memset(c_block + i * ldc_block * esz, 0, nc * esz);
            } else if (beta != 1.f) {
                for(int i = 0; i < mc; i++) {
                    float* c_i = (float*)c_block + i * ldc_block;
                    for(int j = 0; j < nc; j++)
                        c_i[j] *= beta;
                }
            }

            int _nc = static_cast<int>((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz;
            for(int k0 = 0; k0 < K; k0 += KC)
            {
                int kc = K - k0 < KC ? K - k0 : KC;
                fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
                fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz);
                packed_b_ += _nc * kc;
            }
        }

        if (!use_stackbuff) {
            free(packed_a);
        }
    };

    int total = total_tiles;
    int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
    double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
    parallel_for_(Range(0, total), fn, nstripes);
}

}}} // cv::dnn::cpu_baseline
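For a sense of the cache blocking above: with float32 (esz = 4) and large M and N, MC = 64 and NC = 240, so the non-packed kernel chooses KC = FAST_GEMM_DEFAULT_STORAGE / ((MC + NC) * esz) = 1048576 / 1216 = 862 (then clamped to the range [8, K]), and the per-worker scratch buffer is KC * (MC + NC) * esz, roughly 1 MiB. Since that exceeds FAST_GEMM_DEFAULT_MAX_STACKBUF (16 KiB), it is allocated with malloc rather than alloca in that case.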
File diff suppressed because it is too large
@@ -0,0 +1,361 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"
// backends
#include "../op_cuda.hpp"
#ifdef HAVE_CUDA
// #include "../cuda4dnn/primitives/matmul.hpp"
#include "../cuda4dnn/primitives/inner_product.hpp"
using namespace cv::dnn::cuda4dnn;
#endif
#include "../op_cann.hpp"
#include "../ie_ngraph.hpp"
#include "../op_vkcom.hpp"

#include <opencv2/dnn/shape_utils.hpp>
#include "cpu_kernels/fast_gemm.hpp"

namespace cv { namespace dnn {

class GemmLayerImpl CV_FINAL : public GemmLayer {
public:
    GemmLayerImpl(const LayerParams& params) {
        setParamsFrom(params);

        trans_a = params.get<bool>("transA", false);
        trans_b = params.get<bool>("transB", false);
        alpha = params.get<float>("alpha", 1.0f);
        beta = params.get<float>("beta", 1.0f);

        const_B = params.get<bool>("constB", false); // true means blobs[0] is B
        const_C = params.get<bool>("constC", false); // true means blobs.back() is C
        have_bias = params.get<bool>("have_bias", false); // NOTE: have_bias being true does not mean bias is constant

        real_ndims_C = params.get<int>("real_ndims_C", -1);
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE {
        return backendId == DNN_BACKEND_OPENCV ||
               (backendId == DNN_BACKEND_CUDA && const_B && !trans_a) ||
               backendId == DNN_BACKEND_CANN ||
               backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH ||
               (backendId == DNN_BACKEND_VKCOM && haveVulkan() && !have_bias && !trans_a);
    }

    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE {
        int num_inputs = static_cast<int>(inputs.size() + blobs.size());
        CV_CheckGE(num_inputs, 2, "DNN/Gemm: Gemm takes at least two inputs");
        CV_CheckLE(num_inputs, 3, "DNN/Gemm: Gemm takes at most three inputs");

        // Check whether A and B are two-dimensional
        const auto shape_A = inputs[0];
        const auto shape_B = const_B ? shape(blobs[0]) : inputs[1];
        CV_CheckGE(shape_A.size(), static_cast<size_t>(2), "DNN/Gemm: Tensor A must be n-dimensional (n >= 2)");
        CV_CheckEQ(shape_B.size(), static_cast<size_t>(2), "DNN/Gemm: Tensor B must be two dimensional");

        // Check that the matrix multiplication is legal
        size_t dims_A = shape_A.size();
        int ma = shape_A[dims_A - 2], na = shape_A[dims_A - 1];
        int mb = shape_B[0], nb = shape_B[1];
        int M = trans_a ? na : ma;
        int N = trans_b ? mb : nb;
        int K_a = trans_a ? ma : na;
        int K_b = trans_b ? nb : mb;
        CV_CheckEQ(K_a, K_b, "DNN/Gemm: Invalid dimension of dim K");

        // Check whether C can be unidirectionally broadcast to (M, N). Handle carefully with 1D Mat.
        if (have_bias) {
            const auto shape_C = const_C ? shape(blobs.back()) : inputs.back();

            auto ndims_C = shape_C.size();
            CV_CheckLE(ndims_C, static_cast<size_t>(2), "DNN/Gemm: C can only be 0d (scalar) / 1d / 2d tensor");

            if (real_ndims_C == 1) { // (1,) or (N,)
                CV_Check(shape_C[0], shape_C[0] == 1 || shape_C[0] == N, "DNN/Gemm: invalid dimension of C");
            } else if (real_ndims_C == 2) { // (1, 1) or (1, N) or (M, 1) or (M, N)
                // printf("shape_C=[%d, %d]\n", shape_C[0], shape_C[1]);
                CV_Check(shape_C[0], (shape_C[0] == 1 && shape_C[1] == 1) ||
                                     (shape_C[0] == 1 && shape_C[1] == N) ||
                                     (shape_C[0] == M && shape_C[1] == 1) ||
                                     (shape_C[0] == M && shape_C[1] == N),
                                     "DNN/Gemm: C must be of shape (1, 1) or (1, N) or (M, 1) or (M, N)");
                if (shape_C[0] == 1) {
                    CV_Check(shape_C[1], shape_C[1] == 1 || shape_C[1] == N, "DNN/Gemm: invalid dimension of C");
                } else if (shape_C[0] == M) {
                    CV_Check(shape_C[1], shape_C[1] == 1 || shape_C[1] == N, "DNN/Gemm: invalid dimension of C");
                } else {
                    CV_Error(Error::StsBadSize, "DNN/Gemm: invalid dimension of C");
                }
            }
        }

        int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies<int>());
        MatShape shape_y{M * batches, N};
        outputs.assign(1, shape_y);
        return false;
    }

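    // For reference, ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op(X) is X
    // transposed when the corresponding transA/transB attribute is set, and C must be unidirectionally
    // broadcastable to the (M, N) result. For example, with M = 197 and N = 2304 (one of the
    // vision-transformer cases in the perf test), the accepted shapes for C are (), (1,), (2304,),
    // (1, 1), (1, 2304), (197, 1) and (197, 2304).
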
    // TODO: replace with cv::broadcast() once 1d mat is supported
    // FIXME: fix if conditions if 1d mat is supported properly
    void broadcastCWtihBeta(int M, int N, const Mat &C) {
        if (beta != 0 && !C.empty()) {
            broadcast_C.clear();
            broadcast_C.resize(M * N, 0.f);

            const float *ptr_c = C.ptr<const float>();
            const auto shape_C = shape(C);
            if ((real_ndims_C == 0) || (real_ndims_C == 1 && shape_C[0] == 1) ||
                (real_ndims_C == 2 && shape_C[0] == 1 && shape_C[1] == 1)) {
                // (), (1,), (1, 1)
                float c = *ptr_c;
                int total = M * N;
                for (int i = 0; i < total; ++i) {
                    broadcast_C[i] = beta * c;
                }
            } else if ((real_ndims_C == 1 && shape_C[0] == N) ||
                       (real_ndims_C == 2 && shape_C[0] == 1 && shape_C[1] == N)) {
                // (N,), (1, N)
                for (int i = 0; i < M; ++i) {
                    int step = i * N;
                    for (int j = 0; j < N; ++j) {
                        broadcast_C[step + j] = beta * ptr_c[j];
                    }
                }
            } else if (real_ndims_C == 2 && shape_C[0] == M && shape_C[1] == 1) {
                // (M, 1)
                for (int i = 0; i < M; ++i) {
                    int step = i * N;
                    for (int j = 0; j < N; ++j) {
                        broadcast_C[step + j] = beta * ptr_c[i];
                    }
                }
            } else {
                // (M, N)
                std::transform(ptr_c, ptr_c + M * N, broadcast_C.begin(), [this] (const float &c) {
                    return this->beta * c; });
            }
        }
    }

    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
        opt.init();

        // pack B if it is const
        if (const_B) {
            fastGemmPackB(blobs[0], packed_B, trans_b, opt);
        }

        // also pre-broadcast bias
        if (const_C) {
            const auto &C = blobs.back();

            std::vector<Mat> outputs;
            outputs_arr.getMatVector(outputs);
            const auto &Y = outputs[0];
            const auto shape_Y = shape(Y);
            size_t dims_Y = shape_Y.size();
            int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1];

            // broadcast
            broadcastCWtihBeta(M, N, C);
        }
    }

    // Y = A * B + C, note that C is unidirectionally broadcastable to (A * B).
    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16S)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        const auto &A = inputs[0];
        auto &Y = outputs[0];

        const auto shape_A = shape(A), shape_Y = shape(Y);
        size_t dims_A = shape_A.size();
        int ma = shape_A[dims_A - 2], na = shape_A[dims_A - 1];
        size_t dims_Y = shape_Y.size();
        int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1];
        int K = trans_a ? ma : na;
        int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies<int>());

        // broadcast C and copy C to output
        if (have_bias) {
            if (!const_C) {
                broadcastCWtihBeta(M, N, inputs.back());
            }
            int step = M * N;
            CV_CheckEQ(broadcast_C.size(), static_cast<size_t>(step), "DNN/Gemm: C is not broadcast properly");
            float *ptr_y = Y.ptr<float>();
            for (int i = 0; i < batches; i++) {
                std::memcpy(ptr_y + i * step, broadcast_C.data(), step * sizeof(float));
            }
        } else { // initialization
            float *ptr_y = Y.ptr<float>();
            size_t total = Y.total();
            std::memset(ptr_y, 0, total * sizeof(float));
        }

        if (const_B) {
            CV_CheckGT(packed_B.size(), static_cast<size_t>(0), "DNN/Gemm: constant B is not pre-packed");
            M *= batches;
            fastGemm(trans_a, M, N, K, alpha, A.ptr<const float>(), na, packed_B.data(), 1.f, Y.ptr<float>(), N, opt);
        } else {
            fastGemmBatched(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt);
        }
    }

#ifdef HAVE_CUDA
    // Y = A * B + C. B should be guaranteed as two dimensional.
    Ptr<BackendNode> initCUDA(void *context_,
                              const std::vector<Ptr<BackendWrapper>>& inputs,
                              const std::vector<Ptr<BackendWrapper>>& outputs) CV_OVERRIDE {
        CV_CheckFalse(trans_a, "DNN/Gemm/Cuda: does not support transA");
        CV_CheckTrue(const_B, "DNN/Gemm/Cuda: input B (weight) is required to be constant");
        auto context = reinterpret_cast<csl::CSLContext*>(context_);
        auto wrapper_A = inputs[0].dynamicCast<CUDABackendWrapper>();
        auto B = blobs[0];
        auto C = have_bias && const_C ? blobs[1] : Mat(); // in most cases C is constant

        if (!trans_b)
            cv::transpose(B, B);
        auto flatten_start_axis = normalize_axis(1, wrapper_A->getRank());
        return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, B, C);
    }
#endif // HAVE_CUDA

#ifdef HAVE_CANN
    // Y = A * B + C.
    virtual Ptr<BackendNode> initCann(const std::vector<Ptr<BackendWrapper> > &inputs,
                                      const std::vector<Ptr<BackendWrapper> > &outputs,
                                      const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
        auto x1 = inputs[0].dynamicCast<CannBackendWrapper>();
        auto desc_x1 = x1->getTensorDesc();
        auto op_x1 = nodes[0].dynamicCast<CannBackendNode>()->getOp();

        auto op = std::make_shared<ge::op::MatMulV2>(name);

        // set attributes
        op->set_attr_transpose_x1(trans_a);
        op->set_attr_transpose_x2(trans_b);

        // set inputs
        // set inputs : x1
        op->set_input_x1_by_name(*op_x1, x1->name.c_str());
        op->update_input_desc_x1(*desc_x1);
        // set inputs : x2
        if (const_B) {
            auto B = blobs[0];
            auto op_const_B = std::make_shared<CannConstOp>(B.data, B.type(), shape(B), cv::format("%s_w", name.c_str()));
            op->set_input_x2_by_name(*(op_const_B->getOp()), "y");
            op->update_input_desc_x2(*(op_const_B->getTensorDesc()));
        } else {
            CV_CheckGE(inputs.size(), static_cast<size_t>(2), "DNN/Gemm/CANN: input B is required since it is not constant");
            CV_CheckGE(nodes.size(), static_cast<size_t>(2), "DNN/Gemm/CANN: input B is required since it is not constant");
            auto op_x2 = nodes[1].dynamicCast<CannBackendNode>()->getOp();
            auto desc_x2 = inputs[1].dynamicCast<CannBackendWrapper>()->getTensorDesc();
            op->set_input_x2_by_name(*op_x2, "y");
            op->update_input_desc_x2(*desc_x2);
        }
        // set inputs : bias
        auto mat_C = have_bias && const_C ? blobs.back() : Mat::zeros(1, 1, CV_32F);
        auto op_const_C = std::make_shared<CannConstOp>(mat_C.data, mat_C.type(), shape(mat_C), cv::format("%s_b", name.c_str()));
        op->set_input_bias(*(op_const_C->getOp()));
        op->update_input_desc_bias(*(op_const_C->getTensorDesc()));

        // set outputs
        op->update_output_desc_y(*output_desc);
        return Ptr<BackendNode>(new CannBackendNode(op));
    }
#endif // HAVE_CANN

#ifdef HAVE_DNN_NGRAPH
    virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
                                        const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
    {
        auto& ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
        std::shared_ptr<ngraph::Node> matmul;
        int axis = -2;

        if (nodes.size() == 2)
        {
            auto& inp2 = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
            matmul = std::make_shared<ngraph::op::MatMul>(ieInpNode, inp2, transA, transB);
        }
        else
        {
            std::vector<int> shape(1 + normalize_axis(axis, ieInpNode->get_shape().size()), 0);
            shape[shape.size() - 1] = -1;
            auto inp = std::make_shared<ngraph::op::v1::Reshape>(
                ieInpNode,
                std::make_shared<ngraph::op::Constant>(ngraph::element::i32, ngraph::Shape{shape.size()}, shape.data()),
                true
            );

            std::vector<size_t> weight_shape{(size_t)blobs[0].size[0], (size_t)blobs[0].size[1]};
            auto ieWeights = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, weight_shape, blobs[0].data);
            matmul = std::make_shared<ngraph::op::MatMul>(inp, ieWeights, transA, transB);
        }

        if (have_bias && const_C) {
            auto bias_node = std::make_shared<ngraph::op::Constant>(ngraph::element::f32,
                                                                    ngraph::Shape{(size_t)blobs.back().size[1]}, blobs.back().data);
            matmul = std::make_shared<ngraph::op::v1::Add>(matmul, bias_node, ngraph::op::AutoBroadcastType::NUMPY);
        }
        return Ptr<BackendNode>(new InfEngineNgraphNode(matmul));
    }
#endif // HAVE_DNN_NGRAPH

#ifdef HAVE_VULKAN
    // Y = A * B + C. Currently supports 2d matrix multiplication without bias.
    virtual Ptr<BackendNode> initVkCom(const std::vector<Ptr<BackendWrapper> > &inputs,
                                       std::vector<Ptr<BackendWrapper> > &outputs) CV_OVERRIDE
    {
        // does not support bias; only 2d matmul
        auto wrapper_Y = outputs[0].dynamicCast<VkComBackendWrapper>();
        auto shape_Y = shape(*(wrapper_Y->getMat()));
        if (have_bias || shape_Y.size() > static_cast<size_t>(2)) {
            return Ptr<BackendNode>();
        }

        std::vector<Mat> vkBlobs;
        if (const_B) {
            vkBlobs.push_back(blobs[0]);
        }

        auto wrapper_A = inputs[0].dynamicCast<VkComBackendWrapper>();
        auto shape_A = shape(*wrapper_A->getMat());
        Ptr<vkcom::OpBase> op = (new vkcom::OpMatMul(vkBlobs, shape_A[0], shape_A[1], shape_Y[1]));
        return Ptr<BackendNode>(new VkComBackendNode(inputs, op, outputs));
    }
#endif

private:
    bool const_B;
    bool const_C;
    bool have_bias;
    std::vector<float> packed_B;
    std::vector<float> broadcast_C;
    int real_ndims_C;
    FastGemmOpt opt;
};

Ptr<GemmLayer> GemmLayer::create(const LayerParams& params) {
    return makePtr<GemmLayerImpl>(params);
}

}} // namespace cv::dnn