diff --git a/samples/dnn/tf_text_graph_mask_rcnn.py b/samples/dnn/tf_text_graph_mask_rcnn.py index c8803088f9..24d8790d32 100644 --- a/samples/dnn/tf_text_graph_mask_rcnn.py +++ b/samples/dnn/tf_text_graph_mask_rcnn.py @@ -25,7 +25,8 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert', 'FirstStageFeatureExtractor/Shape', 'FirstStageFeatureExtractor/strided_slice', 'FirstStageFeatureExtractor/GreaterEqual', - 'FirstStageFeatureExtractor/LogicalAnd') + 'FirstStageFeatureExtractor/LogicalAnd', + 'Conv/required_space_to_batch_paddings') # Load a config file. config = readTextMessage(args.config) @@ -54,10 +55,30 @@ graph_def = parseTextGraph(args.output) removeIdentity(graph_def) +nodesToKeep = [] def to_remove(name, op): + if name in nodesToKeep: + return False return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ (name.startswith('CropAndResize') and op != 'CropAndResize') +# Fuse atrous convolutions (with dilations). +nodesMap = {node.name: node for node in graph_def.node} +for node in reversed(graph_def.node): + if node.op == 'BatchToSpaceND': + del node.input[2] + conv = nodesMap[node.input[0]] + spaceToBatchND = nodesMap[conv.input[0]] + + paddingsNode = NodeDef() + paddingsNode.name = conv.name + '/paddings' + paddingsNode.op = 'Const' + paddingsNode.addAttr('value', [2, 2, 2, 2]) + graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode) + nodesToKeep.append(paddingsNode.name) + + spaceToBatchND.input[2] = paddingsNode.name + removeUnusedNodesAndAttrs(to_remove, graph_def) @@ -106,8 +127,8 @@ heights = [] for a in aspect_ratios: for s in scales: ar = np.sqrt(a) - heights.append((features_stride**2) * s / ar) - widths.append((features_stride**2) * s * ar) + heights.append((height_stride**2) * s / ar) + widths.append((width_stride**2) * s * ar) proposals.addAttr('width', widths) proposals.addAttr('height', heights) @@ -252,5 +273,25 @@ graph_def.node[-1].name = 'detection_masks' graph_def.node[-1].op = 'Sigmoid' graph_def.node[-1].input.pop() +def getUnconnectedNodes(): + unconnected = [node.name for node in graph_def.node] + for node in graph_def.node: + for inp in node.input: + if inp in unconnected: + unconnected.remove(inp) + return unconnected + +while True: + unconnectedNodes = getUnconnectedNodes() + unconnectedNodes.remove(graph_def.node[-1].name) + if not unconnectedNodes: + break + + for name in unconnectedNodes: + for i in range(len(graph_def.node)): + if graph_def.node[i].name == name: + del graph_def.node[i] + break + # Save as text. graph_def.save(args.output)