diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 5cf65c8304..59d0d57cc8 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -630,6 +630,21 @@ public: } }; +class SoftMaxSlimSubgraph : public Subgraph +{ +public: + SoftMaxSlimSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Const"); + int shapeOp = addNodeToMatch("Shape", input); + int reshape = addNodeToMatch("Reshape", input, shape); + int softmax = addNodeToMatch("Softmax", reshape); + addNodeToMatch("Reshape", softmax, shapeOp); + setFusedNode("Softmax", input); + } +}; + void simplifySubgraphs(tensorflow::GraphDef& net) { std::vector > subgraphs; @@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new ResizeBilinearSubgraph())); subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph())); subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); + subgraphs.push_back(Ptr(new SoftMaxSlimSubgraph())); int numNodes = net.node_size(); std::vector matchedNodesIds; diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 6ce99d6610..480e8c7d29 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -661,7 +661,10 @@ void TFImporter::populateNet(Net dstNet) RemoveIdentityOps(netTxt); if (!netTxt.ByteSize()) + { simplifySubgraphs(netBin); + sortByExecutionOrder(netBin); + } std::set layers_to_ignore; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index a5d5512370..9a7b09c546 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -549,6 +549,7 @@ TEST_P(Test_TensorFlow_layers, slice) TEST_P(Test_TensorFlow_layers, softmax) { runTensorFlowNet("keras_softmax"); + runTensorFlowNet("slim_softmax"); } TEST_P(Test_TensorFlow_layers, relu6)