Enable ResNet-based Mask-RCNN models from TensorFlow Object Detection API

pull/13766/head
Dmitry Kurtaev 6 years ago
parent a63f66c90e
commit 6ad3bf3130
  1. 47
      samples/dnn/tf_text_graph_mask_rcnn.py

@ -25,7 +25,8 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
'FirstStageFeatureExtractor/Shape',
'FirstStageFeatureExtractor/strided_slice',
'FirstStageFeatureExtractor/GreaterEqual',
'FirstStageFeatureExtractor/LogicalAnd')
'FirstStageFeatureExtractor/LogicalAnd',
'Conv/required_space_to_batch_paddings')
# Load a config file.
config = readTextMessage(args.config)
@ -54,10 +55,30 @@ graph_def = parseTextGraph(args.output)
removeIdentity(graph_def)
nodesToKeep = []
def to_remove(name, op):
if name in nodesToKeep:
return False
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
(name.startswith('CropAndResize') and op != 'CropAndResize')
# Fuse atrous convolutions (with dilations).
nodesMap = {node.name: node for node in graph_def.node}
for node in reversed(graph_def.node):
if node.op == 'BatchToSpaceND':
del node.input[2]
conv = nodesMap[node.input[0]]
spaceToBatchND = nodesMap[conv.input[0]]
paddingsNode = NodeDef()
paddingsNode.name = conv.name + '/paddings'
paddingsNode.op = 'Const'
paddingsNode.addAttr('value', [2, 2, 2, 2])
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
nodesToKeep.append(paddingsNode.name)
spaceToBatchND.input[2] = paddingsNode.name
removeUnusedNodesAndAttrs(to_remove, graph_def)
@ -106,8 +127,8 @@ heights = []
for a in aspect_ratios:
for s in scales:
ar = np.sqrt(a)
heights.append((features_stride**2) * s / ar)
widths.append((features_stride**2) * s * ar)
heights.append((height_stride**2) * s / ar)
widths.append((width_stride**2) * s * ar)
proposals.addAttr('width', widths)
proposals.addAttr('height', heights)
@ -252,5 +273,25 @@ graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()
def getUnconnectedNodes():
unconnected = [node.name for node in graph_def.node]
for node in graph_def.node:
for inp in node.input:
if inp in unconnected:
unconnected.remove(inp)
return unconnected
while True:
unconnectedNodes = getUnconnectedNodes()
unconnectedNodes.remove(graph_def.node[-1].name)
if not unconnectedNodes:
break
for name in unconnectedNodes:
for i in range(len(graph_def.node)):
if graph_def.node[i].name == name:
del graph_def.node[i]
break
# Save as text.
graph_def.save(args.output)

Loading…
Cancel
Save