From 252ce0b58194f76d1a6b700dd18a39cd022757b3 Mon Sep 17 00:00:00 2001 From: cqn2219076254 <2219076254@qq.com> Date: Mon, 13 Dec 2021 21:43:13 +0800 Subject: [PATCH] add square layer --- modules/dnn/src/tensorflow/tf_importer.cpp | 22 +++++++++++++++++++++- modules/dnn/test/test_tf_importer.cpp | 7 +++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 5fafa2b9d5..e37e888e35 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -598,7 +598,7 @@ private: void parseLeakyRelu (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); void parseActivation (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); void parseExpandDims (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); - + void parseSquare (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); void parseCustomLayer (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams); }; @@ -676,6 +676,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap() dispatch["Abs"] = dispatch["Tanh"] = dispatch["Sigmoid"] = dispatch["Relu"] = dispatch["Elu"] = dispatch["Exp"] = dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation; dispatch["ExpandDims"] = &TFImporter::parseExpandDims; + dispatch["Square"] = &TFImporter::parseSquare; return dispatch; } @@ -1252,6 +1253,25 @@ void TFImporter::parseExpandDims(tensorflow::GraphDef& net, const tensorflow::No } } +// "Square" +void TFImporter::parseSquare(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams) +{ + const std::string& name = layer.name(); + const int num_inputs = layer.input_size(); + + CV_CheckEQ(num_inputs, 1, ""); + + int id; + layerParams.set("operation", "prod"); + id = dstNet.addLayer(name, "Eltwise", layerParams); + + layer_id[name] = id; + + Pin inp = parsePin(layer.input(0)); + connect(layer_id, dstNet, inp, id, 0); + connect(layer_id, dstNet, inp, id, 1); +} + // "Flatten" "Squeeze" void TFImporter::parseFlatten(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams) { diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 117177a860..1c4beb7468 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -670,6 +670,13 @@ TEST_P(Test_TensorFlow_layers, batch_matmul) runTensorFlowNet("batch_matmul"); } +TEST_P(Test_TensorFlow_layers, square) +{ + if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + runTensorFlowNet("square"); +} + TEST_P(Test_TensorFlow_layers, reshape) { #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2021040000)