From a7fd9446cf5deddb9394cad35a60164c245f6b98 Mon Sep 17 00:00:00 2001
From: Yuantao Feng
Date: Wed, 3 Jul 2024 15:09:05 +0800
Subject: [PATCH] Merge pull request #25630 from fengyuentau:nary-multi-thread

dnn: parallelize nary elementwise forward implementation & enable related conformance tests #25630

This PR introduces the following changes:

- [x] Parallelize binary forward impl
- [x] Parallelize ternary forward impl (Where)
- [x] Parallelize nary forward impl (operators that can take >= 1 operands)
- [x] Enable the related conformance tests where applicable

## Performance

### i7-12700K, RAM 64GB, Ubuntu 22.04

```
Geometric mean (ms)

Name of Test                                      core.x64.0606    core.x64.0606.patch    x-factor
NCHW_C_sum::Layer_NaryEltwise::OCV/CPU            16.116           11.161                 1.44
NCHW_NCHW_add::Layer_NaryEltwise::OCV/CPU         17.469           11.446                 1.53
NCHW_NCHW_div::Layer_NaryEltwise::OCV/CPU         17.531           11.469                 1.53
NCHW_NCHW_equal::Layer_NaryEltwise::OCV/CPU       28.653           13.682                 2.09
NCHW_NCHW_greater::Layer_NaryEltwise::OCV/CPU     21.899           13.422                 1.63
NCHW_NCHW_less::Layer_NaryEltwise::OCV/CPU        21.738           13.185                 1.65
NCHW_NCHW_max::Layer_NaryEltwise::OCV/CPU         16.172           11.473                 1.41
NCHW_NCHW_mean::Layer_NaryEltwise::OCV/CPU        16.309           11.565                 1.41
NCHW_NCHW_min::Layer_NaryEltwise::OCV/CPU         16.166           11.454                 1.41
NCHW_NCHW_mul::Layer_NaryEltwise::OCV/CPU         16.157           11.443                 1.41
NCHW_NCHW_pow::Layer_NaryEltwise::OCV/CPU         163.459          15.234                 10.73
NCHW_NCHW_ref_div::Layer_NaryEltwise::OCV/CPU     10.880           10.868                 1.00
NCHW_NCHW_ref_max::Layer_NaryEltwise::OCV/CPU     10.947           11.058                 0.99
NCHW_NCHW_ref_min::Layer_NaryEltwise::OCV/CPU     10.948           10.910                 1.00
NCHW_NCHW_ref_mul::Layer_NaryEltwise::OCV/CPU     10.874           10.871                 1.00
NCHW_NCHW_ref_sum::Layer_NaryEltwise::OCV/CPU     10.971           10.920                 1.00
NCHW_NCHW_sub::Layer_NaryEltwise::OCV/CPU         17.546           11.462                 1.53
NCHW_NCHW_sum::Layer_NaryEltwise::OCV/CPU         16.175           11.475                 1.41
NHWC_C::Layer_NaryEltwise::OCV/CPU                11.339           11.333                 1.00
NHWC_H::Layer_NaryEltwise::OCV/CPU                16.154           11.102                 1.46
```

### Apple M1, RAM 16GB, macOS 14.4.1

```
Geometric mean (ms)

Name of Test                                      core.m1.0606     core.m1.0606.patch     x-factor
NCHW_C_sum::Layer_NaryEltwise::OCV/CPU            28.418           3.768                  7.54
NCHW_NCHW_add::Layer_NaryEltwise::OCV/CPU         6.942            5.679                  1.22
NCHW_NCHW_div::Layer_NaryEltwise::OCV/CPU         5.822            5.653                  1.03
NCHW_NCHW_equal::Layer_NaryEltwise::OCV/CPU       5.751            5.628                  1.02
NCHW_NCHW_greater::Layer_NaryEltwise::OCV/CPU     5.797            5.599                  1.04
NCHW_NCHW_less::Layer_NaryEltwise::OCV/CPU        7.272            5.578                  1.30
NCHW_NCHW_max::Layer_NaryEltwise::OCV/CPU         5.777            5.562                  1.04
NCHW_NCHW_mean::Layer_NaryEltwise::OCV/CPU        5.819            5.559                  1.05
NCHW_NCHW_min::Layer_NaryEltwise::OCV/CPU         5.830            5.574                  1.05
NCHW_NCHW_mul::Layer_NaryEltwise::OCV/CPU         5.759            5.567                  1.03
NCHW_NCHW_pow::Layer_NaryEltwise::OCV/CPU         342.260          74.655                 4.58
NCHW_NCHW_ref_div::Layer_NaryEltwise::OCV/CPU     8.338            8.280                  1.01
NCHW_NCHW_ref_max::Layer_NaryEltwise::OCV/CPU     8.359            8.309                  1.01
NCHW_NCHW_ref_min::Layer_NaryEltwise::OCV/CPU     8.412            8.295                  1.01
NCHW_NCHW_ref_mul::Layer_NaryEltwise::OCV/CPU     8.380            8.297                  1.01
NCHW_NCHW_ref_sum::Layer_NaryEltwise::OCV/CPU     8.356            8.323                  1.00
NCHW_NCHW_sub::Layer_NaryEltwise::OCV/CPU         6.818            5.561                  1.23
NCHW_NCHW_sum::Layer_NaryEltwise::OCV/CPU         5.805            5.570                  1.04
NHWC_C::Layer_NaryEltwise::OCV/CPU                3.834            4.817                  0.80
NHWC_H::Layer_NaryEltwise::OCV/CPU                28.402           3.771                  7.53
```

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the
project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake --- modules/dnn/src/cuda/eltwise_ops.cu | 7 + modules/dnn/src/cuda/functors.hpp | 15 + .../dnn/src/cuda4dnn/kernels/eltwise_ops.hpp | 3 + .../dnn/src/cuda4dnn/primitives/eltwise.hpp | 11 +- .../dnn/src/layers/nary_eltwise_layers.cpp | 757 ++++++++++-------- modules/dnn/src/onnx/onnx_importer.cpp | 10 +- modules/dnn/test/test_onnx_conformance.cpp | 15 +- ...rmance_layer_filter__cuda_denylist.inl.hpp | 12 - ...e_layer_filter__cuda_fp16_denylist.inl.hpp | 19 + ...conformance_layer_filter__openvino.inl.hpp | 14 +- ...ance_layer_filter__vulkan_denylist.inl.hpp | 3 + ...e_layer_filter_opencv_all_denylist.inl.hpp | 2 +- ...er_filter_opencv_ocl_fp16_denylist.inl.hpp | 1 + ..._conformance_layer_parser_denylist.inl.hpp | 98 +-- 14 files changed, 563 insertions(+), 404 deletions(-) create mode 100644 modules/dnn/test/test_onnx_conformance_layer_filter__cuda_fp16_denylist.inl.hpp diff --git a/modules/dnn/src/cuda/eltwise_ops.cu b/modules/dnn/src/cuda/eltwise_ops.cu index e2a7cc9a67..2949782138 100644 --- a/modules/dnn/src/cuda/eltwise_ops.cu +++ b/modules/dnn/src/cuda/eltwise_ops.cu @@ -350,6 +350,11 @@ void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, eltwise_op>(stream, output, x, y); } +template +void eltwise_pow_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y) { + eltwise_op>(stream, output, x, y); +} + #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530) template void eltwise_mod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_fmod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); @@ -360,6 +365,7 @@ void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, template void eltwise_sum_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_max_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); template void eltwise_min_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); + template void eltwise_pow_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y); #endif template void eltwise_mod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); @@ -370,5 +376,6 @@ void eltwise_fmod_2(const Stream& stream, TensorSpan output, TensorView x, template void eltwise_sum_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_max_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); template void eltwise_min_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); + template void eltwise_pow_2(const Stream& stream, TensorSpan output, TensorView x, TensorView y); }}}} /* namespace cv::dnn::cuda4dnn::kernels */ diff --git 
a/modules/dnn/src/cuda/functors.hpp b/modules/dnn/src/cuda/functors.hpp index cada43387e..5aa271bdf4 100644 --- a/modules/dnn/src/cuda/functors.hpp +++ b/modules/dnn/src/cuda/functors.hpp @@ -833,6 +833,21 @@ struct FModFunctor { } }; +template +struct PowFunctor { + struct Params { + CUDA4DNN_HOST_DEVICE Params() {} + }; + + CUDA4DNN_DEVICE PowFunctor() { } + CUDA4DNN_DEVICE PowFunctor(const Params& params) { } + + CUDA4DNN_DEVICE T operator()(T x, T y) { + using csl::device::pow; + return pow(x, y); + } +}; + }}}} /* namespace cv::dnn::cuda4dnn::kernels */ #endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp index e80db943ae..452d23da64 100644 --- a/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp +++ b/modules/dnn/src/cuda4dnn/kernels/eltwise_ops.hpp @@ -39,6 +39,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { template void eltwise_fmod_2(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView x, csl::TensorView y); + template + void eltwise_pow_2(const csl::Stream& stream, csl::TensorSpan output, csl::TensorView x, csl::TensorView y); + }}}} /* namespace cv::dnn::cuda4dnn::kernels */ #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_OPS_HPP */ diff --git a/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp index 5822f48061..1dfab63136 100644 --- a/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp +++ b/modules/dnn/src/cuda4dnn/primitives/eltwise.hpp @@ -30,6 +30,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { SUB, MOD, FMOD, + POW, }; class EltwiseOpBase : public CUDABackendNode { @@ -62,7 +63,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { const std::vector>& outputs, csl::Workspace& workspace) override { - CV_Assert(inputs.size() >= 2); CV_Assert(outputs.size() == 1); CV_Assert(coeffs.size() == 0 || op == EltwiseOpType::SUM); @@ -94,9 +94,13 @@ namespace cv { namespace dnn { namespace cuda4dnn { case EltwiseOpType::SUB: kernels::eltwise_sub_2(stream, output, input_x, input_y); break; case EltwiseOpType::MOD: kernels::eltwise_mod_2(stream, output, input_x, input_y); break; case EltwiseOpType::FMOD: kernels::eltwise_fmod_2(stream, output, input_x, input_y); break; + case EltwiseOpType::POW: kernels::eltwise_pow_2(stream, output, input_x, input_y); break; } - } - else + } else if (inputs.size() == 1) { + auto input_wrapper_0 = inputs[0].dynamicCast(); + auto input_0 = input_wrapper_0->getView(); + csl::tensor_ops::copy(stream, output, input_0); + } else { auto input_wrapper_0 = inputs[0].dynamicCast(); auto input_0 = input_wrapper_0->getView(); @@ -128,6 +132,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { case EltwiseOpType::SUB: kernels::eltwise_sub_2(stream, output, output, input); break; case EltwiseOpType::MOD: kernels::eltwise_mod_2(stream, output, output, input); break; case EltwiseOpType::FMOD: kernels::eltwise_fmod_2(stream, output, output, input); break; + case EltwiseOpType::POW: kernels::eltwise_pow_2(stream, output, output, input); break; } } } diff --git a/modules/dnn/src/layers/nary_eltwise_layers.cpp b/modules/dnn/src/layers/nary_eltwise_layers.cpp index e3a8b2a583..659e7e29a8 100644 --- a/modules/dnn/src/layers/nary_eltwise_layers.cpp +++ b/modules/dnn/src/layers/nary_eltwise_layers.cpp @@ -44,13 +44,11 @@ public: std::vector all_ndims; std::vector> orig_shapes; std::vector> orig_steps; - std::vector ptrs; std::vector> shapes; std::vector> steps; 
std::vector elemsize; - NaryEltwiseHelper() { - } + NaryEltwiseHelper() {} void init(const std::vector& inputs, const std::vector& outputs) { @@ -59,7 +57,6 @@ public: all_ndims.clear(); orig_shapes.clear(); orig_steps.clear(); - ptrs.clear(); shapes.clear(); steps.clear(); elemsize.clear(); @@ -81,7 +78,6 @@ public: shapes = std::vector>(narrays, std::vector(max_ndims, 0)); steps = std::vector>(narrays, std::vector(max_ndims, 0)); - ptrs = std::vector(narrays, nullptr); for(i = 0; i <= ninputs; i++) { all_ndims.push_back(i == 0 ? out_ndims : inp_ndims[i-1]); @@ -183,6 +179,7 @@ public: this->shapes[k][i] = 1; } } + return true; } }; @@ -288,7 +285,7 @@ public: #ifdef HAVE_VULKAN if (backendId == DNN_BACKEND_VKCOM) return op == OPERATION::ADD || op == OPERATION::PROD || op == OPERATION::SUB || - op == OPERATION::DIV ; + op == OPERATION::DIV; #endif if (backendId == DNN_BACKEND_CUDA) { @@ -333,8 +330,16 @@ public: inputs_arr.getMatVector(inputs); outputs_arr.getMatVector(outputs); + if (op != OPERATION::POW) { + for (size_t i = 0; i < inputs.size(); i++) { + if (inputs[i].depth() != outputs[0].depth()) { + CV_Error(Error::BadDepth, cv::format("NaryEltwiseLayer: Data type mismatch, input %zu of type %d, output of type %d", i, inputs[i].depth(), outputs[0].depth())); + } + } + } + helper.init(inputs, outputs); - CV_Assert(helper.prepare_for_broadcast_op()); + CV_CheckTrue(helper.prepare_for_broadcast_op(), "NaryEltwiseLayer: Preparation for broadcasting failed"); } bool getMemoryShapes(const std::vector &inputs, @@ -342,168 +347,234 @@ public: std::vector &outputs, std::vector &internals) const CV_OVERRIDE { - MatShape outShape = findCommonShape(inputs); - outputs.assign(1, outShape); + if (inputs.size() == 1) { + outputs.assign(1, inputs.front()); + } else { + MatShape outShape = findCommonShape(inputs); + outputs.assign(1, outShape); + } return false; } template - void binary_forward_impl( - int ndims, const std::vector& shape, - const char* data1, const std::vector& step1, - const char* data2, const std::vector& step2, - char* data, const std::vector& step, - const Functor& op) - { + void binary_forward_impl(const Functor& op, int ndims, const std::vector& shape, + const char* data1, const std::vector& step1, + const char* data2, const std::vector& step2, + char* data, const std::vector& step, size_t block_size) { assert(ndims >= 2); - size_t dp1 = step1[ndims-1]/sizeof(T); - size_t dp2 = step2[ndims-1]/sizeof(T); - size_t dp = step[ndims-1]/sizeof(T); - int k, n1 = shape[ndims-1], n2 = shape[ndims-2]; - size_t plane_idx, nplanes = 1; - for (k = 0; k < ndims-2; k++) nplanes *= shape[k]; - - for (plane_idx = 0; plane_idx < nplanes; plane_idx++) { - const char* ptr1_ = data1; - const char* ptr2_ = data2; - char* ptr_ = data; - size_t idx = plane_idx; - for (k = ndims-3; k >= 0; k--) { - size_t next_idx = idx/shape[k]; - int i_k = (int)(idx - next_idx*shape[k]); - ptr1_ += i_k*step1[k]; - ptr2_ += i_k*step2[k]; - ptr_ += i_k*step[k]; - idx = next_idx; - } - for (int i2 = 0; i2 < n2; i2++, ptr1_ += step1[ndims-2], - ptr2_ += step2[ndims-2], - ptr_ += step[ndims-2]) - { - const T* ptr1 = (const T*)ptr1_; - const T* ptr2 = (const T*)ptr2_; - T* ptr = (T*)ptr_; + size_t dp1 = step1.back() / sizeof(T); + size_t dp2 = step2.back() / sizeof(T); + size_t dp = step.back() / sizeof(T); + int plane_size = shape.back(); + int nplanes = std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies()); + + if (nplanes == 1) { // parallelize within the plane + const T* ptr1 = (const T*)data1; + const 
T* ptr2 = (const T*)data2; + T* ptr = (T*)data; + auto worker = [&](const Range &r) { if (dp1 == 1 && dp2 == 1 && dp == 1) { - for(int i1 = 0; i1 < n1; i1++) - ptr[i1] = op(ptr1[i1], ptr2[i1]); + for(int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], ptr2[i]); + } } else if (dp1 == 1 && dp2 == 0 && dp == 1){ T x2 = *ptr2; - for(int i1 = 0; i1 < n1; i1++) - ptr[i1] = op(ptr1[i1], x2); + for(int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], x2); + } } else if (dp1 == 0 && dp2 == 1 && dp == 1){ T x1 = *ptr1; - for(int i1 = 0; i1 < n1; i1++) - ptr[i1] = op(x1, ptr2[i1]); + for(int i = r.start; i < r.end; i++) { + ptr[i] = op(x1, ptr2[i]); + } } else { - for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr += dp) + for(int i = r.start; i < r.end; i++, ptr1 += dp1, ptr2 += dp2, ptr += dp) { *ptr = op(*ptr1, *ptr2); + } } - } + }; + + double nstripes = plane_size * (1.0 / double(block_size)); + parallel_for_(Range(0, plane_size), worker, nstripes); + } else { // parallelize across planes + auto worker = [&](const Range &r) { + for (int plane_idx = r.start; plane_idx < r.end; plane_idx++) { + const char* ptr1_ = data1; + const char* ptr2_ = data2; + char* ptr_ = data; + size_t idx = plane_idx; + for (int k = ndims - 2; k >= 0; k--) { + size_t next_idx = idx / shape[k]; + size_t i_k = (int)(idx - next_idx * shape[k]); + ptr1_ += i_k * step1[k]; + ptr2_ += i_k * step2[k]; + ptr_ += i_k * step[k]; + idx = next_idx; + } + + const T* ptr1 = (const T*)ptr1_; + const T* ptr2 = (const T*)ptr2_; + T* ptr = (T*)ptr_; + if (dp1 == 1 && dp2 == 1 && dp == 1) { + for(int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr1[i], ptr2[i]); + } + } else if (dp1 == 1 && dp2 == 0 && dp == 1){ + T x2 = *ptr2; + for(int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr1[i], x2); + } + } else if (dp1 == 0 && dp2 == 1 && dp == 1){ + T x1 = *ptr1; + for(int i = 0; i < plane_size; i++) { + ptr[i] = op(x1, ptr2[i]); + } + } else { + for(int i = 0; i < plane_size; i++, ptr1 += dp1, ptr2 += dp2, ptr += dp) { + *ptr = op(*ptr1, *ptr2); + } + } + } + }; + double nstripes = nplanes * (1.0 / double(block_size)); + parallel_for_(Range(0, nplanes), worker, nstripes); } } + /* + Elementwise binary operator (like +, -, x, /, etc.) 
which takes two operands + */ template - void binary_forward(const Functor& f, const std::vector& inputs, std::vector& outputs) - { + void binary_forward(const Functor& f, const std::vector& inputs, std::vector& outputs, size_t block_size = 6e6) { const Mat& a = inputs[0]; const Mat& b = inputs[1]; Mat& out = outputs[0]; CV_Assert(helper.shapes.size() == 3 && helper.steps.size() == 3); - binary_forward_impl( - helper.max_ndims, helper.shapes[0], a.ptr(), helper.steps[1], - b.ptr(), helper.steps[2], out.ptr(), helper.steps[0], - f); + binary_forward_impl(f, helper.max_ndims, helper.shapes[0], a.ptr(), helper.steps[1], + b.ptr(), helper.steps[2], out.ptr(), helper.steps[0], block_size); } template - void nary_forward_impl( - const Functor& f, const T scale, int ninputs, int ndims, const std::vector& shape, - const char** inp, char* out, - const std::vector>& steps, std::vector& ptrs) - { + void nary_forward_impl(const Functor& op, const T scale, int ninputs, int ndims, const std::vector& shape, + const char** inp, char* out, const std::vector>& steps, size_t block_size) { CV_Assert(ndims >= 2); - size_t dp = steps[0][ndims-1]/sizeof(T); - size_t dp1 = steps[1][ndims-1]/sizeof(T); - size_t dp2 = steps[2][ndims-1]/sizeof(T); - - enum { BLOCK_SIZE = 1024 }; - T blck[BLOCK_SIZE]; + size_t dp = steps[0].back() / sizeof(T); + size_t dp1 = steps[1].back() / sizeof(T); + size_t dp2 = steps[2].back() / sizeof(T); - int k, i, di1=0, n1 = shape[ndims-1], n2 = shape[ndims-2]; - int second = ninputs == 1 ? 1 : 2; - size_t plane_idx, nplanes = 1; - for (k = 0; k < ndims-2; k++) nplanes *= shape[k]; + int plane_size = shape.back(); + int nplanes = std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies()); - for (plane_idx = 0; plane_idx < nplanes; plane_idx++) { + if (nplanes == 1) { // parallelize within the plane + AutoBuffer buf_ptrs(steps.size()); + auto ptrs = (char**)buf_ptrs.data(); ptrs[0] = out; - for (i = 0; i < ninputs; i++) ptrs[i+1] = (char*)inp[i]; - size_t idx = plane_idx; - for (k = ndims-3; k >= 0; k--) { - size_t next_idx = idx/shape[k]; - int i_k = (int)(idx - next_idx*shape[k]); - for (i = 0; i < ninputs; i++) - ptrs[i] += i_k*steps[i][k]; - idx = next_idx; + for (int i = 0; i < ninputs; i++) { + ptrs[i+1] = (char*)inp[i]; } - for (int i2 = 0; i2 < n2; i2++) - { - const T* ptr1 = (const T*)(ptrs[1] + steps[1][ndims-2]*i2); - const T* ptr2 = (const T*)(ptrs[second] + steps[second][ndims-2]*i2); - T* ptr = (T*)(ptrs[0] + steps[0][ndims-2]*i2); - if (ninputs <= 2) { - if (dp1 == 1 && dp2 == 1) { - for (int i1 = 0; i1 < n1; i1++) - ptr[i1] = saturate_cast(f(ptr1[i1], ptr2[i1])*scale); - } else { - for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr += dp) - *ptr = saturate_cast(f(*ptr1, *ptr2)*scale); + const T* ptr1 = (const T*)(ptrs[1]); + const T* ptr2 = (const T*)(ptrs[2]); + T* ptr = (T*)(ptrs[0]); + auto worker = [&](const Range &r) { + if (dp == 1 && dp1 == 1 && dp2 == 1) { + for (int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], ptr2[i]); } - } else { - for (int i1 = 0; i1 < n1; i1 += di1, ptr += di1) { - di1 = BLOCK_SIZE < n1-i1 ? 
BLOCK_SIZE : n1-i1; - if (dp1 == 1 && dp2 == 1) { - for (int j = 0; j < di1; j++) - blck[j] = f(ptr1[j], ptr2[j]); - ptr1 += di1; - ptr2 += di1; + for (int j = 2; j < ninputs; j++) { + int dpj = steps[j + 1].back(); + const T* ptrj = (const T*)(ptrs[j + 1]); + if (dpj == 1) { + for (int i = r.start; i < r.end; i++) { + ptr[i] = saturate_cast(op(ptr[i], ptrj[i]) * scale); + } } else { - for(int j = 0; j < di1; j++, ptr1 += dp1, ptr2 += dp2) - blck[j] = f(*ptr1, *ptr2); + for (int i = r.start; i < r.end; i++, ptrj += dpj) { + ptr[i] = saturate_cast(op(ptr[i], *ptrj) * scale); + } } - for(i = 2; i < ninputs; i++) { - int dp_i = steps[i+1][ndims-1]/sizeof(T); - const T* ptr_i = (const T*)(ptrs[i+1] + - steps[i+1][ndims-2]*i2) + i1*dp_i; - if (dp_i == 1) { - if (i < ninputs-1) { - for (int j = 0; j < di1; j++) - blck[j] = f(blck[j], ptr_i[j]); - } else { - for (int j = 0; j < di1; j++) - ptr[j] = saturate_cast(f(blck[j], ptr_i[j]) * scale); + } + } else { + auto *tmp = ptr; + for (int i = r.start; i < r.end; i++, ptr += dp, ptr1 += dp1, ptr2 += dp2) { + *ptr = op(*ptr1, *ptr2); + } + ptr = tmp; + for (int j = 2; j < ninputs; j++) { + int dpj = steps[j + 1].back(); + const T* ptr_j = (const T*)(ptrs[j + 1]); + for (int i = r.start; i < r.end; i++, ptr += dp, ptr_j += dpj) { + *ptr = saturate_cast(op(*ptr, *ptr_j) * scale); + } + } + } + }; + double nstripes = plane_size * (1.0 / double(block_size)); + parallel_for_(Range(0, plane_size), worker, nstripes); + } else { // parallelize across the plane + auto worker = [&](const Range &r) { + AutoBuffer buf_ptrs(steps.size()); + auto ptrs = (char**)buf_ptrs.data(); + for (int plane_idx = r.start; plane_idx < r.end; plane_idx++) { + ptrs[0] = out; + for (int i = 0; i < ninputs; i++) ptrs[i+1] = (char*)inp[i]; + size_t idx = plane_idx; + for (int k = ndims - 2; k >= 0; k--) { + size_t next_idx = idx / shape[k]; + int i_k = (int)(idx - next_idx * shape[k]); + for (int i = 0; i <= ninputs; i++) { + ptrs[i] += i_k * steps[i][k]; + } + idx = next_idx; + } + + const T* ptr1 = (const T*)(ptrs[1]); + const T* ptr2 = (const T*)(ptrs[2]); + T* ptr = (T*)(ptrs[0]); + if (dp == 1 && dp1 == 1 && dp2 == 1) { + for (int i = 0; i < plane_size; i++) { + ptr[i] = saturate_cast(op(ptr1[i], ptr2[i]) * scale); + } + for (int j = 2; j < ninputs; j++) { + int dpj = steps[j + 1].back(); + const T* ptrj = (const T*)(ptrs[j + 1]); + if (dpj == 1) { + for (int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr[i], saturate_cast(ptrj[i] * scale)); } } else { - if (i < ninputs-1) { - for (int j = 0; j < di1; j++, ptr_i += dp_i) - blck[j] = f(blck[j], *ptr_i); - } else { - for (int j = 0; j < di1; j++, ptr_i += dp_i) - ptr[j] = saturate_cast(f(blck[j], *ptr_i) * scale); + for (int i = 0; i < plane_size; i++, ptrj += dpj) { + ptr[i] = op(ptr[i], saturate_cast(*ptrj * scale)); } } } + } else { + auto *tmp = ptr; + for (int i = 0; i < plane_size; i++, ptr += dp, ptr1 += dp1, ptr2 += dp2) { + *ptr = saturate_cast(op(*ptr1, *ptr2) * scale); + } + ptr = tmp; + for (int j = 2; j < ninputs; j++) { + int dpj = steps[j + 1].back(); + const T* ptrj = (const T*)(ptrs[j + 1]); + for (int i = 0; i < plane_size; i++, ptr += dp, ptrj += dpj) { + *ptr = op(*ptr, saturate_cast(*ptrj * scale)); + } + } } } - } + }; + double nstripes = nplanes * (1.0 / double(block_size)); + parallel_for_(Range(0, nplanes), worker, nstripes); } } + /* + Elementwise nary operator (like sum, mean, etc.) 
which takes at least one operand + */ template - void nary_forward( - const Functor& f, T scale, - const std::vector& inputs, std::vector& outputs - ) - { + void nary_forward(const Functor& f, T scale, + const std::vector& inputs, std::vector& outputs, + size_t block_size = 6e6) { // collect all input info std::vector v_inp; std::transform(inputs.begin(), inputs.end(), std::back_inserter(v_inp), [] (const Mat& m) { return m.template ptr(); }); @@ -512,13 +583,14 @@ public: // collect output info char* out = outputs[0].ptr(); - nary_forward_impl( - f, scale, helper.ninputs, helper.max_ndims, helper.shapes[0], inp, out, helper.steps, helper.ptrs); + nary_forward_impl(f, scale, helper.ninputs, helper.max_ndims, helper.shapes[0], inp, out, helper.steps, block_size); } + /* + Elementwise ternary operator (like where) which takes three operands + */ template - void trinary_forward(const Functor& f, const std::vector& inputs, std::vector& outputs) - { + void ternary_forward(const Functor& f, const std::vector& inputs, std::vector& outputs, size_t block_size = 6e6) { const Mat& a = inputs[0]; const Mat& b = inputs[1]; const Mat& c = inputs[2]; @@ -526,69 +598,112 @@ public: CV_Assert(helper.shapes.size() == 4 && helper.steps.size() == 4); - trinary_forward_impl( - helper.max_ndims, helper.shapes[0], a.ptr(), helper.steps[1], b.ptr(), helper.steps[2], - c.ptr(), helper.steps[3], out.ptr(), helper.steps[0], - f); + ternary_forward_impl(f, helper.max_ndims, helper.shapes[0], + a.ptr(), helper.steps[1], + b.ptr(), helper.steps[2], + c.ptr(), helper.steps[3], + out.ptr(), helper.steps[0], block_size); } template - void trinary_forward_impl( - int ndims, const std::vector& shape, + void ternary_forward_impl( + const Functor& op, int ndims, const std::vector& shape, const char* data1, const std::vector& step1, const char* data2, const std::vector& step2, const char* data3, const std::vector& step3, - char* data, const std::vector& step, - const Functor& op) - { - assert(ndims >= 2); - size_t dp1 = step1[ndims-1]/sizeof(T); - size_t dp2 = step2[ndims-1]/sizeof(T); - size_t dp3 = step3[ndims-1]/sizeof(T); - size_t dp = step[ndims-1]/sizeof(T); - int k, n1 = shape[ndims-1], n2 = shape[ndims-2]; - size_t plane_idx, nplanes = 1; - for (k = 0; k < ndims-2; k++) nplanes *= shape[k]; - - for (plane_idx = 0; plane_idx < nplanes; plane_idx++) - { - const char* ptr1_ = data1; - const char* ptr2_ = data2; - const char* ptr3_ = data3; - char* ptr_ = data; - size_t idx = plane_idx; - for (k = ndims-3; k >= 0; k--) - { - size_t next_idx = idx/shape[k]; - int i_k = (int)(idx - next_idx*shape[k]); - ptr1_ += i_k*step1[k]; - ptr2_ += i_k*step2[k]; - ptr3_ += i_k*step3[k]; - ptr_ += i_k*step[k]; - idx = next_idx; - } - - for (int i2 = 0; i2 < n2; i2++, ptr1_ += step1[ndims-2], - ptr2_ += step2[ndims-2], - ptr3_ += step3[ndims-2], - ptr_ += step[ndims-2]) - { - const T* ptr1 = (const T*)ptr1_; - const T* ptr2 = (const T*)ptr2_; - const T* ptr3 = (const T*)ptr3_; - T* ptr = (T*)ptr_; - - if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1) - { - for(int i1 = 0; i1 < n1; i1++) - ptr[i1] = op(ptr1[i1], ptr2[i1], ptr3[i1]); - } - else - { - for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr3 += dp3, ptr += dp) + char* data, const std::vector& step, size_t block_size) { + CV_Assert(ndims >= 2); + size_t dp1 = step1.back() / sizeof(T); + size_t dp2 = step2.back() / sizeof(T); + size_t dp3 = step3.back() / sizeof(T); + size_t dp = step.back() / sizeof(T); + int plane_size = shape.back(); + int nplanes = 
std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies()); + + if (nplanes == 1) { // parallelize within the plane + const T *ptr1 = (const T*)data1; + const T *ptr2 = (const T*)data2; + const T *ptr3 = (const T*)data3; + T* ptr = (T*)data; + auto worker = [&](const Range &r) { + if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1) { + for (int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], ptr2[i], ptr3[i]); + } + } else if (dp1 == 0 && dp2 == 1 && dp3 == 1 && dp == 1){ + T x1 = *ptr1; + for (int i = r.start; i < r.end; i++) { + ptr[i] = op(x1, ptr2[i], ptr3[i]); + } + } else if (dp1 == 1 && dp2 == 0 && dp3 == 1 && dp == 1){ + T x2 = *ptr2; + for (int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], x2, ptr3[i]); + } + } else if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1) { + T x3 = *ptr3; + for (int i = r.start; i < r.end; i++) { + ptr[i] = op(ptr1[i], ptr2[i], x3); + } + } else { + for(int i = r.start; i < r.end; i++, ptr1 += dp1, ptr2 += dp2, ptr3 += dp3, ptr += dp) { *ptr = op(*ptr1, *ptr2, *ptr3); + } } - } + }; + double nstripes = plane_size * (1.0 / double(block_size)); + parallel_for_(Range(0, plane_size), worker, nstripes); + } else { // parallelize across planes + auto worker = [&](const Range &r) { + for (int plane_idx = r.start; plane_idx < r.end; plane_idx++) { + const char* ptr1_ = data1; + const char* ptr2_ = data2; + const char* ptr3_ = data3; + char* ptr_ = data; + size_t idx = plane_idx; + for (int k = ndims - 2; k >= 0; k--) + { + size_t next_idx = idx / shape[k]; + int i_k = (int)(idx - next_idx * shape[k]); + ptr1_ += i_k * step1[k]; + ptr2_ += i_k * step2[k]; + ptr3_ += i_k * step3[k]; + ptr_ += i_k * step[k]; + idx = next_idx; + } + + const T *ptr1 = (const T*)ptr1_; + const T *ptr2 = (const T*)ptr2_; + const T *ptr3 = (const T*)ptr3_; + T* ptr = (T*)ptr_; + if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1) { + for (int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr1[i], ptr2[i], ptr3[i]); + } + } else if (dp1 == 0 && dp2 == 1 && dp3 == 1 && dp == 1){ + T x1 = *ptr1; + for (int i = 0; i < plane_size; i++) { + ptr[i] = op(x1, ptr2[i], ptr3[i]); + } + } else if (dp1 == 1 && dp2 == 0 && dp3 == 1 && dp == 1){ + T x2 = *ptr2; + for (int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr1[i], x2, ptr3[i]); + } + } else if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1) { + T x3 = *ptr3; + for (int i = 0; i < plane_size; i++) { + ptr[i] = op(ptr1[i], ptr2[i], x3); + } + } else { + for(int i = 0; i < plane_size; i++, ptr1 += dp1, ptr2 += dp2, ptr3 += dp3, ptr += dp) { + *ptr = op(*ptr1, *ptr2, *ptr3); + } + } + } + }; + double nstripes = nplanes * (1.0 / double(block_size)); + parallel_for_(Range(0, nplanes), worker, nstripes); } } @@ -608,143 +723,147 @@ public: inputs_arr.getMatVector(inputs); outputs_arr.getMatVector(outputs); - // TODO: assert types + if (inputs.size() == 1) { + inputs[0].copyTo(outputs[0]); + return; + } + typeDispatch(outputs[0].type(), inputs.size(), inputs, outputs); } template inline void opDispatch(size_t ninputs, Args&&... 
args) { - switch (op) - { - case OPERATION::EQUAL: - { - auto equal = [](const T &a, const T &b) { return a == b; }; - binary_forward(equal, std::forward(args)...); - break; - } - case OPERATION::GREATER: - { - auto greater = [](const T &a, const T &b) { return a > b; }; - binary_forward(greater, std::forward(args)...); - break; - } - case OPERATION::GREATER_EQUAL: - { - auto greater_equal = [](const T &a, const T &b) { return a >= b; }; - binary_forward(greater_equal, std::forward(args)...); - break; - } - case OPERATION::LESS: - { - auto less = [](const T &a, const T &b) { return a < b; }; - binary_forward(less, std::forward(args)...); - break; - } - case OPERATION::LESS_EQUAL: - { - auto less_equal = [](const T &a, const T &b) { return a <= b; }; - binary_forward(less_equal, std::forward(args)...); - break; - } - case OPERATION::POW: - { - auto pow = [] (const T& a, const T& b) { return std::pow(a, b); }; - binary_forward(pow, std::forward(args)...); - break; - } - case OPERATION::BITSHIFT: - { - auto bitshift = [] (const uint8_t &a, const uint8_t &b) { return a << b; }; - binary_forward(bitshift, std::forward(args)...); - break; - } - case OPERATION::MAX: - { - auto max = [](const T &a, const T &b) { return std::max(a, b); }; - nary_forward(max, T{1}, std::forward(args)...); - break; - } - case OPERATION::MEAN: - { - auto mean = [](const T &a, const T &b) { return (a + b) / T{2}; }; - nary_forward(mean, T{1} / ninputs, std::forward(args)...); - break; - } - case OPERATION::MIN: - { - auto min = [](const T &a, const T &b) { return std::min(a, b); }; - nary_forward(min, T{1}, std::forward(args)...); - break; - } - case OPERATION::MOD: - { - auto mod = [] (const T &a, const T &b) { return static_cast(_mod(int(a), int(b))); }; - binary_forward(mod, std::forward(args)...); - break; - } - case OPERATION::FMOD: - { - auto fmod = [](const T &a, const T &b) { return std::fmod(a, b); }; - binary_forward(fmod, std::forward(args)...); - break; - } - case OPERATION::PROD: - { - auto prod = [](const T &a, const T &b) { return a * b; }; - binary_forward(prod, std::forward(args)...); - break; - } - case OPERATION::SUB: - { - auto sub = [](const T &a, const T &b) { return a - b; }; - binary_forward(sub, std::forward(args)...); - break; - } - case OPERATION::SUM: - { - auto sum = [](const T &a, const T &b) { return a + b; }; - nary_forward(sum, T{1}, std::forward(args)...); - break; - } - case OPERATION::ADD: - { - auto add = [](const T &a, const T &b) { return a + b; }; - binary_forward(add, std::forward(args)...); - break; - } - case OPERATION::DIV: - { - auto div = [](const T &a, const T &b) { return a / b; }; - binary_forward(div, std::forward(args)...); - break; - } - case OPERATION::AND: - { - auto op_and = [](const uint8_t &a, const uint8_t &b) { return a & b; }; - binary_forward(op_and, std::forward(args)...); - break; - } - case OPERATION::OR: - { - auto op_or = [](const uint8_t &a, const uint8_t &b) { return a | b; }; - binary_forward(op_or, std::forward(args)...); - break; - } - case OPERATION::XOR: - { - auto op_xor = [](const uint8_t &a, const uint8_t &b) { return a ^ b; }; - binary_forward(op_xor, std::forward(args)...); - break; + if (ninputs == 2) { // Operators that take two operands + switch (op) { + case OPERATION::AND: { + auto op_and = [](const uint8_t &a, const uint8_t &b) { return a & b; }; + binary_forward(op_and, std::forward(args)...); + break; + } + case OPERATION::EQUAL: { + auto equal = [](const T &a, const T &b) { return a == b; }; + binary_forward(equal, 
std::forward(args)...); + break; + } + case OPERATION::GREATER: { + auto greater = [](const T &a, const T &b) { return a > b; }; + binary_forward(greater, std::forward(args)...); + break; + } + case OPERATION::GREATER_EQUAL: { + auto greater_equal = [](const T &a, const T &b) { return a >= b; }; + binary_forward(greater_equal, std::forward(args)...); + break; + } + case OPERATION::LESS: { + auto less = [](const T &a, const T &b) { return a < b; }; + binary_forward(less, std::forward(args)...); + break; + } + case OPERATION::LESS_EQUAL: { + auto less_equal = [](const T &a, const T &b) { return a <= b; }; + binary_forward(less_equal, std::forward(args)...); + break; + } + case OPERATION::OR: { + auto op_or = [](const uint8_t &a, const uint8_t &b) { return a | b; }; + binary_forward(op_or, std::forward(args)...); + break; + } + case OPERATION::POW: { + auto pow = [] (const T& a, const T& b) { return std::pow(a, b); }; + binary_forward(pow, std::forward(args)..., 1e5); + break; + } + case OPERATION::XOR: { + auto op_xor = [](const uint8_t &a, const uint8_t &b) { return a ^ b; }; + binary_forward(op_xor, std::forward(args)...); + break; + } + case OPERATION::BITSHIFT: { + auto bitshift = [] (const uint8_t &a, const uint8_t &b) { return a << b; }; + binary_forward(bitshift, std::forward(args)...); + break; + } + case OPERATION::MAX: { + auto max = [](const T &a, const T &b) { return std::max(a, b); }; + binary_forward(max, std::forward(args)...); + break; + } + case OPERATION::MEAN: { + auto mean = [](const T &a, const T &b) { return (a + b) / T{2}; }; + binary_forward(mean, std::forward(args)...); + break; + } + case OPERATION::MIN: { + auto min = [](const T &a, const T &b) { return std::min(a, b); }; + binary_forward(min, std::forward(args)...); + break; + } + case OPERATION::MOD: { + auto mod = [] (const T &a, const T &b) { return static_cast(_mod(int(a), int(b))); }; + binary_forward(mod, std::forward(args)...); + break; + } + case OPERATION::FMOD: { + auto fmod = [](const T &a, const T &b) { return std::fmod(a, b); }; + binary_forward(fmod, std::forward(args)...); + break; + } + case OPERATION::PROD: { + auto prod = [](const T &a, const T &b) { return a * b; }; + binary_forward(prod, std::forward(args)...); + break; + } + case OPERATION::SUB: { + auto sub = [](const T &a, const T &b) { return a - b; }; + binary_forward(sub, std::forward(args)...); + break; + } + case OPERATION::ADD: + case OPERATION::SUM: { + auto sum = [](const T &a, const T &b) { return a + b; }; + binary_forward(sum, std::forward(args)...); + break; + } + case OPERATION::DIV: { + auto div = [](const T &a, const T &b) { return a / b; }; + binary_forward(div, std::forward(args)...); + break; + } + default: CV_Error(Error::StsBadArg, "Unsupported operation"); } - case OPERATION::WHERE: + } else if (ninputs == 3 && op == OPERATION::WHERE) { // Operators that take three operands + auto where = [](const T &a, const T &b, const T &c) { return a ? b : c; }; + ternary_forward(where, std::forward(args)...); + } else { // Operators that can take multiple (>= 3) operands + switch (op) { - auto op_where = [](const T &a, const T &b, const T &c) { return a ? 
b : c; }; - trinary_forward(op_where, std::forward(args)...); - break; + case OPERATION::MAX: { + auto max = [](const T &a, const T &b) { return std::max(a, b); }; + nary_forward(max, T{1}, std::forward(args)...); + break; + } + case OPERATION::MEAN: { + // Sum up inputs and then calculate mean by scale = 1 / ninputs + auto sum = [](const T &a, const T &b) { return a + b; }; + nary_forward(sum, T{1} / ninputs, std::forward(args)...); + break; + } + case OPERATION::MIN: { + auto min = [](const T &a, const T &b) { return std::min(a, b); }; + nary_forward(min, T{1}, std::forward(args)...); + break; + } + case OPERATION::SUM: { + auto sum = [](const T &a, const T &b) { return a + b; }; + nary_forward(sum, T{1}, std::forward(args)...); + break; + } + default: + CV_Error(Error::StsBadArg, "Unsupported operation."); } - default: - CV_Error(Error::StsBadArg, "Unsupported operation."); }; } @@ -811,6 +930,9 @@ public: case OPERATION::FMOD: op_ = cuda4dnn::EltwiseOpType::FMOD; break; + case OPERATION::POW: + op_ = cuda4dnn::EltwiseOpType::POW; + break; default: return Ptr(); // return empty cuda_node if the EltwiseOpType is unsupported type. }; @@ -881,6 +1003,15 @@ public: #ifdef HAVE_DNN_NGRAPH virtual Ptr initNgraph(const std::vector >& inputs, const std::vector >& nodes) CV_OVERRIDE { + // In case only one input + if (inputs.size() == 1) { + auto &ieInpNode = nodes[0].dynamicCast()->node; + ngraph::OutputVector inp{ieInpNode}; + auto blank = std::make_shared(inp, 0); + return Ptr(new InfEngineNgraphNode(blank)); + } + + // TODO: Support multiple (>=3) inputs CV_Assert(inputs.size() == 2); auto& inp0 = nodes[0].dynamicCast()->node; auto& inp1 = nodes[1].dynamicCast()->node; diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 565a88b760..3745d7ed86 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -2851,14 +2851,6 @@ void ONNXImporter::parseElementWise(LayerParams& layerParams, const opencv_onnx: }; } - // element-wise layers that can have >=1 inputs but actually have one input - if (node_proto.input_size() == 1 && (op_type == "max" || op_type == "min" || op_type == "mean" || op_type == "sum")) - { - layerParams.type = "Identity"; - addLayer(layerParams, node_proto); - return; - } - auto pre_broadcast_transform = [](Mat& t, int t_real_ndims) { if (t.dims == 2 && t_real_ndims == 1 && t.size[1] == 1) transpose(t, t); @@ -3938,7 +3930,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] = dispatch["LessOrEqual"] = dispatch["Mod"] = &ONNXImporter::parseElementWise; - dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise; + dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = dispatch["Mean"] = &ONNXImporter::parseElementWise; dispatch["Where"] = &ONNXImporter::parseElementWise; dispatch["Range"] = &ONNXImporter::parseRange; dispatch["Einsum"] = &ONNXImporter::parseEinsum; diff --git a/modules/dnn/test/test_onnx_conformance.cpp b/modules/dnn/test/test_onnx_conformance.cpp index 1ca3f2f75b..bd892adb2f 100644 --- a/modules/dnn/test/test_onnx_conformance.cpp +++ b/modules/dnn/test/test_onnx_conformance.cpp @@ -970,6 +970,7 @@ public: #endif #ifdef HAVE_CUDA static std::set cuda_deny_list; + static std::set cuda_fp16_deny_list; #endif Test_ONNX_conformance() @@ -1055,6 +1056,9 @@ public: cuda_deny_list = { #include "test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp" }; + 
cuda_fp16_deny_list = { + #include "test_onnx_conformance_layer_filter__cuda_fp16_denylist.inl.hpp" + }; #endif } @@ -1074,6 +1078,7 @@ std::set Test_ONNX_conformance::vulkan_deny_list; #endif #ifdef HAVE_CUDA std::set Test_ONNX_conformance::cuda_deny_list; +std::set Test_ONNX_conformance::cuda_fp16_deny_list; #endif TEST_P(Test_ONNX_conformance, Layer_Test) @@ -1114,6 +1119,10 @@ TEST_P(Test_ONNX_conformance, Layer_Test) { applyTestTag(CV_TEST_TAG_DNN_SKIP_CPU, CV_TEST_TAG_DNN_SKIP_OPENCV_BACKEND, CV_TEST_TAG_DNN_SKIP_ONNX_CONFORMANCE); } + + if (name == "test_pow") { + default_lInf = 0.00013; // Expected: (normInf) <= (lInf), actual: 0.00012207 vs 0.0001 + } } #ifdef HAVE_HALIDE else if (backend == DNN_BACKEND_HALIDE) @@ -1142,10 +1151,14 @@ TEST_P(Test_ONNX_conformance, Layer_Test) #ifdef HAVE_CUDA else if (backend == DNN_BACKEND_CUDA) { - if (cuda_deny_list.find(name) != cuda_deny_list.end()) + if (target == DNN_TARGET_CUDA && cuda_deny_list.find(name) != cuda_deny_list.end()) { applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA, CV_TEST_TAG_DNN_SKIP_ONNX_CONFORMANCE); } + if (target == DNN_TARGET_CUDA_FP16 && cuda_fp16_deny_list.find(name) != cuda_fp16_deny_list.end()) + { + applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA_FP16, CV_TEST_TAG_DNN_SKIP_ONNX_CONFORMANCE); + } } #endif else diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp index 96778ef5d4..42968ef721 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp @@ -73,21 +73,9 @@ "test_maxunpool_export_with_output_shape", "test_mul_bcast", "test_mul_uint8", -"test_reduce_prod_default_axes_keepdims_example", // FP16 only -"test_reduce_prod_default_axes_keepdims_random", // FP16 only -"test_reduce_prod_do_not_keepdims_random", // FP16 only -"test_reduce_prod_keepdims_random", // FP16 only -"test_reduce_prod_negative_axes_keepdims_random", // FP16 only -"test_reduce_sum_square_default_axes_keepdims_random", // FP16 only -"test_reduce_sum_square_do_not_keepdims_random", // FP16 only -"test_reduce_sum_square_keepdims_random", // FP16 only -"test_reduce_sum_square_negative_axes_keepdims_random", // FP16 only "test_softmax_default_axis", -"test_softmax_large_number", // FP16 only -"test_softmax_large_number_expanded", // FP16 only "test_sub_bcast", "test_sub_uint8", -"test_tan", // FP16 only "test_upsample_nearest", "test_scatter_elements_with_axis", "test_scatter_elements_with_duplicate_indices", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_fp16_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_fp16_denylist.inl.hpp new file mode 100644 index 0000000000..4fe0825632 --- /dev/null +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_fp16_denylist.inl.hpp @@ -0,0 +1,19 @@ +"test_basic_conv_with_padding", // (assert failed) !blobs.empty() in initCUDA +"test_basic_conv_without_padding", // (assert failed) !blobs.empty() in initCUDA +"test_conv_with_autopad_same", // (assert failed) !blobs.empty() in initCUDA +"test_conv_with_strides_and_asymmetric_padding", // (assert failed) !blobs.empty() in initCUDA +"test_conv_with_strides_no_padding", // (assert failed) !blobs.empty() in initCUDA +"test_conv_with_strides_padding", // (assert failed) !blobs.empty() in initCUDA +"test_dropout_default_ratio", +"test_logsoftmax_large_number", // fp16 accuracy issue 
+"test_logsoftmax_large_number_expanded", // fp16 accuracy issue +"test_reduce_prod_default_axes_keepdims_example", // fallback to cpu, accuracy +"test_reduce_prod_default_axes_keepdims_random", // fallback to cpu, accuracy +"test_reduce_sum_square_default_axes_keepdims_random", // fallback to cpu, accuracy +"test_reduce_sum_square_do_not_keepdims_random", // fallback to cpu, accuracy +"test_reduce_sum_square_keepdims_random", // fallback to cpu, accuracy +"test_reduce_sum_square_negative_axes_keepdims_random", // fallback to cpu, accuracy +"test_pow", // fp16 accuracy issue +"test_softmax_large_number", // fp16 accuracy issue +"test_softmax_large_number_expanded", // fp16 accuracy issue +"test_tan", // fp16 accuracy issue diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp index 000e867217..229bb9ca82 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp @@ -86,7 +86,11 @@ CASE(test_adam) CASE(test_adam_multiple) // no filter CASE(test_add) - // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 0.00024; // Expected: (normL1) <= (l1), actual: 0.000234754 vs 1e-05 + default_lInf = 0.0011; // Expected: (normInf) <= (lInf), actual: 0.00106502 vs 0.0001 + } CASE(test_add_bcast) #if SKIP_SET_1 SKIP; @@ -1110,7 +1114,11 @@ CASE(test_momentum) CASE(test_momentum_multiple) // no filter CASE(test_mul) - // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 0.00024; // Expected: (normL1) <= (l1), actual: 0.00023824 vs 1e-05 + default_lInf = 0.0015; // Expected: (normInf) <= (lInf), actual: 0.00145674 vs 0.0001 + } CASE(test_mul_bcast) #if SKIP_SET_1 SKIP; @@ -1262,7 +1270,7 @@ CASE(test_or_bcast4v3d) CASE(test_or_bcast4v4d) // no filter CASE(test_pow) - // no filter + SKIP_OPENCL_FP16; CASE(test_pow_bcast_array) // no filter CASE(test_pow_bcast_scalar) diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp index f87e16a42f..968dd1e025 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp @@ -68,6 +68,9 @@ "test_maxunpool_export_with_output_shape", "test_maxunpool_export_without_output_shape", "test_mul_uint8", +"test_pow_types_float32_int32", // vulkan backend does not take tensor other than float32 data type +"test_pow_types_float32_int64", // vulkan backend does not take tensor other than float32 data type +"test_pow_types_int", // vulkan backend does not take tensor other than float32 data type "test_softmax_default_axis", "test_sub_bcast", "test_sub_uint8", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp index 0da0111990..0370b22764 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp @@ -50,7 +50,7 @@ "test_maxpool_with_argmax_2d_precomputed_strides", "test_maxunpool_export_with_output_shape", // exception during net.forward() call "test_mul_uint8", // output type mismatch -"test_sub_bcast", +"test_sub_bcast", // 1d support is required "test_sub_uint8", // output type mismatch 
"test_upsample_nearest", "test_div_bcast", // remove when 1D Mat is supported diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp index 9b6b2414db..7303348d10 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp @@ -14,6 +14,7 @@ "test_maxpool_2d_same_upper", "test_maxpool_2d_strides", "test_maxpool_3d_default", +"test_pow", // fp16 accuracy issue "test_softmax_large_number", "test_softmax_large_number_expanded", "test_split_equal_parts_1d", diff --git a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp index cb008e9670..243c7e704d 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp @@ -93,7 +93,6 @@ "test_dequantizelinear_axis", "test_det_2d", "test_det_nd", -"test_div_example", "test_dropout_default_mask", "test_dropout_default_mask_ratio", "test_dynamicquantizelinear", @@ -175,50 +174,34 @@ "test_lstm_with_initial_bias", "test_lstm_with_peepholes", "test_matmulinteger", -"test_max_example", -"test_max_float16", -"test_max_float32", -"test_max_float64", -"test_max_int16", -"test_max_int32", -"test_max_int64", -"test_max_int8", -"test_max_one_input", -"test_max_two_inputs", -"test_max_uint16", -"test_max_uint32", -"test_max_uint64", -"test_max_uint8", -"test_mean_example", -"test_mean_one_input", -"test_mean_two_inputs", -"test_min_example", -"test_min_float16", -"test_min_float32", -"test_min_float64", -"test_min_int16", -"test_min_int32", -"test_min_int64", -"test_min_int8", -"test_min_one_input", -"test_min_two_inputs", -"test_min_uint16", -"test_min_uint32", -"test_min_uint64", -"test_min_uint8", -"test_mod_broadcast", -"test_mod_int64_fmod", -"test_mod_mixed_sign_int16", -"test_mod_mixed_sign_int32", -"test_mod_mixed_sign_int64", -"test_mod_mixed_sign_int8", -"test_mod_uint16", -"test_mod_uint32", -"test_mod_uint64", -"test_mod_uint8", +"test_max_int16", // output type (int16) mismatched +"test_max_int32", // output type (int32) mismatched +"test_max_int64", // output type (int64) mismatched +"test_max_int8", // output type (int8) mismatched +"test_max_uint16", // output type (uint16) mismatched +"test_max_uint32", // output type (uint32) mismatched +"test_max_uint64", // output type (uint64) mismatched +"test_max_uint8", // output type (uint8) mismatched +"test_min_int16", // output type (int16) mismatched +"test_min_int32", // output type (int32) mismatched +"test_min_int64", // output type (int64) mismatched +"test_min_int8", // output type (int8) mismatched +"test_min_uint16", // output type (uint16) mismatched +"test_min_uint32", // output type (uint32) mismatched +"test_min_uint64", // output type (uint64) mismatched +"test_min_uint8", // output type (uint8) mismatched +"test_mod_broadcast", // output type (int32) mismatched +"test_mod_int64_fmod", // output type (int64) mismatched +"test_mod_mixed_sign_int16", // unsupported data type (int16) +"test_mod_mixed_sign_int32", // output type (int32) mismatched +"test_mod_mixed_sign_int64", // output type (int64) mismatched +"test_mod_mixed_sign_int8", // output type (int8) mismatched +"test_mod_uint16", // unsupported data type (uint16) +"test_mod_uint32", // unsupported data 
type (uint32) +"test_mod_uint64", // unsupported data type (uint32) +"test_mod_uint8", // output type (int8) mismatched "test_momentum", "test_momentum_multiple", -"test_mul_example", "test_mvn", "test_mvn_expanded", "test_nesterov_momentum", @@ -287,20 +270,14 @@ "test_or_bcast4v2d", "test_or_bcast4v3d", "test_or_bcast4v4d", -"test_pow", -"test_pow_bcast_array", -"test_pow_bcast_scalar", -"test_pow_example", -"test_pow_types_float", -"test_pow_types_float32_int32", -"test_pow_types_float32_int64", -"test_pow_types_float32_uint32", -"test_pow_types_float32_uint64", -"test_pow_types_int", -"test_pow_types_int32_float32", -"test_pow_types_int32_int32", -"test_pow_types_int64_float32", -"test_pow_types_int64_int64", +"test_pow_bcast_array", // 1d support is required +"test_pow_types_float", // output type (int64) mismatched +"test_pow_types_float32_uint32", // exponent of unsupported data type (uint32) +"test_pow_types_float32_uint64", // exponent of unsupported data type (uint64) +"test_pow_types_int32_float32", // output type (int32) mismatched +"test_pow_types_int32_int32", // output type (int32) mismatched +"test_pow_types_int64_float32", // output type (int64) mismatched +"test_pow_types_int64_int64", // output type (int64) mismatched "test_prelu_broadcast", "test_prelu_example", "test_qlinearconv", @@ -468,9 +445,6 @@ "test_strnormalizer_export_monday_empty_output", "test_strnormalizer_export_monday_insensintive_upper_twodim", "test_strnormalizer_nostopwords_nochangecase", -"test_sub_example", -"test_sum_example", -"test_sum_two_inputs", "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", @@ -519,8 +493,8 @@ "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", "test_unsqueeze_unsorted_axes", -"test_where_example", -"test_where_long_example", +"test_where_example", // input of unsupported data type (bool) +"test_where_long_example", // input of unsupported data type (bool) "test_xor2d", "test_xor3d", "test_xor4d",
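
Note (illustrative, not part of the patch): the parallelized binary, ternary, and nary forward paths above all rely on the same block-size-driven `cv::parallel_for_` pattern, where the number of stripes is derived from the element count divided by a block size (defaulting to roughly 6e6 elements, as in the new layer code). A minimal standalone sketch of that pattern is shown below; the `add_parallel` helper and its signature are hypothetical and exist only for illustration.

```cpp
// Illustrative sketch only -- NOT part of the patch above. It mirrors the
// block-size-driven cv::parallel_for_ pattern used by the new forward impls.
#include <opencv2/core.hpp>
#include <opencv2/core/utility.hpp>
#include <vector>

static void add_parallel(const float* a, const float* b, float* out, int total,
                         size_t block_size = 6e6)  // same default granularity as the patch
{
    // One stripe per ~block_size elements: small tensors stay effectively
    // single-threaded, large ones are split across OpenCV's thread pool.
    double nstripes = total * (1.0 / (double)block_size);

    cv::parallel_for_(cv::Range(0, total), [&](const cv::Range& r) {
        for (int i = r.start; i < r.end; i++)
            out[i] = a[i] + b[i];
    }, nstripes);
}

int main()
{
    std::vector<float> a(1 << 20, 1.f), b(1 << 20, 2.f), out(1 << 20);
    add_parallel(a.data(), b.data(), out.data(), (int)a.size());
    CV_Assert(out[0] == 3.f && out.back() == 3.f);
    return 0;
}
```

Tying `nstripes` to `total / block_size` is what produces the speedups reported in the tables above while leaving small workloads (for example, the `ref_*` cases) essentially unchanged.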