dnn: refactor reduce (#23613)

* initial impl

* remove reduce in8; fix reduce importer

* fix bugs and add log sum exp

* remove unnecessary header and fix indentation
pull/23637/head
Yuantao Feng 2 years ago committed by GitHub
parent 5229312ad2
commit eefee8574a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      modules/dnn/include/opencv2/dnn/all_layers.hpp
  2. 1
      modules/dnn/src/init.cpp
  3. 234
      modules/dnn/src/int8layers/reduce_layer.cpp
  4. 707
      modules/dnn/src/layers/reduce_layer.cpp
  5. 192
      modules/dnn/src/onnx/onnx_importer.cpp

@ -346,18 +346,9 @@ CV__DNN_INLINE_NS_BEGIN
class CV_EXPORTS ReduceLayer : public Layer class CV_EXPORTS ReduceLayer : public Layer
{ {
public: public:
int reduceType;
// reduceDims contains the dimensions that need to be reduced, targetDims is the target output dimension.
std::vector<size_t> reduceDims, targetDims;
static Ptr<ReduceLayer> create(const LayerParams& params); static Ptr<ReduceLayer> create(const LayerParams& params);
}; };
class CV_EXPORTS ReduceLayerInt8 : public ReduceLayer
{
public:
static Ptr<ReduceLayerInt8> create(const LayerParams& params);
};
class CV_EXPORTS SoftmaxLayer : public Layer class CV_EXPORTS SoftmaxLayer : public Layer
{ {
public: public:

@ -194,7 +194,6 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8, ConvolutionLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8, ConvolutionLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8, PoolingLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8, PoolingLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(ReduceInt8, ReduceLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8, EltwiseLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8, EltwiseLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8, BatchNormLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8, BatchNormLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8, ScaleLayerInt8); CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8, ScaleLayerInt8);

@ -1,234 +0,0 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <algorithm>
#include <stdlib.h>
#include <numeric>
namespace cv
{
namespace dnn
{
class ReduceLayerInt8Impl CV_FINAL : public ReduceLayerInt8
{
public:
ReduceLayerInt8Impl(const LayerParams& params)
{
// Set reduce type
CV_Assert(params.has("reduce"));
String typeString = toLowerCase(params.get<String>("reduce"));
if (typeString == "max")
reduceType = MAX;
else if (typeString == "min")
reduceType = MIN;
else
CV_Error(Error::StsBadArg, "Unknown reduce type \"" + typeString + "\"");
// Set deleted dims
CV_Assert(params.has("deleted_dims"));
DictValue tempDims = params.get("deleted_dims");
int i, n = tempDims.size();
reduceDims.resize(n);
for (i = 0; i < n; i++)
{
reduceDims[i] = tempDims.get<int>(i);
}
CV_Assert(params.has("target_dims"));
tempDims = params.get("target_dims");
n = tempDims.size();
targetDims.resize(n);
for (i = 0; i < n; i++)
{
targetDims[i] = tempDims.get<int>(i);
}
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_OPENCV)
{
return true;
}
return false;
}
// reduceType == MIN
struct ReduceOpMIN
{
int8_t apply(const int8_t* first, const int8_t* last)
{
return std::accumulate(first, last, *first,
[](int8_t a, int8_t b)
{
return std::min(a, b);
});
}
};
// reduceType == MAX
struct ReduceOpMAX
{
int8_t apply(const int8_t* first, const int8_t* last)
{
return std::accumulate(first, last, *first,
[](int8_t a, int8_t b)
{
return std::max(a, b);
});
}
};
template<typename Func>
class ReduceInvoker : public ParallelLoopBody
{
public:
const Mat* src;
Mat *dst;
std::vector<size_t> reduceDims;
int nstripes;
int reduceType;
Ptr<Func> func;
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
{
CV_Assert_N(src.isContinuous(), dst.isContinuous(), src.type() == CV_8S, src.type() == dst.type());
ReduceInvoker<Func> p;
p.src = &src;
p.dst = &dst;
p.reduceDims = reduceDims;
p.nstripes = nstripes;
p.reduceType = reduceType;
parallel_for_(Range(0, nstripes), p, nstripes);
}
void operator()(const Range& r) const CV_OVERRIDE
{
size_t total = dst->total();
size_t stripeSize = (total + nstripes - 1)/nstripes;
size_t stripeStart = r.start*stripeSize;
size_t stripeEnd = std::min(r.end*stripeSize, total);
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
int8_t *dstData = (int8_t *)dst->data;
int8_t *srcData = (int8_t *)src->data;
for (size_t ofs = stripeStart; ofs < stripeEnd;)
{
const int8_t* first = srcData + ofs * totalDeleted;
const int8_t* last = srcData + (ofs + 1) * totalDeleted;
dstData[ofs] = func->apply(first, last);
ofs += 1;
}
}
};
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
CV_Assert(inputs.size() == 1);
const int nstripes = getNumThreads();
switch (reduceType)
{
case MIN:
{
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case MAX:
{
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
default:
CV_Error(Error::StsNotImplemented, "Not implemented");
break;
}
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE
{
CV_Assert(inputs.size() > 0);
CV_Assert( reduceDims.size() !=0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
// outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape.
std::vector<int> outShapeTmp, outShape;
outShape.assign(targetDims.begin(), targetDims.end());
if (inputs[0].size() == reduceDims.size())
outShapeTmp.push_back(1);
else
{
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
{
outShapeTmp.push_back(inputs[0][i]);
}
}
// Support dynamic shape of Batch size.
// Note that: when there are multiple dynamic inputs, we will give an error.
if (total(outShape) != total(outShapeTmp))
{
if (outShape[0] != outShapeTmp[0])
outShape[0] = outShapeTmp[0];
}
CV_Assert(total(outShape) == total(outShapeTmp));
outputs.assign(1, outShape);
return false;
}
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
return false;
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const CV_OVERRIDE
{
CV_UNUSED(inputs); // suppress unused variable warning
long flops = 0;
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
for (int i = 0; i < outputs.size(); i++)
{
flops += total(outputs[i])*(totalDeleted);
}
return flops;
}
private:
enum Type
{
MAX,
MIN
};
};
Ptr<ReduceLayerInt8> ReduceLayerInt8::create(const LayerParams& params)
{
return Ptr<ReduceLayerInt8>(new ReduceLayerInt8Impl(params));
}
}
}

@ -3,393 +3,505 @@
// of this distribution and at http://opencv.org/license.html. // of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp" #include "../precomp.hpp"
#include "opencv2/core/hal/intrin.hpp" #include <opencv2/dnn/shape_utils.hpp>
#include "../op_cuda.hpp"
#include "../op_webnn.hpp"
#include <float.h>
#include <algorithm>
#include <numeric>
using std::max;
using std::min;
#include <opencv2/core/utils/logger.hpp> namespace cv { namespace dnn {
namespace cv
{
namespace dnn
{
class ReduceLayerImpl CV_FINAL : public ReduceLayer class ReduceLayerImpl CV_FINAL : public ReduceLayer
{ {
public: public:
ReduceLayerImpl(const LayerParams& params) ReduceLayerImpl(const LayerParams& params) {
{
setParamsFrom(params); setParamsFrom(params);
// set reduce type // set reduce type
CV_Assert(params.has("reduce")); CV_Assert(params.has("reduce"));
String typeString = toLowerCase(params.get<String>("reduce")); String op_type = toLowerCase(params.get<String>("reduce"));
if (typeString == "max") if (op_type == "max")
reduceType= MAX; reduce_type = ReduceType::MAX;
else if (typeString == "min") else if (op_type == "min")
reduceType= MIN; reduce_type = ReduceType::MIN;
else if (typeString == "ave") else if (op_type == "mean")
reduceType= AVE; reduce_type = ReduceType::MEAN;
else if (typeString == "sum") else if (op_type == "sum")
reduceType= SUM; reduce_type = ReduceType::SUM;
else if (typeString == "sum_square") else if (op_type == "sum_square")
reduceType= SUM_SQUARE; reduce_type = ReduceType::SUM_SQUARE;
else if (typeString == "l1") else if (op_type == "l1")
reduceType= L1; reduce_type = ReduceType::L1;
else if (typeString == "l2") else if (op_type == "l2")
reduceType= L2; reduce_type = ReduceType::L2;
else if (typeString == "log_sum") else if (op_type == "log_sum")
reduceType= LOG_SUM; reduce_type = ReduceType::LOG_SUM;
else if (typeString == "log_sum_exp") else if (op_type == "log_sum_exp")
reduceType= LOG_SUM_EXP; reduce_type = ReduceType::LOG_SUM_EXP;
else if (typeString == "prod") else if (op_type == "prod")
reduceType= PROD; reduce_type = ReduceType::PROD;
else else
CV_Error(Error::StsBadArg, "Unknown reduce type\"" + typeString + "\""); CV_Error(Error::StsBadArg, "Unknown reduce type\"" + op_type + "\"");
// set deleted dims keepdims = params.get<bool>("keepdims", true);
CV_Assert(params.has("deleted_dims")); noop_with_empty_axes = params.get<bool>("noop_with_empty_axes", false);
DictValue tempDims = params.get("deleted_dims");
int i, n = tempDims.size(); // get axes if it is existed, otherwise reduce all
reduceDims.resize(n); if (params.has("axes")) {
for (i = 0; i < n; i++) auto param_axes = params.get("axes");
{ int num_axes = param_axes.size();
reduceDims[i] = tempDims.get<int>(i); axes.resize(num_axes);
for (int i = 0; i < num_axes; ++i)
axes[i] = param_axes.get<int>(i);
} }
}
CV_Assert(params.has("target_dims")); virtual bool supportBackend(int backendId) CV_OVERRIDE {
tempDims = params.get("target_dims"); return backendId == DNN_BACKEND_OPENCV;
n = tempDims.size(); }
targetDims.resize(n);
for (i = 0; i < n; i++) virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
{ if (axes.empty()) {
targetDims[i] = tempDims.get<int>(i); return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
auto shape_input = shape(inputs[0]);
for (auto i = 0; i < axes.size(); ++i) {
auto norm_axis = normalize_axis(axes[i], shape_input);
axes[i] = norm_axis;
}
bool do_nothing = true;
for (auto axis : axes) {
if (shape_input[axis] != 1) {
do_nothing = false;
}
}
if (do_nothing) {
axes.clear();
noop_with_empty_axes = true;
} }
} }
virtual bool supportBackend(int backendId) CV_OVERRIDE bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE
{ {
if (backendId == DNN_BACKEND_OPENCV) // empty axes
{ if (axes.empty()) {
return true; if (noop_with_empty_axes) {
// do nothing
outputs.assign(1, inputs[0]);
} else {
// reduce all axes
MatShape shape_output;
if (keepdims) {
shape_output = inputs[0];
for (auto i = 0; i < shape_output.size(); ++i)
shape_output[i] = 1;
} else {
shape_output.push_back(1);
}
outputs.assign(1, shape_output);
}
} else {
auto shape_output_ = inputs[0];
for (size_t i = 0; i < axes.size(); ++i) {
auto norm_axis = normalize_axis(axes[i], inputs[0]);
shape_output_[norm_axis] = -1;
}
MatShape shape_output;
for (size_t i = 0; i < shape_output_.size(); ++i) {
if (shape_output_[i] == -1) {
if (keepdims)
shape_output.push_back(1);
else
continue;
} else
shape_output.push_back(shape_output_[i]);
}
if (shape_output.empty())
shape_output.push_back(1);
outputs.assign(1, shape_output);
} }
return false; return false;
} }
// reduceType == MIN template <typename T>
struct ReduceOpMIN class ReduceBase {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) using dtype_input = T;
{
return std::accumulate(first, last, FLT_MAX, ReduceBase(size_t n, const T& init) : n_(n), accumulator_(init) {}
[](float a, float b) virtual void update(const T& a) = 0;
{ virtual T get_value() { return accumulator_; }
return std::min(a, b); virtual ~ReduceBase() = default;
}); protected:
size_t n_;
T accumulator_;
};
template <typename T>
class ReduceMin : public ReduceBase<T> {
public:
ReduceMin(size_t n, const T& init) : ReduceBase<T>(n, init) {}
void update(const T& a) override {
this->accumulator_ = a > this->accumulator_ ? this->accumulator_ : a;
} }
}; };
// reduceType == MAX template <typename T>
struct ReduceOpMAX class ReduceMax : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceMax(size_t n, const T& init) : ReduceBase<T>(n, init) {}
{ void update(const T& a) override {
return std::accumulate(first, last, -FLT_MAX, this->accumulator_ = a > this->accumulator_ ? a : this->accumulator_;
[](float a, float b)
{
return std::max(a, b);
});
} }
}; };
// reduceType == SUM template <typename T>
struct ReduceOpSUM class ReduceSum : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
{ void update(const T& a) override {
return std::accumulate(first, last, 0.f); this->accumulator_ += a;
} }
}; };
// reduceType == AVE template <typename T>
struct ReduceOpAVE class ReduceMean : public ReduceSum<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceMean(size_t n, const T& init) : ReduceSum<T>(n, init) {}
{ T get_value() override {
float output = std::accumulate(first, last, 0.f); return this->accumulator_ / static_cast<T>(this->n_);
return output * ikarea;
} }
}; };
// reduceType == SUM_SQUARE template <typename T>
struct ReduceOpSUM_SQUARE class ReduceSumSquare : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceSumSquare(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
{ void update(const T& a) override {
return std::accumulate(first, last, 0.f, this->accumulator_ += a * a;
[](float a, float b)
{
return a + b * b;
});
} }
}; };
// reduceType == L1 template <typename T>
struct ReduceOpL1 class ReduceL1 : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceL1(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
{ void update(const T& a) override {
return std::accumulate(first, last, 0.f, this->accumulator_ += a > 0 ? a : -a;
[](float a, float b)
{
return a + std::abs(b);
});
} }
}; };
// reduceType == L2 template <typename T>
struct ReduceOpL2 class ReduceL2 : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceL2(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
{ void update(const T& a) override {
float output = std::accumulate(first, last, 0.f, this->accumulator_ += a * a;
[](float a, float b) }
{ T get_value() override {
return a + b * b; return std::sqrt(this->accumulator_);
});
return std::sqrt(output);
} }
}; };
// reduceType == PROD template <typename T>
struct ReduceOpPROD class ReduceProd : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceProd(size_t n, const T& init) : ReduceBase<T>(n, 1) {}
{ void update(const T& a) override {
return std::accumulate(first, last, 1.0f, std::multiplies<float>()); this->accumulator_ *= a;
} }
}; };
// reduceType == LOG_SUM template <typename T>
struct ReduceOpLOG_SUM class ReduceLogSum : public ReduceBase<T> {
{ public:
float apply(const float* first, const float* last, const float ikarea = 1.0f) ReduceLogSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
{ void update(const T& a) override {
float output = std::accumulate(first, last, 0.0f); this->accumulator_ += a;
return std::log(output); }
T get_value() override {
return static_cast<T>(std::log(this->accumulator_));
} }
}; };
// reduceType == LOG_SUM_EXP // FIXME: overflow caution
struct ReduceOpLOG_SUM_EXP template <typename T>
{ class ReduceLogSumExp : public ReduceBase<T> {
float apply(const float* first, const float* last, const float ikarea = 1.0f) public:
{ ReduceLogSumExp(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
float output = std::accumulate(first, last, 0.0f, void update(const T& a) override {
[](float a, float b) this->accumulator_ += static_cast<T>(std::exp(a));
{ }
return a + std::exp(b); T get_value() override {
}); return static_cast<T>(std::log(this->accumulator_));
return std::log(output);
} }
}; };
template<typename Func>
class ReduceInvoker : public ParallelLoopBody template <typename Op>
{ class ReduceAllInvoker : public ParallelLoopBody {
public: public:
const Mat* src; using dtype = typename Op::dtype_input;
Mat *dst;
std::vector<size_t> reduceDims;
int nstripes;
int reduceType;
Ptr<Func> func;
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {} const Mat& src;
Mat& dst;
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes) int n_reduce;
{ int loop_size;
CV_Assert_N( src.isContinuous(), dst.isContinuous(), src.type() == CV_32F, src.type() == dst.type());
ReduceInvoker<Func> p; int total;
int cost_per_thread;
p.src = &src; ReduceAllInvoker(const Mat& src_, Mat& dst_) : src(src_), dst(dst_) {
p.dst = &dst; auto shape_src = shape(src);
p.reduceDims = reduceDims; n_reduce = std::accumulate(shape_src.begin(), shape_src.end(), 1, std::multiplies<int>());
p.nstripes = nstripes; loop_size = n_reduce;
p.reduceType = reduceType;
parallel_for_(Range(0, nstripes), p, nstripes); total = 1;
cost_per_thread = 1;
} }
void operator()(const Range& r) const CV_OVERRIDE void operator()(const Range& r) const CV_OVERRIDE {
{ int start = r.start;
size_t total = dst->total(); int end = r.end;
size_t stripeSize = (total + nstripes - 1)/nstripes;
size_t stripeStart = r.start*stripeSize; const dtype* p_src = src.ptr<const dtype>();
size_t stripeEnd = std::min(r.end*stripeSize, total); dtype* p_dst = dst.ptr<dtype>();
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
for (int i = start; i < end; ++i) {
float *dstData = (float *)dst->data; Op accumulator(n_reduce, *p_src);
float *srcData = (float *)src->data; for (int l = 0; l < loop_size; ++l) {
accumulator.update(p_src[l]);
for (size_t ofs = stripeStart; ofs < stripeEnd;)
{
const float* first = srcData + ofs * stride_w;
const float* last = srcData + (ofs + 1) * stride_w;
if (ofs < stripeEnd)
{
dstData[ofs] = func->apply(first, last, 1.0 / stride_w);
ofs += 1;
} }
p_dst[i] = accumulator.get_value();
} }
} }
}; };
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE template <typename Op>
{ class ReduceInvoker : public ParallelLoopBody {
CV_TRACE_FUNCTION(); public:
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); using dtype = typename Op::dtype_input;
if (inputs_arr.depth() == CV_16S) const Mat& src;
{ Mat& dst;
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs; std::vector<int> reduced_axes; // assume in ascending order
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
CV_Assert(inputs.size() == 1 || (inputs.size() == 2 && reduceType== SUM));
const int nstripes = getNumThreads();
switch (reduceType) int n_reduce;
{ int loop_size;
case MIN:
{ int last_reduced_dim;
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); int last_reduced_step;
break; std::vector<int> projected_steps;
}
case MAX: int last_unreduced_dim;
{ int last_unreduced_step;
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); std::vector<int> unprojected_steps;
break;
} int total;
case AVE: int cost_per_thread;
{
ReduceInvoker<ReduceOpAVE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); ReduceInvoker(const Mat& src_, Mat& dst_, std::vector<int> axes_) : src(src_), dst(dst_), reduced_axes(axes_) {
break; auto shape_src = shape(src);
}
case SUM: auto steps_src = shape_src;
{ steps_src[steps_src.size() - 1] = 1;
ReduceInvoker<ReduceOpSUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); for (int i = static_cast<int>(steps_src.size()) - 2; i >= 0; --i)
break; steps_src[i] = steps_src[i + 1] * shape_src[i + 1];
}
case L1: size_t projection_size = 1;
{ for (auto axis : reduced_axes) {
ReduceInvoker<ReduceOpL1>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); projection_size *= shape_src[axis];
break;
} }
case L2: n_reduce = projection_size;
{
ReduceInvoker<ReduceOpL2>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); last_reduced_dim = shape_src[reduced_axes.back()];
break; last_reduced_step = steps_src[reduced_axes.back()];
loop_size = last_reduced_dim * last_reduced_step;
projection_size /= last_reduced_dim;
// calculate projected_steps
int last_reduced_axis = static_cast<int>(reduced_axes.size()) - 1;
if (last_reduced_axis == 0) {
projected_steps.resize(1, 0);
} else {
projected_steps.resize(projection_size);
std::vector<int> projected_indices(last_reduced_axis, 0);
for (size_t i = 0, current_step = 0; i < projection_size; ++i) {
projected_steps[i] = current_step;
++projected_indices[last_reduced_axis - 1];
current_step += steps_src[reduced_axes[last_reduced_axis - 1]];
for (int j = last_reduced_axis - 1; j > 0; --j) {
if (projected_indices[j] < shape_src[reduced_axes[j]]) {
break;
}
projected_indices[j] = 0;
++projected_indices[j - 1];
current_step = steps_src[reduced_axes[j - 1]];
}
}
} }
case SUM_SQUARE:
{ // calculate unprojected_steps
ReduceInvoker<ReduceOpSUM_SQUARE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); std::vector<int> unreduced_axes;
break; for (int i = 0; i < static_cast<int>(shape_src.size()); ++i) {
if (std::find(reduced_axes.begin(), reduced_axes.end(), i) == reduced_axes.end()) {
unreduced_axes.push_back(i);
}
} }
case PROD: size_t unprojection_size = 1;
{ for (auto axis : unreduced_axes) {
ReduceInvoker<ReduceOpPROD>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); unprojection_size *= shape_src[axis];
break;
} }
case LOG_SUM: last_unreduced_dim = shape_src[unreduced_axes.back()];
{ last_unreduced_step = steps_src[unreduced_axes.back()];
ReduceInvoker<ReduceOpLOG_SUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); unprojection_size /= last_unreduced_dim;
break;
std::vector<int> unprojected_indices(unreduced_axes.size(), 0);
unprojected_steps.reserve(unprojection_size);
if (unprojected_indices.size() <= 1) {
unprojected_steps.push_back(0);
} else {
for (size_t i = 0, current_step = 0; i < unprojection_size; ++i) {
unprojected_steps.push_back(current_step);
++unprojected_indices[unprojected_indices.size() - 2];
current_step += steps_src[unreduced_axes[unreduced_axes.size() - 2]];
for (int j = static_cast<int>(unreduced_axes.size()) - 2; j > 0; --j) {
if (unprojected_indices[j] < shape_src[unreduced_axes[j]]) {
break;
}
unprojected_indices[j] = 0;
++unprojected_indices[j - 1];
current_step = steps_src[unreduced_axes[j - 1]];
}
}
} }
case LOG_SUM_EXP:
{ auto shape_dst = shape(dst);
ReduceInvoker<ReduceOpLOG_SUM_EXP>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes); total = std::accumulate(shape_dst.begin(), shape_dst.end(), 1, std::multiplies<int>());
break; cost_per_thread = static_cast<int>(projected_steps.size() * last_reduced_step);
}
static void run(const Mat& src, Mat& dst, std::vector<int> axes, bool noop_with_empty_axes) {
CV_Assert(src.isContinuous());
CV_Assert(dst.isContinuous());
if (axes.empty()) {
if (noop_with_empty_axes) {
// copyTo is not used here for the reason that we want a
// copy for the case when dims at all axes are 1
const auto p_src = src.ptr<const dtype>();
auto p_dst = dst.ptr<dtype>();
std::memcpy(p_dst, p_src, sizeof(dtype) * dst.total());
return;
}
ReduceAllInvoker<Op> p(src, dst);
double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, p.total), p, nstripes);
return;
} }
default:
CV_Error(Error::StsNotImplemented, "Not implemented"); ReduceInvoker<Op> p(src, dst, axes);
break; double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, p.total), p, nstripes);
} }
}
bool getMemoryShapes(const std::vector<MatShape> &inputs, void operator()(const Range& r) const CV_OVERRIDE {
const int requiredOutputs, int start = r.start;
std::vector<MatShape> &outputs, int end = r.end;
std::vector<MatShape> &internals) const CV_OVERRIDE
{ const dtype* p_src = src.ptr<const dtype>();
CV_Assert(inputs.size() > 0); dtype* p_dst = dst.ptr<dtype>();
CV_Assert( reduceDims.size() !=0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
size_t main_index = start / last_unreduced_dim;
// outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape. size_t loop = start / last_unreduced_dim;
std::vector<int> outShapeTmp, outShape; size_t origin = unprojected_steps[main_index] + loop * last_unreduced_step;
outShape.assign(targetDims.begin(), targetDims.end()); for (int i = start; i < end; ++i) {
if (inputs[0].size() == reduceDims.size()) Op accumulator(n_reduce, p_src[origin + projected_steps[0]]);
outShapeTmp.push_back(1); for (auto projected_step : projected_steps) {
else const dtype* loop_p_src = p_src + origin + projected_step;
{ for (auto l = 0; l < loop_size; l += last_reduced_step) {
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++) accumulator.update(loop_p_src[l]);
{ }
outShapeTmp.push_back(inputs[0][i]); }
p_dst[i] = accumulator.get_value();
++loop;
if (loop >= last_unreduced_dim) {
loop = 0;
++main_index;
if (main_index < unprojected_steps.size()) {
origin = unprojected_steps[main_index];
}
} else {
origin += last_unreduced_step;
}
} }
} }
};
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
// Support dynamic shape of Batch size. if (inputs_arr.depth() == CV_16S)
// Note that: when there are multiple dynamic inputs, we will give an error.
if (total(outShape) != total(outShapeTmp) && outShape[0] != outShapeTmp[0])
{ {
outShape[0] = outShapeTmp[0]; forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
} }
CV_Assert(total(outShape) == total(outShapeTmp)); std::vector<Mat> inputs, outputs;
outputs.assign(1, outShape); inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
return false; typeDispatch(outputs[0].type(), inputs[0], outputs[0], axes, noop_with_empty_axes);
} }
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales, template <typename T, typename... Args>
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE inline void opDispatch(Args&&... args) {
{ switch (reduce_type) {
if (reduceType== MAX || reduceType== MIN) case ReduceType::MAX: ReduceInvoker<ReduceMax<T>>::run(std::forward<Args>(args)...); break;
{ case ReduceType::MIN: ReduceInvoker<ReduceMin<T>>::run(std::forward<Args>(args)...); break;
return true; case ReduceType::MEAN: ReduceInvoker<ReduceMean<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::SUM: ReduceInvoker<ReduceSum<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::L1: ReduceInvoker<ReduceL1<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::L2: ReduceInvoker<ReduceL2<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::PROD: ReduceInvoker<ReduceProd<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::SUM_SQUARE: ReduceInvoker<ReduceSumSquare<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::LOG_SUM: ReduceInvoker<ReduceLogSum<T>>::run(std::forward<Args>(args)...); break;
case ReduceType::LOG_SUM_EXP: ReduceInvoker<ReduceLogSumExp<T>>::run(std::forward<Args>(args)...); break;
default: CV_Error(Error::StsBadArg, "DNN/Reduce: Unsupported operation.");
} }
return false;
} }
virtual int64 getFLOPS(const std::vector<MatShape> &inputs, template <typename... Args>
const std::vector<MatShape> &outputs) const CV_OVERRIDE inline void typeDispatch(const int type, Args&&... args) {
{ switch (type) {
CV_UNUSED(inputs); // suppress unused variable warning case CV_8U: opDispatch<uint8_t>(std::forward<Args>(args)...); break;
long flops = 0; case CV_32S: opDispatch<int32_t>(std::forward<Args>(args)...); break;
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>()); case CV_32F: opDispatch<float>(std::forward<Args>(args)...); break;
for (int i = 0; i < outputs.size(); i++) default: CV_Error(cv::Error::BadDepth, "DNN/Reduce: Unsupported type.");
{
flops += total(outputs[i])*(stride_w);
} }
return flops;
} }
private: private:
enum ReduceType enum ReduceType
{ {
MAX, MAX,
MIN, MIN,
AVE, MEAN,
SUM, SUM,
L1, L1,
L2, L2,
@ -397,7 +509,11 @@ private:
SUM_SQUARE, SUM_SQUARE,
LOG_SUM, LOG_SUM,
LOG_SUM_EXP LOG_SUM_EXP
}; } reduce_type;
bool keepdims;
bool noop_with_empty_axes;
std::vector<int> axes;
}; };
Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params) Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
@ -405,5 +521,4 @@ Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
return Ptr<ReduceLayer>(new ReduceLayerImpl(params)); return Ptr<ReduceLayer>(new ReduceLayerImpl(params));
} }
} }} // cv::dnn
}

