diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index e37e888e35..efaedfaab1 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -599,6 +599,8 @@ private: void parseActivation (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); void parseExpandDims (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); void parseSquare (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); + void parseArg (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); + void parseCustomLayer (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); }; @@ -677,6 +679,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap() dispatch["Elu"] = dispatch["Exp"] = dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation; dispatch["ExpandDims"] = &TFImporter::parseExpandDims; dispatch["Square"] = &TFImporter::parseSquare; + dispatch["ArgMax"] = dispatch["ArgMin"] = &TFImporter::parseArg; return dispatch; } @@ -2624,6 +2627,22 @@ void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::No connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs); } +void TFImporter::parseArg(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams) +{ + const std::string& name = layer.name(); + const std::string& type = layer.op(); + + Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1)); + CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1); + layerParams.set("axis", *dimension.ptr()); + layerParams.set("op", type == "ArgMax" ? "max" : "min"); + layerParams.set("keepdims", false); //tensorflow doesn't have this atrr, the output's dims minus one(default); + + int id = dstNet.addLayer(name, "Arg", layerParams); + layer_id[name] = id; + connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); +} + void TFImporter::parseCustomLayer(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams) { // Importer does not know how to map this TensorFlow's operation onto OpenCV's layer. diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 1c4beb7468..b40a604a6e 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -185,6 +185,14 @@ TEST_P(Test_TensorFlow_layers, reduce_sum_channel_keep_dims) runTensorFlowNet("reduce_sum_channel", false, 0.0, 0.0, false, "_keep_dims"); } +TEST_P(Test_TensorFlow_layers, ArgLayer) +{ + if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU) + throw SkipTestException("Only CPU is supported"); // FIXIT use tags + runTensorFlowNet("argmax"); + runTensorFlowNet("argmin"); +} + TEST_P(Test_TensorFlow_layers, conv_single_conv) { runTensorFlowNet("single_conv");