diff --git a/modules/dnn/src/cuda/crop_and_resize.cu b/modules/dnn/src/cuda/crop_and_resize.cu index 4e597b6417..0104ad1c6c 100644 --- a/modules/dnn/src/cuda/crop_and_resize.cu +++ b/modules/dnn/src/cuda/crop_and_resize.cu @@ -9,6 +9,7 @@ #include "types.hpp" #include "grid_stride_range.hpp" #include "execution.hpp" +#include "memory.hpp" #include "../cuda4dnn/csl/stream.hpp" #include "../cuda4dnn/csl/tensor.hpp" @@ -102,10 +103,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { #pragma unroll 1 /* disable unrolling */ for (int i = 0; i < CHANNELS_PER_ITER; i++) { - auto v_00 = input[in_offset_r0 + in_x0], - v_01 = input[in_offset_r0 + in_x1], - v_10 = input[in_offset_r1 + in_x0], - v_11 = input[in_offset_r1 + in_x1]; + auto v_00 = load_ldg(input[in_offset_r0 + in_x0]), + v_01 = load_ldg(input[in_offset_r0 + in_x1]), + v_10 = load_ldg(input[in_offset_r1 + in_x0]), + v_11 = load_ldg(input[in_offset_r1 + in_x1]); output[out_idx] = v_00 + diff --git a/modules/dnn/src/cuda/functors.hpp b/modules/dnn/src/cuda/functors.hpp index c35a85437c..237c429558 100644 --- a/modules/dnn/src/cuda/functors.hpp +++ b/modules/dnn/src/cuda/functors.hpp @@ -30,8 +30,10 @@ struct tanh_functor { template struct swish_functor { __device__ T operator()(T value) { - using csl::device::sigmoid; - return value * sigmoid(value); + // f(x) = x * sigmoid(x) + using csl::device::fast_divide; + using csl::device::fast_exp; + return fast_divide(value, static_cast(1) + fast_exp(-value)); } }; @@ -44,11 +46,30 @@ struct mish_functor { } }; +template <> +struct mish_functor { + __device__ float operator()(float value) { + // f(x) = x * tanh(log1pexp(x)); + using csl::device::fast_divide; + using csl::device::fast_exp; + + auto e = fast_exp(value); + if (value <= -18.0f) + return value * e; + + auto n = e * e + 2 * e; + if (value <= -5.0f) + return value * fast_divide(n, n + 2); + + return value - 2 * fast_divide(value, n + 2); + } +}; + template struct sigmoid_functor { __device__ T operator()(T value) { - using csl::device::sigmoid; - return sigmoid(value); + using csl::device::fast_sigmoid; + return fast_sigmoid(value); } }; diff --git a/modules/dnn/src/cuda/math.hpp b/modules/dnn/src/cuda/math.hpp index 5fb9f43445..8d4aea8b7d 100644 --- a/modules/dnn/src/cuda/math.hpp +++ b/modules/dnn/src/cuda/math.hpp @@ -160,6 +160,15 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de template <> inline __device__ __half2 ceil(__half2 value) { return h2ceil(value); } #endif + template __device__ T fast_divide(T x, T y) { return x / y; } + template <> inline __device__ float fast_divide(float x, float y) { return __fdividef(x, y); } + + template __device__ T fast_exp(T value) { return exp(value); } + template <> inline __device__ float fast_exp(float value) { return __expf(value); } + + template __device__ T fast_sigmoid(T value) { return sigmoid(value); } + template <> inline __device__ float fast_sigmoid(float value) { return __fdividef(1, 1 + __expf(-value)); } + }}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ #endif /* OPENCV_DNN_SRC_CUDA_MATH_HPP */ diff --git a/modules/dnn/src/cuda/memory.hpp b/modules/dnn/src/cuda/memory.hpp new file mode 100644 index 0000000000..4ee984626a --- /dev/null +++ b/modules/dnn/src/cuda/memory.hpp @@ -0,0 +1,32 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef OPENCV_DNN_SRC_CUDA_MEMORY_HPP +#define OPENCV_DNN_SRC_CUDA_MEMORY_HPP + +#include + +namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device { + +template +__device__ T load_ldg(const T& src) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350) + return __ldg(&src); +#else + return src; +#endif +} + +template +__device__ T load_ldg(const T* src) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350) + return __ldg(src); +#else + return *src; +#endif +} + +}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */ + +#endif /* OPENCV_DNN_SRC_CUDA_MEMORY_HPP */ diff --git a/modules/dnn/src/cuda/permute.cu b/modules/dnn/src/cuda/permute.cu index e79087eb67..082c1bf75e 100644 --- a/modules/dnn/src/cuda/permute.cu +++ b/modules/dnn/src/cuda/permute.cu @@ -7,7 +7,6 @@ #include "array.hpp" #include "types.hpp" -#include "vector_traits.hpp" #include "grid_stride_range.hpp" #include "execution.hpp" #include "kernel_dispatcher.hpp" @@ -50,84 +49,62 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { } } - template + template __global__ void transpose(Span output, View input, size_type in_width, size_type out_width) { - using vector_type = get_vector_type_t; - __shared__ T tile[TILE_SIZE][TILE_SIZE + 1]; - /* blockDim.y = TILE_SIZE, blockDim.x = TILE_SIZE/N */ - const index_type in_x = blockIdx.x * TILE_SIZE + threadIdx.x * N; - const index_type in_y = blockIdx.y * TILE_SIZE + threadIdx.y; + /* blockDim.y = TILE_SIZE / ROWS_PER_THREAD, blockDim.x = TILE_SIZE */ + const index_type in_x = blockIdx.x * TILE_SIZE + threadIdx.x; + const index_type in_y_begin = blockIdx.y * TILE_SIZE + threadIdx.y; /* Every valid input location has a corresponding output location and vice versa. * Hence, if we do not load values into the shared memory for a given location, we * also won't read them for storing in the output. */ - if (in_x < in_width && in_y < out_width) + for (int j = 0; j < TILE_SIZE; j += TILE_SIZE / ROWS_PER_THREAD) { - vector_type vec; - auto input_vPtr = vector_type::get_pointer(input.data()); - v_load(vec, input_vPtr[(in_y * in_width + in_x) / N]); - - for (int i = 0; i < vector_type::size(); i++) - tile[threadIdx.y][threadIdx.x * N + i] = vec.data[i]; + const auto in_y_current = in_y_begin + j; + if (in_x < in_width && in_y_current < out_width) + tile[threadIdx.y + j][threadIdx.x] = input[in_y_current * in_width + in_x]; } __syncthreads(); - /* Note that `blockDim.x * N` is equal to `blockDim.y`. Since there are an equal - * number of them, we can interchange `threadIdx.x` and `threadIdx.y` without changing - * result. The advantage of interchanging is that consecutive output indices map to + /* We interchange `threadIdx.x` and `threadIdx.y` so that consecutive output indices map to * consecutive threads. This would allow writes across threds in a warp to be coalesced. */ - const index_type out_x = blockIdx.y * TILE_SIZE + threadIdx.x * N; - const index_type out_y = blockIdx.x * TILE_SIZE + threadIdx.y; + const index_type out_x = blockIdx.y * TILE_SIZE + threadIdx.x; + const index_type out_y_begin = blockIdx.x * TILE_SIZE + threadIdx.y; - if (out_x < out_width && out_y < in_width) + for (int j = 0; j < TILE_SIZE; j += TILE_SIZE / ROWS_PER_THREAD) { - vector_type vec; - for (int i = 0; i < vector_type::size(); i++) - vec.data[i] = tile[threadIdx.x * N + i][threadIdx.y]; - - auto output_vPtr = vector_type::get_pointer(output.data()); - v_store(output_vPtr[(out_y * out_width + out_x) / N], vec); + const auto out_y_current = out_y_begin + j; + if (out_x < out_width && out_y_current < in_width) + output[out_y_current * out_width + out_x] = tile[threadIdx.x][threadIdx.y + j]; } } } - template static - void launch_transpose_kernel(const Stream& stream, Span output, View input, size_type in_width, size_type out_width) + template + void transpose(const Stream& stream, Span output, View input, std::size_t in_width, std::size_t out_width) { - CV_Assert(is_fully_aligned(output, N)); - CV_Assert(is_fully_aligned(input, N)); - CV_Assert(in_width % N == 0); - CV_Assert(out_width % N == 0); - + /* Each block processes a TILE_SIZE x TILE_SIZE piece */ constexpr int TILE_SIZE = 32; - constexpr int TILE_SIZE_X = TILE_SIZE/N, TILE_SIZE_Y = TILE_SIZE; - auto kernel = raw::transpose; - dim3 grid_size((in_width/N + TILE_SIZE_X - 1)/TILE_SIZE_X, (out_width + TILE_SIZE_Y - 1)/TILE_SIZE_Y); - dim3 block_size(TILE_SIZE_X, TILE_SIZE_Y); + /* Each thread processes ROWS_PER_THREAD rows. We do this to decrease the number of threads required + * in a block so that the cost of the block-wide synchronization is minimized. + */ + constexpr int ROWS_PER_THREAD = 4; + + dim3 grid_size((in_width + TILE_SIZE - 1) / TILE_SIZE, (out_width + TILE_SIZE - 1) / TILE_SIZE); + dim3 block_size(TILE_SIZE, TILE_SIZE / ROWS_PER_THREAD); auto policy = execution_policy(grid_size, block_size, stream); + auto kernel = raw::transpose; launch_kernel(kernel, policy, output, input, in_width, out_width); } - template - void transpose(const Stream& stream, Span output, View input, std::size_t in_width, std::size_t out_width) - { - if (is_fully_aligned(output, 4) && is_fully_aligned(input, 4) && in_width % 4 == 0 && out_width % 4 == 0) { - launch_transpose_kernel(stream, output, input, in_width, out_width); - } else if (is_fully_aligned(output, 2) && is_fully_aligned(input, 2) && in_width % 2 == 0 && out_width % 2 == 0) { - launch_transpose_kernel(stream, output, input, in_width, out_width); - } else { - launch_transpose_kernel(stream, output, input, in_width, out_width); - } - } - template void transpose(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t); template void transpose(const Stream&, Span, View, std::size_t, std::size_t); diff --git a/modules/dnn/src/cuda/region.cu b/modules/dnn/src/cuda/region.cu index b90a13fff6..d9e548f3c9 100644 --- a/modules/dnn/src/cuda/region.cu +++ b/modules/dnn/src/cuda/region.cu @@ -47,20 +47,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { const auto y = (box_index % batch_inner_size) / row_inner_size; const auto x = (box_index % row_inner_size) / col_inner_size; - using device::sigmoid; - output[box_offset + 0] = (T(x) + sigmoid(input[box_offset + 0])) / T(cols); - output[box_offset + 1] = (T(y) + sigmoid(input[box_offset + 1])) / T(rows); + using device::fast_sigmoid; + output[box_offset + 0] = (T(x) + fast_sigmoid(input[box_offset + 0])) / T(cols); + output[box_offset + 1] = (T(y) + fast_sigmoid(input[box_offset + 1])) / T(rows); vector2_type bias_xy; v_load(bias_xy, bias_vPtr[box_of_the_cell]); - using device::exp; - output[box_offset + 2] = exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm); - output[box_offset + 3] = exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm); + using device::fast_exp; + output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm); + output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm); /* squash objectness score into a probability */ - using device::sigmoid; - T objectness_prob = sigmoid(input[box_offset + 4]); + using device::fast_sigmoid; + T objectness_prob = fast_sigmoid(input[box_offset + 4]); /* ignore prediction if the objectness probability is less than the cutoff */ if (objectness_prob < object_prob_cutoff) @@ -91,7 +91,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { * to obtain the actual class probability, we multiply the conditional probability * with the object probability */ - auto actual_class_prob = objectness_prob * sigmoid(input[idx]); + using device::fast_sigmoid; + auto actual_class_prob = objectness_prob * fast_sigmoid(input[idx]); if (actual_class_prob <= class_prob_cutoff) actual_class_prob = T(0); output[idx] = actual_class_prob; diff --git a/modules/dnn/src/cuda/resize.cu b/modules/dnn/src/cuda/resize.cu index c34790f74c..045b4f0a87 100644 --- a/modules/dnn/src/cuda/resize.cu +++ b/modules/dnn/src/cuda/resize.cu @@ -9,6 +9,7 @@ #include "types.hpp" #include "grid_stride_range.hpp" #include "execution.hpp" +#include "memory.hpp" #include "../cuda4dnn/csl/stream.hpp" #include "../cuda4dnn/csl/tensor.hpp" @@ -70,7 +71,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { index_type out_idx = c_start * out_image_size + y * out_width + x; for (int i = 0; i < CHANNELS_PER_ITER; i++) { - output[out_idx] = input[in_idx]; + output[out_idx] = load_ldg(input[in_idx]); in_idx += in_image_size; out_idx += out_image_size; @@ -134,10 +135,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { #pragma unroll 1 /* disable unrolling to reduce register pressure; not sure how but it works */ for (auto c = c_start; c < c_end; c++) { - auto v_00 = input[in_offset_r0 + in_x0], - v_01 = input[in_offset_r0 + in_x1], - v_10 = input[in_offset_r1 + in_x0], - v_11 = input[in_offset_r1 + in_x1]; + auto v_00 = load_ldg(input[in_offset_r0 + in_x0]), + v_01 = load_ldg(input[in_offset_r0 + in_x1]), + v_10 = load_ldg(input[in_offset_r1 + in_x0]), + v_11 = load_ldg(input[in_offset_r1 + in_x1]); output[out_idx] = v_00 + diff --git a/modules/dnn/src/cuda/roi_pooling.cu b/modules/dnn/src/cuda/roi_pooling.cu index ecffbf86cb..c43e332b87 100644 --- a/modules/dnn/src/cuda/roi_pooling.cu +++ b/modules/dnn/src/cuda/roi_pooling.cu @@ -10,6 +10,7 @@ #include "types.hpp" #include "grid_stride_range.hpp" #include "execution.hpp" +#include "memory.hpp" #include "../cuda4dnn/csl/stream.hpp" #include "../cuda4dnn/csl/tensor.hpp" @@ -118,7 +119,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { const auto in_idx = in_offset + iy * in_width; for (auto ix = x_start; ix < x_end; ix++) { - max_val = max(max_val, input[in_idx + ix]); + max_val = max(max_val, load_ldg(input[in_idx + ix])); } } diff --git a/modules/dnn/src/cuda/vector_traits.hpp b/modules/dnn/src/cuda/vector_traits.hpp index b10bcd301b..1b9b76980c 100644 --- a/modules/dnn/src/cuda/vector_traits.hpp +++ b/modules/dnn/src/cuda/vector_traits.hpp @@ -8,6 +8,7 @@ #include #include "types.hpp" +#include "memory.hpp" #include "../cuda4dnn/csl/pointer.hpp" @@ -86,6 +87,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de dest.raw = src->raw; } + template + __device__ void v_load_ldg(V& dest, const V& src) { + dest.raw = load_ldg(src.raw); + } + + template + __device__ void v_load_ldg(V& dest, const V* src) { + dest.raw = load_ldg(src->raw); + } + template __device__ void v_store(V* dest, const V& src) { dest->raw = src.raw; diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp index 87e0468087..a2889a77d1 100644 --- a/modules/dnn/src/layers/convolution_layer.cpp +++ b/modules/dnn/src/layers/convolution_layer.cpp @@ -167,6 +167,10 @@ public: virtual bool tryFuse(Ptr& top) CV_OVERRIDE { + Ptr blank_layer = top.dynamicCast(); + if (blank_layer) + return true; + Mat w, b; top->getScaleShift(w, b); if (!w.empty() || !b.empty())