parent
77b01deb80
commit
a3106d424b
4 changed files with 342 additions and 1 deletion
@@ -0,0 +1,145 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "math.hpp"
#include "types.hpp"
#include "atomics.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"

#include <opencv2/core.hpp>

#include <cstddef>

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

namespace raw {
    template <class T>
    __global__ void reduce_mean(Span<float> means, View<T> input, size_type inner_size) {
        for (auto idx : grid_stride_range(input.size())) {
            const index_type outer_idx = idx / inner_size;
            atomicAdd(&means[outer_idx], static_cast<float>(input[idx]) / inner_size);
        }
    }

    template <class T>
    __global__ void reduce_mean_sqr_sum(Span<float> means, Span<float> sum_sqrs, View<T> input, size_type inner_size) {
        for (auto idx : grid_stride_range(input.size())) {
            const index_type outer_idx = idx / inner_size;
            auto x = static_cast<float>(input[idx]);
            atomicAdd(&means[outer_idx], x / inner_size);
            atomicAdd(&sum_sqrs[outer_idx], x * x);
        }
    }

    __global__ void compute_normalization_scale(Span<float> scale, View<float> means, View<float> sums_sqr, size_type inner_size, float eps) {
        for (auto idx : grid_stride_range(scale.size())) {
            auto mean = means[idx];
            auto var = sums_sqr[idx] / inner_size - mean * mean;
            using device::rsqrt;
            scale[idx] = rsqrt(eps + var);
        }
    }

    template <class T>
    __global__ void normalize_mean(Span<T> output, View<T> input, View<float> means, size_type inner_size) {
        for (auto idx : grid_stride_range(output.size())) {
            const index_type outer_idx = idx / inner_size;
            output[idx] = static_cast<float>(input[idx]) - means[outer_idx];
        }
    }

    template <class T>
    __global__ void normalize_mean_variance(Span<T> output, View<T> input, View<float> means, View<float> scale, size_type inner_size) {
        for (auto idx : grid_stride_range(output.size())) {
            const index_type outer_idx = idx / inner_size;
            output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * scale[outer_idx];
        }
    }
}

template <class T>
void reduce_mean(const Stream& stream, Span<float> means, View<T> input, std::size_t inner_size)
{
    CV_Assert(input.size() / inner_size == means.size());

    auto kernel = raw::reduce_mean<T>;
    auto policy = make_policy(kernel, input.size(), 0, stream);
    launch_kernel(kernel, policy, means, input, inner_size);
}

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void reduce_mean(const Stream&, Span<float>, View<__half>, std::size_t);
#endif
template void reduce_mean(const Stream&, Span<float>, View<float>, std::size_t);

template <class T>
void reduce_mean_sqr_sum(const Stream& stream, Span<float> means, Span<float> sum_sqrs, View<T> input, std::size_t inner_size)
{
    CV_Assert(input.size() / inner_size == means.size());
    CV_Assert(input.size() / inner_size == sum_sqrs.size());

    auto kernel = raw::reduce_mean_sqr_sum<T>;
    auto policy = make_policy(kernel, input.size(), 0, stream);
    launch_kernel(kernel, policy, means, sum_sqrs, input, inner_size);
}

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<__half>, std::size_t);
#endif
template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<float>, std::size_t);

void compute_normalization_scale(const Stream& stream, Span<float> scale, View<float> means, View<float> sum_sqrs, std::size_t inner_size, float eps)
{
    CV_Assert(scale.size() == means.size());
    CV_Assert(scale.size() == sum_sqrs.size());

    auto kernel = raw::compute_normalization_scale;
    auto policy = make_policy(kernel, scale.size(), 0, stream);
    launch_kernel(kernel, policy, scale, means, sum_sqrs, inner_size, eps);
}

template <class T>
void normalize_mean(const Stream& stream, Span<T> output, View<T> input, View<float> means, std::size_t inner_size)
{
    CV_Assert(output.size() == input.size());
    CV_Assert(input.size() / inner_size == means.size());

    auto kernel = raw::normalize_mean<T>;
    auto policy = make_policy(kernel, output.size(), 0, stream);
    launch_kernel(kernel, policy, output, input, means, inner_size);
}

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void normalize_mean(const Stream&, Span<__half>, View<__half>, View<float>, std::size_t);
#endif
template void normalize_mean(const Stream&, Span<float>, View<float>, View<float>, std::size_t);

template <class T>
void normalize_mean_variance(const Stream& stream, Span<T> output, View<T> input, View<float> means, View<float> scale, std::size_t inner_size)
{
    CV_Assert(input.size() == output.size());
    CV_Assert(input.size() / inner_size == means.size());
    CV_Assert(input.size() / inner_size == scale.size());

    auto kernel = raw::normalize_mean_variance<T>;
    auto policy = make_policy(kernel, output.size(), 0, stream);
    launch_kernel(kernel, policy, output, input, means, scale, inner_size);
}

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>, View<float>, View<float>, std::size_t);
#endif
template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */
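Taken together, the kernels above implement MVN in two passes: a reduction pass accumulates the per-group mean (and, when the variance is normalized, the sum of squares), compute_normalization_scale turns those sums into rsqrt(var + eps) with var = E[x^2] - mean^2, and a final pass writes (x - mean) * scale. The host-side sketch below is purely illustrative and not part of the patch; mvn_reference is a hypothetical name that reproduces the same arithmetic serially so the kernel math can be checked on small inputs.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper, not part of the patch: each outer index owns a contiguous
// block of inner_size elements, mirroring the idx / inner_size grouping used by
// the device kernels above.
void mvn_reference(std::vector<float>& data, std::size_t inner_size,
                   bool normalize_variance, float eps)
{
    const std::size_t outer_size = data.size() / inner_size;
    for (std::size_t o = 0; o < outer_size; o++)
    {
        float mean = 0.f, sum_sqr = 0.f;
        for (std::size_t j = 0; j < inner_size; j++)
        {
            const float x = data[o * inner_size + j];
            mean += x / inner_size;   // what reduce_mean / reduce_mean_sqr_sum accumulate
            sum_sqr += x * x;
        }
        // compute_normalization_scale: var = E[x^2] - mean^2, scale = 1/sqrt(var + eps)
        const float var = sum_sqr / inner_size - mean * mean;
        const float scale = normalize_variance ? 1.f / std::sqrt(var + eps) : 1.f;
        for (std::size_t j = 0; j < inner_size; j++)
            data[o * inner_size + j] = (data[o * inner_size + j] - mean) * scale;
    }
}

int main()
{
    std::vector<float> v{1.f, 2.f, 3.f, 4.f, 10.f, 20.f, 30.f, 40.f};
    mvn_reference(v, 4, true, 1e-9f);        // two groups of four values each
    for (float x : v) std::cout << x << ' ';
    std::cout << '\n';                       // each group now has zero mean, unit variance
}

Each block of inner_size consecutive elements corresponds to one outer index, which is why the device kernels can recover the group with a single integer division per element.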
@@ -0,0 +1,31 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP

#include "../csl/stream.hpp"
#include "../csl/span.hpp"

#include <cstddef>

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

template <class T>
void reduce_mean(const csl::Stream& stream, csl::Span<float> means, csl::View<T> input, std::size_t inner_size);

template <class T>
void reduce_mean_sqr_sum(const csl::Stream& stream, csl::Span<float> means, csl::Span<float> sum_sqrs, csl::View<T> input, std::size_t inner_size);

void compute_normalization_scale(const csl::Stream& stream, csl::Span<float> scale, csl::View<float> means, csl::View<float> sum_sqrs, std::size_t inner_size, float eps);

template <class T>
void normalize_mean(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, std::size_t inner_size);

template <class T>
void normalize_mean_variance(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, csl::View<float> means, csl::View<float> scale, std::size_t inner_size);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_MVN_HPP */
@@ -0,0 +1,134 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP

#include "../../op_cuda.hpp"

#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include "../csl/tensor.hpp"
#include "../csl/workspace.hpp"

#include "../kernels/fill_copy.hpp"
#include "../kernels/mvn.hpp"

#include <opencv2/core.hpp>

#include <cstddef>
#include <vector>
#include <utility>

namespace cv { namespace dnn { namespace cuda4dnn {

    struct MVNConfiguration {
        std::vector<std::vector<std::size_t>> input_shapes;

        /*
         * [0, split_axis) = outer range
         * [split_axis, -1] = inner range
         *
         * for each location in the outer range, all the values in the inner range are normalized as a group
         */
        std::size_t split_axis;

        /* The group (described above) is centered always. The following parameter controls whether the variance
         * is also normalized.
         */
        bool normalize_variance;
        float epsilon;
    };

    template <class T>
    class MVNOp final : public CUDABackendNode {
    public:
        using wrapper_type = GetCUDABackendWrapperType<T>;

        MVNOp(csl::Stream stream_, const MVNConfiguration& config)
            : stream(std::move(stream_))
        {
            split_axis = config.split_axis;
            normalize_variance = config.normalize_variance;
            epsilon = config.epsilon;

            std::size_t max_outer_size = 0;
            const auto& input_shapes = config.input_shapes;
            for (int i = 0; i < input_shapes.size(); i++)
            {
                std::size_t outer_size = 1;
                for (int j = 0; j < split_axis; j++)
                    outer_size *= input_shapes[i][j];
                max_outer_size = std::max(max_outer_size, outer_size);
            }

            csl::WorkspaceBuilder builder;
            builder.require<float>(max_outer_size);
            if (normalize_variance)
                builder.require<float>(max_outer_size);
            scratch_mem_in_bytes = builder.required_workspace_size();
        }

        void forward(
            const std::vector<cv::Ptr<BackendWrapper>>& inputs,
            const std::vector<cv::Ptr<BackendWrapper>>& outputs,
            csl::Workspace& workspace) override
        {
            CV_Assert(inputs.size() == outputs.size());

            for (int i = 0; i < inputs.size(); i++)
            {
                auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
                auto input = input_wrapper->getView();

                auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
                auto output = output_wrapper->getSpan();

                auto outer_size = input.size_range(0, split_axis);
                auto inner_size = input.size_range(split_axis, input.rank());
                if (inner_size == 1)
                {
                    kernels::fill<T>(stream, output, 0.0f);
                    return;
                }
                else
                {
                    auto ws_allocator = csl::WorkspaceAllocator(workspace);

                    auto means = ws_allocator.get_span<float>(outer_size);
                    kernels::fill<float>(stream, means, 0);

                    if (normalize_variance)
                    {
                        auto scales = ws_allocator.get_span<float>(outer_size);
                        kernels::fill<float>(stream, scales, 0);

                        kernels::reduce_mean_sqr_sum<T>(stream, means, scales, input, inner_size);
                        kernels::compute_normalization_scale(stream, scales, means, scales, inner_size, epsilon);
                        kernels::normalize_mean_variance<T>(stream, output, input, means, scales, inner_size);
                    }
                    else
                    {
                        kernels::reduce_mean<T>(stream, means, input, inner_size);
                        kernels::normalize_mean<T>(stream, output, input, means, inner_size);
                    }
                }
            }
        }

        std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }

    private:
        csl::Stream stream;

        bool normalize_variance;
        float epsilon;
        std::size_t split_axis;

        std::size_t scratch_mem_in_bytes;
    };

}}} /* namespace cv::dnn::cuda4dnn */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MVN_HPP */
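As a closing illustration (not part of the patch), the sketch below shows how split_axis partitions a concrete shape into the outer and inner ranges described in MVNConfiguration, and how that determines the float workspace MVNOp reserves: one entry per outer index for the means, plus a second array of the same size when normalize_variance is set. The shape and values here are made up for the example.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical illustration: outer range is [0, split_axis), inner range is
// [split_axis, rank); every outer index gets one mean (and one scale when the
// variance is normalized), which is what the constructor sizes the workspace for.
int main()
{
    const std::vector<std::size_t> shape{2, 3, 4, 5};   // e.g. an NCHW tensor
    const std::size_t split_axis = 2;

    std::size_t outer_size = 1, inner_size = 1;
    for (std::size_t j = 0; j < shape.size(); j++)
        (j < split_axis ? outer_size : inner_size) *= shape[j];

    std::cout << "outer = " << outer_size      // 2 * 3 = 6 groups
              << ", inner = " << inner_size    // 4 * 5 = 20 values normalized per group
              << ", workspace floats = " << 2 * outer_size << '\n';
}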