|
|
|
@ -13,6 +13,7 @@ import tensorflow as tf |
|
|
|
|
import argparse |
|
|
|
|
from math import sqrt |
|
|
|
|
from tensorflow.core.framework.node_def_pb2 import NodeDef |
|
|
|
|
from tensorflow.tools.graph_transforms import TransformGraph |
|
|
|
|
from google.protobuf import text_format |
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' |
|
|
|
@ -32,7 +33,7 @@ args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
# Nodes that should be kept. |
|
|
|
|
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm', |
|
|
|
|
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool'] |
|
|
|
|
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity'] |
|
|
|
|
|
|
|
|
|
# Nodes attributes that could be removed because they are not used during import. |
|
|
|
|
unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu', |
|
|
|
@ -46,6 +47,10 @@ with tf.gfile.FastGFile(args.input, 'rb') as f: |
|
|
|
|
graph_def = tf.GraphDef() |
|
|
|
|
graph_def.ParseFromString(f.read()) |
|
|
|
|
|
|
|
|
|
inpNames = ['image_tensor'] |
|
|
|
|
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'] |
|
|
|
|
graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order']) |
|
|
|
|
|
|
|
|
|
def getUnconnectedNodes(): |
|
|
|
|
unconnected = [] |
|
|
|
|
for node in graph_def.node: |
|
|
|
@ -98,6 +103,7 @@ def removeIdentity(): |
|
|
|
|
for node in graph_def.node: |
|
|
|
|
if node.op == 'Identity': |
|
|
|
|
identities[node.name] = node.input[0] |
|
|
|
|
graph_def.node.remove(node) |
|
|
|
|
|
|
|
|
|
for node in graph_def.node: |
|
|
|
|
for i in range(len(node.input)): |
|
|
|
|