diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index e56073788b..8d93bca869 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -601,7 +601,7 @@ public: class UpsamplingKerasSubgraph : public Subgraph { public: - UpsamplingKerasSubgraph() + UpsamplingKerasSubgraph(const std::string& type) { int input = addNodeToMatch(""); int shape = addNodeToMatch("Shape", input); @@ -611,8 +611,8 @@ public: int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); int factors = addNodeToMatch("Const"); int mul = addNodeToMatch("Mul", strided_slice, factors); - addNodeToMatch("ResizeNearestNeighbor", input, mul); - setFusedNode("ResizeNearestNeighbor", input, factors); + addNodeToMatch(type, input, mul); + setFusedNode(type, input, factors); } virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode, @@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new DeconvolutionValidKerasSubgraph())); subgraphs.push_back(Ptr(new DeconvolutionSameKerasSubgraph())); subgraphs.push_back(Ptr(new ResizeBilinearSubgraph())); - subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph())); + subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph("ResizeNearestNeighbor"))); + subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph("ResizeBilinear"))); subgraphs.push_back(Ptr(new SoftMaxSlimSubgraph())); subgraphs.push_back(Ptr(new SoftMaxSlimV2Subgraph())); subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); @@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) tensorflow::NodeDef* layer = net.mutable_node(li); for (int input_id = 0; input_id < layer->input_size(); input_id++) { String input_op_name = layer->input(input_id); + input_op_name = input_op_name.substr(input_op_name.find('^') + 1, + input_op_name.rfind(':')); IdentityOpsMap::iterator it = identity_ops.find(input_op_name); if (it != identity_ops.end()) { diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index abbcf9fde9..c495af8cfe 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -186,6 +186,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm) runTensorFlowNet("unfused_batch_norm_no_gamma"); runTensorFlowNet("mvn_batch_norm"); runTensorFlowNet("mvn_batch_norm_1x1"); + runTensorFlowNet("switch_identity"); } TEST_P(Test_TensorFlow_layers, batch_norm3D)