|
|
|
@ -48,10 +48,42 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath): |
|
|
|
|
|
|
|
|
|
removeIdentity(graph_def) |
|
|
|
|
|
|
|
|
|
nodesToKeep = [] |
|
|
|
|
def to_remove(name, op): |
|
|
|
|
if name in nodesToKeep: |
|
|
|
|
return False |
|
|
|
|
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ |
|
|
|
|
(name.startswith('CropAndResize') and op != 'CropAndResize') |
|
|
|
|
|
|
|
|
|
# Fuse atrous convolutions (with dilations). |
|
|
|
|
nodesMap = {node.name: node for node in graph_def.node} |
|
|
|
|
for node in reversed(graph_def.node): |
|
|
|
|
if node.op == 'BatchToSpaceND': |
|
|
|
|
del node.input[2] |
|
|
|
|
conv = nodesMap[node.input[0]] |
|
|
|
|
spaceToBatchND = nodesMap[conv.input[0]] |
|
|
|
|
|
|
|
|
|
# Extract paddings |
|
|
|
|
stridedSlice = nodesMap[spaceToBatchND.input[2]] |
|
|
|
|
assert(stridedSlice.op == 'StridedSlice') |
|
|
|
|
pack = nodesMap[stridedSlice.input[0]] |
|
|
|
|
assert(pack.op == 'Pack') |
|
|
|
|
|
|
|
|
|
padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]] |
|
|
|
|
padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]] |
|
|
|
|
padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0]) |
|
|
|
|
padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0]) |
|
|
|
|
|
|
|
|
|
paddingsNode = NodeDef() |
|
|
|
|
paddingsNode.name = conv.name + '/paddings' |
|
|
|
|
paddingsNode.op = 'Const' |
|
|
|
|
paddingsNode.addAttr('value', [padH, padH, padW, padW]) |
|
|
|
|
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode) |
|
|
|
|
nodesToKeep.append(paddingsNode.name) |
|
|
|
|
|
|
|
|
|
spaceToBatchND.input[2] = paddingsNode.name |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
removeUnusedNodesAndAttrs(to_remove, graph_def) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -225,6 +257,26 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath): |
|
|
|
|
detectionOut.addAttr('variance_encoded_in_target', True) |
|
|
|
|
graph_def.node.extend([detectionOut]) |
|
|
|
|
|
|
|
|
|
def getUnconnectedNodes(): |
|
|
|
|
unconnected = [node.name for node in graph_def.node] |
|
|
|
|
for node in graph_def.node: |
|
|
|
|
for inp in node.input: |
|
|
|
|
if inp in unconnected: |
|
|
|
|
unconnected.remove(inp) |
|
|
|
|
return unconnected |
|
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
unconnectedNodes = getUnconnectedNodes() |
|
|
|
|
unconnectedNodes.remove(detectionOut.name) |
|
|
|
|
if not unconnectedNodes: |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
for name in unconnectedNodes: |
|
|
|
|
for i in range(len(graph_def.node)): |
|
|
|
|
if graph_def.node[i].name == name: |
|
|
|
|
del graph_def.node[i] |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
# Save as text. |
|
|
|
|
graph_def.save(outputPath) |
|
|
|
|
|
|
|
|
|