diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 0bec9e35e2..45935279f5 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -284,6 +284,16 @@ CV__DNN_INLINE_NS_BEGIN
 
         static Ptr<LRNLayer> create(const LayerParams& params);
     };
+
+    /** @brief ArgMax/ArgMin layer
+     * @note returns indices as floats, which means the supported range is [-2^24; 2^24]
+     */
+    class CV_EXPORTS ArgLayer : public Layer
+    {
+    public:
+        static Ptr<ArgLayer> create(const LayerParams& params);
+    };
+
     class CV_EXPORTS PoolingLayer : public Layer
     {
     public:
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index affaa1a7e1..443d1eaef4 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -123,6 +123,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Identity, BlankLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Silence, BlankLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
diff --git a/modules/dnn/src/layers/arg_layer.cpp b/modules/dnn/src/layers/arg_layer.cpp
new file mode 100644
index 0000000000..94af45882a
--- /dev/null
+++ b/modules/dnn/src/layers/arg_layer.cpp
@@ -0,0 +1,120 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class ArgLayerImpl CV_FINAL : public ArgLayer
+{
+public:
+    enum class ArgOp
+    {
+        MIN = 0,
+        MAX = 1,
+    };
+
+    ArgLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+
+        axis = params.get<int>("axis", 0);
+        keepdims = (params.get<int>("keepdims", 1) == 1);
+        select_last_index = (params.get<int>("select_last_index", 0) == 1);
+
+        const std::string& argOp = params.get<std::string>("op");
+
+        if (argOp == "max")
+        {
+            op = ArgOp::MAX;
+        }
+        else if (argOp == "min")
+        {
+            op = ArgOp::MIN;
+        }
+        else
+        {
+            CV_Error(Error::StsBadArg, "Unsupported operation");
+        }
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_CPU;
+    }
+
+    void handleKeepDims(MatShape& shape, const int axis_) const
+    {
+        if (keepdims)
+        {
+            shape[axis_] = 1;
+        }
+        else
+        {
+            shape.erase(shape.begin() + axis_);
+        }
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        MatShape inpShape = inputs[0];
+
+        const int axis_ = normalize_axis(axis, inpShape);
+        handleKeepDims(inpShape, axis_);
+        outputs.assign(1, inpShape);
+
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        CV_Assert_N(inputs.size() == 1, outputs.size() == 1);
+        std::vector<int> outShape = shape(outputs[0]);
+        Mat output(outShape, CV_32SC1);
+
+        switch (op)
+        {
+        case ArgOp::MIN:
+            cv::reduceArgMin(inputs[0], output, axis, select_last_index);
+            break;
+        case ArgOp::MAX:
+            cv::reduceArgMax(inputs[0], output, axis, select_last_index);
+            break;
+        default:
+            CV_Error(Error::StsBadArg,
+                     "Unsupported operation.");
+        }
+
+        output = output.reshape(1, outShape);
+        output.convertTo(outputs[0], CV_32FC1);
+    }
+
+private:
+    // The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).
+    int axis;
+    // Keep the reduced dimension or not
+    bool keepdims;
+    // Whether to select the first or the last index or Max/Min.
+    bool select_last_index;
+    // Operation to be performed
+    ArgOp op;
+};
+
+Ptr<ArgLayer> ArgLayer::create(const LayerParams& params)
+{
+    return Ptr<ArgLayer>(new ArgLayerImpl(params));
+}
+
+}} // namespace cv::dnn
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index b0d7d4b913..85c4479c6f 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -100,6 +100,7 @@ private:
     const DispatchMap dispatch;
     static const DispatchMap buildDispatchMap();
 
+    void parseArg                (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseMaxPool            (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseAveragePool        (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseReduce             (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@@ -768,6 +769,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
     }
 }
 
+void ONNXImporter::parseArg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    const std::string& layer_type = node_proto.op_type();
+    layerParams.type = "Arg";
+    layerParams.set("op", layer_type == "ArgMax" ?
+                               "max" : "min");
+    addLayer(layerParams, node_proto);
+}
+
 void setCeilMode(LayerParams& layerParams)
 {
     // auto_pad attribute is deprecated and uses ceil
@@ -2986,6 +2995,7 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
 {
     DispatchMap dispatch;
 
+    dispatch["ArgMax"] = dispatch["ArgMin"] = &ONNXImporter::parseArg;
     dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
     dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
     dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index f2deaf1a4e..bfea8550c0 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -355,6 +355,15 @@ TEST_P(Test_ONNX_layers, Min)
     testONNXModels("min", npy, 0, 0, false, true, 2);
 }
 
+TEST_P(Test_ONNX_layers, ArgLayer)
+{
+    if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)
+        throw SkipTestException("Only CPU is supported"); // FIXIT use tags
+
+    testONNXModels("argmax");
+    testONNXModels("argmin");
+}
+
 TEST_P(Test_ONNX_layers, Scale)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)