diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 3109955685..c408921c04 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -1178,6 +1178,11 @@ CV__DNN_INLINE_NS_BEGIN static Ptr create(const LayerParams ¶ms); }; + class CV_EXPORTS AttentionLayer : public Layer { + public: + static Ptr create(const LayerParams ¶ms); + }; + //! @} //! @} CV__DNN_INLINE_NS_END diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp index c26b7a1588..5849b8ff47 100644 --- a/modules/dnn/perf/perf_layer.cpp +++ b/modules/dnn/perf/perf_layer.cpp @@ -739,6 +739,62 @@ PERF_TEST_P_(Layer_InstanceNorm, InstanceNorm) test_layer({N, C, H, W}); } +struct Layer_Attention : public TestBaseWithParam> { + void test_layer(const std::vector x_shape, const std::vector qkv_hidden_sizes, const int num_heads) { + int backendId = get<0>(GetParam()); + int targetId = get<1>(GetParam()); + + auto qk_hidden_size = qkv_hidden_sizes[0]; + auto v_hidden_size = qkv_hidden_sizes[2]; + + auto input_hidden_size = x_shape[2]; + auto hidden_size = qk_hidden_size + qk_hidden_size + v_hidden_size; + + Mat x(x_shape, CV_32F); + Mat weight(std::vector{input_hidden_size, hidden_size}, CV_32F); + Mat bias(std::vector{hidden_size}, CV_32F); + + randu(x, 0.f, 1.f); + randu(weight, 0.f, 1.f); + randu(bias, 0.f, 1.f); + + LayerParams lp; + lp.type = "Attention"; + lp.name = "testLayer"; + lp.set("num_heads", num_heads); + lp.set("qkv_hidden_sizes", DictValue::arrayInt(qkv_hidden_sizes.data(), qkv_hidden_sizes.size())); + + Net net; + int id = net.addLayerToPrev(lp.name, lp.type, lp); + net.connect(0, 0, id, 0); + net.connect(0, 1, id, 1); + net.connect(0, 2, id, 2); + + { + std::vector input_names{"x", "weight", "bias"}; + net.setInputsNames(input_names); + net.setInput(x, input_names[0]); + net.setInput(weight, input_names[1]); + net.setInput(bias, input_names[2]); + + net.setPreferableBackend(backendId); + net.setPreferableTarget(targetId); + Mat out = net.forward(); + } + + TEST_CYCLE() + { + Mat out = net.forward(); + } + + SANITY_CHECK_NOTHING(); + } +}; + +PERF_TEST_P_(Layer_Attention, VisionTransformer) { + test_layer({1, 197, 768}, {768, 768, 768}, 12); +} + INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false)); INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); #ifdef HAVE_CUDA @@ -750,6 +806,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNorm, testing::Values(std::make_tuple(D INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); INSTANTIATE_TEST_CASE_P(/**/, Layer_InstanceNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); +INSTANTIATE_TEST_CASE_P(/**/, Layer_Attention, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); typedef TestBaseWithParam > > Layer_FullyConnected; diff --git a/modules/dnn/perf/perf_net.cpp b/modules/dnn/perf/perf_net.cpp index 63d605b45c..33e49526e9 100644 --- a/modules/dnn/perf/perf_net.cpp +++ b/modules/dnn/perf/perf_net.cpp @@ -93,7 +93,6 @@ public: } }; - PERF_TEST_P_(DNNTestNetwork, AlexNet) { processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt", @@ -391,17 +390,16 @@ PERF_TEST_P_(DNNTestNetwork, CRNN) { processNet("", "dnn/text_recognition_CRNN_EN_2021sep.onnx", "", inp); } -PERF_TEST_P_(DNNTestNetwork, ViTTrack) { +PERF_TEST_P_(DNNTestNetwork, VitTrack) { Mat inp1(cv::Size(128, 128), CV_32FC3); Mat inp2(cv::Size(256, 256), CV_32FC3); randu(inp1, 0.0f, 1.0f); randu(inp2, 0.0f, 1.0f); inp1 = blobFromImage(inp1, 1.0, Size(), Scalar(), false); inp2 = blobFromImage(inp2, 1.0, Size(), Scalar(), false); - processNet("", "dnn/onnx/models/vitTracker.onnx", "", {std::make_tuple(inp1, "template"), std::make_tuple(inp2, "search")}); + processNet("", "dnn/onnx/models/object_tracking_vittrack_2023sep.onnx", "", {std::make_tuple(inp1, "template"), std::make_tuple(inp2, "search")}); } - PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8) { if (target != DNN_TARGET_CPU || (backend != DNN_BACKEND_OPENCV && @@ -413,6 +411,10 @@ PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8) processNet("", "dnn/tflite/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "", inp); } +PERF_TEST_P_(DNNTestNetwork, VIT_B_32) { + processNet("", "dnn/onnx/models/vit_b_32.onnx", "", cv::Size(224, 224)); +} + INSTANTIATE_TEST_CASE_P(/*nothing*/, DNNTestNetwork, dnnBackendsAndTargets()); } // namespace diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index cc316efbfc..9b433dac50 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -162,6 +162,7 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer); CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer); CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer); + CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer); CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer); CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer); diff --git a/modules/dnn/src/layers/attention_layer.cpp b/modules/dnn/src/layers/attention_layer.cpp new file mode 100644 index 0000000000..64b39297f5 --- /dev/null +++ b/modules/dnn/src/layers/attention_layer.cpp @@ -0,0 +1,272 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../precomp.hpp" +#include "cpu_kernels/fast_gemm.hpp" +#include "cpu_kernels/softmax.hpp" + +#include + +namespace cv { namespace dnn { + +static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size, + const float *weight_data, size_t hidden_size, std::vector &packed_weight, const FastGemmOpt &opt) { + // num_heads * pack(head_size, input_hidden_size) + size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt); + size_t packed_weight_size = num_heads * pack_size; + packed_weight.resize(packed_weight_size, 0.f); + auto *packed_weight_data = packed_weight.data(); + for (size_t i = 0; i < num_heads; i++) { + fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt); + packed_weight_data += pack_size; + weight_data += head_size; + } +} + +// Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention +class AttentionLayerImpl CV_FINAL : public AttentionLayer { + public: + AttentionLayerImpl(const LayerParams ¶ms) { + setParamsFrom(params); + + CV_CheckTrue(params.has("num_heads"), "DNN/Attention: num_heads is required but missing"); + num_heads = params.get("num_heads"); // required, no default value + + CV_CheckTrue(params.has("qkv_hidden_sizes"), "DNN/Attention: qkv_hidden_sizes is required but missing"); + auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes"); + CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must and only have three elements"); + + qkv_hidden_sizes.clear(); + qkv_hidden_sizes.resize(3); + qkv_hidden_sizes[0] = static_cast(param_qkv_hidden_sizes.get(0)); + qkv_hidden_sizes[1] = static_cast(param_qkv_hidden_sizes.get(1)); + /* v_hidden_size needs to be initialized in finalize in case v_slice_end=INT_MAX */ + + qkv_head_sizes.clear(); + qkv_head_sizes.resize(3); + qkv_head_sizes[0] = static_cast(qkv_hidden_sizes[0] / num_heads); + qkv_head_sizes[1] = static_cast(qkv_hidden_sizes[1] / num_heads); + + scale = 1.f / params.get("scale", sqrt(qkv_head_sizes[0])); + + output_ndims = params.get("output_ndims", 3); + + is_prepacked = false; + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE { + return backendId == DNN_BACKEND_OPENCV; + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE { + CV_CheckEQ(inputs.size(), static_cast(3), "DNN/Attention: three inputs are required"); + const auto &input_shape = inputs[0]; + const auto &weight_shape = inputs[1]; + const auto &bias_shape = inputs[2]; + + CV_CheckEQ(input_shape.size(), static_cast(3), "DNN/Attention: invalid input dimension"); + CV_CheckEQ(weight_shape.size(), static_cast(2), "DNN/Attention: invalid weight dimension"); + + CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape"); + CV_CheckEQ(weight_shape[1], bias_shape[0], "DNN/Attention: invalid weight or bias shape"); + + if (output_ndims == 3) { + outputs.assign(1, inputs[0]); + } else if (output_ndims == 2) { + int batch = input_shape[0], seq_len = input_shape[1], input_hidden_size = input_shape[2]; + MatShape output_shape{batch * seq_len, input_hidden_size}; + outputs.assign(1, output_shape); + } else { + CV_Error(Error::StsBadArg, format("DNN/Attention: invalid output dimension %zu, valid value is 2 or 3", output_ndims)); + } + return false; + } + + virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { + opt.init(); + + std::vector inputs; + inputs_arr.getMatVector(inputs); + const auto input_shape = shape(inputs[0]); + batch_size = static_cast(input_shape[0]); + seq_len = static_cast(input_shape[1]); + input_hidden_size = static_cast(input_shape[2]); + + const auto weight_shape = shape(inputs[1]); + hidden_size = weight_shape[1]; + qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1]; + qkv_head_sizes[2] = static_cast(qkv_hidden_sizes[2] / num_heads); + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + if (inputs_arr.depth() == CV_16S) + { + forward_fallback(inputs_arr, outputs_arr, internals_arr); + return; + } + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + // prepack weights + if (!is_prepacked) { + // prepack + const auto &weight = inputs[1]; + const auto *weight_data = weight.ptr(); + packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_weight_q, opt); + packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_weight_k, opt); + packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_weight_v, opt); + + is_prepacked = true; + } + + float *packed_weights[3] = {packed_weight_q.data(), packed_weight_k.data(), packed_weight_v.data()}; + size_t packed_weights_size[3] = {packed_weight_q.size() / num_heads, packed_weight_k.size() / num_heads, packed_weight_v.size() / num_heads}; + + Mat gemm_buffer = Mat::zeros(1, int(batch_size * seq_len * hidden_size), CV_32F); + auto *Q = gemm_buffer.ptr(); + auto *K = Q + batch_size * seq_len * qkv_hidden_sizes[0]; + auto *V = K + batch_size * seq_len * qkv_hidden_sizes[1]; + float *QKV[3] = {Q, K, V}; // Q, K, V: [B, N, S, H] + { + const auto &input = inputs[0]; + const auto &bias = inputs[2]; + const auto *input_data = input.ptr(); + const auto *bias_data = bias.ptr(); + + opt.multi_thread = false; + auto fn = [&](const Range &r) { + for (int i = r.start; i < r.end; i++) { + const int batch_index = static_cast((i / 3) / num_heads); + const int head_index = static_cast((i / 3) % num_heads); + const int qkv_index = static_cast(i % 3); + + auto *dst = QKV[qkv_index]; + size_t head_size = qkv_head_sizes[qkv_index]; + + int input_offset = batch_index * seq_len * input_hidden_size; + int bias_offset = qkv_index * qkv_hidden_sizes[0] + head_index * head_size; + int dst_offset = (batch_index * num_heads + head_index) * (seq_len * head_size); + + // broadcast bias ([NH] -> [BN, SH]) and make copy to dst + const auto *bias_data_src = bias_data + bias_offset; + auto *dst_data = dst + dst_offset; + for (size_t seq_len_idx = 0; seq_len_idx < seq_len; seq_len_idx++) { + std::memcpy(dst_data, bias_data_src, head_size * sizeof(float)); + dst_data += head_size; + } + + auto *packed_weight = packed_weights[qkv_index] + packed_weights_size[qkv_index] * head_index; + // single-thread gemm kernel + fastGemm(false, seq_len, head_size, input_hidden_size, + 1.f, input_data + input_offset, input_hidden_size, + packed_weight, 1.f, dst + dst_offset, head_size, opt); + } + }; + + size_t loops = 3 * batch_size * num_heads; + double nstripes = loops * seq_len * qkv_head_sizes[0] * input_hidden_size * (1 / 1024.0); + parallel_for_(Range(0, loops), fn, nstripes); + } + + // Compute softmax(scale * matmul(Q, K)) + std::vector attention_prob_shape{int(batch_size * num_heads), int(seq_len), int(seq_len)}; + Mat attention_prob = Mat::zeros(attention_prob_shape.size(), attention_prob_shape.data(), CV_32F); + { + auto *output = attention_prob.ptr(); + + auto loops = batch_size * num_heads; + auto seq_len_square = seq_len * seq_len; + auto qk_head_size = qkv_head_sizes[0]; + auto qk_inner_size = seq_len * qk_head_size; + + // Compute scale * matmul(Q, K) + opt.multi_thread = false; + parallel_for_(Range(0, loops), [&] (const Range r) { + for (int i = r.start; i < r.end; i++) { + const int output_offset = i * seq_len_square; + + const auto *q = Q + qk_inner_size * i, *k = K + qk_inner_size * i; + fastGemm(false, true, seq_len, qk_head_size, seq_len, qk_head_size, + scale, q, qk_head_size, 1, + k, qk_head_size, 1, 0.f, + output + output_offset, seq_len, opt); + } + }, loops * seq_len * qk_head_size * seq_len * (1 / 1024.0)); + + // Compute softmax + softmax(attention_prob, attention_prob, attention_prob_shape.size() - 1); + } + + // Compute np.matmul(attention_prob, V) + Mat output_buffer = Mat::zeros(1, int(batch_size * num_heads * seq_len * qkv_head_sizes[2]), CV_32F); + { + auto *output = outputs[0].ptr(); + auto *output_buff = output_buffer.ptr(); + const auto *prob = attention_prob.ptr(); + + auto loops = batch_size * num_heads; + auto prob_inner_size = seq_len * seq_len; + auto v_head_size = qkv_head_sizes[2]; + auto v_inner_size = seq_len * v_head_size; + + opt.multi_thread = false; + parallel_for_(Range(0, loops), [&] (const Range &r) { + for (int i = r.start; i < r.end; i++) { + const int output_offset = i * v_inner_size; + + const auto *p = prob + i * prob_inner_size, *v = V + i * v_inner_size; + fastGemm(false, false, seq_len, seq_len, seq_len, v_head_size, + 1.f, p, seq_len, 1, + v, v_head_size, 1, 0.f, + output_buff + output_offset, v_head_size, opt); + + // tranpose on the fly + const int batch_index = static_cast(i / num_heads); + const int head_index = static_cast(i % num_heads); + auto *src = output_buff + output_offset; + auto *dst = output + (batch_index * seq_len * num_heads + head_index) * v_head_size; + for (int j = 0; j < seq_len; j++) { + std::memcpy(dst, src, v_head_size * sizeof(float)); + src += v_head_size; + dst += qkv_hidden_sizes[2]; + } + } + }, loops * seq_len * seq_len * v_head_size * (1 / 1024.0)); + } + } + + private: + size_t num_heads; + std::vector qkv_hidden_sizes; // order: {qk_hidden_size, qk_hidden_size, v_hidden_size} + float scale; + size_t output_ndims; + + std::vector qkv_head_sizes; // order: {qk_head_size, qk_head_size, v_head_size} + + size_t batch_size; + size_t seq_len; + size_t input_hidden_size; + size_t hidden_size; + + bool is_prepacked; + std::vector packed_weight_q; + std::vector packed_weight_k; + std::vector packed_weight_v; + + FastGemmOpt opt; +}; + +Ptr AttentionLayer::create(const LayerParams ¶ms) { + return makePtr(params); +} + +}} // cv::dnn diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp index ef71d8b10c..a8972aba4e 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp @@ -20,6 +20,32 @@ namespace cv { namespace dnn { +size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt) { +#if CV_TRY_NEON + if (opt.use_neon) { + return static_cast(opt_NEON::fastGemmPackBSize(N, K)); + } else +#endif +#if CV_TRY_AVX2 + if (opt.use_avx2) { + return static_cast(opt_AVX2::fastGemmPackBSize(N, K)); + } else +#endif +#if CV_TRY_AVX + if (opt.use_avx) { + return static_cast(opt_AVX::fastGemmPackBSize(N, K)); + } else +#endif +#if CV_TRY_LASX + if (opt.use_lasx) { + return static_cast(opt_LASX::fastGemmPackBSize(N, K)); + } else +#endif + { + return static_cast(cpu_baseline::fastGemmPackBSize(N, K)); + } +} + void fastGemmPackB(const Mat &B, std::vector &packed_B, bool trans, FastGemmOpt &opt) { CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now"); @@ -94,10 +120,45 @@ void fastGemmPackB(const Mat &B, std::vector &packed_B, bool trans, FastG } } +void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt) { + size_t ldb0 = ldb, ldb1 = 1; + if (trans) { + std::swap(K, N); + std::swap(ldb0, ldb1); + } + + const auto &b = (const char *)B; + auto *packed_b = (char *)packed_B; + +#if CV_TRY_NEON + if (opt.use_neon) { + opt_NEON::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float)); + } else +#endif +#if CV_TRY_AVX2 + if (opt.use_avx2) { + opt_AVX2::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float)); + } else +#endif +#if CV_TRY_AVX + if (opt.use_avx) { + opt_AVX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float)); + } else +#endif +#if CV_TRY_LASX + if (opt.use_lasx) { + opt_LASX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float)); + } else +#endif + { + cpu_baseline::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float)); + } +} + static void fast_gemm_thin(float alpha, float beta, int M, int N, int K, const char *a_, int lda0, int lda1, const char *b_, int ldb, - char *c_, int ldc) { + char *c_, int ldc, bool multi_thread) { const float* a = (const float*)a_; auto fn = [&](const Range &r) { @@ -116,16 +177,24 @@ static void fast_gemm_thin(float alpha, float beta, int M, int N, int K, } }; - int total = M; // outer loops - int cost_per_thread = static_cast(K * N); // inner loops - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); + if (multi_thread) { + int total = M; // outer loops + int cost_per_thread = static_cast(K * N); // inner loops + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); + } else { + fn(Range(0, M)); + } } void fastGemm(bool trans_a, int M, int N, int K, float alpha, const float *A, int lda, const float *packed_B, float beta, float *C, int ldc, FastGemmOpt &opt) { + const char *a = (const char *)A; + const char *packed_b = (const char *)packed_B; + char *c = (char *)C; + int lda0 = lda, lda1 = 1; if (trans_a) { std::swap(lda0, lda1); @@ -133,26 +202,26 @@ void fastGemm(bool trans_a, int M, int N, int K, #if CV_TRY_NEON if (opt.use_neon) { - opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + opt_NEON::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_AVX2 if (opt.use_avx2) { - opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + opt_AVX2::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_AVX if (opt.use_avx) { - opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + opt_AVX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_LASX if (opt.use_lasx) { - opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + opt_LASX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread); } else #endif { - cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + cpu_baseline::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread); } } @@ -175,36 +244,41 @@ void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb, } if (!trans_b && ldb1 == 1 && (M <= 4 || (uint64_t)M * N * K <= 10000)) { - return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc); + return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc, opt.multi_thread); } #if CV_TRY_NEON if (opt.use_neon) { - opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, - (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + opt_NEON::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, + b, ldb0, ldb1, beta, + c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_AVX2 if (opt.use_avx2) { - opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, - (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + opt_AVX2::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, + b, ldb0, ldb1, beta, + c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_AVX if (opt.use_avx) { - opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, - (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + opt_AVX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, + b, ldb0, ldb1, beta, + c, ldc, sizeof(float), opt.multi_thread); } else #endif #if CV_TRY_LASX if (opt.use_lasx) { - opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, - (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + opt_LASX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, + b, ldb0, ldb1, beta, + c, ldc, sizeof(float), opt.multi_thread); } else #endif { - cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, - (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + cpu_baseline::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, + b, ldb0, ldb1, beta, + c, ldc, sizeof(float), opt.multi_thread); } } diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp index 9060068080..a207c63c3c 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp @@ -22,12 +22,14 @@ struct FastGemmOpt { bool use_avx2; bool use_neon; bool use_lasx; + bool multi_thread; FastGemmOpt() { use_avx = false; use_avx2 = false; use_neon = false; use_lasx = false; + multi_thread = false; } void init() { @@ -35,6 +37,7 @@ struct FastGemmOpt { use_avx2 = checkHardwareSupport(CPU_AVX2); use_neon = checkHardwareSupport(CPU_NEON); use_lasx = checkHardwareSupport(CPU_LASX); + multi_thread = true; } bool all() { @@ -148,7 +151,10 @@ struct MatMulHelper { } }; +size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt); + void fastGemmPackB(const Mat &m, std::vector &packed_B, bool trans, FastGemmOpt &opt); +void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt); void fastGemm(bool trans_a, int M, int N, int K, float alpha, const float *A, int lda, diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp index e985fc46ee..f6bd7317a2 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp @@ -83,10 +83,10 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz); + float beta, char *C, int ldc, int esz, bool multi_thread); void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz); + const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread); void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets, int M, int N, int K, float alpha, const char *A, int lda0, int lda1, @@ -179,7 +179,7 @@ static void fast_gemm_macro_kernel(int m, int n, int k, void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz) { + float beta, char *C, int ldc, int esz, bool multi_thread) { int GEMM_MC = FAST_GEMM_F32_MC, GEMM_NC = FAST_GEMM_F32_NC, GEMM_MR = FAST_GEMM_F32_MR, @@ -236,15 +236,18 @@ void fastGemmKernel(int M, int N, int K, } }; - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); + if (multi_thread) { + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total_tiles), fn, nstripes); + } else { + fn(Range(0, total_tiles)); + } } void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz) { + const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread) { int GEMM_MC = FAST_GEMM_F32_MC, GEMM_NC = FAST_GEMM_F32_NC, GEMM_MR = FAST_GEMM_F32_MR, @@ -301,10 +304,13 @@ void fastGemmKernel(int M, int N, int K, } }; - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); + if (multi_thread) { + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total_tiles), fn, nstripes); + } else { + fn(Range(0, total_tiles)); + } } void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets, diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp index 74677f73ed..8e63d15137 100644 --- a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp @@ -122,10 +122,10 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz); + float beta, char *C, int ldc, int esz, bool multi_thread); void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz); + const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread); void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets, int M, int N, int K, float alpha, const char *A, int lda0, int lda1, @@ -568,7 +568,7 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, const char *B, int ldb0, int ldb1, - float beta, char *C, int ldc, int esz) { + float beta, char *C, int ldc, int esz, bool multi_thread) { int GEMM_MC = FAST_GEMM_F32_MC, GEMM_NC = FAST_GEMM_F32_NC, GEMM_MR = FAST_GEMM_F32_MR, @@ -646,15 +646,19 @@ void fastGemmKernel(int M, int N, int K, } }; - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); + if (multi_thread) { + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total_tiles), fn, nstripes); + } else { + fn(Range(0, total_tiles)); + } + } void fastGemmKernel(int M, int N, int K, float alpha, const char *A, int lda0, int lda1, - const char *packed_B, float beta, char *C, int ldc, int esz) { + const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread) { int GEMM_MC = FAST_GEMM_F32_MC, GEMM_NC = FAST_GEMM_F32_NC, GEMM_MR = FAST_GEMM_F32_MR, @@ -722,10 +726,13 @@ void fastGemmKernel(int M, int N, int K, } }; - int total = total_tiles; - int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); - double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); - parallel_for_(Range(0, total), fn, nstripes); + if (multi_thread) { + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total_tiles), fn, nstripes); + } else { + fn(Range(0, total_tiles)); + } } void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets, diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index e1fa80c165..77dc1c52df 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -13,6 +13,7 @@ #include #include +#include namespace cv { namespace dnn { CV__DNN_INLINE_NS_BEGIN @@ -181,6 +182,17 @@ static Mat extractConstant(const Ptr& net, int node_id, int } } +static std::string getInputName(const Ptr& net, int node_id, int input_id) { + auto onnx_net = net.dynamicCast(); + int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); + if (initializer_id != -1) { + return onnx_net->getNameOfInitializer(initializer_id); + } else { + const auto node = net->getNode(node_id); + return node->getInputName(input_id); + } +} + /* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have Slice with optional inputs of default values, some of them don't. This Subgraph adjusts all optional inputs of Slice up to 5. @@ -212,12 +224,308 @@ class AdjustSliceAllOptionalInputsSubgraph : public Subgraph { node->add_input(""); } } - -private: + private: int slice_id; size_t num_inputs_; }; +/* The fusion for the multi-head attention from vision transformer. + + Abbreviations: + B - batch_size, symbolic; + S - sequence_length, symbolic; + W - hidden_size, W = N * H; + N - num_heads; + H - head_size; + + Graph before fusion: + [Input](BxSxW) + | + LayerNorm + | + Transpose(perm=[1, 0, 2]) + | + | (SxBxW) + | + Matmul[Weight(Wx3W)] + | + Add[Bias(3W)] + / | \ + q_Slice k_Slice v_Slice (output(SxBxW)) + | | | + q_Reshape k_Reshape v_Reshape (output(Sx(BxN)xH), could be optional if N=1) + | | | + q_Transpose k_Transpose v_Transpose + (1,0,2) (1,2,0) (perm=1,0,2) + |((BxN)xSxH) |((BxN)xHxS) | + q_Div / / + \ / / + qk_MatMul / + | / + qk_Softmax / + | ((BxN)xSxS) / ((BxN)xSxH) + \ / + qkv_MatMul (output((BxN)xSxH)) + | + Transpose(perm=1,2,0) + | + Reshape (output(SxH)) + | + MatMul + | + Add + | + [Output](BxSxW) + + + Attributes: + num_heads - number of attention heads + qkv_hidden_sizes - hidden size of qkv respectively, [qk_hidden_size, qk_hidden_size, v_hidden_size], + assume qk_hidden_size = v_hidden_size for now. TODO: support qk_hidden_size != v_hidden_size + scale - scale factor of q, defaults to sqrt(1/num_heads) + Inputs: + weight - merged Q, K, V weights of shape [input_hidden_size, qk_hidden_size + qk_hidden_size + v_hidden_size] + bias - bias of shape [qk_hidden_size + qk_hidden_size + v_hidden_size] + + Graph after fusion: + [Input](BxSxW) + | + LayerNorm + | + Transpose + | + Attention[weight, bias] + | + MatMul + | + Add + | + [Output](BxSxW) + + More details see See https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention. +*/ +class AttentionSubGraph : public Subgraph { + public: + AttentionSubGraph() { + int input = addNodeToMatch(""); + int transpose = addNodeToMatch("Transpose", input); // tranpose does not make any differences to the accuracy here in this subgraph + att_matmul = addNodeToMatch("MatMul", transpose, addNodeToMatch("")); + att_add = addNodeToMatch("Add", addNodeToMatch(""), att_matmul); + + // v_path + slice_v = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + int reshape_v = addNodeToMatch("Reshape", slice_v, addNodeToMatch("")); + int transpose_v = addNodeToMatch("Transpose", reshape_v); + + // q_path + slice_q = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + reshape_q = addNodeToMatch("Reshape", slice_q, addNodeToMatch("")); + int transpose_q = addNodeToMatch("Transpose", reshape_q); + div_q = addNodeToMatch("Div", transpose_q, addNodeToMatch("")); + + // k_path + slice_k = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + int reshape_k = addNodeToMatch("Reshape", slice_k, addNodeToMatch("")); + int transpose_k = addNodeToMatch("Transpose", reshape_k); + + // qk + int matmul_qk = addNodeToMatch("MatMul", div_q, transpose_k); + int softmax_qk = addNodeToMatch("Softmax", matmul_qk); + + // qkv + int matmul_qkv = addNodeToMatch("MatMul", softmax_qk, transpose_v); + int transpose_qkv = addNodeToMatch("Transpose", matmul_qkv); + last_reshape = addNodeToMatch("Reshape", transpose_qkv, addNodeToMatch("")); + + setFusedNode("Attention", input); + } + + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds) CV_OVERRIDE { + if (Subgraph::match(net, nodeId, matchedNodesIds)) { + // get attrs - qkv_hidden_sizes + qkv_hidden_sizes.clear(); + auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) { + int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at(0); + int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at(0); + if (slice_end == std::numeric_limits::max()) { + qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX + } else { + int64_t hidden_size = static_cast(slice_end - slice_start); + qkv_hidden_sizes.push_back(hidden_size); + } + }; + fill_qkv_hidden_sizes(slice_q); + fill_qkv_hidden_sizes(slice_k); + fill_qkv_hidden_sizes(slice_v); // TODO: take care of INT64_MAX + CV_CheckEQ(qkv_hidden_sizes.size(), static_cast(3), "ONNXSimplifier/Attention: invalid qkv hidden sizes"); + CV_CheckEQ(int(qkv_hidden_sizes[0]), int(qkv_hidden_sizes[1]), "ONNXSimplifier/Attention: invalid qkv hidden sizes, q_hidden_size == v_hidden_size is required"); + // get attrs - num_heads, scale + num_heads = extractConstant(net, matchedNodesIds[reshape_q], 1).at(1); + scale = extractConstant(net, matchedNodesIds[div_q], 1).at(0); + output_ndims = extractConstant(net, matchedNodesIds[last_reshape], 1).size[0]; + + // get names + weight_name = getInputName(net, matchedNodesIds[att_matmul], 1); + bias_name = getInputName(net, matchedNodesIds[att_add], 0); + return true; + } + return false; + } + + virtual void finalize(const Ptr& net, + const Ptr& fusedNode, + std::vector >&) CV_OVERRIDE { + // add attrs + opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; + opencv_onnx::AttributeProto* attr_num_heads = node->add_attribute(); + attr_num_heads->set_name("num_heads"); + attr_num_heads->set_i(num_heads); + opencv_onnx::AttributeProto* attr_qkv_hidden_sizes = node->add_attribute(); + attr_qkv_hidden_sizes->set_name("qkv_hidden_sizes"); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[0]); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[1]); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[2]); + opencv_onnx::AttributeProto* attr_scale = node->add_attribute(); + attr_scale->set_name("scale"); + attr_scale->set_f(scale); + + // add customized attrs + opencv_onnx::AttributeProto* attr_output_ndims = node->add_attribute(); + attr_output_ndims->set_name("output_ndims"); + attr_output_ndims->set_i(output_ndims); + + // add inputs + node->add_input(weight_name); + node->add_input(bias_name); + } + + private: + int att_matmul, att_add; + int slice_q, slice_k, slice_v; + int reshape_q, div_q, last_reshape; + + std::vector qkv_hidden_sizes; // order: [qk_hidden_size, qk_hidden_size, v_hidden_size] + int64_t num_heads; + float scale; + + int64_t output_ndims; + + std::string weight_name; + std::string bias_name; +}; + +/* Attention subgraph with single head. + No Reshape operator is appended after each Slice operator. +*/ +class AttentionSingleHeadSubGraph : public Subgraph { + public: + AttentionSingleHeadSubGraph() { + int input = addNodeToMatch(""); + int transpose = addNodeToMatch("Transpose", input); // tranpose does not make any differences to the accuracy here in this subgraph + att_matmul = addNodeToMatch("MatMul", transpose, addNodeToMatch("")); + att_add = addNodeToMatch("Add", addNodeToMatch(""), att_matmul); + + // v_path + slice_v = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + int transpose_v = addNodeToMatch("Transpose", slice_v); + + // q_path + slice_q = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + int transpose_q = addNodeToMatch("Transpose", slice_q); + div_q = addNodeToMatch("Div", transpose_q, addNodeToMatch("")); + + // k_path + slice_k = addNodeToMatch("Slice", std::vector{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")}); + int transpose_k = addNodeToMatch("Transpose", slice_k); + + // qk + int matmul_qk = addNodeToMatch("MatMul", div_q, transpose_k); + int softmax_qk = addNodeToMatch("Softmax", matmul_qk); + + // qkv + int matmul_qkv = addNodeToMatch("MatMul", softmax_qk, transpose_v); + int transpose_qkv = addNodeToMatch("Transpose", matmul_qkv); + last_reshape = addNodeToMatch("Reshape", transpose_qkv, addNodeToMatch("")); + + setFusedNode("Attention", input); + } + + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds) CV_OVERRIDE { + if (Subgraph::match(net, nodeId, matchedNodesIds)) { + // get attrs - qkv_hidden_sizes + qkv_hidden_sizes.clear(); + auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) { + int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at(0); + int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at(0); + if (slice_end == std::numeric_limits::max()) { + qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX + } else { + int64_t hidden_size = static_cast(slice_end - slice_start); + qkv_hidden_sizes.push_back(hidden_size); + } + }; + fill_qkv_hidden_sizes(slice_q); + fill_qkv_hidden_sizes(slice_k); + fill_qkv_hidden_sizes(slice_v); + CV_CheckEQ(qkv_hidden_sizes.size(), static_cast(3), "ONNXSimplifier/Attention: invalid qkv hidden sizes"); + CV_CheckEQ(int(qkv_hidden_sizes[0]), int(qkv_hidden_sizes[1]), "ONNXSimplifier/Attention: invalid qkv hidden sizes, q_hidden_size == v_hidden_size is required"); + // get attrs - num_heads, scale + num_heads = 1; + scale = extractConstant(net, matchedNodesIds[div_q], 1).at(0); + output_ndims = extractConstant(net, matchedNodesIds[last_reshape], 1).size[0]; + + // get names + weight_name = getInputName(net, matchedNodesIds[att_matmul], 1); + bias_name = getInputName(net, matchedNodesIds[att_add], 0); + return true; + } + return false; + } + + virtual void finalize(const Ptr& net, + const Ptr& fusedNode, + std::vector >&) CV_OVERRIDE { + // add attrs + opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; + opencv_onnx::AttributeProto* attr_num_heads = node->add_attribute(); + attr_num_heads->set_name("num_heads"); + attr_num_heads->set_i(num_heads); + opencv_onnx::AttributeProto* attr_qkv_hidden_sizes = node->add_attribute(); + attr_qkv_hidden_sizes->set_name("qkv_hidden_sizes"); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[0]); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[1]); + attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[2]); + opencv_onnx::AttributeProto* attr_scale = node->add_attribute(); + attr_scale->set_name("scale"); + attr_scale->set_f(scale); + + // add customized attrs + opencv_onnx::AttributeProto* attr_output_ndims = node->add_attribute(); + attr_output_ndims->set_name("output_ndims"); + attr_output_ndims->set_i(output_ndims); + + // add inputs + node->add_input(weight_name); + node->add_input(bias_name); + } + + protected: + int att_matmul, att_add; + int slice_q, slice_k, slice_v; + int div_q, last_reshape; + + std::vector qkv_hidden_sizes; // order: [qk_hidden_size, qk_hidden_size, v_hidden_size] + int64_t num_heads; + float scale; + + int64_t output_ndims; + + std::string weight_name; + std::string bias_name; +}; + /* Fusion for Gelu. Graph before fusion: @@ -390,21 +698,6 @@ public: return axis_; } - static std::string getInputName(const Ptr& net, int node_id, int input_id) - { - auto onnx_net = net.dynamicCast(); - int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); - if (initializer_id != -1) - { - return onnx_net->getNameOfInitializer(initializer_id); - } - else - { - const auto node = net->getNode(node_id); - return node->getInputName(input_id); - } - } - virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds) CV_OVERRIDE { @@ -1252,6 +1545,10 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + if (getParam_DNN_BACKEND_DEFAULT() == DNN_BACKEND_OPENCV) { + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + } simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index eee8b5828e..d65f155a55 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -207,6 +207,7 @@ private: void parseQConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseQGemm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseQSoftmax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); + void parseAttention (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); // '???' domain or '???' layer type void parseCustomLayer (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); @@ -3894,6 +3895,31 @@ void ONNXImporter::parseQSoftmax(LayerParams& layerParams, const opencv_onnx::No addLayer(layerParams, node_proto); } +void ONNXImporter::parseAttention(LayerParams& params, const opencv_onnx::NodeProto& node_proto) { + CV_CheckTrue(params.has("num_heads"), "ONNXImporter/parseAttention: num_heads is required but missing"); + CV_CheckTrue(params.has("qkv_hidden_sizes"), "ONNXImporter/parseAttention: qkv_hidden_sizes is required but missing"); + + auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes"); + CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "ONNXImporter/parseAttention: qkv_hidden_sizes is must and only have three elements"); + + for (size_t i = 1; i < node_proto.input_size(); i++) { + if (layer_id.find(node_proto.input(i)) == layer_id.end()) { + Mat tensor = getBlob(node_proto, i); + + LayerParams const_params; + const_params.name = node_proto.input(i); + const_params.type = "Const"; + const_params.blobs.push_back(tensor); + + opencv_onnx::NodeProto proto; + proto.add_output(const_params.name); + addLayer(const_params, proto); + } + } + + addLayer(params, node_proto); +} + // Domain: ai.onnx (default) // URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) @@ -3977,6 +4003,11 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["QLinearConv"] = &ONNXImporter::parseQConv; dispatch["QLinearMatMul"] = &ONNXImporter::parseQMatMul; + // com.microsft: This operator is added for compatibility via onnx graph simplifier. + // Opset domain cannot be modified from onnx_graph_simplifier.cpp so this + // operator cannot be parsed if only added in buildDispatchMap_COM_MICROSOFT + dispatch["Attention"] = &ONNXImporter::parseAttention; + domain_dispatch_map[str_domain_ai_onnx] = dispatch; } @@ -3994,6 +4025,7 @@ void ONNXImporter::buildDispatchMap_COM_MICROSOFT(int opset_version) dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat; dispatch["QGemm"] = &ONNXImporter::parseQGemm; dispatch["QLinearSoftmax"] = &ONNXImporter::parseQSoftmax; + dispatch["Attention"] = &ONNXImporter::parseAttention; domain_dispatch_map["com.microsoft"] = dispatch; } diff --git a/modules/dnn/test/test_graph_simplifier.cpp b/modules/dnn/test/test_graph_simplifier.cpp index f6b85de230..e09a68c158 100644 --- a/modules/dnn/test/test_graph_simplifier.cpp +++ b/modules/dnn/test/test_graph_simplifier.cpp @@ -130,4 +130,13 @@ TEST_F(Test_Graph_Simplifier, MishSubgraph) { test("mish", "Mish"); } +TEST_F(Test_Graph_Simplifier, AttentionSubgraph) { + /* Test for 2 subgraphs + - AttentionSubgraph + - AttentionSingleHeadSubgraph + */ + test("attention", "Attention"); + test("attention_single_head", "Attention"); +} + }} diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index a265b31db9..46064462fb 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2949,6 +2949,63 @@ TEST_P(Test_ONNX_layers, Expand_shape_model4) { testONNXModels("test_expand_shape_model4", pb, 0, 0, false, true, 1); } +TEST_P(Test_ONNX_layers, Attention) { + testONNXModels("attention"); +} +TEST_P(Test_ONNX_layers, AttentionSingleHead) { + testONNXModels("attention_single_head"); +} + +TEST_P(Test_ONNX_nets, ViT_B_32) { + applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_DEBUG_LONG); + + if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16) + { + // does not pass test for now + applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA_FP16); + } + + const std::string model_path = _tf("models/vit_b_32.onnx", false); + + auto net = readNet(model_path); + ASSERT_FALSE(net.empty()); + + net.setPreferableBackend(backend); + net.setPreferableTarget(target); + + auto image = imread(_tf("../googlenet_0.png")); + auto blob = blobFromImage(image, 1.f, Size(224, 224)); + auto ref = blobFromNPY(_tf("data/output_vit_b_32.npy")); + checkBackend(&blob, &ref); + + net.setInput(blob); + auto out = net.forward(); + + normAssert(ref, out, "ViTB_32", default_l1, default_lInf); +} + +TEST_P(Test_ONNX_nets, VitTrack) { + auto image = imread(_tf("../dog_orig_size.png")); + auto input0 = blobFromImage(image, 1.f, Size(128, 128)); + auto input1 = blobFromImage(image, 1.f, Size(256, 256)); + + auto net = readNet(_tf("models/object_tracking_vittrack_2023sep.onnx", false)); + net.setInput(input0, "template"); + net.setInput(input1, "search"); + + std::vector output_names{"output1", "output2", "output3"}; + std::vector outputs; + net.forward(outputs, output_names); + + auto ref_output1 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_0.npy")); + auto ref_output2 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_1.npy")); + auto ref_output3 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_2.npy")); + + normAssert(ref_output1, outputs[0], "VitTrack output1"); + normAssert(ref_output2, outputs[1], "VitTrack output2"); + normAssert(ref_output3, outputs[2], "VitTrack output3"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); }} // namespace