diff --git a/modules/dnn/src/cuda/shortcut.cu b/modules/dnn/src/cuda/shortcut.cu
new file mode 100644
index 0000000000..e2958627ab
--- /dev/null
+++ b/modules/dnn/src/cuda/shortcut.cu
@@ -0,0 +1,137 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include "grid_stride_range.hpp"
+#include "execution.hpp"
+#include "vector_traits.hpp"
+
+#include "../cuda4dnn/csl/stream.hpp"
+#include "../cuda4dnn/csl/span.hpp"
+#include "../cuda4dnn/csl/tensor.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <cstddef>
+
+using namespace cv::dnn::cuda4dnn::csl;
+using namespace cv::dnn::cuda4dnn::csl::device;
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
+
+namespace raw {
+    /* Writes `input` to `output` and adds `from` on the channels which exist in both tensors.
+     *
+     * Every loop iteration processes one vector of `N` consecutive elements. The flat element index
+     * is decomposed into (batch, channel, offset inside the channel) using `channel_stride`.
+     * Channels with an index of `c_from` or more are copied from `input` unchanged; channels of
+     * `from` with an index of `c_input` or more are never read.
+     *
+     * The host-side launcher asserts that `channel_stride` is a multiple of `N`; therefore, all
+     * `N` elements of a vector belong to the same channel.
+     */
+    template <class T, std::size_t N>
+    __global__ void input_shortcut_vec(
+        Span<T> output,
+        View<T> input, index_type c_input, /* `c_input` = number of channels in `input` */
+        View<T> from, index_type c_from, /* `c_from` = number of channels in `from` */
+        size_type channel_stride /* common for both `input` and `from` */)
+    {
+        using vector_type = get_vector_type_t<T, N>;
+
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto input_vPtr = vector_type::get_pointer(input.data());
+        auto from_vPtr = vector_type::get_pointer(from.data());
+
+        auto batch_stride_input = c_input * channel_stride;
+        auto batch_stride_from = c_from * channel_stride;
+
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            /* `i` indexes vectors; `actual_idx` is the index of the first scalar element of that vector */
+            const auto actual_idx = i * vector_type::size();
+            const auto b = actual_idx / batch_stride_input; /* `input` and `output` have the same shape */
+            const auto c = (actual_idx % batch_stride_input) / channel_stride;
+            const auto c_offset = actual_idx % channel_stride;
+
+            vector_type vec_input;
+            v_load(vec_input, input_vPtr[i]);
+
+            /* We can break down the shortcut operation into two steps:
+             * - copy `input` to `output`
+             * - add `from` to corresponding channels in `output`
+             *
+             * In this scheme, only some channels in the `output` differ from `input`. They differ in the channels
+             * which have a corresponding channel in `from`.
+             */
+            if (c < c_from) {
+                /* same (batch, channel, offset) position, expressed with the strides of `from` */
+                const auto from_actual_idx = b * batch_stride_from + c * channel_stride + c_offset;
+                const auto from_vec_idx = from_actual_idx / vector_type::size();
+
+                vector_type vec_from;
+                v_load(vec_from, from_vPtr[from_vec_idx]);
+                for (int j = 0; j < vector_type::size(); j++)
+                    vec_input.data[j] += vec_from.data[j];
+            }
+
+            v_store(output_vPtr[i], vec_input);
+        }
+    }
+}
+
+/* Checks the preconditions of `raw::input_shortcut_vec<T, N>` and launches it on `stream` with a
+ * launch policy created for `output.size() / N` vectors.
+ */
+template <class T, std::size_t N>
+void launch_vectorized_input_shortcut(const Stream& stream, Span<T> output, View<T> input, std::size_t c_input, View<T> from, std::size_t c_from, std::size_t channel_stride) {
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(input, N));
+    CV_Assert(is_fully_aligned<T>(from, N));
+    CV_Assert(channel_stride % N == 0); /* a vector must not straddle two channels */
+
+    auto kernel = raw::input_shortcut_vec<T, N>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, input, c_input, from, c_from, channel_stride);
+}
+
+/* Computes `output = input + from`; `from` is added only on the channels (axis 1) which exist in
+ * both tensors.
+ *
+ * `output` and `input` must have the same shape. `from` must have the same rank as `output` and
+ * must match it on every axis except axis 1. The kernel reads `input` and writes `output` at the
+ * same vector index, which permits passing the same tensor as both `output` and `input`.
+ *
+ * The first vector width among 4, 2 and 1 which meets the alignment and channel stride
+ * requirements of the kernel is used.
+ */
+template <class T>
+void input_shortcut(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T> from) {
+    CV_Assert(is_shape_same(output, input));
+    CV_Assert(output.rank() == from.rank());
+    for (int i = 0; i < output.rank(); i++) {
+        if (i != 1) {
+            CV_Assert(from.get_axis_size(i) == output.get_axis_size(i));
+        }
+    }
+
+    auto channel_stride = output.size_range(2, output.rank()); /* same for `output`, `input` and `from` */
+    auto c_input = input.get_axis_size(1);
+    auto c_from = from.get_axis_size(1);
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && is_fully_aligned<T>(from, 4) && channel_stride % 4 == 0) {
+        launch_vectorized_input_shortcut<T, 4>(stream, output, input, c_input, from, c_from, channel_stride);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && is_fully_aligned<T>(from, 2) && channel_stride % 2 == 0) {
+        launch_vectorized_input_shortcut<T, 2>(stream, output, input, c_input, from, c_from, channel_stride);
+    } else {
+        launch_vectorized_input_shortcut<T, 1>(stream, output, input, c_input, from, c_from, channel_stride);
+    }
+}
+
+/* explicit instantiations: the template is defined in this file only, so these are the element types callers can use */
+template void input_shortcut(const Stream&, TensorSpan<__half>, TensorView<__half>, TensorView<__half>);
+template void input_shortcut(const Stream&, TensorSpan<float>, TensorView<float>, TensorView<float>);
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
diff --git a/modules/dnn/src/cuda4dnn/kernels/shortcut.hpp b/modules/dnn/src/cuda4dnn/kernels/shortcut.hpp
new file mode 100644
index 0000000000..169d7558a2
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/kernels/shortcut.hpp
@@ -0,0 +1,22 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP
+
+#include "../csl/stream.hpp"
+#include "../csl/tensor.hpp"
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
+
+    /* Computes `output = input + from`; `from` is added only on the channels (axis 1) which exist in
+     * both tensors. `output` and `input` must have the same shape, and `from` must match `output` on
+     * every axis except axis 1.
+     */
+    template <class T>
+    void input_shortcut(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T> from);
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP */
diff --git a/modules/dnn/src/cuda4dnn/primitives/shortcut.hpp b/modules/dnn/src/cuda4dnn/primitives/shortcut.hpp
new file mode 100644
index 0000000000..bfdabfc6bc
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/primitives/shortcut.hpp
@@ -0,0 +1,88 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/tensor_ops.hpp"
+
+#include "../kernels/shortcut.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+    /* Adds one or more `from` tensors to the first input over the channels which they have in common.
+     *
+     * The output has the shape of the first input. Every other input must have the same rank as the
+     * output and must match it on every axis except axis 1.
+     */
+    template <class T>
+    class ShortcutOp final : public CUDABackendNode {
+    public:
+        using wrapper_type = GetCUDABackendWrapperType<T>;
+
+        ShortcutOp(csl::Stream stream_) : stream(std::move(stream_)) { }
+
+        void forward(
+            const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+            const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+            csl::Workspace& workspace) override
+        {
+            /* `output` is written only while a `from` tensor is being added; hence, there must be at
+             * least one input in addition to `inputs[0]`
+             */
+            CV_Assert(inputs.size() >= 2);
+            CV_Assert(outputs.size() == 1);
+
+            auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+            auto output = output_wrapper->getSpan();
+
+            auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
+            auto input = input_wrapper->getView();
+
+            /* output shape is determined by the input shape */
+            CV_Assert(is_shape_same(output, input));
+
+            for (std::size_t idx = 1; idx < inputs.size(); idx++)
+            {
+                auto from_wrapper = inputs[idx].dynamicCast<wrapper_type>();
+                auto from = from_wrapper->getView();
+
+                /* `from` may differ from `output` in the number of channels (axis 1) only */
+                CV_Assert(output.rank() == from.rank());
+                for (int axis = 0; axis < output.rank(); axis++) {
+                    if (axis != 1) {
+                        CV_Assert(from.get_axis_size(axis) == output.get_axis_size(axis));
+                    }
+                }
+
+                if (idx == 1)
+                {
+                    /* first addition: reads `input` and writes the sum to `output` */
+                    kernels::input_shortcut<T>(stream, output, input, from);
+                }
+                else
+                {
+                    /* subsequent additions accumulate into `output` in-place */
+                    kernels::input_shortcut<T>(stream, output, output, from);
+                }
+            }
+        }
+
+    private:
+        csl::Stream stream;
+    };
+
+}}} /* namespace cv::dnn::cuda4dnn */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP */
diff --git
a/modules/dnn/src/layers/eltwise_layer.cpp b/modules/dnn/src/layers/eltwise_layer.cpp index 81a947cd2f..6d2827c3a4 100644 --- a/modules/dnn/src/layers/eltwise_layer.cpp +++ b/modules/dnn/src/layers/eltwise_layer.cpp @@ -53,6 +53,7 @@ #ifdef HAVE_CUDA #include "../cuda4dnn/primitives/eltwise.hpp" +#include "../cuda4dnn/primitives/shortcut.hpp" using namespace cv::dnn::cuda4dnn; #endif @@ -155,8 +156,14 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { + if (backendId == DNN_BACKEND_CUDA) + { + if(channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 || channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE) + return op == SUM && coeffs.empty(); + return channelsModeInput == ELTWISE_CHANNNELS_SAME; + } + return backendId == DNN_BACKEND_OPENCV || - backendId == DNN_BACKEND_CUDA || (backendId == DNN_BACKEND_HALIDE && op != DIV) || // TODO: not implemented, see PR #15811 ((((backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && (preferableTarget != DNN_TARGET_OPENCL || coeffs.empty())) || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) && channelsMode == ELTWISE_CHANNNELS_SAME)); @@ -623,6 +630,25 @@ public: { auto context = reinterpret_cast<csl::CSLContext*>(context_); + CV_Assert(channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 || + channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE || + channelsModeInput == ELTWISE_CHANNNELS_SAME); + + if(channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 || channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE) + { + auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>(); + for (int i = 1; i < inputs.size(); i++) + { + auto from_wrapper = inputs[i].dynamicCast<CUDABackendWrapper>(); + if (input_wrapper->getShape()[1] != from_wrapper->getShape()[1]) + { + CV_Assert(op == SUM); + CV_Assert(coeffs.empty()); + return make_cuda_node<cuda4dnn::ShortcutOp>(preferableTarget, std::move(context->stream)); + } + } + } + auto op_ = [this] { switch (op) { case MAX: return cuda4dnn::EltwiseOpType::MAX; diff --git 
a/modules/dnn/test/test_darknet_importer.cpp b/modules/dnn/test/test_darknet_importer.cpp index a61e6420f1..7545b35b8e 100644 --- a/modules/dnn/test/test_darknet_importer.cpp +++ b/modules/dnn/test/test_darknet_importer.cpp @@ -528,8 +528,6 @@ INSTANTIATE_TEST_CASE_P(/**/, Test_Darknet_nets, dnnBackendsAndTargets()); TEST_P(Test_Darknet_layers, shortcut) { - if (backend == DNN_BACKEND_CUDA) - applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); testDarknetLayer("shortcut"); testDarknetLayer("shortcut_leaky"); testDarknetLayer("shortcut_unequal"); diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp index 742357be9b..b64a9ca07a 100644 --- a/modules/dnn/test/test_layers.cpp +++ b/modules/dnn/test/test_layers.cpp @@ -1624,7 +1624,7 @@ TEST_P(Layer_Test_Eltwise_unequal, accuracy_input_0_truncate) int backendId = get<0>(get<1>(GetParam())); int targetId = get<1>(get<1>(GetParam())); - if (backendId == DNN_BACKEND_CUDA) + if (backendId == DNN_BACKEND_CUDA && weighted) applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); Net net; @@ -1690,15 +1690,15 @@ TEST_P(Layer_Test_Eltwise_unequal, accuracy_input_0) int backendId = get<0>(get<1>(GetParam())); int targetId = get<1>(get<1>(GetParam())); - if (backendId == DNN_BACKEND_CUDA) - applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); - Net net; LayerParams lp; lp.type = "Eltwise"; lp.name = "testLayer"; lp.set<std::string>("output_channels_mode", "input_0"); + if (backendId == DNN_BACKEND_CUDA && weighted) + applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); + const int inpShapes[][4] = {{1, 4, 2, 2}, {1, 2, 2, 2}, {1, 3, 2, 2}}; const int out_channels = inpShapes[0][1]; std::vector<String> inpNames(3);