@ -1178,165 +1178,49 @@ void ONNXImporter::parseGlobalPool(LayerParams &layerParams, const opencv_onnx::
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{ {
opencv_onnx::NodeProto node_proto = node_proto_; const auto& op_type = node_proto.op_type();
const std::string& layer_type = node_proto.op_type(); String reduce_type;
const std::string output_name = node_proto.output(0); if (op_type == "ReduceMax")
int depth = layerParams.get<int>("depth", CV_32F); reduce_type = "MAX";
else if (op_type == "ReduceMean")
CV_Assert(node_proto.input_size() <= 2); reduce_type = "MEAN";
String reduceType; else if (op_type == "ReduceMin")
reduce_type = "MIN";
if (layer_type == "ReduceMax") else if (op_type == "ReduceProd")
reduceType = "MAX"; reduce_type = "PROD";
else if (layer_type == "ReduceMin") else if (op_type == "ReduceSum")
reduceType = "MIN"; reduce_type = "SUM";
else if (layer_type == "ReduceSum") else if (op_type == "ReduceL1")
reduceType = "SUM"; reduce_type = "L1";
else if (layer_type == "ReduceSumSquare") else if (op_type == "ReduceL2")
reduceType = "SUM_SQUARE"; reduce_type = "L2";
else if (layer_type == "ReduceProd") else if (op_type == "ReduceLogSum")
reduceType = "PROD"; reduce_type = "LOG_SUM";
else if (layer_type == "ReduceL1") else if (op_type == "ReduceLogSumExp")
reduceType = "L1"; reduce_type = "LOG_SUM_EXP";
else if (layer_type == "ReduceL2") else if (op_type == "ReduceSumSquare")
reduceType = "L2"; reduce_type = "SUM_SQUARE";
else if (layer_type == "ReduceLogSum")
reduceType = "LOG_SUM";
else if (layer_type == "ReduceLogSumExp")
reduceType = "LOG_SUM_EXP";
else if (layer_type == "ReduceMean")
reduceType = "AVE";
else else
CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation."); CV_Error(Error::StsNotImplemented, "DNN/ONNX: " + op_type + " is not supported.");
layerParams.set("reduce", reduce_type);
// The ReduceInt8 can only support "MAX" and "MIN".
if (depth == CV_8S)
{
CV_Assert(reduceType == "MAX" || reduceType == "MIN");
}
layerParams.type = (depth == CV_8S) ? "ReduceInt8" : "Reduce";
layerParams.set("reduce", reduceType);
bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
MatShape inpShape = outShapes[node_proto.input(0)];
std::vector<bool> shouldDelete(inpShape.size(), false);
if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
{
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
{
Mat axesMat = getBlob(node_proto, 1);
int axesNum = axesMat.total();
for (int i = 0; i < axesNum; i++)
{
int axis = normalize_axis(axesMat.at<int>(i), inpShape.size());
shouldDelete[axis] = true;
}
}
else
// in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
// details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
CV_Error(Error::StsNotImplemented, "Non-constant axis values in ReduceSum are not supported.");
}
else
{
if (layerParams.has("axes"))
{
DictValue axes = layerParams.get("axes");
for (int i = 0; i < axes.size(); i++)
{
int axis = normalize_axis(axes.get<int>(i), inpShape.size());
shouldDelete[axis] = true;
}
}
else
{
for (int i = 0; i < inpShape.size(); i++)
{
shouldDelete[i] = true;
}
}
}
std::vector<int> targetShape;
for (int i = 0; i < inpShape.size(); ++i)
{
if (!shouldDelete[i])
{
targetShape.push_back(inpShape[i]);
}
else if (keepdims)
{
targetShape.push_back(1);
}
}
if (targetShape.empty())
targetShape.push_back(1);
// Using PermuteLayer to move the deleted axis to the last. int num_inputs = node_proto.input_size();
std::vector<int> perm(inpShape.size(), 0); CV_Check(num_inputs, num_inputs >= 1 && num_inputs <= 2, "DNN/ONNX: Reduce layers should have at least one input and at most two inputs");
for (int i = 0; i < inpShape.size(); i++)
perm[i] = i;
bool needPermuet = false;
for (int i = 0; i < inpShape.size(); i++)
{
if (shouldDelete[i])
{
// find the first not deleted element.
std::vector<bool>::iterator iter = std::find(shouldDelete.begin() + i, shouldDelete.end(), false);
if (iter != shouldDelete.end())
{
int index = iter - shouldDelete.begin();
bool temp = shouldDelete[index];
shouldDelete[index] = shouldDelete[i];
shouldDelete[i] = temp;
std::swap(perm[index], perm[i]);
std::swap(inpShape[index], inpShape[i]);
needPermuet = true;
}
else
break;
}
}
auto inputString= node_proto.input(0);
if (needPermuet)
{
LayerParams permuteLp;
permuteLp.name = layerParams.name + "/permute";
permuteLp.type = (depth == CV_8S) ? "PermuteInt8" : "Permute";
permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size()));
opencv_onnx::NodeProto protoPermute;
protoPermute.add_input(inputString);
protoPermute.add_output(permuteLp.name);
addLayer(permuteLp, protoPermute);
inputString = permuteLp.name;
}
std::vector<int> deletedDims; // "axes" is turned to one of the inputs since opset 18,
for (int axis_i = 0; axis_i < inpShape.size(); ++axis_i) // except for ReduceSum, which has "axes" input since opset 13.
{ if (!layerParams.has("axes") && num_inputs == 2 && constBlobs.find(node_proto.input(1)) != constBlobs.end()) {
if (shouldDelete[axis_i]) Mat mat_axes = getBlob(node_proto, 1);
{ int num_axes = mat_axes.total();
deletedDims.push_back(inpShape[axis_i]); std::vector<int> axes(num_axes);
} for (int i = 0; i < num_axes; ++i)
axes[i] = mat_axes.at<int>(i);
layerParams.set("axes", DictValue::arrayInt(&axes[0], num_axes));
} }
layerParams.set("deleted_dims", DictValue::arrayInt(&deletedDims[0], deletedDims.size())); layerParams.type = "Reduce";
layerParams.set("target_dims", DictValue::arrayInt(&targetShape[0], targetShape.size()));
node_proto.set_input(0, inputString);
node_proto.set_output(0, output_name);
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }

Loading…
Cancel
Save