From 76cfa65d55726152c3ced49e6cc8bde2f22959db Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Mon, 30 Dec 2019 20:06:58 +0300 Subject: [PATCH] AddV2 from TensorFlow --- modules/dnn/src/tensorflow/tf_importer.cpp | 2 +- samples/dnn/tf_text_graph_ssd.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index b1d7178798..192b94e3bb 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -996,7 +996,7 @@ void TFImporter::populateNet(Net dstNet) if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN) data_layouts[name] = DATA_LAYOUT_NHWC; } - else if (type == "BiasAdd" || type == "Add" || type == "Sub" || type=="AddN") + else if (type == "BiasAdd" || type == "Add" || type == "AddV2" || type == "Sub" || type=="AddN") { bool haveConst = false; for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii) diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py index 1060047260..b466548502 100644 --- a/samples/dnn/tf_text_graph_ssd.py +++ b/samples/dnn/tf_text_graph_ssd.py @@ -62,7 +62,7 @@ class MultiscaleAnchorGenerator: def createSSDGraph(modelPath, configPath, outputPath): # Nodes that should be kept. - keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm', + keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity', 'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3'] @@ -151,6 +151,9 @@ def createSSDGraph(modelPath, configPath, outputPath): subgraphBatchNorm = ['Add', ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']], ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]] + subgraphBatchNormV2 = ['AddV2', + ['Mul', 'input', ['Mul', ['Rsqrt', ['AddV2', 'moving_variance', 'add_y']], 'gamma']], + ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]] # Detect unfused nearest neighbor resize. subgraphResizeNN = ['Reshape', ['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']], @@ -177,7 +180,8 @@ def createSSDGraph(modelPath, configPath, outputPath): for node in graph_def.node: inputs = {} fusedNodes = [] - if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes): + if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes) or \ + checkSubgraph(node, subgraphBatchNormV2, inputs, fusedNodes): name = node.name node.Clear() node.name = name