Merge pull request #20877 from rogday:simple_layers

pull/20921/head
Alexander Alekhin 3 years ago
commit ec10f2e72b
Files changed (12; changed line counts in parentheses):
  1. modules/dnn/include/opencv2/dnn/all_layers.hpp (43)
  2. modules/dnn/src/cuda/activations.cu (42)
  3. modules/dnn/src/cuda/functors.hpp (90)
  4. modules/dnn/src/cuda/math.hpp (21)
  5. modules/dnn/src/cuda4dnn/kernels/activations.hpp (18)
  6. modules/dnn/src/cuda4dnn/primitives/activation.hpp (320)
  7. modules/dnn/src/init.cpp (7)
  8. modules/dnn/src/layers/elementwise_layers.cpp (897)
  9. modules/dnn/src/layers/scale_layer.cpp (53)
  10. modules/dnn/src/onnx/onnx_importer.cpp (36)
  11. modules/dnn/src/opencl/activations.cl (36)
  12. modules/dnn/test/test_onnx_importer.cpp (76)

@@ -600,6 +600,42 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<ExpLayer> create(const LayerParams &params);
};
class CV_EXPORTS CeilLayer : public ActivationLayer
{
public:
static Ptr<CeilLayer> create(const LayerParams &params);
};
class CV_EXPORTS FloorLayer : public ActivationLayer
{
public:
static Ptr<FloorLayer> create(const LayerParams &params);
};
class CV_EXPORTS LogLayer : public ActivationLayer
{
public:
static Ptr<LogLayer> create(const LayerParams &params);
};
class CV_EXPORTS RoundLayer : public ActivationLayer
{
public:
static Ptr<RoundLayer> create(const LayerParams &params);
};
class CV_EXPORTS SqrtLayer : public ActivationLayer
{
public:
static Ptr<SqrtLayer> create(const LayerParams &params);
};
class CV_EXPORTS NotLayer : public ActivationLayer
{
public:
static Ptr<NotLayer> create(const LayerParams &params);
};
class CV_EXPORTS ActivationLayerInt8 : public ActivationLayer
{
public:
@@ -665,6 +701,7 @@ CV__DNN_INLINE_NS_BEGIN
public:
bool hasBias;
int axis;
String mode;
static Ptr<ScaleLayer> create(const LayerParams& params);
};
@@ -689,6 +726,12 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<Layer> create(const LayerParams& params);
};
class CV_EXPORTS CompareLayer : public Layer
{
public:
static Ptr<Layer> create(const LayerParams& params);
};
class CV_EXPORTS DataAugmentationLayer : public Layer
{
public:

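Usage sketch (not part of this diff): the new layers are plain element-wise activations, so once registered they can be exercised through the generic cv::dnn::Net API. A minimal C++ example, assuming the "Sqrt" type string registered in init.cpp further down:

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    cv::dnn::Net net;
    cv::dnn::LayerParams lp;
    lp.name = "sqrt1";
    lp.type = "Sqrt";  // type string registered via CV_DNN_REGISTER_LAYER_CLASS(Sqrt, SqrtLayer)
    net.addLayerToPrev(lp.name, lp.type, lp);

    cv::Mat input = (cv::Mat_<float>(1, 4) << 1.f, 4.f, 9.f, 16.f);
    net.setInput(input);
    std::cout << net.forward() << std::endl;  // expected: [1, 2, 3, 4]
    return 0;
}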
@@ -128,6 +128,36 @@ void bnll(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, BNLLFunctor<T>>(stream, output, input);
}
template <class T>
void ceil(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, CeilFunctor<T>>(stream, output, input);
}
template <class T>
void floor(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, FloorFunctor<T>>(stream, output, input);
}
template <class T>
void log(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, LogFunctor<T>>(stream, output, input);
}
template <class T>
void rint(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, RintFunctor<T>>(stream, output, input);
}
template <class T>
void sqrt(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, SqrtFunctor<T>>(stream, output, input);
}
template <class T>
void not_k(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, NotFunctor<T>>(stream, output, input);
}
template <class T>
void abs(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, AbsFunctor<T>>(stream, output, input);
@@ -160,6 +190,12 @@ template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
template void elu<__half>(const Stream&, Span<__half>, View<__half>);
template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
template void ceil<__half>(const Stream&, Span<__half>, View<__half>);
template void floor<__half>(const Stream&, Span<__half>, View<__half>);
template void log<__half>(const Stream&, Span<__half>, View<__half>);
template void rint<__half>(const Stream&, Span<__half>, View<__half>);
template void sqrt<__half>(const Stream&, Span<__half>, View<__half>);
template void not_k<__half>(const Stream&, Span<__half>, View<__half>);
template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
template void exp<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
#endif
@@ -174,6 +210,12 @@ template void sigmoid<float>(const Stream&, Span<float>, View<float>);
template void elu<float>(const Stream&, Span<float>, View<float>);
template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
template void bnll<float>(const Stream&, Span<float>, View<float>);
template void ceil<float>(const Stream&, Span<float>, View<float>);
template void floor<float>(const Stream&, Span<float>, View<float>);
template void log<float>(const Stream&, Span<float>, View<float>);
template void rint<float>(const Stream&, Span<float>, View<float>);
template void sqrt<float>(const Stream&, Span<float>, View<float>);
template void not_k<float>(const Stream&, Span<float>, View<float>);
template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
template void exp<float>(const Stream&, Span<float>, View<float>, float, float);

@@ -209,6 +209,96 @@ struct BNLLFunctor {
}
};
template <class T>
struct CeilFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE CeilFunctor() { }
CUDA4DNN_DEVICE CeilFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::ceil;
return ceil(value);
}
};
template <class T>
struct FloorFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE FloorFunctor() { }
CUDA4DNN_DEVICE FloorFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::floor;
return floor(value);
}
};
template <class T>
struct LogFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE LogFunctor() { }
CUDA4DNN_DEVICE LogFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::log;
return log(value);
}
};
template <class T>
struct RintFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE RintFunctor() { }
CUDA4DNN_DEVICE RintFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::rint;
return rint(value);
}
};
template <class T>
struct SqrtFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE SqrtFunctor() { }
CUDA4DNN_DEVICE SqrtFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::sqrt;
return sqrt(value);
}
};
template <class T>
struct NotFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() { }
};
CUDA4DNN_DEVICE NotFunctor() { }
CUDA4DNN_DEVICE NotFunctor(const Params& params) { }
CUDA4DNN_DEVICE T operator()(T value) {
using csl::device::floor;
return floor(static_cast<T>(1.) - value);
}
};
template <class T>
struct PowerFunctor {
struct Params {

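All of the new functors follow the same stateless shape consumed by generic_op: an empty Params struct, trivial constructors, and a per-element operator(). A host-side C++ analogue of that pattern (illustration only, not the CUDA code path):

#include <algorithm>
#include <cmath>
#include <vector>

// Host-side analogue of CeilFunctor: no state, one value in, one value out.
template <class T>
struct CeilFunctorHost {
    T operator()(T value) const { return std::ceil(value); }
};

// Rough stand-in for generic_op<T, Functor>: apply the functor element-wise.
template <class T, class Functor>
void generic_op_host(const std::vector<T>& input, std::vector<T>& output)
{
    output.resize(input.size());
    std::transform(input.begin(), input.end(), output.begin(), Functor());
}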
@@ -119,6 +119,27 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
template <> inline __device__ __half round(__half value) { return hrint(value); }
#endif
template <class T> __device__ T floor(T value);
template <> inline __device__ double floor(double value) { return ::floor(value); }
template <> inline __device__ float floor(float value) { return floorf(value); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half floor(__half value) { return hfloor(value); }
#endif
template <class T> __device__ T log(T value);
template <> inline __device__ double log(double value) { return ::log(value); }
template <> inline __device__ float log(float value) { return logf(value); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half log(__half value) { return hlog(value); }
#endif
template <class T> __device__ T rint(T value);
template <> inline __device__ double rint(double value) { return ::rint(value); }
template <> inline __device__ float rint(float value) { return rintf(value); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half rint(__half value) { return hrint(value); }
#endif
template <class T> __device__ T ceil(T value);
template <> inline __device__ double ceil(double value) { return ::ceil(value); }
template <> inline __device__ float ceil(float value) { return ceilf(value); }

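The additions to math.hpp reuse the file's existing idiom: a generic template declaration plus explicit per-type specializations, with the __half overloads guarded by the compute-capability check. The same dispatch idiom in plain C++ (no device qualifiers, illustration only):

#include <cmath>

// Generic declaration; only the explicit specializations are ever defined.
template <class T> T device_log(T value);

template <> inline double device_log(double value) { return std::log(value); }
template <> inline float device_log(float value) { return std::log(value); }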
@@ -42,6 +42,24 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void bnll(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void ceil(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void floor(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void log(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void rint(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void sqrt(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void not_k(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
template <class T>
void power(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T exp, T scale, T shift);

@@ -18,14 +18,12 @@
namespace cv { namespace dnn { namespace cuda4dnn {
template <class T>
class ReLUOp final : public CUDABackendNode {
public:
template <template<class> class Op, class T>
struct BaseOp : public CUDABackendNode
{
protected:
using wrapper_type = GetCUDABackendWrapperType<T>;
ReLUOp(csl::Stream stream_, T slope_)
: stream(std::move(stream_)), slope{ slope_ } { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
@@ -39,9 +37,21 @@ namespace cv { namespace dnn { namespace cuda4dnn {
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::relu<T>(stream, output, input, slope);
static_cast<const Op<T>*>(this)->calculate(output, input);
}
}
};
template <class T>
class ReLUOp final : public BaseOp<ReLUOp, T> {
public:
ReLUOp(csl::Stream stream_, T slope_)
: stream(std::move(stream_)), slope{ slope_ } { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::relu<T>(stream, output, input, slope);
}
private:
csl::Stream stream;
@@ -49,28 +59,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class ClippedReLUOp final : public CUDABackendNode {
class ClippedReLUOp final : public BaseOp<ClippedReLUOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
ClippedReLUOp(csl::Stream stream_, T min_, T max_)
: stream(std::move(stream_)), min{ min_ }, max{ max_ } { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::clipped_relu<T>(stream, output, input, min, max);
}
kernels::clipped_relu<T>(stream, output, input, min, max);
}
private:
@@ -79,35 +75,21 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class ChannelwiseReLUOp final : public CUDABackendNode {
class ChannelwiseReLUOp final : public BaseOp<ChannelwiseReLUOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
ChannelwiseReLUOp(csl::Stream stream_, const Mat& slope)
: stream(std::move(stream_))
: stream(std::move(stream_))
{
CV_Assert(!slope.empty());
slopeTensor = csl::makeTensorHeader<T>(slope);
csl::copyMatToTensor<T>(slope, slopeTensor, stream);
}
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
CV_Assert(input.get_axis_size(1) == slopeTensor.size());
std::size_t inner_size = input.size_range(2, input.rank());
kernels::axiswise_relu<T>(stream, output, input, inner_size, slopeTensor);
}
CV_Assert(input.get_axis_size(1) == slopeTensor.size());
std::size_t inner_size = input.size_range(2, input.rank());
kernels::axiswise_relu<T>(stream, output, input, inner_size, slopeTensor);
}
private:
@@ -116,27 +98,13 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class TanHOp final : public CUDABackendNode {
class TanHOp final : public BaseOp<TanHOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
TanHOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::tanh<T>(stream, output, input);
}
kernels::tanh<T>(stream, output, input);
}
private:
@@ -144,27 +112,13 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class SwishOp final : public CUDABackendNode {
class SwishOp final : public BaseOp<SwishOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
SwishOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::swish<T>(stream, output, input);
}
kernels::swish<T>(stream, output, input);
}
private:
@@ -172,27 +126,13 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class MishOp final : public CUDABackendNode {
class MishOp final : public BaseOp<MishOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
MishOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::mish<T>(stream, output, input);
}
kernels::mish<T>(stream, output, input);
}
private:
@@ -200,27 +140,13 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class SigmoidOp final : public CUDABackendNode {
class SigmoidOp final : public BaseOp<SigmoidOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
SigmoidOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::sigmoid<T>(stream, output, input);
}
kernels::sigmoid<T>(stream, output, input);
}
private:
@@ -228,27 +154,27 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class ELUOp final : public CUDABackendNode {
class ELUOp final : public BaseOp<ELUOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
ELUOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
kernels::elu<T>(stream, output, input);
}
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
private:
csl::Stream stream;
};
kernels::elu<T>(stream, output, input);
}
template <class T>
class AbsValOp final : public BaseOp<AbsValOp, T> {
public:
AbsValOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::abs<T>(stream, output, input);
}
private:
@@ -256,27 +182,41 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class AbsValOp final : public CUDABackendNode {
class BNLLOp final : public BaseOp<BNLLOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
BNLLOp(csl::Stream stream_) : stream(std::move(stream_)) { }
AbsValOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::bnll<T>(stream, output, input);
}
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
private:
csl::Stream stream;
};
template <class T>
class CeilOp final : public BaseOp<CeilOp, T> {
public:
CeilOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
kernels::ceil<T>(stream, output, input);
}
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
private:
csl::Stream stream;
};
kernels::abs<T>(stream, output, input);
}
template <class T>
class FloorOp final : public BaseOp<FloorOp, T> {
public:
FloorOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::floor<T>(stream, output, input);
}
private:
@@ -284,27 +224,41 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class BNLLOp final : public CUDABackendNode {
class LogOp final : public BaseOp<LogOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
LogOp(csl::Stream stream_) : stream(std::move(stream_)) { }
BNLLOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::log<T>(stream, output, input);
}
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
private:
csl::Stream stream;
};
template <class T>
class RoundOp final : public BaseOp<RoundOp, T> {
public:
RoundOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
kernels::rint<T>(stream, output, input);
}
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
private:
csl::Stream stream;
};
kernels::bnll<T>(stream, output, input);
}
template <class T>
class SqrtOp final : public BaseOp<SqrtOp, T> {
public:
SqrtOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::sqrt<T>(stream, output, input);
}
private:
@@ -312,28 +266,28 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class PowerOp final : public CUDABackendNode {
class NotOp final : public BaseOp<NotOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
NotOp(csl::Stream stream_) : stream(std::move(stream_)) { }
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::not_k<T>(stream, output, input);
}
private:
csl::Stream stream;
};
template <class T>
class PowerOp final : public BaseOp<PowerOp, T> {
public:
PowerOp(csl::Stream stream_, T exp_, T scale_, T shift_)
: stream(std::move(stream_)), exp{ exp_ }, scale{ scale_ }, shift{ shift_ } { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::power<T>(stream, output, input, exp, scale, shift);
}
kernels::power<T>(stream, output, input, exp, scale, shift);
}
private:
@@ -342,28 +296,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
};
template <class T>
class ExpOp final : public CUDABackendNode {
class ExpOp final : public BaseOp<ExpOp, T> {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
ExpOp(csl::Stream stream_, T nScale_, T nShift_)
: stream(std::move(stream_)), normScale{ nScale_ }, normShift{ nShift_ } { }
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
for (int i = 0; i < inputs.size(); i++)
{
auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
auto input = input_wrapper->getView();
auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
kernels::exp<T>(stream, output, input, normScale, normShift);
}
kernels::exp<T>(stream, output, input, normScale, normShift);
}
private:

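The refactoring in activation.hpp replaces the per-class forward() loops with a CRTP base: BaseOp<Op, T>::forward walks the input/output wrappers once and hands each tensor pair to the derived class via static_cast<const Op<T>*>(this)->calculate(output, input). A stripped-down sketch of the same pattern (plain C++, an assumed simplification of the cuda4dnn types):

#include <cmath>
#include <cstddef>
#include <vector>

// CRTP base: owns the iteration, delegates the per-tensor work to Derived.
template <template <class> class Derived, class T>
struct BaseOpSketch {
    void forward(const std::vector<std::vector<T>>& inputs,
                 std::vector<std::vector<T>>& outputs) const
    {
        for (std::size_t i = 0; i < inputs.size(); i++)
            static_cast<const Derived<T>*>(this)->calculate(outputs[i], inputs[i]);
    }
};

// A concrete op only supplies calculate(), mirroring SqrtOp above.
template <class T>
struct SqrtOpSketch : BaseOpSketch<SqrtOpSketch, T> {
    void calculate(std::vector<T>& output, const std::vector<T>& input) const
    {
        output.resize(input.size());
        for (std::size_t j = 0; j < input.size(); j++)
            output[j] = std::sqrt(input[j]);
    }
};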
@@ -111,6 +111,12 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(AbsVal, AbsLayer);
CV_DNN_REGISTER_LAYER_CLASS(Power, PowerLayer);
CV_DNN_REGISTER_LAYER_CLASS(Exp, ExpLayer);
CV_DNN_REGISTER_LAYER_CLASS(Ceil, CeilLayer);
CV_DNN_REGISTER_LAYER_CLASS(Floor, FloorLayer);
CV_DNN_REGISTER_LAYER_CLASS(Log, LogLayer);
CV_DNN_REGISTER_LAYER_CLASS(Round, RoundLayer);
CV_DNN_REGISTER_LAYER_CLASS(Sqrt, SqrtLayer);
CV_DNN_REGISTER_LAYER_CLASS(Not, NotLayer);
CV_DNN_REGISTER_LAYER_CLASS(BatchNorm, BatchNormLayer);
CV_DNN_REGISTER_LAYER_CLASS(MaxUnpool, MaxUnpoolLayer);
CV_DNN_REGISTER_LAYER_CLASS(Dropout, BlankLayer);
@@ -133,6 +139,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Padding, PaddingLayer);
CV_DNN_REGISTER_LAYER_CLASS(Proposal, ProposalLayer);
CV_DNN_REGISTER_LAYER_CLASS(Scale, ScaleLayer);
CV_DNN_REGISTER_LAYER_CLASS(Compare, CompareLayer);
CV_DNN_REGISTER_LAYER_CLASS(DataAugmentation, DataAugmentationLayer);
CV_DNN_REGISTER_LAYER_CLASS(Correlation, CorrelationLayer);
CV_DNN_REGISTER_LAYER_CLASS(Accum, AccumLayer);

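For reference (not part of the diff): a registered type string can also be turned into a standalone layer instance through the public factory, assuming the standard cv::dnn::LayerFactory API.

#include <opencv2/dnn.hpp>

// Sketch: instantiate one of the newly registered layers by its type string.
cv::Ptr<cv::dnn::Layer> makeRoundLayer()
{
    cv::dnn::LayerParams params;
    params.name = "round1";
    // "Round" maps to RoundLayer via CV_DNN_REGISTER_LAYER_CLASS(Round, RoundLayer) above.
    return cv::dnn::LayerFactory::createLayerInstance("Round", params);
}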
File diff suppressed because it is too large (modules/dnn/src/layers/elementwise_layers.cpp, 897 changed lines).

@@ -38,6 +38,7 @@ public:
hasBias = params.get<bool>("bias_term", false);
axis = params.get<int>("axis", 1);
hasWeights = false;
mode = params.get<String>("mode", "scale");
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -59,6 +60,10 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (mode != "scale")
{
return backendId == DNN_BACKEND_OPENCV;
}
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA ||
backendId == DNN_BACKEND_HALIDE ||
@@ -66,6 +71,20 @@ public:
(backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && axis > 0);
}
template<typename T>
void handleCompare(const Mat& a, const T& b, Mat& dst, const int spatialSize)
{
Mat out(1, spatialSize, CV_8U);
if (mode == "equal")
compare(a, b, out, CMP_EQ);
else if (mode == "greater")
compare(a, b, out, CMP_GT);
else
compare(a, b, out, CMP_LT);
out.convertTo(dst, CV_32F, 1. / 255.);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@@ -123,7 +142,16 @@ public:
float b = biasesData ? biasesData[j] : 0;
Mat inpSlice(1, spatialSize, CV_32F, inpData);
Mat outSlice(1, spatialSize, CV_32F, outData);
inpSlice.convertTo(outSlice, CV_32F, w, b);
if (mode == "scale")
{
inpSlice.convertTo(outSlice, CV_32F, w, b);
}
else
{
handleCompare(inpSlice, b, outSlice, spatialSize);
}
inpData += spatialSize;
outData += spatialSize;
}
@@ -142,7 +170,16 @@ public:
add(outSlice, bias, outSlice);
}
else if (hasBias)
add(inpSlice, bias, outSlice);
{
if (mode == "scale")
{
add(inpSlice, bias, outSlice);
}
else
{
handleCompare(inpSlice, bias, outSlice, numWeights);
}
}
inpData += numWeights;
outData += numWeights;
}
@@ -385,6 +422,18 @@ Ptr<Layer> ShiftLayer::create(const LayerParams& params)
return Ptr<ScaleLayer>(new ScaleLayerImpl(scaleParams));
}
Ptr<Layer> CompareLayer::create(const LayerParams& params)
{
LayerParams compareParams;
compareParams.name = params.name;
compareParams.type = "Scale";
compareParams.blobs = params.blobs;
compareParams.set("bias_term", true);
compareParams.set("axis", 0);
compareParams.set("mode", params.get<String>("mode"));
return Ptr<ScaleLayer>(new ScaleLayerImpl(compareParams));
}
class DataAugmentationLayerImpl CV_FINAL : public DataAugmentationLayer
{
public:

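A note on the compare path added above (standalone illustration, not from the diff): cv::compare writes 0/255 into a CV_8U mask, so the final convertTo(dst, CV_32F, 1. / 255.) turns it into the 0/1 floats the Compare layer outputs.

#include <opencv2/core.hpp>
#include <iostream>

int main()
{
    cv::Mat a = (cv::Mat_<float>(1, 4) << 1.f, 2.f, 3.f, 4.f);
    double b = 2.0;

    cv::Mat mask;                     // CV_8U, values 0 or 255
    cv::compare(a, b, mask, cv::CMP_GT);

    cv::Mat out;
    mask.convertTo(out, CV_32F, 1. / 255.);
    std::cout << out << std::endl;    // [0, 0, 1, 1]
    return 0;
}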
@@ -118,6 +118,8 @@ private:
void parseRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseElu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTanh (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseAbs (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCompare (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLRN (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -1410,6 +1412,38 @@ void ONNXImporter::parseTanh(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseAbs(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "AbsVal";
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseCompare(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 2);
const std::string& layer_type = node_proto.op_type();
bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
if (is_const_0 || is_const_1)
{
Mat blob = getBlob(node_proto, static_cast<int>(is_const_1));
blob = blob.reshape(1, 1);
layerParams.blobs.push_back(blob);
}
layerParams.type = "Compare";
if (layer_type == "Equal")
layerParams.set("mode", "equal");
else if (layer_type == "Greater")
layerParams.set("mode", "greater");
else
layerParams.set("mode", "less");
addLayer(layerParams, node_proto);
}
void ONNXImporter::parsePRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "PReLU";
@@ -2939,6 +2973,8 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
dispatch["Relu"] = &ONNXImporter::parseRelu;
dispatch["Elu"] = &ONNXImporter::parseElu;
dispatch["Tanh"] = &ONNXImporter::parseTanh;
dispatch["Abs"] = &ONNXImporter::parseAbs;
dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = &ONNXImporter::parseCompare;
dispatch["PRelu"] = &ONNXImporter::parsePRelu;
dispatch["LRN"] = &ONNXImporter::parseLRN;
dispatch["InstanceNormalization"] = &ONNXImporter::parseInstanceNormalization;

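Usage sketch for the new Equal/Greater/Less handling (the file name below is a placeholder; the real test models live in opencv_extra): once the dispatch entries above are registered, such an ONNX graph loads and runs like any other model.

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // "greater.onnx": hypothetical model with a Greater node and one constant input.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("greater.onnx");

    cv::Mat input(1, 4, CV_32F, cv::Scalar(0.5f));
    net.setInput(input);
    cv::Mat out = net.forward();      // 0/1 mask produced by the Compare layer
    std::cout << out << std::endl;
    return 0;
}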
@@ -151,3 +151,39 @@ __kernel void ExpForward(const int n, __global const T* in, __global T* out,
out[index] = exp(normShift + normScale * in[index]);
}
}
__kernel void CeilForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = ceil(in[index]);
}
__kernel void FloorForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = floor(in[index]);
}
__kernel void LogForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = log(in[index]);
}
__kernel void RoundForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = rint(in[index]);
}
__kernel void SqrtForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = sqrt(in[index]);
}
__kernel void NotForward(const int n, __global T* in, __global T* out) {
int index = get_global_id(0);
if(index < n)
out[index] = floor(1.0f - in[index]);
}

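Worth noting about NotForward above (and the matching CUDA NotFunctor): for 0/1 inputs, floor(1.0f - x) is logical negation, mapping 0 to 1 and 1 to 0. A quick host-side check of the same expression:

#include <cmath>
#include <initializer_list>
#include <iostream>

int main()
{
    for (float x : {0.f, 1.f})
        std::cout << x << " -> " << std::floor(1.0f - x) << std::endl;  // 0 -> 1, 1 -> 0
    return 0;
}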
@@ -353,6 +353,82 @@ TEST_P(Test_ONNX_layers, Elementwise_Ceil)
testONNXModels("exp");
}
TEST_P(Test_ONNX_layers, Elementwise_Ceil)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("ceil");
}
TEST_P(Test_ONNX_layers, Elementwise_Floor)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("floor");
}
TEST_P(Test_ONNX_layers, Elementwise_Log)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("log");
}
TEST_P(Test_ONNX_layers, Elementwise_Round)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("round");
}
TEST_P(Test_ONNX_layers, Elementwise_Sqrt)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("sqrt");
}
TEST_P(Test_ONNX_layers, Elementwise_not)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("not");
}
TEST_P(Test_ONNX_layers, Compare)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("equal");
testONNXModels("greater");
testONNXModels("less");
}
TEST_P(Test_ONNX_layers, CompareSameDims)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("equal_same_dims", npy, 0, 0, false, true, 2);
testONNXModels("greater_same_dims", npy, 0, 0, false, true, 2);
testONNXModels("less_same_dims", npy, 0, 0, false, true, 2);
}
TEST_P(Test_ONNX_layers, Concatenation)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
