From 3f13ce797b89b4bc83a3caddc682ea2158c3f32c Mon Sep 17 00:00:00 2001
From: Yuantao Feng
Date: Sat, 22 Jun 2024 00:28:22 +0800
Subject: [PATCH] Merge pull request #25779 from fengyuentau:dnn/fix_onnx_depthtospace

dnn: add DepthToSpace and SpaceToDepth #25779

We are working on updating the WeChat QRCode module. One of the new models is
fully convolutional and hence should be able to run with different input
shapes. However, it has a `DepthToSpace` operator, which is parsed as a
`Reshape -> Permute -> Reshape` subgraph with shapes fixed at parse time. The
subgraph itself is not a problem; the real problem is that its input and
output shapes stay fixed regardless of the actual input, which prevents the
model from running with different input shapes. The solution is to add
dedicated layers for DepthToSpace and SpaceToDepth.

Backend support:

- [x] CPU
- [x] CUDA
- [x] OpenCL
- [x] OpenVINO
- [x] CANN
- [x] TIMVX
- ~Vulkan~ (missing fundamental tools, like permutation and reshape)

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable.
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
---
 .../dnn/include/opencv2/dnn/all_layers.hpp    |  10 +
 modules/dnn/src/cuda4dnn/csl/tensor.hpp       |  29 +-
 .../cuda4dnn/primitives/depth_space_ops.hpp   |  76 +++
 modules/dnn/src/init.cpp                      |   4 +
 .../dnn/src/layers/depth_space_ops_layer.cpp  | 492 ++++++++++++++++++
 modules/dnn/src/onnx/onnx_importer.cpp        |  82 +--
 modules/dnn/test/test_int8_layers.cpp         |  55 ++
 ...conformance_layer_filter__openvino.inl.hpp |  10 +
 8 files changed, 663 insertions(+), 95 deletions(-)
 create mode 100644 modules/dnn/src/cuda4dnn/primitives/depth_space_ops.hpp
 create mode 100644 modules/dnn/src/layers/depth_space_ops_layer.cpp

diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 3301f20fde..2abce0c87b 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1188,6 +1188,16 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr create(const LayerParams &params);
     };
 
+    class CV_EXPORTS DepthToSpaceLayer : public Layer {
+    public:
+        static Ptr<DepthToSpaceLayer> create(const LayerParams &params);
+    };
+
+    class CV_EXPORTS SpaceToDepthLayer : public Layer {
+    public:
+        static Ptr<SpaceToDepthLayer> create(const LayerParams &params);
+    };
+
     //! @}
     //!
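For reference, DCR-mode `DepthToSpace` is equivalent to the fixed-shape subgraph the importer used to emit. A minimal standalone sketch (the helper name is illustrative; shapes and permutation mirror the removed importer path) shows why the old approach could not adapt to new input sizes:

```cpp
// Sketch, not part of the patch: DCR DepthToSpace as the old
// Reshape -> Permute -> Reshape subgraph, with all shapes baked in.
#include <opencv2/core.hpp>

cv::Mat depthToSpaceDCR(const cv::Mat &input, int N, int C, int H, int W, int b)
{
    // [N, C, H, W] -> [N, b, b, C/(b*b), H, W]
    cv::Mat x = input.reshape(1, {N, b, b, C / (b * b), H, W});
    cv::Mat y;
    // move the two block axes next to their spatial axes
    cv::transposeND(x, {0, 3, 4, 1, 5, 2}, y);
    // [N, b, b, C/(b*b), H, W] permuted -> [N, C/(b*b), H*b, W*b]
    return y.reshape(1, {N, C / (b * b), H * b, W * b});
}
```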
@} CV__DNN_INLINE_NS_END diff --git a/modules/dnn/src/cuda4dnn/csl/tensor.hpp b/modules/dnn/src/cuda4dnn/csl/tensor.hpp index 8f495ac807..15cda2fff5 100644 --- a/modules/dnn/src/cuda4dnn/csl/tensor.hpp +++ b/modules/dnn/src/cuda4dnn/csl/tensor.hpp @@ -265,7 +265,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { typename std::enable_if::value, void> ::type reshape(ForwardItr start, ForwardItr end) { CV_Assert(start != end); - CV_Assert(std::distance(start, end) <= rank()); using ItrValueType = typename std::iterator_traits::value_type; @@ -284,6 +283,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { auto total = std::accumulate(start, end, 1, std::multiplies()); if (total < 0) { /* there is an unknown size */ + CV_CheckEQ(size() % std::abs(total), static_cast(0), "cannot be reshaped"); // must be divisible if (std::abs(total) <= size()) { unknown_size = size() / std::abs(total); total = size(); @@ -298,11 +298,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); } - /* we assume the size of the unspecified axes to be one */ - std::fill(std::begin(shape), std::end(shape), 1); - std::copy_backward(start, end, std::end(shape)); - - /* replace the unknown axis with the correct value */ + /* copy shape from given iterator and reshape -1 with deduced value */ + shape.resize(std::distance(start, end)); + std::copy(start, end, shape.begin()); std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); } @@ -600,6 +598,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { auto total = std::accumulate(start, end, 1, std::multiplies()); if (total < 0) { /* there is an unknown size */ + CV_CheckEQ(size() % std::abs(total), static_cast(0), "cannot be reshaped"); // must be divisible if (std::abs(total) <= size()) { unknown_size = size() / std::abs(total); total = size(); @@ -614,11 +613,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); } - /* we assume the size of the unspecified axes to be one */ - std::fill(std::begin(shape), std::end(shape), 1); - std::copy_backward(start, end, std::end(shape)); - - /* replace the unknown axis with the correct value */ + /* copy shape from given iterator and reshape -1 with deduced value */ + shape.resize(std::distance(start, end)); + std::copy(start, end, shape.begin()); std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); } @@ -946,7 +943,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { typename std::enable_if::value, void> ::type reshape(ForwardItr start, ForwardItr end) { CV_Assert(start != end); - CV_Assert(std::distance(start, end) <= rank()); using ItrValueType = typename std::iterator_traits::value_type; @@ -965,6 +961,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { auto total = std::accumulate(start, end, 1, std::multiplies()); if (total < 0) { /* there is an unknown size */ + CV_CheckEQ(size() % std::abs(total), static_cast(0), "cannot be reshaped"); // must be divisible if (std::abs(total) <= size()) { unknown_size = size() / std::abs(total); total = size(); @@ -979,11 +976,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { CV_Error(Error::StsBadArg, "new axes do not preserve the tensor element count"); } - /* we assume the size of the unspecified axes to be one */ - 
std::fill(std::begin(shape), std::end(shape), 1); - std::copy_backward(start, end, std::end(shape)); - - /* replace the unknown axis with the correct value */ + /* copy shape from given iterator and reshape -1 with deduced value */ + shape.resize(std::distance(start, end)); + std::copy(start, end, shape.begin()); std::replace(std::begin(shape), std::end(shape), size_type(-1), unknown_size); } diff --git a/modules/dnn/src/cuda4dnn/primitives/depth_space_ops.hpp b/modules/dnn/src/cuda4dnn/primitives/depth_space_ops.hpp new file mode 100644 index 0000000000..7846881e70 --- /dev/null +++ b/modules/dnn/src/cuda4dnn/primitives/depth_space_ops.hpp @@ -0,0 +1,76 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DEPTH_SPACE_OPS_HPP +#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DEPTH_SPACE_OPS_HPP + +#include "../../op_cuda.hpp" + +#include "../csl/stream.hpp" +#include "../csl/tensor.hpp" +#include "../csl/tensor_ops.hpp" +#include "../csl/memory.hpp" +#include "../kernels/permute.hpp" + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { + + template + class DepthSpaceOps final : public CUDABackendNode { + public: + using wrapper_type = GetCUDABackendWrapperType; + + DepthSpaceOps(csl::Stream stream_, const std::vector &internal_shape_, + const std::vector &permutation_) + : stream(std::move(stream_)), internal_shape(internal_shape_), + permutation(permutation_) + { + transposed_internal_shape = std::vector(internal_shape.size()); + for (size_t i = 0; i < permutation.size(); i++) { + transposed_internal_shape[i] = internal_shape[permutation[i]]; + } + + size_t num_elements = std::accumulate(internal_shape.begin(), internal_shape.end(), 1, std::multiplies()); + csl::WorkspaceBuilder builder; + builder.require(num_elements); + scratch_mem_in_bytes = builder.required_workspace_size(); + } + + void forward(const std::vector> &inputs, + const std::vector> &outputs, + csl::Workspace &workspace) override { + CV_CheckEQ(inputs.size(), size_t(1), "DepthSpaceOps: only one input is accepted"); + CV_CheckEQ(outputs.size(), size_t(1), "DepthSpaceOps: only one output is accepted"); + + auto input_wrapper = inputs.front().dynamicCast(); + auto input = input_wrapper->getView(); + CV_CheckEQ(input.rank(), size_t(4), "DepthSpaceOps: input needs to be 4-dimensional [N, C, H, W]"); + auto output_wrapper = outputs.front().dynamicCast(); + auto output = output_wrapper->getSpan(); + auto ws_allocator = csl::WorkspaceAllocator(workspace); + auto transposed_internal = ws_allocator.get_tensor_span(transposed_internal_shape.begin(), transposed_internal_shape.end()); + + // Call reshape on input so that it has the correct shape for permutation + input.reshape(internal_shape.begin(), internal_shape.end()); + kernels::permute(stream, transposed_internal, input, permutation); + // Only copying is needed as output already has the expected shape + auto t = csl::TensorView(transposed_internal); + csl::memcpy(output.get(), t.get(), output.size(), stream); + } + + std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; } + + private: + csl::Stream stream; + std::vector internal_shape; + std::vector permutation; + std::vector transposed_internal_shape; + + std::size_t scratch_mem_in_bytes; + }; + +}}} // namespace cv::dnn::cuda4dnn + +#endif // 
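The layer implementations below hand this CUDA primitive a 6-D internal shape and a permutation. As a rough sketch (illustrative helper, not part of the patch), for a [N, C, H, W] input with blocksize b they are built as follows for the two DepthToSpace modes:

```cpp
// Sketch: arguments passed to DepthSpaceOps for DepthToSpace.
#include <vector>

void makeDepthToSpaceArgs(int N, int C, int H, int W, int b, bool is_crd,
                          std::vector<int> &internal_shape,
                          std::vector<int> &permutation)
{
    if (is_crd) {
        // CRD: the channel axis is split as [C/(b*b), b, b]
        internal_shape = {N, C / (b * b), b, b, H, W};
        permutation    = {0, 1, 4, 2, 5, 3};
    } else {
        // DCR: the channel axis is split as [b, b, C/(b*b)]
        internal_shape = {N, b, b, C / (b * b), H, W};
        permutation    = {0, 3, 4, 1, 5, 2};
    }
}
```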
OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DEPTH_SPACE_OPS_HPP diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index e8450c18f9..ce1eb77649 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -164,6 +164,10 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer); CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer); CV_DNN_REGISTER_LAYER_CLASS(GroupNormalization, GroupNormLayer); + CV_DNN_REGISTER_LAYER_CLASS(DepthToSpace, DepthToSpaceLayer) + CV_DNN_REGISTER_LAYER_CLASS(SpaceToDepth, SpaceToDepthLayer) + CV_DNN_REGISTER_LAYER_CLASS(DepthToSpaceInt8, DepthToSpaceLayer) + CV_DNN_REGISTER_LAYER_CLASS(SpaceToDepthInt8, SpaceToDepthLayer) CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer); CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer); diff --git a/modules/dnn/src/layers/depth_space_ops_layer.cpp b/modules/dnn/src/layers/depth_space_ops_layer.cpp new file mode 100644 index 0000000000..3877758b20 --- /dev/null +++ b/modules/dnn/src/layers/depth_space_ops_layer.cpp @@ -0,0 +1,492 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../precomp.hpp" +#include + +// OpenCL backend +#ifdef HAVE_OPENCL +#include "opencl_kernels_dnn.hpp" +#endif + +// OpenVINO backend +#ifdef HAVE_DNN_NGRAPH +#include "../op_inf_engine.hpp" +#include "../ie_ngraph.hpp" +#endif + +// CUDA backend +#ifdef HAVE_CUDA +#include "../op_cuda.hpp" +#include "../cuda4dnn/primitives/depth_space_ops.hpp" +#endif + +// CANN backend +#ifdef HAVE_CANN +#include "../op_cann.hpp" +#endif + +// TIM-VX backend +#ifdef HAVE_TIMVX +#include "../op_timvx.hpp" +#endif + +namespace cv { namespace dnn { + +struct DepthSpaceOps { + MatShape internal_shape; + MatShape transposed_internal_shape; + std::vector permutation; + +#ifdef HAVE_OPENCL + UMat umat_permutation; + UMat umat_internal_strides; + UMat umat_transposed_internal_strides; +#endif + + void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) { + transposed_internal_shape = MatShape(internal_shape.size()); + for (size_t i = 0; i < permutation.size(); i++) { + transposed_internal_shape[i] = internal_shape[permutation[i]]; + } + +#ifdef HAVE_OPENCL + umat_permutation.release(); + umat_internal_strides.release(); + umat_transposed_internal_strides.release(); +#endif + } + + void cpuCompute(const Mat &input, Mat &output) { + const auto output_shape = shape(output); + Mat tmp; + cv::transposeND(input.reshape(1, internal_shape), permutation, tmp); + tmp.reshape(1, output_shape).copyTo(output); + } + +#ifdef HAVE_OPENCL + bool oclCompute(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) { + std::vector inputs, outputs; + + inputs_arr.getUMatVector(inputs); + outputs_arr.getUMatVector(outputs); + + if (umat_permutation.empty() || umat_internal_strides.empty() || umat_transposed_internal_strides.empty()) { + Mat mat_permutation(1, permutation.size(), CV_32S, permutation.data()); + mat_permutation.copyTo(umat_permutation); + + std::vector internal_strides(permutation.size(), 1), transposed_internal_stides(permutation.size(), 1); + for (int i = static_cast(permutation.size()) - 2; i >= 0; i--) { + internal_strides[i] = internal_strides[i + 1] * internal_shape[i + 1]; + transposed_internal_stides[i] = transposed_internal_stides[i + 1] * transposed_internal_shape[i + 1]; + } + Mat 
mat_internal_strides(1, internal_strides.size(), CV_32S, internal_strides.data()); + mat_internal_strides.copyTo(umat_internal_strides); + Mat mat_transposed_internal_strides(1, transposed_internal_stides.size(), CV_32S, transposed_internal_stides.data()); + mat_transposed_internal_strides.copyTo(umat_transposed_internal_strides); + } + + const auto output_shape = shape(outputs.front()); + UMat tmp = inputs.front().reshape(1, static_cast(internal_shape.size()), internal_shape.data()); + + bool use_half = (inputs_arr.depth() == CV_16F); + std::string permute_options = cv::format("-DDtype=%s", use_half ? "half" : "float"); + ocl::Kernel permute_kernel("permute", ocl::dnn::permute_oclsrc, permute_options); + if (permute_kernel.empty()) { + return false; + } + UMat transposed_tmp(static_cast(transposed_internal_shape.size()), transposed_internal_shape.data(), inputs_arr.depth()); + size_t num_element = static_cast(std::accumulate(internal_shape.begin(), internal_shape.end(), 1, std::multiplies())); + permute_kernel.set(0, static_cast(num_element)); + permute_kernel.set(1, ocl::KernelArg::PtrReadOnly(tmp)); + permute_kernel.set(2, ocl::KernelArg::PtrReadOnly(umat_permutation)); + permute_kernel.set(3, ocl::KernelArg::PtrReadOnly(umat_internal_strides)); + permute_kernel.set(4, ocl::KernelArg::PtrReadOnly(umat_transposed_internal_strides)); + permute_kernel.set(5, static_cast(permutation.size())); + permute_kernel.set(6, ocl::KernelArg::PtrWriteOnly(transposed_tmp)); + if (!permute_kernel.run(1, &num_element, NULL, false)) { + return false; + } + + transposed_tmp.reshape(1, static_cast(output_shape.size()), output_shape.data()).copyTo(outputs.front()); + return true; + } +#endif // HAVE_OPENCL +}; + +class DepthToSpaceLayerImpl CV_FINAL : public DepthToSpaceLayer, public DepthSpaceOps { +public: + DepthToSpaceLayerImpl(const LayerParams ¶ms) { + setParamsFrom(params); + + CV_CheckTrue(params.has("blocksize"), "DepthSpaceLayer: blocksize is required"); + blocksize = params.get("blocksize"); + + auto mode = params.get("mode", "DCR"); + if (mode == "CRD") { + is_crd = true; + permutation = {0, 1, 4, 2, 5, 3}; + } else if (mode == "DCR") { + is_crd = false; + permutation = {0, 3, 4, 1, 5, 2}; + } else { + CV_Error(Error::StsBadArg, cv::format("DepthToSpace: unsupported mode %s\n", mode.c_str())); + } + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE { + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_CANN || + (backendId == DNN_BACKEND_TIMVX && is_crd); + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE { + CV_CheckEQ(inputs.size(), static_cast(1), "DepthSpaceLayer: accepts only one input"); + const auto &input = inputs.front(); + CV_CheckEQ(input.size(), static_cast(4), "DepthSpaceLayer: input needs to be 4-dimensional [N, C, H, W]"); + int batch = input[0], input_depth = input[1], input_height = input[2], input_width = input[3]; + int output_depth = -1, output_height = -1, output_width = -1; + + CV_CheckEQ(input_depth % (blocksize * blocksize), 0, + "DepthSpaceLayer: requires input depth to be a multiple of (blocksize * blocksize)"); + output_depth = input_depth / blocksize / blocksize; + output_height = input_height * blocksize; + output_width = input_width * blocksize; + + outputs.assign(1, MatShape{batch, output_depth, output_height, output_width}); + return false; 
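For a quick sanity check of the shape inference above (and of the SpaceToDepth counterpart further down), these illustrative helpers, not part of the patch, state the expected output shapes:

```cpp
#include <array>

// DepthToSpace: [N, C, H, W] -> [N, C/(b*b), H*b, W*b]; requires C % (b*b) == 0
std::array<int, 4> depthToSpaceShape(const std::array<int, 4> &s, int b)
{
    return {s[0], s[1] / (b * b), s[2] * b, s[3] * b};
}

// SpaceToDepth: [N, C, H, W] -> [N, C*b*b, H/b, W/b]; requires H % b == 0 && W % b == 0
std::array<int, 4> spaceToDepthShape(const std::array<int, 4> &s, int b)
{
    return {s[0], s[1] * b * b, s[2] / b, s[3] / b};
}
```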
+ } + + virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { + std::vector inputs; + inputs_arr.getMatVector(inputs); + + auto input_shape = shape(inputs.front()); + int batch = input_shape[0], input_depth = input_shape[1], input_height = input_shape[2], input_width = input_shape[3]; + if (is_crd) { + internal_shape = MatShape{batch, input_depth / (blocksize * blocksize), blocksize, blocksize, input_height, input_width}; + } else { + internal_shape = MatShape{batch, blocksize, blocksize, input_depth / (blocksize * blocksize), input_height, input_width}; + } + + DepthSpaceOps::finalize(inputs_arr, outputs_arr); + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + // TODO: support 8-bit int in permute kernel + CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) && inputs_arr.depth() != CV_8S, + DepthSpaceOps::oclCompute(inputs_arr, outputs_arr, internals_arr)) + + if (inputs_arr.depth() == CV_16F) { + forward_fallback(inputs_arr, outputs_arr, internals_arr); + return; + } + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + DepthSpaceOps::cpuCompute(inputs.front(), outputs.front()); + } + +#ifdef HAVE_DNN_NGRAPH + virtual Ptr initNgraph(const std::vector> &inputs, + const std::vector> &nodes) CV_OVERRIDE { + using namespace ov::op; + auto input_node = nodes[0].dynamicCast()->node; + std::shared_ptr output_node; + if (is_crd) { + output_node = std::make_shared(input_node, v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, static_cast(blocksize)); + } else { + output_node = std::make_shared(input_node, v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, static_cast(blocksize)); + } + return Ptr(new InfEngineNgraphNode(output_node)); + } +#endif // HAVE_DNN_NGRAPH + +#ifdef HAVE_CUDA + Ptr initCUDA(void *context_, + const std::vector>& inputs, + const std::vector>& outputs) override { + using namespace cv::dnn::cuda4dnn; + auto context = reinterpret_cast(context_); + std::vector perm(permutation.begin(), permutation.end()); + return make_cuda_node(preferableTarget, std::move(context->stream), internal_shape, perm); + } +#endif // HAVE_CUDA + +#ifdef HAVE_CANN + virtual Ptr initCann(const std::vector> &inputs, + const std::vector> &outputs, + const std::vector> &nodes) CV_OVERRIDE { + CV_CheckEQ(inputs.size(), static_cast(1), "DepthToSpace/CANN: only accepts one input wrapper"); + CV_CheckEQ(nodes.size(), static_cast(1), "DepthToSpace/CANN: only accepts one input node"); + + auto input_tensor_wrapper = inputs.front().dynamicCast(); + auto input_tensor_desc = input_tensor_wrapper->getTensorDesc(); + auto input_node = nodes.front().dynamicCast()->getOp(); + + auto node = std::make_shared(name); + + node->set_attr_block_size(blocksize); + if (is_crd) { + node->set_attr_mode("CRD"); + } else { + node->set_attr_mode("DCR"); + } + node->set_attr_data_format("NCHW"); + + node->set_input_x_by_name(*input_node, input_tensor_wrapper->name.c_str()); + node->update_input_desc_x(*input_tensor_desc); + + auto output_tensor_desc = std::make_shared(ge::Shape(), ge::FORMAT_NCHW, ge::DT_FLOAT); + node->update_output_desc_y(*output_tensor_desc); + + return Ptr(new CannBackendNode(node)); + } +#endif + +#ifdef HAVE_TIMVX + virtual Ptr initTimVX(void* timvx_info_, + const std::vector> &inputs, + const std::vector> &outputs, + bool isLast) CV_OVERRIDE { + auto info = 
reinterpret_cast(timvx_info_); + CV_Assert(info); + auto timvx_graph = info->getGraph(); + CV_Assert(timvx_graph); + auto graph = timvx_graph->graph; + + auto input_wrapper = inputs.front().dynamicCast(); + int input_wrapper_index = -1; + if (input_wrapper->isTensor()) { + input_wrapper_index = timvx_graph->getTensorIndex(input_wrapper->getTensor()); + if (input_wrapper_index == -1) { + auto tmp = input_wrapper->getMat(); + input_wrapper = std::make_shared(tmp); + } + } + if (!input_wrapper->isTensor() || input_wrapper_index == 1) { + auto input_node_quant = Ptr(new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, 1.0f, 0)); + input_wrapper->createTensor(graph, tim::vx::TensorAttribute::INPUT, input_node_quant); + input_wrapper_index = timvx_graph->addWrapper(input_wrapper); + } + + auto output_wrapper = outputs.front().dynamicCast(); + auto output_node_quant = input_wrapper->getTensorQuantization(); + if (isLast) { + auto shape_type = getShapeTypeFromMat(output_wrapper->getMat()); + output_wrapper->setTensorShape(shape_type); + output_wrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, output_node_quant); + } else { + output_wrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, output_node_quant); + } + int output_wrapper_index = timvx_graph->addWrapper(output_wrapper); + + std::shared_ptr timvx_node; + timvx_node = graph->CreateOperation(blocksize); + std::vector input_wrapper_indices{input_wrapper_index}, output_wrapper_indices{output_wrapper_index}; + return Ptr(new TimVXBackendNode(timvx_graph, timvx_node, input_wrapper_indices, output_wrapper_indices)); + } +#endif + + virtual bool tryQuantize(const std::vector> &scales, + const std::vector> &zeropoints, LayerParams ¶ms) CV_OVERRIDE { + return true; + } + +private: + int blocksize; + + bool is_crd; +}; + +Ptr DepthToSpaceLayer::create(const LayerParams ¶ms) { + return makePtr(params); +} + +class SpaceToDepthLayerImpl CV_FINAL : public SpaceToDepthLayer, public DepthSpaceOps { +public: + SpaceToDepthLayerImpl(const LayerParams ¶ms) { + setParamsFrom(params); + + CV_CheckTrue(params.has("blocksize"), "SpaceToDepthLayer: blocksize is required"); + blocksize = params.get("blocksize"); + + permutation = {0, 3, 5, 1, 2, 4}; + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE { + return backendId == DNN_BACKEND_OPENCV || + backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH || + backendId == DNN_BACKEND_CUDA || + backendId == DNN_BACKEND_CANN || + (backendId == DNN_BACKEND_TIMVX); + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE { + CV_CheckEQ(inputs.size(), static_cast(1), "SpaceToDepthLayer: accepts only one input"); + const auto &input = inputs.front(); + CV_CheckEQ(input.size(), static_cast(4), "SpaceToDepthLayer: input needs to be 4-dimensional [N, C, H, W]"); + int batch = input[0], input_depth = input[1], input_height = input[2], input_width = input[3]; + int output_depth = -1, output_height = -1, output_width = -1; + + CV_CheckEQ(input_height % blocksize, 0, "SpaceToDepthLayer: requires input height to be a multiple of blocksize"); + CV_CheckEQ(input_width % blocksize, 0, "SpaceToDepthLayer: requires input width to be a multiple of blocksize"); + output_depth = input_depth * blocksize * blocksize; + output_height = input_height / blocksize; + output_width = input_width / blocksize; + + outputs.assign(1, MatShape{batch, output_depth, output_height, output_width}); + return false; + } + + 
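On the CPU path (`DepthSpaceOps::cpuCompute`), SpaceToDepth boils down to a single reshape/transpose/reshape with the internal shape and permutation set in `finalize`. A minimal sketch with an illustrative helper name:

```cpp
// Sketch, not part of the patch: SpaceToDepth for blocksize b.
#include <opencv2/core.hpp>

cv::Mat spaceToDepth(const cv::Mat &input, int N, int C, int H, int W, int b)
{
    // [N, C, H, W] -> [N, C, H/b, b, W/b, b]
    cv::Mat x = input.reshape(1, {N, C, H / b, b, W / b, b});
    cv::Mat y;
    cv::transposeND(x, {0, 3, 5, 1, 2, 4}, y);
    // permuted axes flattened back to [N, C*b*b, H/b, W/b]
    return y.reshape(1, {N, C * b * b, H / b, W / b});
}
```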
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { + std::vector inputs; + inputs_arr.getMatVector(inputs); + + auto input_shape = shape(inputs.front()); + int batch = input_shape[0], input_depth = input_shape[1], input_height = input_shape[2], input_width = input_shape[3]; + internal_shape = MatShape{batch, input_depth, input_height / blocksize, blocksize, input_width / blocksize, blocksize}; + + DepthSpaceOps::finalize(inputs_arr, outputs_arr); + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + // TODO: support 8-bit int in permute kernel + CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget) && inputs_arr.depth() != CV_8S, + DepthSpaceOps::oclCompute(inputs_arr, outputs_arr, internals_arr)) + + if (inputs_arr.depth() == CV_16F) { + forward_fallback(inputs_arr, outputs_arr, internals_arr); + return; + } + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + DepthSpaceOps::cpuCompute(inputs.front(), outputs.front()); + } + +#ifdef HAVE_DNN_NGRAPH + virtual Ptr initNgraph(const std::vector> &inputs, + const std::vector> &nodes) CV_OVERRIDE { + using namespace ov::op; + auto input_node = nodes[0].dynamicCast()->node; + std::shared_ptr output_node; + output_node = std::make_shared(input_node, v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, static_cast(blocksize)); + return Ptr(new InfEngineNgraphNode(output_node)); + } +#endif // HAVE_DNN_NGRAPH + +#ifdef HAVE_CUDA + Ptr initCUDA(void *context_, + const std::vector> &inputs, + const std::vector> &outputs) override { + using namespace cv::dnn::cuda4dnn; + auto context = reinterpret_cast(context_); + std::vector perm(permutation.begin(), permutation.end()); + return make_cuda_node(preferableTarget, std::move(context->stream), internal_shape, perm); + } +#endif // HAVE_CUDA + +#ifdef HAVE_CANN + virtual Ptr initCann(const std::vector> &inputs, + const std::vector> &outputs, + const std::vector> &nodes) CV_OVERRIDE { + CV_CheckEQ(inputs.size(), static_cast(1), "DepthToSpace/CANN: only accepts one input wrapper"); + CV_CheckEQ(nodes.size(), static_cast(1), "DepthToSpace/CANN: only accepts one input node"); + + auto input_tensor_wrapper = inputs.front().dynamicCast(); + auto input_tensor_desc = input_tensor_wrapper->getTensorDesc(); + auto input_node = nodes.front().dynamicCast()->getOp(); + + auto node = std::make_shared(name); + + node->set_attr_block_size(blocksize); + node->set_attr_data_format("NCHW"); + + node->set_input_x_by_name(*input_node, input_tensor_wrapper->name.c_str()); + node->update_input_desc_x(*input_tensor_desc); + + auto output_tensor_desc = std::make_shared(ge::Shape(), ge::FORMAT_NCHW, ge::DT_FLOAT); + node->update_output_desc_y(*output_tensor_desc); + + return Ptr(new CannBackendNode(node)); + } +#endif + +#ifdef HAVE_TIMVX + virtual Ptr initTimVX(void* timvx_info_, + const std::vector> &inputs, + const std::vector> &outputs, + bool isLast) CV_OVERRIDE { + auto info = reinterpret_cast(timvx_info_); + CV_Assert(info); + auto timvx_graph = info->getGraph(); + CV_Assert(timvx_graph); + auto graph = timvx_graph->graph; + + auto input_wrapper = inputs.front().dynamicCast(); + int input_wrapper_index = -1; + if (input_wrapper->isTensor()) { + input_wrapper_index = timvx_graph->getTensorIndex(input_wrapper->getTensor()); + if (input_wrapper_index == -1) { + auto tmp = 
input_wrapper->getMat(); + input_wrapper = std::make_shared(tmp); + } + } + if (!input_wrapper->isTensor() || input_wrapper_index == 1) { + auto input_node_quant = Ptr(new tim::vx::Quantization(tim::vx::QuantType::ASYMMETRIC, 1.0f, 0)); + input_wrapper->createTensor(graph, tim::vx::TensorAttribute::INPUT, input_node_quant); + input_wrapper_index = timvx_graph->addWrapper(input_wrapper); + } + + auto output_wrapper = outputs.front().dynamicCast(); + auto output_node_quant = input_wrapper->getTensorQuantization(); + if (isLast) { + auto shape_type = getShapeTypeFromMat(output_wrapper->getMat()); + output_wrapper->setTensorShape(shape_type); + output_wrapper->createTensor(graph, tim::vx::TensorAttribute::OUTPUT, output_node_quant); + } else { + output_wrapper->createTensor(graph, tim::vx::TensorAttribute::TRANSIENT, output_node_quant); + } + int output_wrapper_index = timvx_graph->addWrapper(output_wrapper); + + std::shared_ptr timvx_node; + timvx_node = graph->CreateOperation(std::vector{blocksize, blocksize}); + std::vector input_wrapper_indices{input_wrapper_index}, output_wrapper_indices{output_wrapper_index}; + return Ptr(new TimVXBackendNode(timvx_graph, timvx_node, input_wrapper_indices, output_wrapper_indices)); + } +#endif + + virtual bool tryQuantize(const std::vector> &scales, + const std::vector> &zeropoints, LayerParams ¶ms) CV_OVERRIDE { + return true; + } + +private: + int blocksize; +}; + +Ptr SpaceToDepthLayer::create(const LayerParams ¶ms) { + return makePtr(params); +} + +}} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 7b63e39a3a..565a88b760 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -189,7 +189,7 @@ private: void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseCumSum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseElementWise (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parseDepthToSpace (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); + void parseDepthSpaceOps (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseRange (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseTile (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); @@ -2916,82 +2916,8 @@ void ONNXImporter::parseElementWise(LayerParams& layerParams, const opencv_onnx: addLayer(layerParams, node_proto); } -void ONNXImporter::parseDepthToSpace(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) -{ - // We parse "DepthToSpace" and "SpaceToDepth" in this function. - opencv_onnx::NodeProto node_proto = node_proto_; - const std::string& layer_type = node_proto.op_type(); - CV_Assert(layer_type == "DepthToSpace" || layer_type == "SpaceToDepth"); - - // Get blocksize - CV_Assert(layerParams.has("blocksize")); - int blocksize = layerParams.get("blocksize"); - CV_Assert(blocksize > 0); - - // Get mode, only for "DepthToSpace" - std::string modeType = layerParams.get("mode", "DCR"); - - MatShape inpShape = outShapes[node_proto.input(0)]; - CV_Assert(inpShape.size() == 4); - int N = inpShape[0], C = inpShape[1], H = inpShape[2], W = inpShape[3]; - - // Implement DepthToSpace and SpaceToDepth by the Reshape and Permute layer. 
- std::array shape0, perm; - std::array shape1; - - if (layer_type == "DepthToSpace") - { - if (modeType == "DCR") - { - shape0 = {N, blocksize, blocksize, C/(blocksize * blocksize), H, W}; - perm = {0, 3, 4, 1, 5, 2}; - shape1 = {N, C/(blocksize * blocksize), H * blocksize, W * blocksize}; - } - else if (modeType == "CRD") - { - shape0 = {N, C/(blocksize * blocksize), blocksize, blocksize, H, W}; - perm = {0, 1, 4, 2, 5, 3}; - shape1 = {N, C/(blocksize * blocksize), H * blocksize, W * blocksize}; - } - else - CV_Error(Error::StsNotImplemented, "The mode of " + modeType + " in " + layer_type + " Layer is not supported"); - } - else // SpaceToDepth - { - shape0 = {N, C, H/blocksize, blocksize, W/blocksize, blocksize}; - perm = {0, 3, 5, 1, 2, 4}; - shape1 = {N, C * blocksize * blocksize, H/blocksize, W/blocksize}; - } - - // Step1: Reshape - LayerParams reshapeLp; - reshapeLp.name = layerParams.name + "/reshape"; - reshapeLp.type = "Reshape"; - CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end()); - reshapeLp.set("dim", DictValue::arrayInt(shape0.data(), shape0.size())); - - opencv_onnx::NodeProto protoReshape; - protoReshape.add_input(node_proto.input(0)); - protoReshape.add_output(reshapeLp.name); - addLayer(reshapeLp, protoReshape); - - // Step2: Transpose - LayerParams permuteLp; - permuteLp.name = layerParams.name + "/permute"; - permuteLp.type = "Permute"; - CV_Assert(layer_id.find(permuteLp.name) == layer_id.end()); - permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size())); - - opencv_onnx::NodeProto protoPermute; - protoPermute.add_input(reshapeLp.name); - protoPermute.add_output(permuteLp.name); - addLayer(permuteLp, protoPermute); - - // Step3: Reshape - layerParams.type = "Reshape"; - layerParams.set("dim", DictValue::arrayInt(shape1.data(), shape1.size())); - - node_proto.set_input(0, permuteLp.name); +void ONNXImporter::parseDepthSpaceOps(LayerParams &layerParams, const opencv_onnx::NodeProto& node_proto) { + CV_CheckTrue(layerParams.has("blocksize"), "blocksize is required but not found"); addLayer(layerParams, node_proto); } @@ -4002,7 +3928,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["SoftMax"] = dispatch["Softmax"] = dispatch["LogSoftmax"] = &ONNXImporter::parseSoftMax; dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput; dispatch["CumSum"] = &ONNXImporter::parseCumSum; - dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace; + dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthSpaceOps; dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter; dispatch["Tile"] = &ONNXImporter::parseTile; dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm; diff --git a/modules/dnn/test/test_int8_layers.cpp b/modules/dnn/test/test_int8_layers.cpp index bc5d9388a9..9d102937bd 100644 --- a/modules/dnn/test/test_int8_layers.cpp +++ b/modules/dnn/test/test_int8_layers.cpp @@ -509,6 +509,61 @@ TEST_P(Test_Int8_layers, Eltwise) testLayer("split_max", "ONNX", 0.004, 0.012); } +TEST_P(Test_Int8_layers, DepthSpaceOps) { + auto test_layer_with_onnx_conformance_models = [&](const std::string &model_name, double l1, double lInf) { + std::string model_path = _tf("onnx/conformance/node/test_" + model_name + "/model.onnx"); + auto net = readNet(model_path); + + // load reference inputs and outputs + std::string data_base_path = _tf("onnx/conformance/node/test_" + model_name + "/test_data_set_0"); + Mat input = 
readTensorFromONNX(data_base_path + "/input_0.pb"); + Mat ref_output = readTensorFromONNX(data_base_path + "/output_0.pb"); + + std::vector input_scales, output_scales; + std::vector input_zeropoints, output_zeropoints; + auto qnet = net.quantize(std::vector{input}, CV_8S, CV_8S, false); + qnet.getInputDetails(input_scales, input_zeropoints); + qnet.getOutputDetails(output_scales, output_zeropoints); + qnet.setPreferableBackend(backend); + qnet.setPreferableTarget(target); + + Mat quantized_input, quantized_output; + input.convertTo(quantized_input, CV_8S, 1.f / input_scales.front(), input_zeropoints.front()); + qnet.setInput(quantized_input); + quantized_output = qnet.forward(); + + Mat output; + quantized_output.convertTo(output, CV_32F, output_scales.front(), -(output_scales.front() * output_zeropoints.front())); + normAssert(ref_output, output, model_name.c_str(), l1, lInf); + }; + + double l1 = default_l1, lInf = default_lInf; + { + l1 = 0.001; lInf = 0.002; + if (backend == DNN_BACKEND_TIMVX) { l1 = 0.001; lInf = 0.002; } + test_layer_with_onnx_conformance_models("spacetodepth", l1, lInf); + } + { + l1 = 0.022; lInf = 0.044; + if (backend == DNN_BACKEND_TIMVX) { l1 = 0.022; lInf = 0.044; } + test_layer_with_onnx_conformance_models("spacetodepth_example", l1, lInf); + } + { + l1 = 0.001; lInf = 0.002; + if (backend == DNN_BACKEND_TIMVX) { l1 = 0.24; lInf = 0.99; } + test_layer_with_onnx_conformance_models("depthtospace_crd_mode", l1, lInf); + } + test_layer_with_onnx_conformance_models("depthtospace_dcr_mode", 0.001, 0.002); + test_layer_with_onnx_conformance_models("depthtospace_example", 0.07, 0.14); + + { + l1 = 0.07; lInf = 0.14; + if (backend == DNN_BACKEND_TIMVX) // diff too huge, l1 = 13.6; lInf = 27.2 + applyTestTag(CV_TEST_TAG_DNN_SKIP_TIMVX); + test_layer_with_onnx_conformance_models("depthtospace_crd_mode_example", l1, lInf); + } +} + INSTANTIATE_TEST_CASE_P(/**/, Test_Int8_layers, dnnBackendsAndTargetsInt8()); class Test_Int8_nets : public DNNTestLayer diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp index 509cf6007d..000e867217 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp @@ -544,10 +544,20 @@ CASE(test_cumsum_2d_negative_axis) // no filter CASE(test_depthtospace_crd_mode) // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 1e-4; // Expected: (normL1) <= (l1), actual: 9.33057e-05 vs 1e-05 + default_lInf = 2.5e-4; // Expected: (normInf) <= (lInf), actual: 0.000243843 vs 0.0001 + } CASE(test_depthtospace_crd_mode_example) // no filter CASE(test_depthtospace_dcr_mode) // no filter + if (target == DNN_TARGET_OPENCL) + { + default_l1 = 1e-4; // Expected: (normL1) <= (l1), actual: 9.33057e-05 vs 1e-05 + default_lInf = 2.5e-4; // Expected: (normInf) <= (lInf), actual: 0.000243843 vs 0.0001 + } CASE(test_depthtospace_example) // no filter CASE(test_dequantizelinear)
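Finally, a usage sketch of the scenario this patch enables: one ONNX model containing `DepthToSpace`, run at two different input resolutions without re-exporting. The model path is a placeholder, not a file shipped with the patch.

```cpp
#include <opencv2/dnn.hpp>

int main()
{
    cv::dnn::Net net = cv::dnn::readNetFromONNX("model.onnx");  // placeholder path

    for (int side : {256, 512}) {
        // NCHW blob; with the dedicated layer, internal shapes are recomputed in
        // finalize(), so both sizes run through the same network.
        cv::Mat blob(std::vector<int>{1, 3, side, side}, CV_32F, cv::Scalar(0));
        net.setInput(blob);
        cv::Mat out = net.forward();
        CV_Assert(!out.empty());
    }
    return 0;
}
```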