From 476a02739e1254607767cca744a7f58ea3ba8b91 Mon Sep 17 00:00:00 2001 From: Yashas Samaga B L Date: Mon, 9 Dec 2019 20:01:27 +0530 Subject: [PATCH] Merge pull request #16097 from YashasSamaga:cuda4dnn-optimize-resize-bilinear cuda4dnn(resize): process multiple channels each iteration * resize bilinear: process multiple chans. per iter. * remove unused headers * correct dispatch logic * resize_nn: process multiple chans. per iter. --- modules/dnn/src/cuda/fill.cu | 2 +- modules/dnn/src/cuda/resize.cu | 188 +++++++++++++++++++++++++-------- 2 files changed, 145 insertions(+), 45 deletions(-) diff --git a/modules/dnn/src/cuda/fill.cu b/modules/dnn/src/cuda/fill.cu index e4fea27ca3..d08884facb 100644 --- a/modules/dnn/src/cuda/fill.cu +++ b/modules/dnn/src/cuda/fill.cu @@ -32,7 +32,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { } } - template + template static void launch_vectorized_fill(const Stream& stream, Span output, T value) { CV_Assert(is_fully_aligned(output, N)); diff --git a/modules/dnn/src/cuda/resize.cu b/modules/dnn/src/cuda/resize.cu index 6eed48aba3..306325ec3c 100644 --- a/modules/dnn/src/cuda/resize.cu +++ b/modules/dnn/src/cuda/resize.cu @@ -22,7 +22,7 @@ using namespace cv::dnn::cuda4dnn::csl::device; namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { namespace raw { - template + template __global__ void resize_nn( Span output, size_type out_height, size_type out_width, View input, size_type in_height, size_type in_width) @@ -30,29 +30,55 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { auto in_image_size = in_height * in_width; auto out_image_size = out_height * out_width; - /* o2i = output to input */ - auto o2i_fx = static_cast(in_width) / out_width; - auto o2i_fy = static_cast(in_height) / out_height; - /* think of the output and input as a collection of 2d images with the last axis * representing the width and the last but one axis representing the height * - * the remaining axis together form a collection of these images + * the remaining axis together form a collection of these images/channels */ - for (auto idx : grid_stride_range(output.size())) { - const index_type n = idx / out_image_size; - const index_type x = (idx % out_image_size) % out_width; - const index_type y = (idx % out_image_size) / out_width; + auto num_effective_channels = output.size() / out_image_size; + + /* we process multiple channels every iteration to reuse the identical computation + * involved with the spatial dimensions + * + * if we are processing `CHANNELS_PER_ITER` channels per iteration, we will need + * (num_effective_channels / CHANNELS_PER_ITER) iterations per (x, y) location + */ + auto num_channel_iters_per_xy = (num_effective_channels / CHANNELS_PER_ITER); + + /* we need `num_channel_iters_per_xy` iterations per (x, y) and there are `out_image_size` + * combinations of (x, y); hence, we'll need `num_channel_iters_per_xy * out_image_size` + * iterations in total to finish the resize operation + */ + auto iters_required = num_channel_iters_per_xy * out_image_size; + for (auto iter : grid_stride_range(iters_required)) { + const index_type c_start = (iter / out_image_size) * CHANNELS_PER_ITER; + + /* note here that consecutive `iter` values will often have consecutive `x` values + * => stores into output will be coalesced across threads + */ + const index_type y = (iter % out_image_size) / out_width; + const index_type x = iter % out_width; + + /* o2i = output to input */ + auto o2i_fy = static_cast(in_height) / out_height; + auto o2i_fx = static_cast(in_width) / out_width; - auto in_x = static_cast(x * o2i_fx); auto in_y = static_cast(y * o2i_fy); + auto in_x = static_cast(x * o2i_fx); + + index_type in_idx = c_start * in_image_size + in_y * in_width + in_x; + index_type out_idx = c_start * out_image_size + y * out_width + x; - index_type in_idx = n * in_image_size + in_y * in_width + in_x; - output[idx] = input[in_idx]; + for (int i = 0; i < CHANNELS_PER_ITER; i++) { + output[out_idx] = input[in_idx]; + + in_idx += in_image_size; + out_idx += out_image_size; + } } } - template + template __global__ void resize_bilinear( Span output, size_type out_height, size_type out_width, View input, size_type in_height, size_type in_width, @@ -64,12 +90,33 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { /* think of the output and input as a collection of 2d images with the last axis * representing the width and the last but one axis representing the height * - * the remaining axis together form a collection of these images + * the remaining axis together form a collection of these images/channels + */ + auto num_effective_channels = output.size() / out_image_size; + + /* we process multiple channels every iteration to reuse the identical computation + * involved with the spatial dimensions + * + * if we are processing `CHANNELS_PER_ITER` channels per iteration, we will need + * (num_effective_channels / CHANNELS_PER_ITER) iterations per (x, y) location + */ + auto num_channel_iters_per_xy = (num_effective_channels / CHANNELS_PER_ITER); + + /* we need `num_channel_iters_per_xy` iterations per (x, y) and there are `out_image_size` + * combinations of (x, y); hence, we'll need `num_channel_iters_per_xy * out_image_size` + * iterations in total to finish the resize operation */ - for (auto idx : grid_stride_range(output.size())) { - const index_type n = idx / out_image_size; - const index_type x = (idx % out_image_size) % out_width; - const index_type y = (idx % out_image_size) / out_width; + auto iters_required = num_channel_iters_per_xy * out_image_size; + + for (auto iter : grid_stride_range(iters_required)) { + const index_type c_start = (iter / out_image_size) * CHANNELS_PER_ITER; + const index_type c_end = c_start + CHANNELS_PER_ITER; + + /* note here that consecutive `iter` values will often have consecutive `x` values + * => stores into output will be coalesced across threads + */ + const index_type y = (iter % out_image_size) / out_width; + const index_type x = iter % out_width; auto in_x = x * o2i_fx; auto in_y = y * o2i_fy; @@ -81,50 +128,103 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels { auto in_x1 = min(in_x0 + 1, in_width - 1); auto in_y1 = min(in_y0 + 1, in_height - 1); - const index_type in_offset_r0 = n * in_image_size + in_y0 * in_width; - const index_type in_offset_r1 = n * in_image_size + in_y1 * in_width; - - auto v_00 = input[in_offset_r0 + in_x0], - v_01 = input[in_offset_r0 + in_x1], - v_10 = input[in_offset_r1 + in_x0], - v_11 = input[in_offset_r1 + in_x1]; - - output[idx] = - v_00 + - T(in_y - in_y0) * T(v_10 - v_00) + - T(in_x - in_x0) * T(v_01 - v_00) + - T(in_y - in_y0) * T(in_x - in_x0) * T(v_11 - v_01 - v_10 + v_00); + index_type in_offset_r0 = c_start * in_image_size + in_y0 * in_width; + index_type in_offset_r1 = c_start * in_image_size + in_y1 * in_width; + index_type out_idx = c_start * out_image_size + y * out_width + x; + + #pragma unroll 1 /* disable unrolling to reduce register pressure; not sure how but it works */ + for (auto c = c_start; c < c_end; c++) { + auto v_00 = input[in_offset_r0 + in_x0], + v_01 = input[in_offset_r0 + in_x1], + v_10 = input[in_offset_r1 + in_x0], + v_11 = input[in_offset_r1 + in_x1]; + + output[out_idx] = + v_00 + + T(in_y - in_y0) * T(v_10 - v_00) + + T(in_x - in_x0) * T(v_01 - v_00) + + T(in_y - in_y0) * T(in_x - in_x0) * T(v_11 - v_01 - v_10 + v_00); + + in_offset_r0 += in_image_size; + in_offset_r1 += in_image_size; + out_idx += out_image_size; + } } } } + template static + void launch_multichannel_resize_nn(const Stream& stream, + Span output, size_type out_height, size_type out_width, + View input, size_type in_height, size_type in_width) + { + auto kernel = raw::resize_nn; + auto policy = make_policy(kernel, output.size() / CHANNELS_PER_ITER, 0, stream); + launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width); + } + template void resize_nn(const Stream& stream, TensorSpan output, TensorView input) { - auto in_height = input.get_axis_size(-2); - auto in_width = input.get_axis_size(-1); - auto out_height = output.get_axis_size(-2); auto out_width = output.get_axis_size(-1); - auto kernel = raw::resize_nn; - auto policy = make_policy(kernel, output.size(), 0, stream); - launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width); + auto in_height = input.get_axis_size(-2); + auto in_width = input.get_axis_size(-1); + + auto num_effective_channels = input.size_range(0, 2); + auto num_iters = num_effective_channels * out_height * out_width; + + if (num_effective_channels % 32 == 0 && num_iters > 655360) { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } else if (num_effective_channels % 16 == 0 && num_iters > 327680) { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } else if (num_effective_channels % 8 == 0 && num_iters > 163840) { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } else if (num_effective_channels % 4 == 0 && num_iters > 81920) { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } else if (num_effective_channels % 2 == 0) { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } else { + launch_multichannel_resize_nn(stream, output, out_height, out_width, input, in_height, in_width); + } } template void resize_nn<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>); template void resize_nn(const Stream&, TensorSpan, TensorView); + template static + void launch_multichannel_resize_bilinear(const Stream& stream, + Span output, size_type out_height, size_type out_width, + View input, size_type in_height, size_type in_width, + float scale_y, float scale_x) + { + auto kernel = raw::resize_bilinear; + auto policy = make_policy(kernel, output.size() / CHANNELS_PER_ITER, 0, stream); + launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } + template void resize_bilinear(const Stream& stream, TensorSpan output, TensorView input, float scale_y, float scale_x) { - auto in_height = input.get_axis_size(-2); - auto in_width = input.get_axis_size(-1); - auto out_height = output.get_axis_size(-2); auto out_width = output.get_axis_size(-1); - auto kernel = raw::resize_bilinear; - auto policy = make_policy(kernel, output.size(), 0, stream); - launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + auto in_height = input.get_axis_size(-2); + auto in_width = input.get_axis_size(-1); + + auto num_effective_channels = input.size_range(0, 2); + auto num_iters = num_effective_channels * out_height * out_width; + + if (num_effective_channels % 16 == 0 && num_iters > 163840) { + launch_multichannel_resize_bilinear(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } else if (num_effective_channels % 8 == 0 && num_iters > 81920) { + launch_multichannel_resize_bilinear(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } else if (num_effective_channels % 4 == 0 && num_iters > 40960) { + launch_multichannel_resize_bilinear(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } else if (num_effective_channels % 2 == 0) { + launch_multichannel_resize_bilinear(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } else { + launch_multichannel_resize_bilinear(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x); + } } template void resize_bilinear<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float);