From 01f97f150c5f63ff71ffb18a4b99d725dcf9a0a3 Mon Sep 17 00:00:00 2001
From: YashasSamaga
Date: Mon, 30 Dec 2019 00:05:39 +0530
Subject: [PATCH] perform fp conversions on GPU

---
 modules/dnn/src/cuda/fp_conversion.cu              | 102 ++++++++++++++++++
 .../src/cuda4dnn/kernels/fp_conversion.hpp          |  18 ++++
 modules/dnn/src/op_cuda.hpp                         |  55 +++++++++-
 3 files changed, 172 insertions(+), 3 deletions(-)
 create mode 100644 modules/dnn/src/cuda/fp_conversion.cu
 create mode 100644 modules/dnn/src/cuda4dnn/kernels/fp_conversion.hpp

diff --git a/modules/dnn/src/cuda/fp_conversion.cu b/modules/dnn/src/cuda/fp_conversion.cu
new file mode 100644
index 0000000000..7614174800
--- /dev/null
+++ b/modules/dnn/src/cuda/fp_conversion.cu
@@ -0,0 +1,102 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include "grid_stride_range.hpp"
+#include "execution.hpp"
+#include "vector_traits.hpp"
+
+#include "../cuda4dnn/csl/stream.hpp"
+#include "../cuda4dnn/csl/span.hpp"
+
+using namespace cv::dnn::cuda4dnn::csl;
+using namespace cv::dnn::cuda4dnn::csl::device;
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
+
+    namespace raw {
+        template <std::size_t N>
+        __global__ void fp32_to_fp16(Span<__half> output, View<float> input) {
+            using output_vector_type = get_vector_type_t<__half, N>;
+            using input_vector_type = get_vector_type_t<float, N>;
+
+            auto output_vPtr = output_vector_type::get_pointer(output.data());
+            auto input_vPtr = input_vector_type::get_pointer(input.data());
+
+            for (auto i : grid_stride_range(output.size() / output_vector_type::size())) {
+                input_vector_type in_vec;
+                v_load(in_vec, input_vPtr[i]);
+
+                output_vector_type out_vec;
+                for (int j = 0; j < output_vector_type::size(); j++)
+                    out_vec.data[j] = __float2half(in_vec.data[j]);
+
+                v_store(output_vPtr[i], out_vec);
+            }
+        }
+
+        template <std::size_t N>
+        __global__ void fp16_to_fp32(Span<float> output, View<__half> input) {
+            using output_vector_type = get_vector_type_t<float, N>;
+            using input_vector_type = get_vector_type_t<__half, N>;
+
+            auto output_vPtr = output_vector_type::get_pointer(output.data());
+            auto input_vPtr = input_vector_type::get_pointer(input.data());
+
+            for (auto i : grid_stride_range(output.size() / output_vector_type::size())) {
+                input_vector_type in_vec;
+                v_load(in_vec, input_vPtr[i]);
+
+                output_vector_type out_vec;
+                for (int j = 0; j < output_vector_type::size(); j++)
+                    out_vec.data[j] = __half2float(in_vec.data[j]);
+
+                v_store(output_vPtr[i], out_vec);
+            }
+        }
+    }
+
+    template <std::size_t N> static
+    void launch_vectorized_fp32_to_fp16(const Stream& stream, Span<__half> output, View<float> input) {
+        CV_Assert(is_fully_aligned<__half>(output, N));
+        CV_Assert(is_fully_aligned<float>(input, N));
+
+        auto kernel = raw::fp32_to_fp16<N>;
+        auto policy = make_policy(kernel, output.size() / N, 0, stream);
+        launch_kernel(kernel, policy, output, input);
+    }
+
+    void fp32_to_fp16(const Stream& stream, Span<__half> output, View<float> input) {
+        if (is_fully_aligned<__half>(output, 4) && is_fully_aligned<float>(input, 4)) {
+            launch_vectorized_fp32_to_fp16<4>(stream, output, input);
+        } else if (is_fully_aligned<__half>(output, 2) && is_fully_aligned<float>(input, 2)) {
+            launch_vectorized_fp32_to_fp16<2>(stream, output, input);
+        } else {
+            launch_vectorized_fp32_to_fp16<1>(stream, output, input);
+        }
+    }
+
+    template <std::size_t N> static
+    void launch_vectorized_fp16_to_fp32(const Stream& stream, Span<float> output, View<__half> input) {
+        CV_Assert(is_fully_aligned<float>(output, N));
+        CV_Assert(is_fully_aligned<__half>(input, N));
+
+        auto kernel = raw::fp16_to_fp32<N>;
+        auto policy = make_policy(kernel, output.size() / N, 0, stream);
+        launch_kernel(kernel, policy, output, input);
+    }
+
+    void fp16_to_fp32(const Stream& stream, Span<float> output, View<__half> input) {
+        if (is_fully_aligned<float>(output, 4) && is_fully_aligned<__half>(input, 4)) {
+            launch_vectorized_fp16_to_fp32<4>(stream, output, input);
+        } else if (is_fully_aligned<float>(output, 2) && is_fully_aligned<__half>(input, 2)) {
+            launch_vectorized_fp16_to_fp32<2>(stream, output, input);
+        } else {
+            launch_vectorized_fp16_to_fp32<1>(stream, output, input);
+        }
+    }
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
diff --git a/modules/dnn/src/cuda4dnn/kernels/fp_conversion.hpp b/modules/dnn/src/cuda4dnn/kernels/fp_conversion.hpp
new file mode 100644
index 0000000000..31913c12c4
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/kernels/fp_conversion.hpp
@@ -0,0 +1,18 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FP_CONVERSION_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FP_CONVERSION_HPP
+
+#include "../csl/stream.hpp"
+#include "../csl/span.hpp"
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
+
+    void fp32_to_fp16(const csl::Stream& stream, csl::Span<half> output, csl::View<float> input);
+    void fp16_to_fp32(const csl::Stream& stream, csl::Span<float> output, csl::View<half> input);
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FP_CONVERSION_HPP */
diff --git a/modules/dnn/src/op_cuda.hpp b/modules/dnn/src/op_cuda.hpp
index ccb7877e88..d702989c8c 100644
--- a/modules/dnn/src/op_cuda.hpp
+++ b/modules/dnn/src/op_cuda.hpp
@@ -13,6 +13,7 @@
 #include "cuda4dnn/csl/memory.hpp"
 #include "cuda4dnn/csl/fp16.hpp"
 #include "cuda4dnn/csl/workspace.hpp"
+#include "cuda4dnn/kernels/fp_conversion.hpp"
 #endif
 
 #include
@@ -149,7 +150,6 @@ namespace cv { namespace dnn {
             if (temp.data != destMat.data)
                 temp.copyTo(destMat);
         }
-
     }} /* namespace cuda4dnn::csl */
 
     /** base class for CUDA operation nodes (for all supported targets) */
@@ -219,6 +219,45 @@
         virtual void setStream(cuda4dnn::csl::Stream stream) noexcept = 0;
     };
 
+    namespace cuda4dnn { namespace detail {
+
+        template <class DEVICE_T, class HOST_T>
+        void convert_D2H(const cv::Mat& mat, cuda4dnn::csl::View<DEVICE_T> view, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
+
+        template <> inline
+        void convert_D2H<half, float>(const cv::Mat& mat, cuda4dnn::csl::View<half> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            if (device_temp.size() < view.size())
+                device_temp.reset(view.size());
+            auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), view.size());
+
+            cuda4dnn::kernels::fp16_to_fp32(stream, temp_span, view);
+            cuda4dnn::csl::memcpy(reinterpret_cast<float*>(mat.data), temp_span.data(), view.size(), stream);
+        }
+
+        template <> inline
+        void convert_D2H<float, float>(const cv::Mat& mat, cuda4dnn::csl::View<float> view, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy(reinterpret_cast<float*>(mat.data), view.data(), view.size(), stream);
+        }
+
+        template <class DEVICE_T, class HOST_T>
+        void convert_H2D(cuda4dnn::csl::Span<DEVICE_T> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<HOST_T>& device_temp, const cuda4dnn::csl::Stream& stream);
+
+        template <> inline
+        void convert_H2D<half, float>(cuda4dnn::csl::Span<half> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            if (device_temp.size() < span.size())
+                device_temp.reset(span.size());
+            auto temp_span = cuda4dnn::csl::Span<float>(device_temp.get(), span.size());
+
+            cuda4dnn::csl::memcpy(temp_span.data(), reinterpret_cast<float*>(mat.data), span.size(), stream);
+            cuda4dnn::kernels::fp32_to_fp16(stream, span, temp_span);
+        }
+
+        template <> inline
+        void convert_H2D<float, float>(cuda4dnn::csl::Span<float> span, const cv::Mat& mat, cuda4dnn::csl::ManagedPtr<float>& device_temp, const cuda4dnn::csl::Stream& stream) {
+            cuda4dnn::csl::memcpy(span.data(), reinterpret_cast<float*>(mat.data), span.size(), stream);
+        }
+    }} /* namespace cuda4dnn::detail */
+
     template <class T, int TargetID>
     class GenericCUDABackendWrapper final : public CUDABackendWrapper {
     public:
@@ -283,8 +322,12 @@ namespace cv { namespace dnn {
                  * We use a view to ensure that only the required region of memory is copied.
                  */
                 auto view = tensor_view_type(shared_block->device.get(), std::begin(shape), std::end(shape));
-                cuda4dnn::csl::copyTensorToMat(view, shared_block->host, shared_block->stream);
 
+                auto& mat = shared_block->host;
+                CV_Assert(mat.isContinuous());
+                CV_Assert(mat.type() == CV_32F);
+
+                cuda4dnn::detail::convert_D2H<T, float>(mat, view, shared_block->device_temp, shared_block->stream);
                 shared_block->stream.synchronize();
             }
         }
@@ -300,7 +343,12 @@ namespace cv { namespace dnn {
                 shared_block->device_dirty = false;
 
                 auto span = tensor_span_type(shared_block->device.get(), std::begin(shape), std::end(shape));
-                cuda4dnn::csl::copyMatToTensor(shared_block->host, span, shared_block->stream);
+
+                auto& mat = shared_block->host;
+                CV_Assert(mat.isContinuous());
+                CV_Assert(mat.type() == CV_32F);
+
+                cuda4dnn::detail::convert_H2D<T, float>(span, mat, shared_block->device_temp, shared_block->stream);
             }
         }
 
@@ -368,6 +416,7 @@ namespace cv { namespace dnn {
             cuda4dnn::csl::MemoryLockGuard memGuard; /* keeps host memory page-locked if possible */
 
             cuda4dnn::csl::ManagedPtr<T> device;
+            cuda4dnn::csl::ManagedPtr<float> device_temp; /* used for conversions */
             cuda4dnn::csl::Stream stream;
         };
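
For readers who do not know the CSL helpers used above (Span/View, grid_stride_range, vector_traits), the standalone sketch below shows the idea behind the patch with plain CUDA: the host buffer stays in fp32, is copied to a device scratch buffer, and a grid-stride kernel performs the fp32-to-fp16 conversion on the GPU instead of the CPU. It is a simplified, non-vectorized stand-in for raw::fp32_to_fp16; the kernel name convert_fp32_to_fp16, the buffer size, and the launch configuration are illustrative assumptions, not part of the patch.

// Minimal standalone sketch (not part of the patch); compile with nvcc.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>
#include <vector>

// Grid-stride loop: each thread converts multiple elements, mirroring the
// access pattern of the patch's fp32_to_fp16 kernel (without vectorization).
__global__ void convert_fp32_to_fp16(__half* output, const float* input, size_t n) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x * blockDim.x)
        output[i] = __float2half(input[i]);
}

int main() {
    const size_t n = 1 << 20;
    std::vector<float> host(n, 3.14f);

    float* d_fp32 = nullptr;
    __half* d_fp16 = nullptr;
    cudaMalloc(&d_fp32, n * sizeof(float));
    cudaMalloc(&d_fp16, n * sizeof(__half));

    // Host-to-device path analogous to convert_H2D<half, float>: copy fp32 data
    // to a device scratch buffer, then convert to fp16 entirely on the GPU.
    cudaMemcpy(d_fp32, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    convert_fp32_to_fp16<<<256, 256>>>(d_fp16, d_fp32, n);
    cudaDeviceSynchronize();

    std::printf("converted %zu elements\n", n);
    cudaFree(d_fp32);
    cudaFree(d_fp16);
    return 0;
}

The device-to-host direction in the patch is the mirror image: fp16_to_fp32 fills the float scratch buffer on the GPU, and only then is that buffer copied into the cv::Mat.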