Merge pull request #17363 from YashasSamaga:cuda4dnn-eltwise-fusion2

cuda4dnn(conv): fuse eltwise with convolutions

* fuse eltwise with convolutions

* manually rebase to avoid bad git merge
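
For context, the fusion this PR implements collapses several passes over the convolution output into one. A hedged scalar sketch of the idea (my summary, not code from this diff):

    // Before: the epilogue costs extra full passes over the output tensor
    //   out[i] = conv[i]                     (cuDNN convolution)
    //   out[i] = act(out[i] + bias[c])       (bias + activation kernel)
    //   out[i] = out[i] + residual[i]        (separate eltwise kernel)
    //
    // After: bias, activation and the eltwise op run in a single fused kernel,
    // so the convolution output is read and written only once:
    //   out[i] = act(conv[i] + bias[c]) + residual[i]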
Author: Yashas Samaga B L (committed via GitHub)
parent 44d473fba0
commit d0e6d2438c
22 changed files (lines changed per file):
  121  modules/dnn/src/cuda/activation_eltwise.cu
   84  modules/dnn/src/cuda/activations.cu
   62  modules/dnn/src/cuda/bias_activation.cu
  125  modules/dnn/src/cuda/bias_activation_eltwise.cu
  132  modules/dnn/src/cuda/bias_eltwise_activation.cu
  125  modules/dnn/src/cuda/eltwise_activation.cu
   36  modules/dnn/src/cuda/eltwise_ops.cu
  245  modules/dnn/src/cuda/functors.hpp
    6  modules/dnn/src/cuda/scale_shift.cu
   95  modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
   11  modules/dnn/src/cuda4dnn/csl/span.hpp
   29  modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
   40  modules/dnn/src/cuda4dnn/kernels/activation_eltwise.hpp
   18  modules/dnn/src/cuda4dnn/kernels/activations.hpp
   10  modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp
   42  modules/dnn/src/cuda4dnn/kernels/bias_activation_eltwise.hpp
   45  modules/dnn/src/cuda4dnn/kernels/bias_eltwise_activation.hpp
   40  modules/dnn/src/cuda4dnn/kernels/eltwise_activation.hpp
  289  modules/dnn/src/cuda4dnn/primitives/convolution.hpp
   29  modules/dnn/src/cuda4dnn/primitives/eltwise.hpp
  242  modules/dnn/src/dnn.cpp
   56  modules/dnn/src/layers/convolution_layer.cpp

@@ -0,0 +1,121 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "functors.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, class ActivationOp, class EltwiseOp, std::size_t N>
__global__ void generic_op_eltwise_op_inplace_vec(Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params act_params, const typename EltwiseOp::Params eltwise_params) {
using vector_type = get_vector_type_t<T, N>;
auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
auto eltwise_vPtr = vector_type::get_pointer(eltwise.data());
ActivationOp activation_op(act_params);
EltwiseOp eltwise_op(eltwise_params);
for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
vector_type output_vec, eltwise_vec;
v_load(output_vec, inplace_output_vPtr[i]);
v_load(eltwise_vec, eltwise_vPtr[i]);
for(int j = 0; j < output_vec.size(); j++)
output_vec.data[j] = eltwise_op(activation_op(output_vec.data[j]), eltwise_vec.data[j]);
v_store(inplace_output_vPtr[i], output_vec);
}
}
}
template <class T, class ActivationOp, class EltwiseOp, std::size_t N> static
void launch_vectorized_generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params& act_params, const typename EltwiseOp::Params& eltwise_params) {
CV_Assert(is_fully_aligned<T>(inplace_output, N));
CV_Assert(is_fully_aligned<T>(eltwise, N));
auto kernel = raw::generic_op_eltwise_op_inplace_vec<T, ActivationOp, EltwiseOp, N>;
auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
launch_kernel(kernel, policy, inplace_output, eltwise, act_params, eltwise_params);
}
template <class T, class ActivationOp, class EltwiseOp> static
void generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params& act_params = {}, const typename EltwiseOp::Params& eltwise_params = {}) {
CV_Assert(inplace_output.size() == eltwise.size());
if (is_fully_aligned<T>(inplace_output, 4) && is_fully_aligned<T>(eltwise, 4)) {
launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 4>(stream, inplace_output, eltwise, act_params, eltwise_params);
} else if (is_fully_aligned<T>(inplace_output, 2) && is_fully_aligned<T>(eltwise, 2)) {
launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 2>(stream, inplace_output, eltwise, act_params, eltwise_params);
} else {
launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 1>(stream, inplace_output, eltwise, act_params, eltwise_params);
}
}
template <class T>
void relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T slope) {
generic_op_eltwise_op_inplace<T, ReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {slope});
}
template <class T>
void clipped_relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T floor, T ceiling) {
CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
generic_op_eltwise_op_inplace<T, ClippedReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {floor, ceiling});
}
template <class T>
void tanh_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
generic_op_eltwise_op_inplace<T, TanHFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}
template <class T>
void swish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
generic_op_eltwise_op_inplace<T, SwishFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}
template <class T>
void mish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
generic_op_eltwise_op_inplace<T, MishFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}
template <class T>
void sigmoid_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
generic_op_eltwise_op_inplace<T, SigmoidFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}
template <class T>
void power_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T exp, T scale, T shift) {
generic_op_eltwise_op_inplace<T, PowerFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {exp, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half);
template void clipped_relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void tanh_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void swish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void mish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void sigmoid_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void power_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
#endif
template void relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float);
template void clipped_relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float, float);
template void tanh_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void swish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void mish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void sigmoid_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void power_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
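
A hedged sketch of how a caller is expected to use these in-place entry points; `stream`, `conv_output` and `residual` stand for the csl::Stream, Span and View that the convolution primitive would already hold (the names are mine, not the diff's):

    // conv_output holds the raw convolution result on entry; after the call it
    // holds relu(conv_output) + residual, computed in a single pass.
    kernels::relu_eltwise_sum_2_inplace<float>(stream, conv_output, residual, 0.f /* slope */);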

@@ -26,20 +26,20 @@ using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
-template <class T, class Functor, std::size_t N, class ...FunctorArgs>
-__global__ void generic_op_vec(Span<T> output, View<T> input, FunctorArgs ...functorArgs) {
+template <class T, class ActivationOp, std::size_t N>
+__global__ void generic_op_vec(Span<T> output, View<T> input, const typename ActivationOp::Params params) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto input_vPtr = vector_type::get_pointer(input.data());
-Functor functor(functorArgs...);
+ActivationOp activation_op(params);
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec;
v_load(vec, input_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++)
-vec.data[j] = functor(vec.data[j]);
+vec.data[j] = activation_op(vec.data[j]);
v_store(output_vPtr[i], vec);
}
}
@@ -51,9 +51,8 @@ namespace raw {
auto output_vPtr = vector_type::get_pointer(output.data());
auto input_vPtr = vector_type::get_pointer(input.data());
-inner_size /= vector_type::size();
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
+const index_type c = (i / inner_size) % slope.size();
vector_type vec;
v_load(vec, input_vPtr[i]);
@@ -65,73 +64,73 @@ namespace raw {
} /* namespace raw */
-template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
-void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp, std::size_t N> static
+void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(input, N));
-auto kernel = raw::generic_op_vec<T, Activation<T>, N, ActivationArgs...>;
+auto kernel = raw::generic_op_vec<T, ActivationOp, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
-launch_kernel(kernel, policy, output, input, activationArgs...);
+launch_kernel(kernel, policy, output, input, params);
}
-template <class T, template <class> class Activation, class ...ActivationArgs> static
-void generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp> static
+void generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params = {}) {
CV_Assert(input.size() == output.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-launch_vectorized_generic_op<T, Activation, 4>(stream, output, input, activationArgs...);
+launch_vectorized_generic_op<T, ActivationOp, 4>(stream, output, input, params);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-launch_vectorized_generic_op<T, Activation, 2>(stream, output, input, activationArgs...);
+launch_vectorized_generic_op<T, ActivationOp, 2>(stream, output, input, params);
} else {
-launch_vectorized_generic_op<T, Activation, 1>(stream, output, input, activationArgs...);
+launch_vectorized_generic_op<T, ActivationOp, 1>(stream, output, input, params);
}
}
template <class T>
-void abs(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, abs_functor>(stream, output, input);
+void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
+generic_op<T, ReLUFunctor<T>>(stream, output, input, {slope});
}
+template <class T>
+void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
+CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
+generic_op<T, ClippedReLUFunctor<T>>(stream, output, input, {floor, ceiling});
+}
template <class T>
void tanh(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, tanh_functor>(stream, output, input);
+generic_op<T, TanHFunctor<T>>(stream, output, input);
}
template <class T>
void swish(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, swish_functor>(stream, output, input);
+generic_op<T, SwishFunctor<T>>(stream, output, input);
}
template <class T>
void mish(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, mish_functor>(stream, output, input);
+generic_op<T, MishFunctor<T>>(stream, output, input);
}
template <class T>
void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, sigmoid_functor>(stream, output, input);
-}
-template <class T>
-void bnll(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, bnll_functor>(stream, output, input);
+generic_op<T, SigmoidFunctor<T>>(stream, output, input);
}
template <class T>
void elu(const Stream& stream, Span<T> output, View<T> input) {
-generic_op<T, elu_functor>(stream, output, input);
+generic_op<T, ELUFunctor<T>>(stream, output, input);
}
template <class T>
-void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-generic_op<T, relu_functor>(stream, output, input, slope);
+void bnll(const Stream& stream, Span<T> output, View<T> input) {
+generic_op<T, BNLLFunctor<T>>(stream, output, input);
}
template <class T>
-void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
-generic_op<T, clipped_relu_functor>(stream, output, input, floor, ceiling);
+void abs(const Stream& stream, Span<T> output, View<T> input) {
+generic_op<T, AbsFunctor<T>>(stream, output, input);
}
template <class T>
@@ -143,31 +142,32 @@ void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale,
return;
}
-generic_op<T, power_functor>(stream, output, input, exp, scale, shift);
+generic_op<T, PowerFunctor<T>>(stream, output, input, {exp, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
+template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
template void swish<__half>(const Stream&, Span<__half>, View<__half>);
template void mish<__half>(const Stream&, Span<__half>, View<__half>);
template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
-template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
template void elu<__half>(const Stream&, Span<__half>, View<__half>);
-template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
-template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
#endif
-template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
template void tanh<float>(const Stream&, Span<float>, View<float>);
template void swish<float>(const Stream&, Span<float>, View<float>);
template void mish<float>(const Stream&, Span<float>, View<float>);
template void sigmoid<float>(const Stream&, Span<float>, View<float>);
-template void bnll<float>(const Stream&, Span<float>, View<float>);
template void elu<float>(const Stream&, Span<float>, View<float>);
-template void relu<float>(const Stream&, Span<float>, View<float>, float);
-template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
+template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void bnll<float>(const Stream&, Span<float>, View<float>);
template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
template <class T, std::size_t N> static
@@ -178,7 +178,7 @@ void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<
auto kernel = raw::axiswise_relu_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
-launch_kernel(kernel, policy, output, input, inner_size, slope);
+launch_kernel(kernel, policy, output, input, inner_size / N, slope);
}
template <class T>
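
The refactor visible above replaces the old variadic functor arguments (`FunctorArgs ...functorArgs`) with a nested `Params` aggregate, so every instantiation of `generic_op_vec` gets the same uniform parameter list. A minimal sketch of the pattern with a hypothetical functor of my own (not part of the PR):

    template <class T>
    struct ScaleFunctor { // hypothetical, for illustration only
        struct Params {
            CUDA4DNN_HOST_DEVICE Params() : scale(1) { }
            CUDA4DNN_HOST_DEVICE Params(T scale_) : scale(scale_) { }
            T scale;
        };
        CUDA4DNN_DEVICE ScaleFunctor() : ScaleFunctor(Params{}) { }
        CUDA4DNN_DEVICE ScaleFunctor(const Params& params) : scale(params.scale) { }
        CUDA4DNN_DEVICE T operator()(T value) { return scale * value; }
        T scale;
    };
    // generic_op<T, ScaleFunctor<T>>(stream, output, input, {2}); // brace-inits Params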

@@ -20,103 +20,101 @@ using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
-template <class T, class Functor, std::size_t N, class ...FunctorArgs>
-__global__ void biasN_generic_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, FunctorArgs ...functorArgs) {
+template <class T, class ActivationOp, std::size_t N>
+__global__ void biasN_generic_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, const typename ActivationOp::Params params) {
using vector_type = get_vector_type_t<T, N>;
auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-Functor functor(functorArgs...);
+ActivationOp activation_op(params);
-inner_size /= vector_type::size();
for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
+const index_type bias_idx = (i / inner_size) % bias.size();
vector_type vec;
v_load(vec, inplace_output_vPtr[i]);
for(int j = 0; j < vec.size(); j++)
-vec.data[j] = functor(vec.data[j] + bias[bias_idx]);
+vec.data[j] = activation_op(vec.data[j] + bias[bias_idx]);
v_store(inplace_output_vPtr[i], vec);
}
}
} /* namespace raw */
-template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
-void launch_vectorized_biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp, std::size_t N> static
+void launch_vectorized_biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, const typename ActivationOp::Params& params) {
CV_Assert(inplace_output.size() % inner_size == 0);
CV_Assert(inplace_output.size() % bias.size() == 0);
CV_Assert(is_fully_aligned<T>(inplace_output, N));
CV_Assert(inner_size % N == 0);
-auto kernel = raw::biasN_generic_op_inplace_vec<T, Activation<T>, N, ActivationArgs...>;
+auto kernel = raw::biasN_generic_op_inplace_vec<T, ActivationOp, N>;
auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-launch_kernel(kernel, policy, inplace_output, inner_size, bias, activationArgs...);
+launch_kernel(kernel, policy, inplace_output, inner_size / N, bias, params);
}
-template <class T, template <class> class Activation, class ...ActivationArgs> static
-void biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp> static
+void biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, const typename ActivationOp::Params& params = {}) {
if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-launch_vectorized_biasN_generic_op_inplace<T, Activation, 4>(stream, inplace_output, inner_size, bias, activationArgs...);
+launch_vectorized_biasN_generic_op_inplace<T, ActivationOp, 4>(stream, inplace_output, inner_size, bias, params);
} else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-launch_vectorized_biasN_generic_op_inplace<T, Activation, 2>(stream, inplace_output, inner_size, bias, activationArgs...);
+launch_vectorized_biasN_generic_op_inplace<T, ActivationOp, 2>(stream, inplace_output, inner_size, bias, params);
} else {
-launch_vectorized_biasN_generic_op_inplace<T, Activation, 1>(stream, inplace_output, inner_size, bias, activationArgs...);
+launch_vectorized_biasN_generic_op_inplace<T, ActivationOp, 1>(stream, inplace_output, inner_size, bias, params);
}
}
template <class T>
void biasN_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
-biasN_generic_op_inplace<T, relu_functor>(stream, inplace_output, inner_size, bias, slope);
+biasN_generic_op_inplace<T, ReLUFunctor<T>>(stream, inplace_output, inner_size, bias, {slope});
}
template <class T>
void biasN_clipped_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T floor, T ceil) {
CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceil));
-biasN_generic_op_inplace<T, clipped_relu_functor>(stream, inplace_output, inner_size, bias, floor, ceil);
+biasN_generic_op_inplace<T, ClippedReLUFunctor<T>>(stream, inplace_output, inner_size, bias, {floor, ceil});
}
template <class T>
-void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power, T scale, T shift) {
-biasN_generic_op_inplace<T, power_functor>(stream, inplace_output, inner_size, bias, power, scale, shift);
+void biasN_tanh_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
+biasN_generic_op_inplace<T, TanHFunctor<T>>(stream, inplace_output, inner_size, bias);
}
template <class T>
-void biasN_tanh_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-biasN_generic_op_inplace<T, tanh_functor>(stream, inplace_output, inner_size, bias);
+void biasN_swish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
+biasN_generic_op_inplace<T, SwishFunctor<T>>(stream, inplace_output, inner_size, bias);
}
template <class T>
-void biasN_sigmoid_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-biasN_generic_op_inplace<T, sigmoid_functor>(stream, inplace_output, inner_size, bias);
+void biasN_mish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
+biasN_generic_op_inplace<T, MishFunctor<T>>(stream, inplace_output, inner_size, bias);
}
template <class T>
-void biasN_swish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-biasN_generic_op_inplace<T, swish_functor>(stream, inplace_output, inner_size, bias);
+void biasN_sigmoid_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
+biasN_generic_op_inplace<T, SigmoidFunctor<T>>(stream, inplace_output, inner_size, bias);
}
template <class T>
-void biasN_mish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-biasN_generic_op_inplace<T, mish_functor>(stream, inplace_output, inner_size, bias);
+void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power, T scale, T shift) {
+biasN_generic_op_inplace<T, PowerFunctor<T>>(stream, inplace_output, inner_size, bias, {power, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void biasN_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
template void biasN_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half);
-template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half, __half);
template void biasN_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
template void biasN_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
template void biasN_mish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half, __half);
#endif
template void biasN_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
template void biasN_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float);
-template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float, float);
template void biasN_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
template void biasN_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
template void biasN_mish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
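
To make the `inner_size` convention concrete: for NCHW output the caller passes inner_size = H*W, the launcher divides it by the vector width N, and the kernel recovers the channel of each vectorized index. A worked example under those assumptions:

    // Assume float4 packets (N = 4), C = 64 channels, H = W = 56, so the launcher
    // passes inner_size = 56*56/4 = 784 vector packets per channel plane.
    //   i = 51000  ->  bias_idx = (51000 / 784) % 64 = 65 % 64 = 1
    // i.e. packet 51000 lies in the second channel plane and picks up bias[1].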

@@ -0,0 +1,125 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "functors.hpp"
#include "types.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, class ActivationOp, class EltwiseOp, std::size_t N>
__global__ void biasN_generic_op_eltwise_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, View<T> eltwise, const typename ActivationOp::Params act_params, const typename EltwiseOp::Params eltwise_params) {
using vector_type = get_vector_type_t<T, N>;
auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
auto eltwise_vPtr = vector_type::get_pointer(eltwise.data());
ActivationOp activation_op(act_params);
EltwiseOp eltwise_op(eltwise_params);
for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
const index_type bias_idx = (i / inner_size) % bias.size();
vector_type output_vec, eltwise_vec;
v_load(output_vec, inplace_output_vPtr[i]);
v_load(eltwise_vec, eltwise_vPtr[i]);
for(int j = 0; j < output_vec.size(); j++)
output_vec.data[j] = eltwise_op(activation_op(output_vec.data[j] + bias[bias_idx]), eltwise_vec.data[j]);
v_store(inplace_output_vPtr[i], output_vec);
}
}
}
template <class T, class ActivationOp, class EltwiseOp, std::size_t N> static
void launch_vectorized_biasN_generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, const typename ActivationOp::Params& act_params, const typename EltwiseOp::Params& eltwise_params) {
CV_Assert(is_fully_aligned<T>(inplace_output, N));
CV_Assert(is_fully_aligned<T>(eltwise, N));
CV_Assert(inner_size % N == 0);
auto kernel = raw::biasN_generic_op_eltwise_op_inplace_vec<T, ActivationOp, EltwiseOp, N>;
auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
launch_kernel(kernel, policy, inplace_output, inner_size / N, bias, eltwise, act_params, eltwise_params);
}
template <class T, class ActivationOp, class EltwiseOp> static
void biasN_generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, const typename ActivationOp::Params& act_params = {}, const typename EltwiseOp::Params& eltwise_params = {}) {
CV_Assert(inplace_output.size() == eltwise.size());
if (is_fully_aligned<T>(inplace_output, 4) && is_fully_aligned<T>(eltwise, 4) && inner_size % 4 == 0) {
launch_vectorized_biasN_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 4>(stream, inplace_output, inner_size, bias, eltwise, act_params, eltwise_params);
} else if (is_fully_aligned<T>(inplace_output, 2) && is_fully_aligned<T>(eltwise, 2) && inner_size % 2 == 0) {
launch_vectorized_biasN_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 2>(stream, inplace_output, inner_size, bias, eltwise, act_params, eltwise_params);
} else {
launch_vectorized_biasN_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 1>(stream, inplace_output, inner_size, bias, eltwise, act_params, eltwise_params);
}
}
template <class T>
void biasN_relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T slope) {
biasN_generic_op_eltwise_op_inplace<T, ReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {slope});
}
template <class T>
void biasN_clipped_relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T floor, T ceiling) {
CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
biasN_generic_op_eltwise_op_inplace<T, ClippedReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {floor, ceiling});
}
template <class T>
void biasN_tanh_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_generic_op_eltwise_op_inplace<T, TanHFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_swish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_generic_op_eltwise_op_inplace<T, SwishFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_mish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_generic_op_eltwise_op_inplace<T, MishFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_sigmoid_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_generic_op_eltwise_op_inplace<T, SigmoidFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_power_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T exp, T scale, T shift) {
biasN_generic_op_eltwise_op_inplace<T, PowerFunctor<T>, SumFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {exp, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void biasN_relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half);
template void biasN_clipped_relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half, __half);
template void biasN_tanh_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_swish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_mish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_sigmoid_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_power_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half, __half, __half);
#endif
template void biasN_relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float);
template void biasN_clipped_relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float, float);
template void biasN_tanh_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_swish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_mish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_sigmoid_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_power_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
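
These `_inplace` kernels deliberately reuse the convolution's output buffer: `inplace_output` arrives holding the convolution result and leaves holding the fused epilogue, while `eltwise` is read-only. A sketch of the contract (the names are the diff's; the comment is mine):

    // precondition : inplace_output[i] == conv_result[i]
    // postcondition: inplace_output[i] == eltwise_op(act(conv_result[i] + bias[c]), eltwise[i])
    // where eltwise_op is SumFunctor in all of the instantiations above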

@@ -0,0 +1,132 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "functors.hpp"
#include "types.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, class EltwiseOp, class ActivationOp, std::size_t N>
__global__ void biasN_eltwise_op_generic_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, View<T> eltwise, const typename EltwiseOp::Params eltwise_params, const typename ActivationOp::Params act_params) {
using vector_type = get_vector_type_t<T, N>;
auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
auto eltwise_vPtr = vector_type::get_pointer(eltwise.data());
EltwiseOp eltwise_op(eltwise_params);
ActivationOp activation_op(act_params);
for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
const index_type bias_idx = (i / inner_size) % bias.size();
vector_type output_vec, eltwise_vec;
v_load(output_vec, inplace_output_vPtr[i]);
v_load(eltwise_vec, eltwise_vPtr[i]);
for(int j = 0; j < output_vec.size(); j++)
output_vec.data[j] = activation_op(eltwise_op(output_vec.data[j] + bias[bias_idx], eltwise_vec.data[j]));
v_store(inplace_output_vPtr[i], output_vec);
}
}
}
template <class T, class EltwiseOp, class ActivationOp, std::size_t N> static
void launch_vectorized_biasN_eltwise_op_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, const typename EltwiseOp::Params& eltwise_params, const typename ActivationOp::Params& act_params) {
CV_Assert(is_fully_aligned<T>(inplace_output, N));
CV_Assert(inplace_output.size() % bias.size() == 0);
CV_Assert(is_fully_aligned<T>(eltwise, N));
CV_Assert(inner_size % N == 0);
auto kernel = raw::biasN_eltwise_op_generic_op_inplace_vec<T, EltwiseOp, ActivationOp, N>;
auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
launch_kernel(kernel, policy, inplace_output, inner_size / N, bias, eltwise, eltwise_params, act_params);
}
template <class T, class EltwiseOp, class ActivationOp> static
void biasN_eltwise_op_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, const typename EltwiseOp::Params& eltwise_params = {}, const typename ActivationOp::Params& act_params = {}) {
CV_Assert(inplace_output.size() == eltwise.size());
if (is_fully_aligned<T>(inplace_output, 4) && is_fully_aligned<T>(eltwise, 4) && inner_size % 4 == 0) {
launch_vectorized_biasN_eltwise_op_generic_op_inplace<T, EltwiseOp, ActivationOp, 4>(stream, inplace_output, inner_size, bias, eltwise, eltwise_params, act_params);
} else if (is_fully_aligned<T>(inplace_output, 2) && is_fully_aligned<T>(eltwise, 2) && inner_size % 2 == 0) {
launch_vectorized_biasN_eltwise_op_generic_op_inplace<T, EltwiseOp, ActivationOp, 2>(stream, inplace_output, inner_size, bias, eltwise, eltwise_params, act_params);
} else {
launch_vectorized_biasN_eltwise_op_generic_op_inplace<T, EltwiseOp, ActivationOp, 1>(stream, inplace_output, inner_size, bias, eltwise, eltwise_params, act_params);
}
}
template <class T>
void biasN_eltwise_sum_2_identity_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, IdentityFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_eltwise_sum_2_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T slope) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, ReLUFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {}, {slope});
}
template <class T>
void biasN_eltwise_sum_2_clipped_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T floor, T ceiling) {
CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, ClippedReLUFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {}, {floor, ceiling});
}
template <class T>
void biasN_eltwise_sum_2_tanh_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, TanHFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_eltwise_sum_2_swish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, SwishFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_eltwise_sum_2_mish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, MishFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_eltwise_sum_2_sigmoid_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, SigmoidFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise);
}
template <class T>
void biasN_eltwise_sum_2_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, View<T> eltwise, T exp, T scale, T shift) {
biasN_eltwise_op_generic_op_inplace<T, SumFunctor<T>, PowerFunctor<T>>(stream, inplace_output, inner_size, bias, eltwise, {}, {exp, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void biasN_eltwise_sum_2_identity_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_eltwise_sum_2_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half);
template void biasN_eltwise_sum_2_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half, __half);
template void biasN_eltwise_sum_2_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_eltwise_sum_2_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_eltwise_sum_2_mish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_eltwise_sum_2_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>);
template void biasN_eltwise_sum_2_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, View<__half>, __half, __half, __half);
#endif
template void biasN_eltwise_sum_2_identity_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_eltwise_sum_2_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float);
template void biasN_eltwise_sum_2_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float, float);
template void biasN_eltwise_sum_2_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_eltwise_sum_2_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_eltwise_sum_2_mish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_eltwise_sum_2_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>);
template void biasN_eltwise_sum_2_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, View<float>, float, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
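
The previous file and this one differ only in where the activation sits relative to the eltwise op; written out as scalar pseudocode (my summary, not from the diff):

    // bias_activation_eltwise: out[i] = eltwise_op(act(out[i] + bias[c]), other[i])
    // bias_eltwise_activation: out[i] = act(eltwise_op(out[i] + bias[c], other[i]))
    // The first matches graphs of the form conv -> act -> add; the second matches
    // conv -> add -> act (e.g. a ResNet block whose ReLU follows the residual sum).
    // The identity instantiation above additionally covers plain conv -> add fusion.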

@@ -0,0 +1,125 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "functors.hpp"
#include "types.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, class EltwiseOp, class ActivationOp, std::size_t N>
__global__ void eltwise_op_generic_op_vec(Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params eltwise_params, const typename ActivationOp::Params act_params) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
EltwiseOp eltwise_op(eltwise_params);
ActivationOp activation_op(act_params);
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for(int j = 0; j < vec_x.size(); j++)
vec_x.data[j] = activation_op(eltwise_op(vec_x.data[j], vec_y.data[j]));
v_store(output_vPtr[i], vec_x);
}
}
}
template <class T, class EltwiseOp, class ActivationOp, std::size_t N> static
void launch_vectorized_eltwise_op_generic_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params& eltwise_params, const typename ActivationOp::Params& act_params) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
auto kernel = raw::eltwise_op_generic_op_vec<T, EltwiseOp, ActivationOp, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, x, y, eltwise_params, act_params);
}
template <class T, class EltwiseOp, class ActivationOp> static
void eltwise_op_generic_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params& eltwise_params = {}, const typename ActivationOp::Params& act_params = {}) {
CV_Assert(output.size() == x.size());
CV_Assert(output.size() == y.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
launch_vectorized_eltwise_op_generic_op<T, EltwiseOp, ActivationOp, 4>(stream, output, x, y, eltwise_params, act_params);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
launch_vectorized_eltwise_op_generic_op<T, EltwiseOp, ActivationOp, 2>(stream, output, x, y, eltwise_params, act_params);
} else {
launch_vectorized_eltwise_op_generic_op<T, EltwiseOp, ActivationOp, 1>(stream, output, x, y, eltwise_params, act_params);
}
}
template <class T>
void eltwise_sum_2_relu(const Stream& stream, Span<T> output, View<T> x, View<T> y, T slope) {
eltwise_op_generic_op<T, SumFunctor<T>, ReLUFunctor<T>>(stream, output, x, y, {}, {slope});
}
template <class T>
void eltwise_sum_2_clipped_relu(const Stream& stream, Span<T> output, View<T> x, View<T> y, T floor, T ceiling) {
CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
eltwise_op_generic_op<T, SumFunctor<T>, ClippedReLUFunctor<T>>(stream, output, x, y, {}, {floor, ceiling});
}
template <class T>
void eltwise_sum_2_tanh(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
eltwise_op_generic_op<T, SumFunctor<T>, TanHFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_2_swish(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
eltwise_op_generic_op<T, SumFunctor<T>, SwishFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_2_mish(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
eltwise_op_generic_op<T, SumFunctor<T>, MishFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_2_sigmoid(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
eltwise_op_generic_op<T, SumFunctor<T>, SigmoidFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_2_power(const Stream& stream, Span<T> output, View<T> x, View<T> y, T exp, T scale, T shift) {
eltwise_op_generic_op<T, SumFunctor<T>, PowerFunctor<T>>(stream, output, x, y, {}, {exp, scale, shift});
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void eltwise_sum_2_relu<__half>(const Stream&, Span<__half>, View<__half>, View<__half>, __half);
template void eltwise_sum_2_clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, View<__half>, __half, __half);
template void eltwise_sum_2_tanh<__half>(const Stream&, Span<__half>, View<__half>, View<__half>);
template void eltwise_sum_2_swish<__half>(const Stream&, Span<__half>, View<__half>, View<__half>);
template void eltwise_sum_2_mish<__half>(const Stream&, Span<__half>, View<__half>, View<__half>);
template void eltwise_sum_2_sigmoid<__half>(const Stream&, Span<__half>, View<__half>, View<__half>);
template void eltwise_sum_2_power<__half>(const Stream&, Span<__half>, View<__half>, View<__half>, __half, __half, __half);
#endif
template void eltwise_sum_2_relu<float>(const Stream&, Span<float>, View<float>, View<float>, float);
template void eltwise_sum_2_clipped_relu<float>(const Stream&, Span<float>, View<float>, View<float>, float, float);
template void eltwise_sum_2_tanh<float>(const Stream&, Span<float>, View<float>, View<float>);
template void eltwise_sum_2_swish<float>(const Stream&, Span<float>, View<float>, View<float>);
template void eltwise_sum_2_mish<float>(const Stream&, Span<float>, View<float>, View<float>);
template void eltwise_sum_2_sigmoid<float>(const Stream&, Span<float>, View<float>, View<float>);
template void eltwise_sum_2_power<float>(const Stream&, Span<float>, View<float>, View<float>, float, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
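
All of these launchers dispatch on alignment the same way: vector loads of 4 or 2 elements are only legal when every span involved is suitably aligned, which is why the N=2 branch must check all three operands against 2. A plausible shape for the check (an assumption of mine; the real definition lives in csl/span.hpp and vector_traits.hpp):

    #include <cstddef>
    #include <cstdint>
    template <class T>
    bool is_fully_aligned_sketch(const T* ptr, std::size_t size, std::size_t n) {
        // base pointer must sit on an n*sizeof(T) boundary and the element count
        // must be divisible by n, so every packet load/store stays in bounds
        return reinterpret_cast<std::uintptr_t>(ptr) % (n * sizeof(T)) == 0
            && size % n == 0;
    }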

@@ -21,77 +21,77 @@ using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
-template <class T, class Functor, std::size_t N, class ...FunctorArgs>
-__global__ void eltwise_op_vec(Span<T> output, View<T> x, View<T> y, FunctorArgs ...functorArgs) {
+template <class T, class EltwiseOp, std::size_t N>
+__global__ void eltwise_op_vec(Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params params) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
-Functor functor(functorArgs...);
+EltwiseOp eltwise_op(params);
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++)
-vec_x.data[j] = functor(vec_x.data[j], vec_y.data[j]);
+vec_x.data[j] = eltwise_op(vec_x.data[j], vec_y.data[j]);
v_store(output_vPtr[i], vec_x);
}
}
}
-template <class T, template <class> class EltwiseOp, std::size_t N, class ...EltwiseOpArgs> static
-void launch_vectorized_eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+template <class T, class EltwiseOp, std::size_t N> static
+void launch_vectorized_eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params& params) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
-auto kernel = raw::eltwise_op_vec<T, EltwiseOp<T>, N, EltwiseOpArgs...>;
+auto kernel = raw::eltwise_op_vec<T, EltwiseOp, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
-launch_kernel(kernel, policy, output, x, y, eltwiseOpArgs...);
+launch_kernel(kernel, policy, output, x, y, params);
}
-template <class T, template <class> class EltwiseOp, class ...EltwiseOpArgs> static
-void eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+template <class T, class EltwiseOp> static
+void eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params& params = {}) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-launch_vectorized_eltwise_op<T, EltwiseOp, 4>(stream, output, x, y, eltwiseOpArgs...);
+launch_vectorized_eltwise_op<T, EltwiseOp, 4>(stream, output, x, y, params);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-launch_vectorized_eltwise_op<T, EltwiseOp, 2>(stream, output, x, y, eltwiseOpArgs...);
+launch_vectorized_eltwise_op<T, EltwiseOp, 2>(stream, output, x, y, params);
} else {
-launch_vectorized_eltwise_op<T, EltwiseOp, 1>(stream, output, x, y, eltwiseOpArgs...);
+launch_vectorized_eltwise_op<T, EltwiseOp, 1>(stream, output, x, y, params);
}
}
template <class T>
void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-eltwise_op<T, max_functor>(stream, output, x, y);
+eltwise_op<T, MaxFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-eltwise_op<T, sum_functor>(stream, output, x, y);
+eltwise_op<T, SumFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-eltwise_op<T, scaled_sum_functor>(stream, output, x, y, coeff_x, coeff_y);
+eltwise_op<T, ScaledSumFunctor<T>>(stream, output, x, y, {coeff_x, coeff_y});
}
template <class T>
void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-eltwise_op<T, product_functor>(stream, output, x, y);
+eltwise_op<T, ProductFunctor<T>>(stream, output, x, y);
}
template <class T>
void eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-eltwise_op<T, div_functor>(stream, output, x, y);
+eltwise_op<T, DivFunctor<T>>(stream, output, x, y);
}
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
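
The visible effect at call sites: coefficients previously forwarded as loose variadic arguments are now brace-initialized into the functor's Params, and parameterless ops rely on the defaulted `params = {}`. For example (arguments as in the diff, comments mine):

    // old: eltwise_op<T, scaled_sum_functor>(stream, output, x, y, coeff_x, coeff_y);
    // new: eltwise_op<T, ScaledSumFunctor<T>>(stream, output, x, y, {coeff_x, coeff_y});
    // new: eltwise_op<T, SumFunctor<T>>(stream, output, x, y); // Params defaulted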

@@ -9,27 +9,87 @@
#include "math.hpp"
#include "../cuda4dnn/csl/nvcc_defs.hpp"
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
-struct abs_functor {
-__device__ T operator()(T value) {
-using csl::device::abs;
-return abs(value);
+struct IdentityFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE IdentityFunctor() { }
+CUDA4DNN_DEVICE IdentityFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
+return value;
+};
};
template <class T>
+struct ReLUFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() : slope(0) { }
+CUDA4DNN_HOST_DEVICE Params(T slope_) : slope(slope_) { }
+T slope;
+};
+CUDA4DNN_DEVICE ReLUFunctor() : ReLUFunctor(Params{}) { }
+CUDA4DNN_DEVICE ReLUFunctor(const Params& params) : slope(params.slope) { }
+CUDA4DNN_DEVICE T operator()(T value) {
+using csl::device::log1pexp;
+return value >= T(0) ? value : slope * value;
+}
+T slope;
+};
+template <class T>
-struct tanh_functor {
-__device__ T operator()(T value) {
+struct ClippedReLUFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() : floor(0), ceiling(6) { }
+CUDA4DNN_HOST_DEVICE Params(T floor_, T ceiling_) : floor(floor_), ceiling(ceiling_) { }
+T floor, ceiling;
+};
+CUDA4DNN_DEVICE ClippedReLUFunctor() : ClippedReLUFunctor(Params{}) { }
+CUDA4DNN_DEVICE ClippedReLUFunctor(const Params& params) : floor{params.floor}, ceiling{params.ceiling} { }
+CUDA4DNN_DEVICE T operator()(T value) {
+using csl::device::clamp;
+return clamp(value, floor, ceiling);
+}
+T floor, ceiling;
+};
+template <class T>
+struct TanHFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE TanHFunctor() { }
+CUDA4DNN_DEVICE TanHFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::tanh;
return tanh(value);
}
};
template <class T>
-struct swish_functor {
-__device__ T operator()(T value) {
+struct SwishFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE SwishFunctor() { }
+CUDA4DNN_DEVICE SwishFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
// f(x) = x * sigmoid(x)
using csl::device::fast_divide;
using csl::device::fast_exp;
@@ -38,8 +98,15 @@ struct swish_functor {
};
template <class T>
-struct mish_functor {
-__device__ T operator()(T value) {
+struct MishFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE MishFunctor() { }
+CUDA4DNN_DEVICE MishFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::tanh;
using csl::device::log1pexp;
return value * tanh(log1pexp(value));
@@ -47,8 +114,15 @@ struct mish_functor {
};
template <>
-struct mish_functor<float> {
-__device__ float operator()(float value) {
+struct MishFunctor<float> {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE MishFunctor() { }
+CUDA4DNN_DEVICE MishFunctor(const Params& params) { }
+CUDA4DNN_DEVICE float operator()(float value) {
// f(x) = x * tanh(log1pexp(x));
using csl::device::fast_divide;
using csl::device::fast_exp;
@@ -63,63 +137,90 @@ struct mish_functor<float> {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <>
-struct mish_functor<__half> {
-__device__ __half operator()(__half value) {
-return mish_functor<float>()(value);
+struct MishFunctor<__half> {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE MishFunctor() { }
+CUDA4DNN_DEVICE MishFunctor(const Params& params) { }
+CUDA4DNN_DEVICE __half operator()(__half value) {
+return MishFunctor<float>()(value);
}
};
#endif
template <class T>
-struct sigmoid_functor {
-__device__ T operator()(T value) {
+struct SigmoidFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
+CUDA4DNN_DEVICE SigmoidFunctor() { }
+CUDA4DNN_DEVICE SigmoidFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::fast_sigmoid;
return fast_sigmoid(value);
}
};
template <class T>
-struct bnll_functor {
-__device__ T operator()(T value) {
-using csl::device::log1pexp;
-return value > T(0) ? value + log1pexp(-value) : log1pexp(value);
-}
-};
+struct ELUFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
-template <class T>
-struct elu_functor {
-__device__ T operator()(T value) {
+CUDA4DNN_DEVICE ELUFunctor() { }
+CUDA4DNN_DEVICE ELUFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::expm1;
return value >= T(0) ? value : expm1(value);
}
};
template <class T>
-struct relu_functor {
-__device__ relu_functor(T slope_) : slope{slope_} { }
-__device__ T operator()(T value) {
-using csl::device::log1pexp;
-return value >= T(0) ? value : slope * value;
-}
+struct AbsFunctor {
+struct Params { };
-T slope;
+CUDA4DNN_DEVICE AbsFunctor() { }
+CUDA4DNN_DEVICE AbsFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
+using csl::device::abs;
+return abs(value);
+}
};
template <class T>
-struct clipped_relu_functor {
-__device__ clipped_relu_functor(T floor_, T ceiling_) : floor{floor_}, ceiling{ceiling_} { }
-__device__ T operator()(T value) {
-using csl::device::clamp;
-return clamp(value, floor, ceiling);
-}
+struct BNLLFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() { }
+};
-T floor, ceiling;
+CUDA4DNN_DEVICE BNLLFunctor() { }
+CUDA4DNN_DEVICE BNLLFunctor(const Params& params) { }
+CUDA4DNN_DEVICE T operator()(T value) {
+using csl::device::log1pexp;
+return value > T(0) ? value + log1pexp(-value) : log1pexp(value);
+}
};
template <class T>
-struct power_functor {
-__device__ power_functor(T exp_, T scale_, T shift_) : exp{exp_}, scale{scale_}, shift{shift_} { }
-__device__ T operator()(T value) {
+struct PowerFunctor {
+struct Params {
+CUDA4DNN_HOST_DEVICE Params() : exp(1), scale(1), shift(0) { }
+CUDA4DNN_HOST_DEVICE Params(T exp_, T scale_, T shift_) : exp(exp_), scale(scale_), shift(shift_) { }
+T exp, scale, shift;
+};
+CUDA4DNN_DEVICE PowerFunctor() : PowerFunctor(Params{}) { }
+CUDA4DNN_DEVICE PowerFunctor(const Params& params) : exp{params.exp}, scale{params.scale}, shift{params.shift} { }
+CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::pow;
return pow(shift + scale * value, exp);
}
@ -128,36 +229,70 @@ struct power_functor {
};
template <class T>
struct max_functor {
__device__ T operator()(T x, T y) {
struct MaxFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE MaxFunctor() { }
CUDA4DNN_DEVICE MaxFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T x, T y) {
using csl::device::max;
return max(x, y);
}
};
template <class T>
struct sum_functor {
__device__ T operator()(T x, T y) { return x + y; }
struct SumFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE SumFunctor() { }
CUDA4DNN_DEVICE SumFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T x, T y) { return x + y; }
};
template <class T>
struct scaled_sum_functor {
__device__ scaled_sum_functor(T scale_x_, T scale_y_)
: scale_x{scale_x_}, scale_y{scale_y_} { }
struct ScaledSumFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() : scale_x(1), scale_y(1) { }
CUDA4DNN_HOST_DEVICE Params(T scale_x_, T scale_y_) : scale_x(scale_x_), scale_y(scale_y_) { }
T scale_x, scale_y;
};
CUDA4DNN_DEVICE ScaledSumFunctor() : scale_x(1), scale_y(1) { }
CUDA4DNN_DEVICE ScaledSumFunctor(const Params& params) : scale_x{params.scale_x}, scale_y{params.scale_y} { }
__device__ T operator()(T x, T y) { return scale_x * x + scale_y * y; }
CUDA4DNN_DEVICE T operator()(T x, T y) { return scale_x * x + scale_y * y; }
T scale_x, scale_y;
};
template <class T>
struct product_functor {
__device__ T operator()(T x, T y) { return x * y; }
struct ProductFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE ProductFunctor() { }
CUDA4DNN_DEVICE ProductFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T x, T y) { return x * y; }
};
template <class T>
struct div_functor {
__device__ T operator()(T x, T y) { return x / y; }
struct DivFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE DivFunctor() { }
CUDA4DNN_DEVICE DivFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T x, T y) { return x / y; }
};
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
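The `Params`-based refactor above is what allows one generic kernel to be instantiated for any activation/eltwise pair. As a minimal sketch, a dispatch function for the fused ReLU-then-sum case could compose the functors like this (assuming a `ReLUFunctor` whose `Params` carries the slope, refactored like the functors above, and the vectorized launcher added in activation_eltwise.cu; the real dispatch also falls back to narrower vector widths when alignment does not permit N = 4):
template <class T>
void relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T slope) {
    /* pack the runtime arguments into the functors' Params structs */
    typename ReLUFunctor<T>::Params act_params(slope);
    typename SumFunctor<T>::Params eltwise_params;
    /* per element: inplace_output[i] = act(inplace_output[i]) + eltwise[i] */
    launch_vectorized_generic_op_eltwise_op_inplace<T, ReLUFunctor<T>, SumFunctor<T>, 4>(
        stream, inplace_output, eltwise, act_params, eltwise_params);
}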

@ -33,7 +33,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
inner_size /= vector_type::size();
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
const index_type bias_idx = (i / inner_size) % bias.size();
vector_type vec;
v_load(vec, input_vPtr[i]);
@ -53,7 +53,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
inner_size /= vector_type::size();
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
const index_type scale_idx = (i / inner_size) % static_cast<size_type>(weights.size());
const index_type scale_idx = (i / inner_size) % weights.size();
vector_type vec;
v_load(vec, input_vPtr[i]);
@ -90,7 +90,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
inner_size /= vector_type::size();
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
const index_type scale_idx = (i / inner_size) % static_cast<size_type>(weights.size());
const index_type scale_idx = (i / inner_size) % weights.size();
vector_type vec;
v_load(vec, input_vPtr[i]);

@ -537,6 +537,101 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
outputDesc.get(), outputPtr.get()));
}
/** @brief performs convolution, bias addition, eltwise addition and activation simultaneously
*
* dstValue = act(alpha1 * conv(input) + bias + alpha2 * eltwise)
*
* @tparam T convolution element type (must be `half` or `float`)
*
* @param handle valid cuDNN Handle
* @param convDesc convolution description
* @param convAlgo algorithm to use for convolution
* @param workspace workspace memory which meets the requirements of \p convAlgo
* @param filterDesc filter descriptor
* @param[in] filterPtr pointer to device memory containing the filters
* @param alpha1 convolution scale factor
* @param inputDesc tensor descriptor describing the input
* @param[in] inputPtr pointer to input tensor in device memory
* @param biasDesc tensor descriptor describing the bias
* @param[in] biasPtr pointer to bias tensor in device memory
* @param alpha2 eltwise scale factor
* @param eltwiseDesc tensor descriptor describing the eltwise tensor
* @param[in] eltwisePtr pointer to the eltwise tensor in device memory
* @param actDesc activation descriptor
* @param outputDesc tensor descriptor describing the output
* @param[out] outputPtr pointer to output tensor in device memory
*
* Exception Guarantee: Basic
*/
template <class T>
void convolve_with_bias_eltwise_activation(
const Handle& handle,
T alpha1,
const ConvolutionDescriptor<T>& convDesc,
const ConvolutionAlgorithm<T>& convAlgo,
WorkspaceInstance workspace,
const FilterDescriptor<T>& filterDesc,
DevicePtr<const T> filterPtr,
const TensorDescriptor<T>& inputDesc,
DevicePtr<const T> inputPtr,
const TensorDescriptor<T>& biasDesc,
DevicePtr<const T> biasPtr,
T alpha2,
const TensorDescriptor<T>& eltwiseDesc,
DevicePtr<const T> eltwisePtr,
const ActivationDescriptor& actDesc,
const TensorDescriptor<T>& outputDesc,
DevicePtr<T> outputPtr)
{
CV_Assert(handle);
CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
handle.get(),
&alpha1, inputDesc.get(), inputPtr.get(),
filterDesc.get(), filterPtr.get(),
convDesc.get(), convAlgo.get(),
static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
&alpha2, eltwiseDesc.get(), eltwisePtr.get(),
biasDesc.get(), biasPtr.get(),
actDesc.get(),
outputDesc.get(), outputPtr.get()));
}
template <> inline
void convolve_with_bias_eltwise_activation(
const Handle& handle,
half alpha1,
const ConvolutionDescriptor<half>& convDesc,
const ConvolutionAlgorithm<half>& convAlgo,
WorkspaceInstance workspace,
const FilterDescriptor<half>& filterDesc,
DevicePtr<const half> filterPtr,
const TensorDescriptor<half>& inputDesc,
DevicePtr<const half> inputPtr,
const TensorDescriptor<half>& biasDesc,
DevicePtr<const half> biasPtr,
half alpha2,
const TensorDescriptor<half>& eltwiseDesc,
DevicePtr<const half> eltwisePtr,
const ActivationDescriptor& actDesc,
const TensorDescriptor<half>& outputDesc,
DevicePtr<half> outputPtr)
{
CV_Assert(handle);
float alpha1_ = alpha1, alpha2_ = alpha2;
CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
handle.get(),
&alpha1_, inputDesc.get(), inputPtr.get(),
filterDesc.get(), filterPtr.get(),
convDesc.get(), convAlgo.get(),
static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
&alpha2_, eltwiseDesc.get(), eltwisePtr.get(),
biasDesc.get(), biasPtr.get(),
actDesc.get(),
outputDesc.get(), outputPtr.get()));
}
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP */
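For reference, both overloads wrap `cudnnConvolutionBiasActivationForward`, whose per-element result matches the formula in the doc comment. A plain C++ sketch of the semantics only (not an implementation; `conv_output` stands for the raw convolution result, `channel_of` is a hypothetical index-to-channel helper, and ReLU is the only activation this fusion path uses):
for (std::size_t i = 0; i < total; i++) {
    float value = alpha1 * conv_output[i] + bias[channel_of(i)] + alpha2 * eltwise[i];
    output[i] = value > 0.0f ? value : 0.0f; /* act == ReLU */
}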

@ -8,6 +8,8 @@
#include "pointer.hpp"
#include "nvcc_defs.hpp"
#include "../../cuda/types.hpp"
#include <cstddef>
#include <type_traits>
@ -24,17 +26,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
public:
using value_type = T;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using size_type = device::size_type;
using index_type = device::index_type;
using pointer = DevicePtr<value_type>;
using const_pointer = DevicePtr<typename std::add_const<value_type>::type>;
using reference = typename std::add_lvalue_reference<value_type>::type;
using const_reference = typename std::add_lvalue_reference<typename std::add_const<value_type>::type>;
using iterator = pointer;
using const_iterator = const_pointer;
Span() noexcept : ptr{ nullptr }, sz{ 0 } { }
CUDA4DNN_HOST_DEVICE Span(pointer first, pointer last) noexcept : ptr{ first }, sz{ last - first } { }
CUDA4DNN_HOST_DEVICE Span(pointer first, size_type count) noexcept : ptr{ first }, sz{ count } { }
@ -42,7 +41,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
CUDA4DNN_HOST_DEVICE size_type size() const noexcept { return sz; }
CUDA4DNN_HOST_DEVICE bool empty() const noexcept { return size() == 0; }
CUDA4DNN_DEVICE reference operator[](difference_type index) const { return ptr[index]; }
CUDA4DNN_DEVICE reference operator[](index_type index) const { return ptr[index]; }
CUDA4DNN_HOST_DEVICE pointer data() const noexcept { return ptr; }
template<class U = T, class V = typename std::add_const<U>::type,

@ -152,6 +152,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
/* bias and activation (only RELU supported) */
std::vector<std::size_t> bias_shape;
ActivationType activation_type; /* MUST BE identity if there is no bias and ReLU if there is bias */
bool eltwise;
};
Convolution() = default;
@ -164,19 +165,21 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
filterDesc = FilterDescriptor(params.filter_shape);
convDesc = ConvolutionDescriptor(params.padding, params.stride, params.dilation, params.groups);
std::vector<int> output_dims;
getConvolutionForwardOutputDim(convDesc, filterDesc, inputTensorDesc, output_dims);
outputTensorDesc = TensorDescriptor(output_dims);
algo = ConvolutionAlgorithm(cudnnHandle, convDesc, filterDesc, inputTensorDesc, outputTensorDesc);
if (!params.bias_shape.empty()) {
CV_Assert(params.activation_type == ActivationType::RELU);
biasTensorDesc = TensorDescriptor(params.bias_shape);
if (params.eltwise)
eltwiseTensorDesc = TensorDescriptor(output_dims);
activationDesc = ActivationDescriptor(params.activation_type, 0.0);
} else {
CV_Assert(params.activation_type == ActivationType::IDENTITY);
}
std::vector<int> output_dims;
getConvolutionForwardOutputDim(convDesc, filterDesc, inputTensorDesc, output_dims);
outputTensorDesc = TensorDescriptor(output_dims);
algo = ConvolutionAlgorithm(cudnnHandle, convDesc, filterDesc, inputTensorDesc, outputTensorDesc);
}
Convolution& operator=(const Convolution&) = delete;
@ -208,6 +211,19 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
);
}
void convolve_with_bias_eltwise_activation(TensorSpan<T> output, TensorView<T> input, TensorView<T> filters, TensorView<T> bias, TensorView<T> eltwise, WorkspaceInstance scratchpad) {
cudnn::convolve_with_bias_eltwise_activation<T>(
cudnnHandle,
1.0, convDesc, algo, scratchpad,
filterDesc, filters.get(),
inputTensorDesc, input.get(),
biasTensorDesc, bias.get(),
1.0, eltwiseTensorDesc, eltwise.get(),
activationDesc,
outputTensorDesc, output.get()
);
}
private:
cudnn::Handle cudnnHandle;
TensorDescriptor inputTensorDesc, outputTensorDesc;
@ -215,6 +231,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
ConvolutionDescriptor convDesc;
ConvolutionAlgorithm algo;
TensorDescriptor biasTensorDesc;
TensorDescriptor eltwiseTensorDesc;
ActivationDescriptor activationDesc;
};
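Putting the new `Params` fields to use, a call site enabling the fully fused path might look as follows (a sketch: `rank` and `output_feature_maps` are assumed to come from the layer configuration, and the tensors, workspace and cuDNN handle are presumed to be set up as elsewhere in this patch):
csl::Convolution<float>::Params params;
/* padding, stride, dilation, groups, input_shape and filter_shape filled in as usual */
params.bias_shape = std::vector<std::size_t>(rank, 1);
params.bias_shape[1] = output_feature_maps; /* bias broadcasts over all axes except channels */
params.activation_type = csl::Convolution<float>::ActivationType::RELU;
params.eltwise = true; /* triggers creation of eltwiseTensorDesc */
csl::Convolution<float> conv(cudnnHandle, params);
conv.convolve_with_bias_eltwise_activation(output, input, filters, bias, eltwise, scratchpad);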

@ -0,0 +1,40 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATION_ELTWISE_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATION_ELTWISE_HPP
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/* inplace_output = activation(inplace_output) + eltwise */
template <class T>
void relu_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise, T slope);
template <class T>
void clipped_relu_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise, T floor, T ceiling);
template <class T>
void tanh_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise);
template <class T>
void swish_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise);
template <class T>
void mish_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise);
template <class T>
void sigmoid_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise);
template <class T>
void power_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, csl::View<T> eltwise, T exp, T scale, T shift);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATION_ELTWISE_HPP */

@ -13,7 +13,13 @@
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void abs(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
void relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T slope);
template <class T>
void clipped_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T floor, T ceiling);
template <class T>
void axiswise_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, std::size_t inner_size, csl::View<T> slope);
template <class T>
void tanh(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
@ -27,20 +33,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void sigmoid(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void bnll(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void elu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T slope);
template <class T>
void clipped_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T floor, T ceiling);
void abs(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void axiswise_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, std::size_t inner_size, csl::View<T> slope);
void bnll(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void power(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T exp, T scale, T shift);

@ -19,19 +19,19 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
void biasN_clipped_relu_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T floor, T ceiling);
template <class T>
void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp, T scale, T shift);
void biasN_tanh_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
template <class T>
void biasN_tanh_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
void biasN_swish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
template <class T>
void biasN_sigmoid_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
void biasN_mish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
template <class T>
void biasN_swish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
void biasN_sigmoid_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
template <class T>
void biasN_mish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp, T scale, T shift);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

@ -0,0 +1,42 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ACTIVATION_ELTWISE_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ACTIVATION_ELTWISE_HPP
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/* inplace_output = activation(inplace_output + bias) + eltwise
* broadcasting on `bias` is allowed but not on `eltwise`
*/
template <class T>
void biasN_relu_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T slope);
template <class T>
void biasN_clipped_relu_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T floor, T ceiling);
template <class T>
void biasN_tanh_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_sigmoid_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_swish_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_mish_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_power_eltwise_sum_2_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T exp, T scale, T shift);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ACTIVATION_ELTWISE_HPP */

@ -0,0 +1,45 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ELTWISE_ACTIVATION_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ELTWISE_ACTIVATION_HPP
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/* inplace_output = activation(inplace_output + bias + eltwise)
* broadcasting on `bias` is allowed but not on `eltwise`
*/
template <class T>
void biasN_eltwise_sum_2_identity_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_eltwise_sum_2_relu_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T slope);
template <class T>
void biasN_eltwise_sum_2_clipped_relu_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T floor, T ceiling);
template <class T>
void biasN_eltwise_sum_2_tanh_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_eltwise_sum_2_swish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_eltwise_sum_2_mish_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_eltwise_sum_2_sigmoid_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise);
template <class T>
void biasN_eltwise_sum_2_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, csl::View<T> eltwise, T exp, T scale, T shift);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_BIAS_ELTWISE_ACTIVATION_HPP */

@ -0,0 +1,40 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_ACTIVATION_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_ACTIVATION_HPP
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
/* output = activation(x + y) */
template <class T>
void eltwise_sum_2_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y, T slope);
template <class T>
void eltwise_sum_2_clipped_relu(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y, T floor, T ceiling);
template <class T>
void eltwise_sum_2_tanh(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y);
template <class T>
void eltwise_sum_2_swish(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y);
template <class T>
void eltwise_sum_2_mish(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y);
template <class T>
void eltwise_sum_2_sigmoid(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y);
template <class T>
void eltwise_sum_2_power(const csl::Stream& stream, csl::Span<T> output, csl::View<T> x, csl::View<T> y, T exp, T scale, T shift);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ELTWISE_ACTIVATION_HPP */
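The four kernel families introduced above differ only in where the eltwise operand enters relative to the bias and the activation; summarizing their per-element semantics (with `b[c]` the broadcast bias and `e[i]` the eltwise operand):
/* activation_eltwise:       out[i] = act(out[i]) + e[i]
 * bias_activation_eltwise:  out[i] = act(out[i] + b[c]) + e[i]
 * bias_eltwise_activation:  out[i] = act(out[i] + b[c] + e[i])
 * eltwise_activation:       out[i] = act(x[i] + y[i])
 */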

@ -11,9 +11,16 @@
#include "../csl/stream.hpp"
#include "../csl/tensor.hpp"
#include "../csl/tensor_ops.hpp"
#include "../kernels/scale_shift.hpp"
#include "../kernels/activations.hpp"
#include "../kernels/activation_eltwise.hpp"
#include "../kernels/bias_activation.hpp"
#include "../kernels/bias_eltwise_activation.hpp"
#include "../kernels/bias_activation_eltwise.hpp"
#include "../kernels/activation_eltwise.hpp"
#include "../kernels/eltwise_activation.hpp"
#include "../kernels/eltwise_ops.hpp"
#include <opencv2/core.hpp>
@ -47,11 +54,21 @@ namespace cv { namespace dnn { namespace cuda4dnn {
/* group count for grouped convolution */
std::size_t groups;
enum class FusionMode {
NONE,
ACTIVATION, /* act(conv) */
ELTWISE_SUM, /* conv + eltwise; the eltwise tensor is passed as the second input to forward */
ELTWISE_SUM_THEN_ACTIVATION, /* act(conv + eltwise) */
ACTIVATION_THEN_ELTWISE_SUM, /* act(conv) + eltwise */
};
FusionMode fusion_mode;
enum class ActivationType {
IDENTITY,
RELU, /* uses value provided in `relu_negative_slope` */
CLIPPED_RELU, /* uses values provided in `crelu_floor` and `crelu_ceil` */
POWER, /* scale and shift fused beforehand (fuseWeights); only `power_exp` is handled by CUDA */
POWER, /* scale and shift fused with weights and bias; only `power_exp` is handled here */
TANH,
SIGMOID,
SWISH,
@ -67,16 +84,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
ConvolutionOp(csl::Stream stream_, csl::cudnn::Handle handle, const ConvolutionConfiguration& config, const Mat& filters, const Mat& bias)
: stream(std::move(stream_)), cudnnHandle(std::move(handle))
ConvolutionOp(csl::Stream stream_, csl::cudnn::Handle handle_, const ConvolutionConfiguration& config, const Mat& filters, const Mat& bias)
: stream(std::move(stream_)), cudnnHandle(std::move(handle_))
{
const auto& kernel_size = config.kernel_size;
const auto& dilations = config.dilations;
const auto& strides = config.strides;
const auto convolution_order = kernel_size.size();
CV_Assert(convolution_order > 1);
CV_Assert(convolution_order == dilations.size());
CV_Assert(convolution_order == strides.size());
@ -87,8 +102,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
const auto groups = config.groups;
if (convolution_order > 3)
CV_Error(Error::StsNotImplemented, "Only 2D/3D convolution is supported.");
CV_Assert (1 < convolution_order && convolution_order <= 3);
const auto rank = input_shape.size();
const auto output_feature_maps = output_shape[1];
@ -204,32 +218,63 @@ namespace cv { namespace dnn { namespace cuda4dnn {
params.dilation = dilations;
params.groups = config.groups;
/* check if we can perform fused convolution using cudnn */
params.activation_type = csl::Convolution<T>::ActivationType::IDENTITY;
fusion_location = InternalFusionLocation::NATIVE;
if (!biasTensor.empty() &&
biasTensor.size() == output_feature_maps && /* cuDNN requirement */
config.activation_type == ConvolutionConfiguration::ActivationType::RELU &&
config.relu_negative_slope == 0.0)
{
fusion_location = InternalFusionLocation::CUDNN;
auto bias_shape = std::vector<std::size_t>(rank, 1);
bias_shape[1] = output_feature_maps;
params.bias_shape = bias_shape;
params.activation_type = csl::Convolution<T>::ActivationType::RELU;
}
convoluter = csl::Convolution<T>(cudnnHandle, params);
fusion_mode = config.fusion_mode;
activation = config.activation_type;
relu_negative_slope = config.relu_negative_slope;
crelu_floor = config.crelu_floor;
crelu_ceil = config.crelu_ceil;
power_exp = config.power_exp;
/* the scale and shift parameters of POWER have already been fused with weights and bias */
if (activation == ConvolutionConfiguration::ActivationType::POWER && power_exp == 1.0f)
activation = ConvolutionConfiguration::ActivationType::IDENTITY;
/* we normally use cuDNN for convolution and perform bias, activation and eltwise ops ourselves
* hence, the activation for cuDNN is IDENTITY by default
*/
fusion_location = InternalFusionLocation::NATIVE; /* i.e. we perform bias, act and eltwise */
params.eltwise = false;
params.activation_type = csl::Convolution<T>::ActivationType::IDENTITY;
/* cuDNN can fuse these operations with the convolution in some cases; check whether that is possible */
if (!biasTensor.empty() &&
biasTensor.size() == output_feature_maps && /* cuDNN requirement */
activation == ConvolutionConfiguration::ActivationType::RELU && /* cuDNN requirement */
relu_negative_slope == 0.0 && /* cuDNN requirement */
(fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION || /* act(conv + bias) */
fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION) /* act(conv + bias + eltwise) */
)
{
bool do_not_fuse = false;
if (std::is_same<T, half>::value)
{
/* in most cases, fusion degrades performance when tensor core based convolutions are used */
int device;
CUDA4DNN_CHECK_CUDA(cudaGetDevice(&device));
int cc_major;
CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
if (cc_major >= 7)
do_not_fuse = true;
}
if (!do_not_fuse)
{
fusion_location = InternalFusionLocation::CUDNN;
auto bias_shape = std::vector<std::size_t>(rank, 1);
bias_shape[1] = output_feature_maps;
params.bias_shape = bias_shape;
if (config.fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION)
params.eltwise = true;
params.activation_type = csl::Convolution<T>::ActivationType::RELU;
}
}
convoluter = csl::Convolution<T>(cudnnHandle, params);
csl::WorkspaceBuilder builder;
if (!transformed_shape.empty())
{
@ -246,7 +291,9 @@ namespace cv { namespace dnn { namespace cuda4dnn {
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
{
CV_Assert(inputs.size() == 1 && outputs.size() == 1);
/* input[0] = conv input, input[1] = eltwise operand (the output of the layer fused through eltwise) */
CV_Assert(inputs.size() == 1 || inputs.size() == 2);
CV_Assert(outputs.size() == 1);
csl::WorkspaceAllocator allocator(workspace);
@ -270,7 +317,16 @@ namespace cv { namespace dnn { namespace cuda4dnn {
{
try
{
convoluter.convolve_with_bias_activation(output, input, filtersTensor, biasTensor, conv_scratchpad);
if (fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION)
convoluter.convolve_with_bias_activation(output, input, filtersTensor, biasTensor, conv_scratchpad);
else if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION)
{
auto eltwise_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto eltwise = eltwise_wrapper->getView();
CV_Assert(is_shape_same(eltwise, output));
convoluter.convolve_with_bias_eltwise_activation(output, input, filtersTensor, biasTensor, eltwise, conv_scratchpad);
}
}
catch(const csl::cudnn::cuDNNException& ex)
{
@ -287,8 +343,100 @@ namespace cv { namespace dnn { namespace cuda4dnn {
if (fusion_location == InternalFusionLocation::NATIVE)
{
convoluter.convolve(output, input, filtersTensor, conv_scratchpad);
if (!biasTensor.empty())
if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM ||
fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION ||
fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM)
{
CV_Assert(inputs.size() == 2);
}
if (!biasTensor.empty() && inputs.size() == 2)
{
/* bias and eltwise */
CV_Assert(fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM ||
fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION ||
fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM);
auto eltwise_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto eltwise = eltwise_wrapper->getView();
CV_Assert(is_shape_same(eltwise, output));
std::size_t inner_size = output.size_range(2, output.rank());
if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM)
{
kernels::biasN_eltwise_sum_2_identity_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
}
else if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION)
{
/* activation(conv + bias + eltwise) */
switch (activation)
{
case ConvolutionConfiguration::ActivationType::IDENTITY:
kernels::biasN_eltwise_sum_2_identity_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::RELU:
kernels::biasN_eltwise_sum_2_relu_inplace<T>(stream, output, inner_size, biasTensor, eltwise, relu_negative_slope);
break;
case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
kernels::biasN_eltwise_sum_2_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, eltwise, crelu_floor, crelu_ceil);
break;
case ConvolutionConfiguration::ActivationType::POWER:
kernels::biasN_eltwise_sum_2_power_inplace<T>(stream, output, inner_size, biasTensor, eltwise, power_exp, 1.0, 0.0);
break;
case ConvolutionConfiguration::ActivationType::TANH:
kernels::biasN_eltwise_sum_2_tanh_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SIGMOID:
kernels::biasN_eltwise_sum_2_sigmoid_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SWISH:
kernels::biasN_eltwise_sum_2_swish_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::MISH:
kernels::biasN_eltwise_sum_2_mish_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
}
}
else if (fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM)
{
/* activation(conv + bias) + eltwise */
switch (activation)
{
case ConvolutionConfiguration::ActivationType::IDENTITY:
kernels::biasN_eltwise_sum_2_identity_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::RELU:
kernels::biasN_relu_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise, relu_negative_slope);
break;
case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
kernels::biasN_clipped_relu_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise, crelu_floor, crelu_ceil);
break;
case ConvolutionConfiguration::ActivationType::POWER:
kernels::biasN_power_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise, power_exp, 1.0, 0.0);
break;
case ConvolutionConfiguration::ActivationType::TANH:
kernels::biasN_tanh_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SIGMOID:
kernels::biasN_sigmoid_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SWISH:
kernels::biasN_swish_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
case ConvolutionConfiguration::ActivationType::MISH:
kernels::biasN_mish_eltwise_sum_2_inplace<T>(stream, output, inner_size, biasTensor, eltwise);
break;
}
}
}
else if (!biasTensor.empty() && inputs.size() == 1)
{
/* bias but no eltwise */
CV_Assert(fusion_mode == ConvolutionConfiguration::FusionMode::NONE ||
fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION);
std::size_t inner_size = output.size_range(2, output.rank());
switch(activation)
{
@ -302,7 +450,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
kernels::biasN_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, crelu_floor, crelu_ceil);
break;
case ConvolutionConfiguration::ActivationType::POWER:
kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp, T(1.0), T(0.0));
kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp, 1.0, 0.0);
break;
case ConvolutionConfiguration::ActivationType::TANH:
kernels::biasN_tanh_inplace<T>(stream, output, inner_size, biasTensor);
@ -318,8 +466,90 @@ namespace cv { namespace dnn { namespace cuda4dnn {
break;
}
}
else
else if (biasTensor.empty() && inputs.size() == 2)
{
/* no bias but eltwise */
CV_Assert(fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM ||
fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION ||
fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM);
auto eltwise_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto eltwise = eltwise_wrapper->getView();
CV_Assert(is_shape_same(eltwise, output));
/* the fused eltwise/activation kernels below handle the no-bias case directly */
if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM)
{
kernels::eltwise_sum_2<T>(stream, output, output, eltwise);
}
else if (fusion_mode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION)
{
switch (activation)
{
case ConvolutionConfiguration::ActivationType::IDENTITY:
kernels::eltwise_sum_2<T>(stream, output, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::RELU:
kernels::eltwise_sum_2_relu<T>(stream, output, output, eltwise, relu_negative_slope);
break;
case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
kernels::eltwise_sum_2_clipped_relu<T>(stream, output, output, eltwise, crelu_floor, crelu_ceil);
break;
case ConvolutionConfiguration::ActivationType::POWER:
kernels::eltwise_sum_2_power<T>(stream, output, output, eltwise, power_exp, 1.0, 0.0);
break;
case ConvolutionConfiguration::ActivationType::TANH:
kernels::eltwise_sum_2_tanh<T>(stream, output, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SIGMOID:
kernels::eltwise_sum_2_sigmoid<T>(stream, output, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SWISH:
kernels::eltwise_sum_2_swish<T>(stream, output, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::MISH:
kernels::eltwise_sum_2_mish<T>(stream, output, output, eltwise);
break;
}
}
else if (fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM)
{
switch (activation)
{
case ConvolutionConfiguration::ActivationType::IDENTITY:
kernels::eltwise_sum_2<T>(stream, output, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::RELU:
kernels::relu_eltwise_sum_2_inplace<T>(stream, output, eltwise, relu_negative_slope);
break;
case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
kernels::clipped_relu_eltwise_sum_2_inplace<T>(stream, output, eltwise, crelu_floor, crelu_ceil);
break;
case ConvolutionConfiguration::ActivationType::POWER:
kernels::power_eltwise_sum_2_inplace<T>(stream, output, eltwise, power_exp, 1.0, 0.0);
break;
case ConvolutionConfiguration::ActivationType::TANH:
kernels::tanh_eltwise_sum_2_inplace<T>(stream, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SIGMOID:
kernels::sigmoid_eltwise_sum_2_inplace<T>(stream, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::SWISH:
kernels::swish_eltwise_sum_2_inplace<T>(stream, output, eltwise);
break;
case ConvolutionConfiguration::ActivationType::MISH:
kernels::mish_eltwise_sum_2_inplace<T>(stream, output, eltwise);
break;
}
}
}
else if(biasTensor.empty() && inputs.size() == 1)
{
/* no bias and no eltwise */
CV_Assert(fusion_mode == ConvolutionConfiguration::FusionMode::NONE ||
fusion_mode == ConvolutionConfiguration::FusionMode::ACTIVATION);
switch(activation)
{
case ConvolutionConfiguration::ActivationType::IDENTITY:
@ -363,6 +593,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
std::size_t scratch_mem_in_bytes;
ConvolutionConfiguration::FusionMode fusion_mode;
ConvolutionConfiguration::ActivationType activation;
float relu_negative_slope, crelu_floor, crelu_ceil, power_exp;

@ -28,14 +28,28 @@ namespace cv { namespace dnn { namespace cuda4dnn {
DIV
};
class EltwiseOpBase : public CUDABackendNode {
public:
EltwiseOpBase(csl::Stream stream_, EltwiseOpType op_, std::vector<float> coeffs_)
: stream(std::move(stream_)), op(op_), coeffs(std::move(coeffs_))
{
}
protected:
csl::Stream stream;
public:
EltwiseOpType op;
std::vector<float> coeffs;
};
template <class T>
class EltwiseOp final : public CUDABackendNode {
class EltwiseOp final : public EltwiseOpBase {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
template <class V>
EltwiseOp(csl::Stream stream_, EltwiseOpType op_, std::vector<V> coeffs_)
: stream(std::move(stream_)), op{ op_ }, coeffs(std::begin(coeffs_), std::end(coeffs_))
EltwiseOp(csl::Stream stream_, EltwiseOpType op_, std::vector<float> coeffs_)
: EltwiseOpBase(std::move(stream_), op_, std::move(coeffs_))
{
}
@ -98,7 +112,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
else
{
/* if this is the first op, we must scale output too */
auto coeff_x = (i == 1) ? coeffs[0] : static_cast<T>(1.0);
T coeff_x = (i == 1) ? coeffs[0] : 1.0;
kernels::eltwise_sum_coeff_2<T>(stream, output, coeff_x, output, coeffs[i], input);
}
break;
@ -106,11 +120,6 @@ namespace cv { namespace dnn { namespace cuda4dnn {
}
}
}
private:
csl::Stream stream;
EltwiseOpType op;
std::vector<T> coeffs;
};
}}} /* namespace cv::dnn::cuda4dnn */
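The `EltwiseOpBase` split exists so that type-erased code can inspect an eltwise node's configuration without knowing the element type `T`. A minimal sketch of the intended check (the same pattern appears in the dnn.cpp fusion logic later in this patch; `node` is the `Ptr<BackendNode>` returned by the eltwise layer's initCUDA):
auto eltwiseNode = node.dynamicCast<cuda4dnn::EltwiseOpBase>();
bool fusable = !eltwiseNode.empty()
            && eltwiseNode->op == cuda4dnn::EltwiseOpType::SUM
            && eltwiseNode->coeffs.empty(); /* only a plain (uncoefficiented) sum fuses with convolution */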

@ -46,6 +46,10 @@
#include "op_vkcom.hpp"
#include "op_cuda.hpp"
#ifdef HAVE_CUDA
#include "cuda4dnn/primitives/eltwise.hpp"
#endif
#include "halide_scheduler.hpp"
#include <set>
@ -2554,6 +2558,11 @@ struct Net::Impl : public detail::NetImplBase
LayerPin lpNext(ld.consumers[0].lid, 0);
while (nextData)
{
/* eltwise fusion is handled later through the convolution layer's `tryFuse` method;
* it is not meant to be fused here, so we stop when we encounter an eltwise layer
*/
if (preferableBackend == DNN_BACKEND_CUDA && ld.type == "Convolution" && nextData->type == "Eltwise")
break;
Ptr<Layer> nextLayer = nextData->layerInstance;
if (currLayer->tryFuse(nextLayer))
{
@ -2629,15 +2638,41 @@ struct Net::Impl : public detail::NetImplBase
break;
}
// fuse convolution layer followed by eltwise + relu
if ( IS_DNN_OPENCL_TARGET(preferableTarget) && ld.layerInstance->type == "Convolution" )
// OpenCL: fuse convolution layer followed by eltwise + relu
// CUDA: fuse convolution layer followed by eltwise (and optional activation)
if ((IS_DNN_OPENCL_TARGET(preferableTarget) || IS_DNN_CUDA_TARGET(preferableTarget)) &&
ld.layerInstance->type == "Convolution" )
{
Ptr<EltwiseLayer> nextEltwiseLayer;
if( nextData )
nextEltwiseLayer = nextData->layerInstance.dynamicCast<EltwiseLayer>();
if( !nextEltwiseLayer.empty() && pinsToKeep.count(lpNext) == 0 &&
nextData && nextData->inputBlobsId.size() == 2 )
#ifdef HAVE_CUDA
// the CUDA backend supports fusion with an eltwise sum (without variable channels)
// `nextEltwiseLayer` is reset if the eltwise layer's configuration is incompatible with fusion
if (IS_DNN_CUDA_TARGET(preferableTarget) && !nextEltwiseLayer.empty())
{
// we create a temporary backend node for eltwise layer to obtain the eltwise configuration
auto context = cudaInfo->context; /* make a copy so that initCUDA doesn't modify cudaInfo */
const auto node = nextData->layerInstance->initCUDA(&context, nextData->inputBlobsWrappers, nextData->outputBlobsWrappers);
const auto eltwiseNode = node.dynamicCast<cuda4dnn::EltwiseOpBase>();
if (eltwiseNode->op != cuda4dnn::EltwiseOpType::SUM || !eltwiseNode->coeffs.empty())
nextEltwiseLayer = Ptr<EltwiseLayer>();
// check for variable channels
auto& inputs = nextData->inputBlobs;
for (int i = 1; i < inputs.size(); ++i)
{
if (inputs[i]->size[1] != inputs[0]->size[1])
{
nextEltwiseLayer = Ptr<EltwiseLayer>();
break;
}
}
}
#endif
if (!nextEltwiseLayer.empty() && nextData && nextData->inputBlobsId.size() == 2)
{
LayerData *eltwiseData = nextData;
@ -2666,65 +2701,160 @@ struct Net::Impl : public detail::NetImplBase
}
CV_Assert(biasLayerData);
{
if( eltwiseData->consumers.size() == 1 )
// fuse eltwise + activation layer
// bias must already be computed to fuse => bias layer must appear before convolution
if (biasLayerData->id < ld.id)
{
// fuse eltwise + activation layer
if (biasLayerData->id < ld.id)
/* we can fuse the activation if:
* => the activation layer that follows is the only consumer of the eltwise output
* => the activation layer does not process multiple inputs
* => we are not required to keep the eltwise output
*/
Ptr<ActivationLayer> nextFusabeleActivLayer;
if (eltwiseData->consumers.size() == 1 && pinsToKeep.count(lpNext) == 0)
{
nextData = &layers[eltwiseData->consumers[0].lid];
lpNext = LayerPin(eltwiseData->consumers[0].lid, 0);
Ptr<ActivationLayer> nextActivLayer;
if( nextData )
nextActivLayer = nextData->layerInstance.dynamicCast<ActivationLayer>();
if( !nextActivLayer.empty() && pinsToKeep.count(lpNext) == 0 &&
(!nextData->type.compare("ReLU") ||
!nextData->type.compare("ChannelsPReLU") ||
!nextData->type.compare("Power")) &&
currLayer->setActivation(nextActivLayer) )
if (pinsToKeep.count(lpNext) == 0 && nextData->outputBlobs.size() == 1)
nextFusabeleActivLayer = nextData->layerInstance.dynamicCast<ActivationLayer>();
}
else
{
// OCL backend cannot fuse in this case but the CUDA backend can continue with just eltwise
nextData = 0;
}
// the requirements of OCV OpenCL backend and CUDA backend are different
// we need to check them separately; hence, the fuse variables
bool fuse_eltwise = false, fuse_activation = false;
if (IS_DNN_OPENCL_TARGET(preferableTarget) && !nextFusabeleActivLayer.empty() &&
(!nextData->type.compare("ReLU") ||
!nextData->type.compare("ChannelsPReLU") ||
!nextData->type.compare("Power")) &&
currLayer->setActivation(nextFusabeleActivLayer))
{
fuse_eltwise = true;
fuse_activation = true;
}
if (IS_DNN_CUDA_TARGET(preferableTarget))
{
/* supported fusion options:
* => convolution + eltwise
* => activation(convolution) + eltwise
* > convolution + activation would have been fused already; we have to fuse eltwise
* => activation(convolution + eltwise)
* > fuse eltwise and then activation
*/
auto layer = nextEltwiseLayer.staticCast<Layer>();
if (currLayer->tryFuse(layer))
{
fuse_eltwise = true; /* eltwise was successfully fused */
if (!nextFusabeleActivLayer.empty())
{
if ((!nextData->type.compare("ReLU") ||
!nextData->type.compare("ReLU6") ||
!nextData->type.compare("Power") ||
!nextData->type.compare("TanH") ||
!nextData->type.compare("Sigmoid") ||
!nextData->type.compare("Swish") ||
!nextData->type.compare("Mish")) &&
currLayer->setActivation(nextFusabeleActivLayer))
{
// activation was fused
fuse_activation = true;
}
}
}
}
CV_Assert(!fuse_activation || fuse_eltwise); /* cannot fuse activation without eltwise */
if(fuse_eltwise && fuse_activation)
{
CV_Assert_N(biasLayerData->outputBlobsWrappers.size() == 1, ld.inputBlobsWrappers.size() == 1);
ld.inputBlobsWrappers.push_back(biasLayerData->outputBlobsWrappers[0]);
printf_(("\tfused with %s\n", nextEltwiseLayer->name.c_str()));
printf_(("\tfused with %s\n", nextFusabeleActivLayer->name.c_str()));
eltwiseData->skip = true;
nextData->skip = true;
// This optimization is for cases like
// some_layer conv
// | |
// +-- eltwise --+
// |
// activ
// This way all the element-wise computations
// (i.e. some_layer+conv or some_layer*conv)
// would be done at [conv] layer. So we need to
// replace [conv]'s output blob to [eltwise]'s one
// considering that [activ] is an in-place layer.
// Also we need to move all the consumers' references.
// To prevent memory collisions (i.e. when input of
// [conv] and output of [eltwise] is the same blob)
// we allocate a new blob.
CV_Assert_N(ld.outputBlobs.size() == 1, ld.outputBlobsWrappers.size() == 1);
ld.outputBlobs[0] = ld.outputBlobs[0].clone();
ld.outputBlobsWrappers[0] = wrap(ld.outputBlobs[0]);
eltwiseData->outputBlobs = ld.outputBlobs;
nextData->outputBlobs = ld.outputBlobs;
eltwiseData->outputBlobsWrappers = ld.outputBlobsWrappers;
nextData->outputBlobsWrappers = ld.outputBlobsWrappers;
// Move references of [activ] layer consumers to the newly allocated blob.
for (int i = 0; i < nextData->consumers.size(); ++i)
{
LayerData& consumer = layers[nextData->consumers[i].lid];
for (int j = 0; j < consumer.inputBlobsId.size(); ++j)
{
if (consumer.inputBlobsId[j].lid == lpNext.lid)
{
consumer.inputBlobs[j] = &ld.outputBlobs[0];
consumer.inputBlobsWrappers[j] = ld.outputBlobsWrappers[0];
break;
}
}
}
}
else if (fuse_eltwise) // conv + eltwise (note: conv could have fused activations before eltwise)
{
CV_Assert(IS_DNN_CUDA_TARGET(preferableTarget));
CV_Assert_N(biasLayerData->outputBlobsWrappers.size() == 1, ld.inputBlobsWrappers.size() == 1);
ld.inputBlobsWrappers.push_back(biasLayerData->outputBlobsWrappers[0]);
printf_(("\tfused with %s\n", nextEltwiseLayer->name.c_str()));
eltwiseData->skip = true;
// This optimization is for cases like
// some_layer conv (maybe fused with activ)
// | |
// +-- eltwise --+
//
// This way all the element-wise computations
// (i.e. some_layer+conv or some_layer*conv)
// would be done at [conv] layer. So we need to
// replace [conv]'s output blob to [eltwise]'s one.
// Also we need to move all the consumers' references.
// To prevent memory collisions (i.e. when input of
// [conv] and output of [eltwise] is the same blob)
// we allocate a new blob.
CV_Assert_N(ld.outputBlobs.size() == 1, ld.outputBlobsWrappers.size() == 1);
ld.outputBlobs[0] = ld.outputBlobs[0].clone();
ld.outputBlobsWrappers[0] = wrap(ld.outputBlobs[0]);
eltwiseData->outputBlobs = ld.outputBlobs;
eltwiseData->outputBlobsWrappers = ld.outputBlobsWrappers;
// Move references of [eltwise] layer consumers to the newly allocated blob.
for (int i = 0; i < eltwiseData->consumers.size(); ++i)
{
CV_Assert_N(biasLayerData->outputBlobsWrappers.size() == 1, ld.inputBlobsWrappers.size() == 1);
ld.inputBlobsWrappers.push_back(biasLayerData->outputBlobsWrappers[0]);
printf_(("\tfused with %s\n", nextEltwiseLayer->name.c_str()));
printf_(("\tfused with %s\n", nextActivLayer->name.c_str()));
eltwiseData->skip = true;
nextData->skip = true;
// This optimization for cases like
// some_layer conv
// | |
// +-- eltwise --+
// |
// activ
// This way all the element-wise computations
// (i.e. some_layer+conv or some_layer*conv)
// would be done at [conv] layer. So we need to
// replace [conv]'s output blob to [eltwise]'s one
// considering that [activ] is an in-place layer.
// Also we need to move all the consumers' references.
// To prevent memory collisions (i.e. when input of
// [conv] and output of [eltwise] is the same blob)
// we allocate a new blob.
CV_Assert_N(ld.outputBlobs.size() == 1, ld.outputBlobsWrappers.size() == 1);
ld.outputBlobs[0] = ld.outputBlobs[0].clone();
ld.outputBlobsWrappers[0] = wrap(ld.outputBlobs[0]);
eltwiseData->outputBlobs = ld.outputBlobs;
nextData->outputBlobs = ld.outputBlobs;
eltwiseData->outputBlobsWrappers = ld.outputBlobsWrappers;
nextData->outputBlobsWrappers = ld.outputBlobsWrappers;
// Move references of [activ] layer consumers to the newly allocated blob.
for (int i = 0; i < nextData->consumers.size(); ++i)
LayerData& consumer = layers[eltwiseData->consumers[i].lid];
for (int j = 0; j < consumer.inputBlobsId.size(); ++j)
{
LayerData& consumer = layers[nextData->consumers[i].lid];
for (int j = 0; j < consumer.inputBlobsId.size(); ++j)
if (consumer.inputBlobsId[j].lid == eltwiseData->id)
{
if (consumer.inputBlobsId[j].lid == lpNext.lid)
{
consumer.inputBlobs[j] = &ld.outputBlobs[0];
consumer.inputBlobsWrappers[j] = ld.outputBlobsWrappers[0];
break;
}
consumer.inputBlobs[j] = &ld.outputBlobs[0];
consumer.inputBlobsWrappers[j] = ld.outputBlobsWrappers[0];
break;
}
}
}

@ -248,6 +248,7 @@ public:
#endif
#ifdef HAVE_CUDA
cuda4dnn::ConvolutionConfiguration::FusionMode cudaFusionMode;
cuda4dnn::ConvolutionConfiguration::ActivationType cudaActType;
float cuda_relu_slope, cuda_crelu_floor, cuda_crelu_ceil, cuda_power_exp;
#endif
@ -261,6 +262,7 @@ public:
#endif
#ifdef HAVE_CUDA
cudaFusionMode = cuda4dnn::ConvolutionConfiguration::FusionMode::NONE;
cudaActType = cuda4dnn::ConvolutionConfiguration::ActivationType::IDENTITY;
#endif
}
@ -425,10 +427,18 @@ public:
#endif
#ifdef HAVE_CUDA
cudaActType = cuda4dnn::ConvolutionConfiguration::ActivationType::IDENTITY;
if (activ.empty())
{
/* setActivation was called with empty argument => reset all fusions */
cudaFusionMode = cuda4dnn::ConvolutionConfiguration::FusionMode::NONE;
cudaActType = cuda4dnn::ConvolutionConfiguration::ActivationType::IDENTITY;
}
if(IS_DNN_CUDA_TARGET(preferableTarget))
{
CV_Assert(cudaFusionMode == ConvolutionConfiguration::FusionMode::NONE ||
cudaFusionMode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM);
Ptr<ReLULayer> activ_relu = activ.dynamicCast<ReLULayer>();
if(!activ_relu.empty())
{
@ -475,12 +485,53 @@ public:
cudaActType = cuda4dnn::ConvolutionConfiguration::ActivationType::MISH;
if (cudaActType == cuda4dnn::ConvolutionConfiguration::ActivationType::IDENTITY)
{
/* no activation fused */
activ.reset();
}
else
{
/* activation was fused */
if (cudaFusionMode == ConvolutionConfiguration::FusionMode::NONE) /* no previous fusion */
cudaFusionMode = ConvolutionConfiguration::FusionMode::ACTIVATION; /* now activation */
else if (cudaFusionMode == ConvolutionConfiguration::FusionMode::ELTWISE_SUM) /* previously eltwise was fused */
cudaFusionMode = ConvolutionConfiguration::FusionMode::ELTWISE_SUM_THEN_ACTIVATION; /* now activation on eltwise output */
}
}
#endif
return !activ.empty();
}
virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
{
#ifdef HAVE_CUDA
if(IS_DNN_CUDA_TARGET(preferableTarget))
{
Ptr<EltwiseLayer> eltwise = top.dynamicCast<EltwiseLayer>();
if (!eltwise.empty()) // && eltwise->op == EltwiseLayer::SUM && eltwise->coeffs.empty())
{
/* we also need to check that the eltwise input does not require the shortcut mechanism;
* that is difficult to verify here, so we rely on `fuseLayers` having already done the check
*/
if (cudaFusionMode == ConvolutionConfiguration::FusionMode::NONE)
{
/* no previous fusion */
cudaFusionMode = ConvolutionConfiguration::FusionMode::ELTWISE_SUM; /* now eltwise */
return true;
}
else if(cudaFusionMode == ConvolutionConfiguration::FusionMode::ACTIVATION)
{
/* previously an activation was fused */
cudaFusionMode = ConvolutionConfiguration::FusionMode::ACTIVATION_THEN_ELTWISE_SUM;
return true;
}
return false;
}
}
#endif
return BaseConvolutionLayerImpl::tryFuse(top);
}
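Taken together, `setActivation` and `tryFuse` implement a small state machine over `cudaFusionMode`; the transitions encoded above are:
/* setActivation: NONE        -> ACTIVATION
 *                ELTWISE_SUM -> ELTWISE_SUM_THEN_ACTIVATION  (activation applied after the eltwise sum)
 * tryFuse:       NONE        -> ELTWISE_SUM
 *                ACTIVATION  -> ACTIVATION_THEN_ELTWISE_SUM  (eltwise applied after the activation)
 * any other eltwise fusion attempt returns false; calling setActivation with an empty
 * activation resets the state to NONE
 */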
void fuseWeights(const Mat& w_, const Mat& b_) CV_OVERRIDE
{
// Convolution weights have OIHW data layout. Parameters fusion in case of
@ -1493,7 +1544,7 @@ public:
{
auto context = reinterpret_cast<csl::CSLContext*>(context_);
CV_Assert(inputs.size() == 1);
CV_Assert(inputs.size() == 1 || inputs.size() == 2);
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto input_shape = input_wrapper->getShape();
@ -1534,6 +1585,7 @@ public:
config.output_shape.assign(std::begin(output_shape), std::end(output_shape));
config.groups = groups;
config.fusion_mode = cudaFusionMode;
config.activation_type = cudaActType;
config.relu_negative_slope = cuda_relu_slope;
config.crelu_floor = cuda_crelu_floor;
