From eefee8574ab26be0bc102938f9f74e5c09829e77 Mon Sep 17 00:00:00 2001
From: Yuantao Feng
Date: Wed, 17 May 2023 15:03:45 +0800
Subject: [PATCH] dnn: refactor reduce (#23613)

* initial impl

* remove reduce int8; fix reduce importer

* fix bugs and add log sum exp

* remove unnecessary header and fix indentation
---
 .../dnn/include/opencv2/dnn/all_layers.hpp  |   9 -
 modules/dnn/src/init.cpp                    |   1 -
 modules/dnn/src/int8layers/reduce_layer.cpp | 234 ------
 modules/dnn/src/layers/reduce_layer.cpp     | 707 ++++++++++--------
 modules/dnn/src/onnx/onnx_importer.cpp      | 192 +----
 5 files changed, 449 insertions(+), 694 deletions(-)
 delete mode 100644 modules/dnn/src/int8layers/reduce_layer.cpp

diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 49be0674f4..fe08d58ec6 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -346,18 +346,9 @@ CV__DNN_INLINE_NS_BEGIN
     class CV_EXPORTS ReduceLayer : public Layer
     {
     public:
-        int reduceType;
-        // reduceDims contains the dimensions that need to be reduced, targetDims is the target output dimension.
-        std::vector<int> reduceDims, targetDims;
         static Ptr<ReduceLayer> create(const LayerParams& params);
     };

-    class CV_EXPORTS ReduceLayerInt8 : public ReduceLayer
-    {
-    public:
-        static Ptr<ReduceLayerInt8> create(const LayerParams& params);
-    };
-
     class CV_EXPORTS SoftmaxLayer : public Layer
     {
     public:
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index 360113196b..2ce54ac0bb 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -194,7 +194,6 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8, ConvolutionLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8, PoolingLayerInt8);
-    CV_DNN_REGISTER_LAYER_CLASS(ReduceInt8, ReduceLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8, EltwiseLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8, BatchNormLayerInt8);
     CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8, ScaleLayerInt8);
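With the public reduceType/reduceDims/targetDims fields gone from all_layers.hpp, a Reduce layer is configured purely through LayerParams. A minimal sketch of building one directly (the attribute names are the ones this patch introduces below; the layer name and axis value are illustrative):

    #include <opencv2/dnn.hpp>
    #include <opencv2/dnn/all_layers.hpp>

    static cv::Ptr<cv::dnn::ReduceLayer> makeReduceSum()
    {
        cv::dnn::LayerParams lp;
        lp.name = "reduce_sum";   // illustrative name
        lp.type = "Reduce";
        lp.set("reduce", "sum");  // operation selector parsed by the new ctor
        int axes[] = {1};         // reduce over axis 1
        lp.set("axes", cv::dnn::DictValue::arrayInt(axes, 1));
        lp.set("keepdims", true); // keep the reduced axis as size 1
        return cv::dnn::ReduceLayer::create(lp);
    }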
diff --git a/modules/dnn/src/int8layers/reduce_layer.cpp b/modules/dnn/src/int8layers/reduce_layer.cpp
deleted file mode 100644
index 9ffb4897a0..0000000000
--- a/modules/dnn/src/int8layers/reduce_layer.cpp
+++ /dev/null
@@ -1,234 +0,0 @@
-// This file is part of OpenCV project.
-// It is subject to the license terms in the LICENSE file found in the top-level directory
-// of this distribution and at http://opencv.org/license.html.
-
-#include "../precomp.hpp"
-#include "layers_common.hpp"
-
-#include <float.h>
-#include <algorithm>
-#include <numeric>
-
-namespace cv
-{
-namespace dnn
-{
-
-class ReduceLayerInt8Impl CV_FINAL : public ReduceLayerInt8
-{
-public:
-    ReduceLayerInt8Impl(const LayerParams& params)
-    {
-        // Set reduce type
-        CV_Assert(params.has("reduce"));
-        String typeString = toLowerCase(params.get<String>("reduce"));
-        if (typeString == "max")
-            reduceType = MAX;
-        else if (typeString == "min")
-            reduceType = MIN;
-        else
-            CV_Error(Error::StsBadArg, "Unknown reduce type \"" + typeString + "\"");
-
-        // Set deleted dims
-        CV_Assert(params.has("deleted_dims"));
-        DictValue tempDims = params.get("deleted_dims");
-        int i, n = tempDims.size();
-        reduceDims.resize(n);
-        for (i = 0; i < n; i++)
-        {
-            reduceDims[i] = tempDims.get<int>(i);
-        }
-
-        CV_Assert(params.has("target_dims"));
-        tempDims = params.get("target_dims");
-        n = tempDims.size();
-        targetDims.resize(n);
-        for (i = 0; i < n; i++)
-        {
-            targetDims[i] = tempDims.get<int>(i);
-        }
-    }
-
-    virtual bool supportBackend(int backendId) CV_OVERRIDE
-    {
-        if (backendId == DNN_BACKEND_OPENCV)
-        {
-            return true;
-        }
-        return false;
-    }
-
-    // reduceType == MIN
-    struct ReduceOpMIN
-    {
-        int8_t apply(const int8_t* first, const int8_t* last)
-        {
-            return std::accumulate(first, last, *first,
-                                   [](int8_t a, int8_t b)
-                                   {
-                                       return std::min(a, b);
-                                   });
-        }
-    };
-
-    // reduceType == MAX
-    struct ReduceOpMAX
-    {
-        int8_t apply(const int8_t* first, const int8_t* last)
-        {
-            return std::accumulate(first, last, *first,
-                                   [](int8_t a, int8_t b)
-                                   {
-                                       return std::max(a, b);
-                                   });
-        }
-    };
-
-    template<typename Func>
-    class ReduceInvoker : public ParallelLoopBody
-    {
-    public:
-        const Mat* src;
-        Mat *dst;
-        std::vector<int> reduceDims;
-        int nstripes;
-        int reduceType;
-        Ptr<Func> func;
-
-        ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
-
-        static void run(const Mat& src, Mat& dst, std::vector<int> reduceDims, int reduceType, int nstripes)
-        {
-            CV_Assert_N(src.isContinuous(), dst.isContinuous(), src.type() == CV_8S, src.type() == dst.type());
-
-            ReduceInvoker<Func> p;
-
-            p.src = &src;
-            p.dst = &dst;
-
-            p.reduceDims = reduceDims;
-            p.nstripes = nstripes;
-            p.reduceType = reduceType;
-
-            parallel_for_(Range(0, nstripes), p, nstripes);
-        }
-
-        void operator()(const Range& r) const CV_OVERRIDE
-        {
-            size_t total = dst->total();
-            size_t stripeSize = (total + nstripes - 1)/nstripes;
-            size_t stripeStart = r.start*stripeSize;
-            size_t stripeEnd = std::min(r.end*stripeSize, total);
-            size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<int>());
-
-            int8_t *dstData = (int8_t *)dst->data;
-            int8_t *srcData = (int8_t *)src->data;
-
-            for (size_t ofs = stripeStart; ofs < stripeEnd;)
-            {
-                const int8_t* first = srcData + ofs * totalDeleted;
-                const int8_t* last = srcData + (ofs + 1) * totalDeleted;
-
-                dstData[ofs] = func->apply(first, last);
-                ofs += 1;
-            }
-        }
-    };
-
-    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
-    {
-        CV_TRACE_FUNCTION();
-        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
-
-        std::vector<Mat> inputs, outputs;
-        inputs_arr.getMatVector(inputs);
-        outputs_arr.getMatVector(outputs);
-        CV_Assert(inputs.size() == 1);
-        const int nstripes = getNumThreads();
-
-        switch (reduceType)
-        {
-            case MIN:
-            {
-                ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            case MAX:
-            {
-                ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            default:
-                CV_Error(Error::StsNotImplemented, "Not implemented");
-                break;
-        }
-    }
-
-    bool getMemoryShapes(const std::vector<MatShape> &inputs,
-                         const int requiredOutputs,
-                         std::vector<MatShape> &outputs,
-                         std::vector<MatShape> &internals) const CV_OVERRIDE
-    {
-        CV_Assert(inputs.size() > 0);
-        CV_Assert( reduceDims.size() != 0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
-
-        // outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape.
-        std::vector<int> outShapeTmp, outShape;
-        outShape.assign(targetDims.begin(), targetDims.end());
-        if (inputs[0].size() == reduceDims.size())
-            outShapeTmp.push_back(1);
-        else
-        {
-            for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
-            {
-                outShapeTmp.push_back(inputs[0][i]);
-            }
-        }
-
-        // Support dynamic shape of Batch size.
-        // Note that: when there are multiple dynamic inputs, we will give an error.
-        if (total(outShape) != total(outShapeTmp))
-        {
-            if (outShape[0] != outShapeTmp[0])
-                outShape[0] = outShapeTmp[0];
-        }
-
-        CV_Assert(total(outShape) == total(outShapeTmp));
-        outputs.assign(1, outShape);
-
-        return false;
-    }
-
-    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
-                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
-    {
-        return false;
-    }
-
-    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
-                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
-    {
-        CV_UNUSED(inputs); // suppress unused variable warning
-        long flops = 0;
-        size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<int>());
-        for (int i = 0; i < outputs.size(); i++)
-        {
-            flops += total(outputs[i])*(totalDeleted);
-        }
-        return flops;
-    }
-private:
-    enum Type
-    {
-        MAX,
-        MIN
-    };
-};
-
-Ptr<ReduceLayerInt8> ReduceLayerInt8::create(const LayerParams& params)
-{
-    return Ptr<ReduceLayerInt8>(new ReduceLayerInt8Impl(params));
-}
-
-}
-}
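The deleted int8 implementation only worked because the ONNX importer (see the importer hunk at the bottom of this patch) permuted every reduced axis to the end of the tensor, so each output element reduced one contiguous block of inputs. Condensed to a freestanding sketch (block size and the max op are illustrative), that old scheme was:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Old scheme: with the reduced axes permuted last, output element i
    // reduces the contiguous block src[i*blockSize, (i+1)*blockSize).
    static void reduceMaxTrailing(const std::vector<int8_t>& src,
                                  std::vector<int8_t>& dst, size_t blockSize)
    {
        dst.resize(src.size() / blockSize);
        for (size_t i = 0; i < dst.size(); ++i)
            dst[i] = *std::max_element(src.begin() + i * blockSize,
                                       src.begin() + (i + 1) * blockSize);
    }

The rewritten float layer below removes that permute-then-reduce detour by walking arbitrary reduced axes with precomputed strides.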
diff --git a/modules/dnn/src/layers/reduce_layer.cpp b/modules/dnn/src/layers/reduce_layer.cpp
index c1f74f1cc1..b983a791c5 100644
--- a/modules/dnn/src/layers/reduce_layer.cpp
+++ b/modules/dnn/src/layers/reduce_layer.cpp
@@ -3,393 +3,505 @@
 // of this distribution and at http://opencv.org/license.html.

 #include "../precomp.hpp"
-#include "opencv2/core/hal/intrin.hpp"
-#include "../op_cuda.hpp"
-#include "../op_webnn.hpp"
+#include <opencv2/dnn/shape_utils.hpp>

-#include <float.h>
-#include <algorithm>
-#include <numeric>
-using std::max;
-using std::min;
-
-#include <opencv2/dnn/shape_utils.hpp>
-
-namespace cv
-{
-namespace dnn
-{
+namespace cv { namespace dnn {

 class ReduceLayerImpl CV_FINAL : public ReduceLayer
 {
 public:
-    ReduceLayerImpl(const LayerParams& params)
-    {
+    ReduceLayerImpl(const LayerParams& params) {
         setParamsFrom(params);
+        // set reduce type
         CV_Assert(params.has("reduce"));
-        String typeString = toLowerCase(params.get<String>("reduce"));
-        if (typeString == "max")
-            reduceType= MAX;
-        else if (typeString == "min")
-            reduceType= MIN;
-        else if (typeString == "ave")
-            reduceType= AVE;
-        else if (typeString == "sum")
-            reduceType= SUM;
-        else if (typeString == "sum_square")
-            reduceType= SUM_SQUARE;
-        else if (typeString == "l1")
-            reduceType= L1;
-        else if (typeString == "l2")
-            reduceType= L2;
-        else if (typeString == "log_sum")
-            reduceType= LOG_SUM;
-        else if (typeString == "log_sum_exp")
-            reduceType= LOG_SUM_EXP;
-        else if (typeString == "prod")
-            reduceType= PROD;
+        String op_type = toLowerCase(params.get<String>("reduce"));
+        if (op_type == "max")
+            reduce_type = ReduceType::MAX;
+        else if (op_type == "min")
+            reduce_type = ReduceType::MIN;
+        else if (op_type == "mean")
+            reduce_type = ReduceType::MEAN;
+        else if (op_type == "sum")
+            reduce_type = ReduceType::SUM;
+        else if (op_type == "sum_square")
+            reduce_type = ReduceType::SUM_SQUARE;
+        else if (op_type == "l1")
+            reduce_type = ReduceType::L1;
+        else if (op_type == "l2")
+            reduce_type = ReduceType::L2;
+        else if (op_type == "log_sum")
+            reduce_type = ReduceType::LOG_SUM;
+        else if (op_type == "log_sum_exp")
+            reduce_type = ReduceType::LOG_SUM_EXP;
+        else if (op_type == "prod")
+            reduce_type = ReduceType::PROD;
         else
-            CV_Error(Error::StsBadArg, "Unknown reduce type\"" + typeString + "\"");
-
-        // set deleted dims
-        CV_Assert(params.has("deleted_dims"));
-        DictValue tempDims = params.get("deleted_dims");
-        int i, n = tempDims.size();
-        reduceDims.resize(n);
-        for (i = 0; i < n; i++)
-        {
-            reduceDims[i] = tempDims.get<int>(i);
+            CV_Error(Error::StsBadArg, "Unknown reduce type \"" + op_type + "\"");
+
+        keepdims = params.get<bool>("keepdims", true);
+        noop_with_empty_axes = params.get<bool>("noop_with_empty_axes", false);
+
+        // get axes if present, otherwise reduce over all axes
+        if (params.has("axes")) {
+            auto param_axes = params.get("axes");
+            int num_axes = param_axes.size();
+            axes.resize(num_axes);
+            for (int i = 0; i < num_axes; ++i)
+                axes[i] = param_axes.get<int>(i);
         }
+    }

-        CV_Assert(params.has("target_dims"));
-        tempDims = params.get("target_dims");
-        n = tempDims.size();
-        targetDims.resize(n);
-        for (i = 0; i < n; i++)
-        {
-            targetDims[i] = tempDims.get<int>(i);
+    virtual bool supportBackend(int backendId) CV_OVERRIDE {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
+        if (axes.empty()) {
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        auto shape_input = shape(inputs[0]);
+        for (auto i = 0; i < axes.size(); ++i) {
+            auto norm_axis = normalize_axis(axes[i], shape_input);
+            axes[i] = norm_axis;
+        }
+
+        bool do_nothing = true;
+        for (auto axis : axes) {
+            if (shape_input[axis] != 1) {
+                do_nothing = false;
+            }
+        }
+        if (do_nothing) {
+            axes.clear();
+            noop_with_empty_axes = true;
         }
     }

-    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
     {
-        if (backendId == DNN_BACKEND_OPENCV)
-        {
-            return true;
+        // empty axes
+        if (axes.empty()) {
+            if (noop_with_empty_axes) {
+                // do nothing
+                outputs.assign(1, inputs[0]);
+            } else {
+                // reduce all axes
+                MatShape shape_output;
+                if (keepdims) {
+                    shape_output = inputs[0];
+                    for (auto i = 0; i < shape_output.size(); ++i)
+                        shape_output[i] = 1;
+                } else {
+                    shape_output.push_back(1);
+                }
+                outputs.assign(1, shape_output);
+            }
+        } else {
+            auto shape_output_ = inputs[0];
+            for (size_t i = 0; i < axes.size(); ++i) {
+                auto norm_axis = normalize_axis(axes[i], inputs[0]);
+                shape_output_[norm_axis] = -1;
+            }
+            MatShape shape_output;
+            for (size_t i = 0; i < shape_output_.size(); ++i) {
+                if (shape_output_[i] == -1) {
+                    if (keepdims)
+                        shape_output.push_back(1);
+                    else
+                        continue;
+                } else
+                    shape_output.push_back(shape_output_[i]);
+            }
+            if (shape_output.empty())
+                shape_output.push_back(1);
+
+            outputs.assign(1, shape_output);
         }
+
         return false;
     }

-    // reduceType == MIN
-    struct ReduceOpMIN
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, FLT_MAX,
-                                   [](float a, float b)
-                                   {
-                                       return std::min(a, b);
-                                   });
+    template <typename T>
+    class ReduceBase {
+    public:
+        using dtype_input = T;
+
+        ReduceBase(size_t n, const T& init) : n_(n), accumulator_(init) {}
+        virtual void update(const T& a) = 0;
+        virtual T get_value() { return accumulator_; }
+        virtual ~ReduceBase() = default;
+    protected:
+        size_t n_;
+        T accumulator_;
+    };
+
+    template <typename T>
+    class ReduceMin : public ReduceBase<T> {
+    public:
+        ReduceMin(size_t n, const T& init) : ReduceBase<T>(n, init) {}
+        void update(const T& a) override {
+            this->accumulator_ = a > this->accumulator_ ? this->accumulator_ : a;
         }
     };

-    // reduceType == MAX
-    struct ReduceOpMAX
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, -FLT_MAX,
-                                   [](float a, float b)
-                                   {
-                                       return std::max(a, b);
-                                   });
+    template <typename T>
+    class ReduceMax : public ReduceBase<T> {
+    public:
+        ReduceMax(size_t n, const T& init) : ReduceBase<T>(n, init) {}
+        void update(const T& a) override {
+            this->accumulator_ = a > this->accumulator_ ? a : this->accumulator_;
         }
     };

-    // reduceType == SUM
-    struct ReduceOpSUM
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, 0.f);
+    template <typename T>
+    class ReduceSum : public ReduceBase<T> {
+    public:
+        ReduceSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += a;
         }
     };

-    // reduceType == AVE
-    struct ReduceOpAVE
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            float output = std::accumulate(first, last, 0.f);
-            return output * ikarea;
+    template <typename T>
+    class ReduceMean : public ReduceSum<T> {
+    public:
+        ReduceMean(size_t n, const T& init) : ReduceSum<T>(n, init) {}
+        T get_value() override {
+            return this->accumulator_ / static_cast<T>(this->n_);
         }
     };

-    // reduceType == SUM_SQUARE
-    struct ReduceOpSUM_SQUARE
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, 0.f,
-                                   [](float a, float b)
-                                   {
-                                       return a + b * b;
-                                   });
+    template <typename T>
+    class ReduceSumSquare : public ReduceBase<T> {
+    public:
+        ReduceSumSquare(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += a * a;
         }
     };

-    // reduceType == L1
-    struct ReduceOpL1
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, 0.f,
-                                   [](float a, float b)
-                                   {
-                                       return a + std::abs(b);
-                                   });
+    template <typename T>
+    class ReduceL1 : public ReduceBase<T> {
+    public:
+        ReduceL1(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += a > 0 ? a : -a;
         }
     };

-    // reduceType == L2
-    struct ReduceOpL2
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            float output = std::accumulate(first, last, 0.f,
-                                           [](float a, float b)
-                                           {
-                                               return a + b * b;
-                                           });
-            return std::sqrt(output);
+    template <typename T>
+    class ReduceL2 : public ReduceBase<T> {
+    public:
+        ReduceL2(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += a * a;
+        }
+        T get_value() override {
+            return std::sqrt(this->accumulator_);
         }
     };

-    // reduceType == PROD
-    struct ReduceOpPROD
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            return std::accumulate(first, last, 1.0f, std::multiplies<float>());
+    template <typename T>
+    class ReduceProd : public ReduceBase<T> {
+    public:
+        ReduceProd(size_t n, const T& init) : ReduceBase<T>(n, 1) {}
+        void update(const T& a) override {
+            this->accumulator_ *= a;
         }
     };

-    // reduceType == LOG_SUM
-    struct ReduceOpLOG_SUM
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            float output = std::accumulate(first, last, 0.0f);
-            return std::log(output);
+    template <typename T>
+    class ReduceLogSum : public ReduceBase<T> {
+    public:
+        ReduceLogSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += a;
+        }
+        T get_value() override {
+            return static_cast<T>(std::log(this->accumulator_));
        }
     };

-    // reduceType == LOG_SUM_EXP
-    struct ReduceOpLOG_SUM_EXP
-    {
-        float apply(const float* first, const float* last, const float ikarea = 1.0f)
-        {
-            float output = std::accumulate(first, last, 0.0f,
-                                           [](float a, float b)
-                                           {
-                                               return a + std::exp(b);
-                                           });
-            return std::log(output);
+    // FIXME: overflow caution
+    template <typename T>
+    class ReduceLogSumExp : public ReduceBase<T> {
+    public:
+        ReduceLogSumExp(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
+        void update(const T& a) override {
+            this->accumulator_ += static_cast<T>(std::exp(a));
+        }
+        T get_value() override {
+            return static_cast<T>(std::log(this->accumulator_));
         }
     };

-    template<typename Func>
-    class ReduceInvoker : public ParallelLoopBody
-    {
+
+    template <typename Op>
+    class ReduceAllInvoker : public ParallelLoopBody {
     public:
-        const Mat* src;
-        Mat *dst;
-        std::vector<int> reduceDims;
-        int nstripes;
-        int reduceType;
-        Ptr<Func> func;
+        using dtype = typename Op::dtype_input;

-        ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
+        const Mat& src;
+        Mat& dst;

-        static void run(const Mat& src, Mat& dst, std::vector<int> reduceDims, int reduceType, int nstripes)
-        {
-            CV_Assert_N( src.isContinuous(), dst.isContinuous(), src.type() == CV_32F, src.type() == dst.type());
+        int n_reduce;
+        int loop_size;

-            ReduceInvoker<Func> p;
+        int total;
+        int cost_per_thread;

-            p.src = &src;
-            p.dst = &dst;
+        ReduceAllInvoker(const Mat& src_, Mat& dst_) : src(src_), dst(dst_) {
+            auto shape_src = shape(src);

-            p.reduceDims = reduceDims;
-            p.nstripes = nstripes;
-            p.reduceType = reduceType;
+            n_reduce = std::accumulate(shape_src.begin(), shape_src.end(), 1, std::multiplies<int>());
+            loop_size = n_reduce;

-            parallel_for_(Range(0, nstripes), p, nstripes);
+            total = 1;
+            cost_per_thread = 1;
         }

-        void operator()(const Range& r) const CV_OVERRIDE
-        {
-            size_t total = dst->total();
-            size_t stripeSize = (total + nstripes - 1)/nstripes;
-            size_t stripeStart = r.start*stripeSize;
-            size_t stripeEnd = std::min(r.end*stripeSize, total);
-            size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<int>());
-
-            float *dstData = (float *)dst->data;
-            float *srcData = (float *)src->data;
-
-            for (size_t ofs = stripeStart; ofs < stripeEnd;)
-            {
-                const float* first = srcData + ofs * stride_w;
-                const float* last = srcData + (ofs + 1) * stride_w;
-
-                if (ofs < stripeEnd)
-                {
-                    dstData[ofs] = func->apply(first, last, 1.0 / stride_w);
-                    ofs += 1;
+        void operator()(const Range& r) const CV_OVERRIDE {
+            int start = r.start;
+            int end = r.end;
+
+            const dtype* p_src = src.ptr<const dtype>();
+            dtype* p_dst = dst.ptr<dtype>();
+
+            for (int i = start; i < end; ++i) {
+                Op accumulator(n_reduce, *p_src);
+                for (int l = 0; l < loop_size; ++l) {
+                    accumulator.update(p_src[l]);
                 }
+                p_dst[i] = accumulator.get_value();
             }
         }
     };

-    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
-    {
-        CV_TRACE_FUNCTION();
-        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+    template <typename Op>
+    class ReduceInvoker : public ParallelLoopBody {
+    public:
+        using dtype = typename Op::dtype_input;

-        if (inputs_arr.depth() == CV_16S)
-        {
-            forward_fallback(inputs_arr, outputs_arr, internals_arr);
-            return;
-        }
+        const Mat& src;
+        Mat& dst;

-        std::vector<Mat> inputs, outputs;
-        inputs_arr.getMatVector(inputs);
-        outputs_arr.getMatVector(outputs);
-        CV_Assert(inputs.size() == 1 || (inputs.size() == 2 && reduceType== SUM));
-        const int nstripes = getNumThreads();
+        std::vector<int> reduced_axes; // assume in ascending order

-        switch (reduceType)
-        {
-            case MIN:
-            {
-                ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            case MAX:
-            {
-                ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            case AVE:
-            {
-                ReduceInvoker<ReduceOpAVE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            case SUM:
-            {
-                ReduceInvoker<ReduceOpSUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
-            }
-            case L1:
-            {
-                ReduceInvoker<ReduceOpL1>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+        int n_reduce;
+        int loop_size;
+
+        int last_reduced_dim;
+        int last_reduced_step;
+        std::vector<int> projected_steps;
+
+        int last_unreduced_dim;
+        int last_unreduced_step;
+        std::vector<int> unprojected_steps;
+
+        int total;
+        int cost_per_thread;
+
+        ReduceInvoker(const Mat& src_, Mat& dst_, std::vector<int> axes_) : src(src_), dst(dst_), reduced_axes(axes_) {
+            auto shape_src = shape(src);
+
+            auto steps_src = shape_src;
+            steps_src[steps_src.size() - 1] = 1;
+            for (int i = static_cast<int>(steps_src.size()) - 2; i >= 0; --i)
+                steps_src[i] = steps_src[i + 1] * shape_src[i + 1];
+
+            size_t projection_size = 1;
+            for (auto axis : reduced_axes) {
+                projection_size *= shape_src[axis];
+            }
-            case L2:
-            {
-                ReduceInvoker<ReduceOpL2>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+            n_reduce = projection_size;
+
+            last_reduced_dim = shape_src[reduced_axes.back()];
+            last_reduced_step = steps_src[reduced_axes.back()];
+            loop_size = last_reduced_dim * last_reduced_step;
+            projection_size /= last_reduced_dim;
+
+            // calculate projected_steps
+            int last_reduced_axis = static_cast<int>(reduced_axes.size()) - 1;
+            if (last_reduced_axis == 0) {
+                projected_steps.resize(1, 0);
+            } else {
+                projected_steps.resize(projection_size);
+                std::vector<int> projected_indices(last_reduced_axis, 0);
+                for (size_t i = 0, current_step = 0; i < projection_size; ++i) {
+                    projected_steps[i] = current_step;
+                    ++projected_indices[last_reduced_axis - 1];
+                    current_step += steps_src[reduced_axes[last_reduced_axis - 1]];
+                    for (int j = last_reduced_axis - 1; j > 0; --j) {
+                        if (projected_indices[j] < shape_src[reduced_axes[j]]) {
+                            break;
+                        }
+                        projected_indices[j] = 0;
+                        ++projected_indices[j - 1];
+                        current_step = steps_src[reduced_axes[j - 1]];
+                    }
+                }
             }
-            case SUM_SQUARE:
-            {
-                ReduceInvoker<ReduceOpSUM_SQUARE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+
+            // calculate unprojected_steps
+            std::vector<int> unreduced_axes;
+            for (int i = 0; i < static_cast<int>(shape_src.size()); ++i) {
+                if (std::find(reduced_axes.begin(), reduced_axes.end(), i) == reduced_axes.end()) {
+                    unreduced_axes.push_back(i);
+                }
             }
-            case PROD:
-            {
-                ReduceInvoker<ReduceOpPROD>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+            size_t unprojection_size = 1;
+            for (auto axis : unreduced_axes) {
+                unprojection_size *= shape_src[axis];
+            }
-            case LOG_SUM:
-            {
-                ReduceInvoker<ReduceOpLOG_SUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+            last_unreduced_dim = shape_src[unreduced_axes.back()];
+            last_unreduced_step = steps_src[unreduced_axes.back()];
+            unprojection_size /= last_unreduced_dim;
+
+            std::vector<int> unprojected_indices(unreduced_axes.size(), 0);
+            unprojected_steps.reserve(unprojection_size);
+            if (unprojected_indices.size() <= 1) {
+                unprojected_steps.push_back(0);
+            } else {
+                for (size_t i = 0, current_step = 0; i < unprojection_size; ++i) {
+                    unprojected_steps.push_back(current_step);
+                    ++unprojected_indices[unprojected_indices.size() - 2];
+                    current_step += steps_src[unreduced_axes[unreduced_axes.size() - 2]];
+                    for (int j = static_cast<int>(unreduced_axes.size()) - 2; j > 0; --j) {
+                        if (unprojected_indices[j] < shape_src[unreduced_axes[j]]) {
+                            break;
+                        }
+                        unprojected_indices[j] = 0;
+                        ++unprojected_indices[j - 1];
+                        current_step = steps_src[unreduced_axes[j - 1]];
+                    }
+                }
             }
-            case LOG_SUM_EXP:
-            {
-                ReduceInvoker<ReduceOpLOG_SUM_EXP>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
-                break;
+
+            auto shape_dst = shape(dst);
+            total = std::accumulate(shape_dst.begin(), shape_dst.end(), 1, std::multiplies<int>());
+            cost_per_thread = static_cast<int>(projected_steps.size() * last_reduced_step);
+        }
+
+        static void run(const Mat& src, Mat& dst, std::vector<int> axes, bool noop_with_empty_axes) {
+            CV_Assert(src.isContinuous());
+            CV_Assert(dst.isContinuous());
+
+            if (axes.empty()) {
+                if (noop_with_empty_axes) {
+                    // copyTo is not used here for the reason that we want a
+                    // copy for the case when dims at all axes are 1
+                    const auto p_src = src.ptr<const dtype>();
+                    auto p_dst = dst.ptr<dtype>();
+                    std::memcpy(p_dst, p_src, sizeof(dtype) * dst.total());
+                    return;
+                }
+
+                ReduceAllInvoker<Op> p(src, dst);
+                double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
+                parallel_for_(Range(0, p.total), p, nstripes);
+                return;
             }
-            default:
-                CV_Error(Error::StsNotImplemented, "Not implemented");
-                break;
+
+            ReduceInvoker<Op> p(src, dst, axes);
+            double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
+            parallel_for_(Range(0, p.total), p, nstripes);
         }
-    }

-    bool getMemoryShapes(const std::vector<MatShape> &inputs,
-                         const int requiredOutputs,
-                         std::vector<MatShape> &outputs,
-                         std::vector<MatShape> &internals) const CV_OVERRIDE
-    {
-        CV_Assert(inputs.size() > 0);
-        CV_Assert( reduceDims.size() !=0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
-
-        // outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape.
-        std::vector<int> outShapeTmp, outShape;
-        outShape.assign(targetDims.begin(), targetDims.end());
-        if (inputs[0].size() == reduceDims.size())
-            outShapeTmp.push_back(1);
-        else
-        {
-            for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
-            {
-                outShapeTmp.push_back(inputs[0][i]);
+        void operator()(const Range& r) const CV_OVERRIDE {
+            int start = r.start;
+            int end = r.end;
+
+            const dtype* p_src = src.ptr<const dtype>();
+            dtype* p_dst = dst.ptr<dtype>();
+
+            size_t main_index = start / last_unreduced_dim;
+            size_t loop = start % last_unreduced_dim;
+            size_t origin = unprojected_steps[main_index] + loop * last_unreduced_step;
+            for (int i = start; i < end; ++i) {
+                Op accumulator(n_reduce, p_src[origin + projected_steps[0]]);
+                for (auto projected_step : projected_steps) {
+                    const dtype* loop_p_src = p_src + origin + projected_step;
+                    for (auto l = 0; l < loop_size; l += last_reduced_step) {
+                        accumulator.update(loop_p_src[l]);
+                    }
+                }
+                p_dst[i] = accumulator.get_value();
+
+                ++loop;
+                if (loop >= last_unreduced_dim) {
+                    loop = 0;
+                    ++main_index;
+                    if (main_index < unprojected_steps.size()) {
+                        origin = unprojected_steps[main_index];
+                    }
+                } else {
+                    origin += last_unreduced_step;
+                }
             }
         }
+    };
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

-        // Support dynamic shape of Batch size.
-        // Note that: when there are multiple dynamic inputs, we will give an error.
-        if (total(outShape) != total(outShapeTmp) && outShape[0] != outShapeTmp[0])
+        if (inputs_arr.depth() == CV_16S)
         {
-            outShape[0] = outShapeTmp[0];
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
         }

-        CV_Assert(total(outShape) == total(outShapeTmp));
-        outputs.assign(1, outShape);
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);

-        return false;
+        typeDispatch(outputs[0].type(), inputs[0], outputs[0], axes, noop_with_empty_axes);
     }

-    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
-                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
-    {
-        if (reduceType== MAX || reduceType== MIN)
-        {
-            return true;
+    template <typename T, typename... Args>
+    inline void opDispatch(Args&&... args) {
+        switch (reduce_type) {
+            case ReduceType::MAX: ReduceInvoker<ReduceMax<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::MIN: ReduceInvoker<ReduceMin<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::MEAN: ReduceInvoker<ReduceMean<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::SUM: ReduceInvoker<ReduceSum<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::L1: ReduceInvoker<ReduceL1<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::L2: ReduceInvoker<ReduceL2<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::PROD: ReduceInvoker<ReduceProd<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::SUM_SQUARE: ReduceInvoker<ReduceSumSquare<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::LOG_SUM: ReduceInvoker<ReduceLogSum<T>>::run(std::forward<Args>(args)...); break;
+            case ReduceType::LOG_SUM_EXP: ReduceInvoker<ReduceLogSumExp<T>>::run(std::forward<Args>(args)...); break;
+            default: CV_Error(Error::StsBadArg, "DNN/Reduce: Unsupported operation.");
         }
-        return false;
     }

-    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
-                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
-    {
-        CV_UNUSED(inputs); // suppress unused variable warning
-        long flops = 0;
-        size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<int>());
-        for (int i = 0; i < outputs.size(); i++)
-        {
-            flops += total(outputs[i])*(stride_w);
+    template <typename... Args>
+    inline void typeDispatch(const int type, Args&&... args) {
+        switch (type) {
+            case CV_8U: opDispatch<uint8_t>(std::forward<Args>(args)...); break;
+            case CV_32S: opDispatch<int32_t>(std::forward<Args>(args)...); break;
+            case CV_32F: opDispatch<float>(std::forward<Args>(args)...); break;
+            default: CV_Error(cv::Error::BadDepth, "DNN/Reduce: Unsupported type.");
+        }
-        return flops;
-    }
+    }

 private:
     enum ReduceType
     {
         MAX,
         MIN,
-        AVE,
+        MEAN,
         SUM,
         L1,
         L2,
@@ -397,7 +509,11 @@ private:
         SUM_SQUARE,
         LOG_SUM,
         LOG_SUM_EXP
-    };
+    } reduce_type;
+
+    bool keepdims;
+    bool noop_with_empty_axes;
+    std::vector<int> axes;
 };

 Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
@@ -405,5 +521,4 @@ Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
     return Ptr<ReduceLayer>(new ReduceLayerImpl(params));
 }

-}
-}
+}} // cv::dnn
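The projected/unprojected step tables above are a stride-based enumeration of which input elements fold into each output element. A naive reference with the same semantics (walk the input once and scatter into the output slot obtained by zeroing the reduced coordinates) is handy for sanity-checking the invoker. This sketch is not part of the patch; it uses sum as the example op and keepdims-style output strides:

    #include <vector>

    // Naive reference for Reduce(sum) over arbitrary axes: reduced[i] marks
    // axis i as reduced; the output keeps reduced dims as size 1.
    static std::vector<float> reduceSumRef(const std::vector<float>& src,
                                           const std::vector<int>& shape,
                                           const std::vector<bool>& reduced)
    {
        std::vector<int> out_shape(shape.size());
        for (size_t i = 0; i < shape.size(); ++i)
            out_shape[i] = reduced[i] ? 1 : shape[i];

        // row-major strides for input and (keepdims) output
        std::vector<int> src_steps(shape.size(), 1), dst_steps(shape.size(), 1);
        for (int i = (int)shape.size() - 2; i >= 0; --i) {
            src_steps[i] = src_steps[i + 1] * shape[i + 1];
            dst_steps[i] = dst_steps[i + 1] * out_shape[i + 1];
        }

        std::vector<float> dst(dst_steps[0] * out_shape[0], 0.f);
        for (size_t ofs = 0; ofs < src.size(); ++ofs) {
            size_t rem = ofs, out = 0;
            for (size_t i = 0; i < shape.size(); ++i) {
                int coord = (int)(rem / src_steps[i]);
                rem %= src_steps[i];
                out += (reduced[i] ? 0 : coord) * dst_steps[i];
            }
            dst[out] += src[ofs];
        }
        return dst;
    }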
"ReduceInt8" : "Reduce"; - layerParams.set("reduce", reduceType); - bool keepdims = layerParams.get("keepdims", 1) == 1; - - MatShape inpShape = outShapes[node_proto.input(0)]; - std::vector shouldDelete(inpShape.size(), false); - - if (layer_type == "ReduceSum" && node_proto.input_size() == 2) - { - if (constBlobs.find(node_proto.input(1)) != constBlobs.end()) - { - Mat axesMat = getBlob(node_proto, 1); - int axesNum = axesMat.total(); - for (int i = 0; i < axesNum; i++) - { - int axis = normalize_axis(axesMat.at(i), inpShape.size()); - shouldDelete[axis] = true; - } - } - else - // in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute - // details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687 - CV_Error(Error::StsNotImplemented, "Non-constant axis values in ReduceSum are not supported."); - } - else - { - if (layerParams.has("axes")) - { - DictValue axes = layerParams.get("axes"); - for (int i = 0; i < axes.size(); i++) - { - int axis = normalize_axis(axes.get(i), inpShape.size()); - shouldDelete[axis] = true; - } - } - else - { - for (int i = 0; i < inpShape.size(); i++) - { - shouldDelete[i] = true; - } - } - } - - std::vector targetShape; - for (int i = 0; i < inpShape.size(); ++i) - { - if (!shouldDelete[i]) - { - targetShape.push_back(inpShape[i]); - } - else if (keepdims) - { - targetShape.push_back(1); - } - } - - if (targetShape.empty()) - targetShape.push_back(1); + CV_Error(Error::StsNotImplemented, "DNN/ONNX: " + op_type + " is not supported."); + layerParams.set("reduce", reduce_type); - // Using PermuteLayer to move the deleted axis to the last. - std::vector perm(inpShape.size(), 0); - for (int i = 0; i < inpShape.size(); i++) - perm[i] = i; - - bool needPermuet = false; - for (int i = 0; i < inpShape.size(); i++) - { - if (shouldDelete[i]) - { - // find the first not deleted element. - std::vector::iterator iter = std::find(shouldDelete.begin() + i, shouldDelete.end(), false); - - if (iter != shouldDelete.end()) - { - int index = iter - shouldDelete.begin(); - - bool temp = shouldDelete[index]; - shouldDelete[index] = shouldDelete[i]; - shouldDelete[i] = temp; - - std::swap(perm[index], perm[i]); - std::swap(inpShape[index], inpShape[i]); - needPermuet = true; - } - else - break; - } - } - - auto inputString= node_proto.input(0); - if (needPermuet) - { - LayerParams permuteLp; - permuteLp.name = layerParams.name + "/permute"; - permuteLp.type = (depth == CV_8S) ? "PermuteInt8" : "Permute"; - permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size())); - - opencv_onnx::NodeProto protoPermute; - protoPermute.add_input(inputString); - protoPermute.add_output(permuteLp.name); - addLayer(permuteLp, protoPermute); - inputString = permuteLp.name; - } + int num_inputs = node_proto.input_size(); + CV_Check(num_inputs, num_inputs >= 1 && num_inputs <= 2, "DNN/ONNX: Reduce layers should have at least one input and at most two inputs"); - std::vector deletedDims; - for (int axis_i = 0; axis_i < inpShape.size(); ++axis_i) - { - if (shouldDelete[axis_i]) - { - deletedDims.push_back(inpShape[axis_i]); - } + // "axes" is turned to one of the inputs since opset 18, + // except for ReduceSum, which has "axes" input since opset 13. 
+    if (!layerParams.has("axes") && num_inputs == 2 && constBlobs.find(node_proto.input(1)) != constBlobs.end()) {
+        Mat mat_axes = getBlob(node_proto, 1);
+        int num_axes = mat_axes.total();
+        std::vector<int> axes(num_axes);
+        for (int i = 0; i < num_axes; ++i)
+            axes[i] = mat_axes.at<int>(i);
+        layerParams.set("axes", DictValue::arrayInt(&axes[0], num_axes));
     }
-    layerParams.set("deleted_dims", DictValue::arrayInt(&deletedDims[0], deletedDims.size()));
-    layerParams.set("target_dims", DictValue::arrayInt(&targetShape[0], targetShape.size()));
-
-    node_proto.set_input(0, inputString);
-    node_proto.set_output(0, output_name);
-
+    layerParams.type = "Reduce";
     addLayer(layerParams, node_proto);
 }
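End to end, the importer now forwards the ONNX attributes unchanged and the layer does all shape work itself. A sketch of exercising the result through the public dnn API (the layer name, input shape and fill value are illustrative, not taken from this commit):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        LayerParams lp;
        lp.name = "reduce_mean";
        lp.type = "Reduce";
        lp.set("reduce", "mean");
        int axes[] = {1};
        lp.set("axes", DictValue::arrayInt(axes, 1));
        lp.set("keepdims", false);             // drop the reduced axis
        lp.set("noop_with_empty_axes", false); // reduce all axes if "axes" were empty

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        std::vector<int> in_shape = {2, 3, 4};
        Mat input(in_shape, CV_32F);
        input.setTo(2.f); // mean over axis 1 is then 2 everywhere

        net.setInput(input);
        Mat output = net.forward(); // expected shape 2x4 with keepdims=false
        std::cout << "output dims: " << output.dims << std::endl;
        return 0;
    }

With an empty "axes" list and noop_with_empty_axes=true, the layer would instead copy the input through unchanged, matching the ONNX attribute semantics parsed above.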