Fix Identity Switch from Keras

pull/14662/head
Dmitry Kurtaev 6 years ago
parent e07ffe902e
commit 081d9bc73f
  1. 11
      modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
  2. 1
      modules/dnn/test/test_tf_importer.cpp

@ -601,7 +601,7 @@ public:
class UpsamplingKerasSubgraph : public Subgraph
{
public:
UpsamplingKerasSubgraph()
UpsamplingKerasSubgraph(const std::string& type)
{
int input = addNodeToMatch("");
int shape = addNodeToMatch("Shape", input);
@ -611,8 +611,8 @@ public:
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
int factors = addNodeToMatch("Const");
int mul = addNodeToMatch("Mul", strided_slice, factors);
addNodeToMatch("ResizeNearestNeighbor", input, mul);
setFusedNode("ResizeNearestNeighbor", input, factors);
addNodeToMatch(type, input, mul);
setFusedNode(type, input, factors);
}
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeNearestNeighbor")));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeBilinear")));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
tensorflow::NodeDef* layer = net.mutable_node(li);
for (int input_id = 0; input_id < layer->input_size(); input_id++) {
String input_op_name = layer->input(input_id);
input_op_name = input_op_name.substr(input_op_name.find('^') + 1,
input_op_name.rfind(':'));
IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
if (it != identity_ops.end()) {

@ -186,6 +186,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
runTensorFlowNet("unfused_batch_norm_no_gamma");
runTensorFlowNet("mvn_batch_norm");
runTensorFlowNet("mvn_batch_norm_1x1");
runTensorFlowNet("switch_identity");
}
TEST_P(Test_TensorFlow_layers, batch_norm3D)

Loading…
Cancel
Save