|
|
|
@ -25,7 +25,8 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert', |
|
|
|
|
'FirstStageFeatureExtractor/Shape', |
|
|
|
|
'FirstStageFeatureExtractor/strided_slice', |
|
|
|
|
'FirstStageFeatureExtractor/GreaterEqual', |
|
|
|
|
'FirstStageFeatureExtractor/LogicalAnd') |
|
|
|
|
'FirstStageFeatureExtractor/LogicalAnd', |
|
|
|
|
'Conv/required_space_to_batch_paddings') |
|
|
|
|
|
|
|
|
|
# Load a config file. |
|
|
|
|
config = readTextMessage(args.config) |
|
|
|
@ -54,10 +55,30 @@ graph_def = parseTextGraph(args.output) |
|
|
|
|
|
|
|
|
|
removeIdentity(graph_def) |
|
|
|
|
|
|
|
|
|
nodesToKeep = [] |
|
|
|
|
def to_remove(name, op): |
|
|
|
|
if name in nodesToKeep: |
|
|
|
|
return False |
|
|
|
|
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ |
|
|
|
|
(name.startswith('CropAndResize') and op != 'CropAndResize') |
|
|
|
|
|
|
|
|
|
# Fuse atrous convolutions (with dilations). |
|
|
|
|
nodesMap = {node.name: node for node in graph_def.node} |
|
|
|
|
for node in reversed(graph_def.node): |
|
|
|
|
if node.op == 'BatchToSpaceND': |
|
|
|
|
del node.input[2] |
|
|
|
|
conv = nodesMap[node.input[0]] |
|
|
|
|
spaceToBatchND = nodesMap[conv.input[0]] |
|
|
|
|
|
|
|
|
|
paddingsNode = NodeDef() |
|
|
|
|
paddingsNode.name = conv.name + '/paddings' |
|
|
|
|
paddingsNode.op = 'Const' |
|
|
|
|
paddingsNode.addAttr('value', [2, 2, 2, 2]) |
|
|
|
|
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode) |
|
|
|
|
nodesToKeep.append(paddingsNode.name) |
|
|
|
|
|
|
|
|
|
spaceToBatchND.input[2] = paddingsNode.name |
|
|
|
|
|
|
|
|
|
removeUnusedNodesAndAttrs(to_remove, graph_def) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -106,8 +127,8 @@ heights = [] |
|
|
|
|
for a in aspect_ratios: |
|
|
|
|
for s in scales: |
|
|
|
|
ar = np.sqrt(a) |
|
|
|
|
heights.append((features_stride**2) * s / ar) |
|
|
|
|
widths.append((features_stride**2) * s * ar) |
|
|
|
|
heights.append((height_stride**2) * s / ar) |
|
|
|
|
widths.append((width_stride**2) * s * ar) |
|
|
|
|
|
|
|
|
|
proposals.addAttr('width', widths) |
|
|
|
|
proposals.addAttr('height', heights) |
|
|
|
@ -252,5 +273,25 @@ graph_def.node[-1].name = 'detection_masks' |
|
|
|
|
graph_def.node[-1].op = 'Sigmoid' |
|
|
|
|
graph_def.node[-1].input.pop() |
|
|
|
|
|
|
|
|
|
def getUnconnectedNodes(): |
|
|
|
|
unconnected = [node.name for node in graph_def.node] |
|
|
|
|
for node in graph_def.node: |
|
|
|
|
for inp in node.input: |
|
|
|
|
if inp in unconnected: |
|
|
|
|
unconnected.remove(inp) |
|
|
|
|
return unconnected |
|
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
unconnectedNodes = getUnconnectedNodes() |
|
|
|
|
unconnectedNodes.remove(graph_def.node[-1].name) |
|
|
|
|
if not unconnectedNodes: |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
for name in unconnectedNodes: |
|
|
|
|
for i in range(len(graph_def.node)): |
|
|
|
|
if graph_def.node[i].name == name: |
|
|
|
|
del graph_def.node[i] |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
# Save as text. |
|
|
|
|
graph_def.save(args.output) |
|
|
|
|