Open Source Computer Vision Library https://opencv.org/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

292 lines
11 KiB

import argparse
import numpy as np
import tensorflow as tf
from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
'SSD model from TensorFlow Object Detection API. '
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--features_stride', default=16, type=float, nargs='+',
help='Hyper-parameter from a config file.')
args = parser.parse_args()
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
'FirstStageBoxPredictor/BoxEncodingPredictor',
'FirstStageBoxPredictor/ClassPredictor',
'CropAndResize',
'MaxPool2D',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'image_tensor')
scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
'FirstStageFeatureExtractor/Shape',
'FirstStageFeatureExtractor/strided_slice',
'FirstStageFeatureExtractor/GreaterEqual',
'FirstStageFeatureExtractor/LogicalAnd')
unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
'Index', 'Tperm', 'is_training', 'Tpaddings']
# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Removes Identity nodes
def removeIdentity():
identities = {}
for node in graph_def.node:
if node.op == 'Identity':
identities[node.name] = node.input[0]
graph_def.node.remove(node)
for node in graph_def.node:
for i in range(len(node.input)):
if node.input[i] in identities:
node.input[i] = identities[node.input[i]]
removeIdentity()
removedNodes = []
for i in reversed(range(len(graph_def.node))):
op = graph_def.node[i].op
name = graph_def.node[i].name
if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
if op != 'Const':
removedNodes.append(name)
del graph_def.node[i]
else:
for attr in unusedAttrs:
if attr in graph_def.node[i].attr:
del graph_def.node[i].attr[attr]
# Remove references to removed nodes except Const nodes.
for node in graph_def.node:
for i in reversed(range(len(node.input))):
if node.input[i] in removedNodes:
del node.input[i]
# Connect input node to the first layer
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)
# Temporarily remove top nodes.
topNodes = []
while True:
node = graph_def.node.pop()
topNodes.append(node)
if node.op == 'CropAndResize':
break
def tensorMsg(values):
if all([isinstance(v, float) for v in values]):
dtype = 'DT_FLOAT'
field = 'float_val'
elif all([isinstance(v, int) for v in values]):
dtype = 'DT_INT32'
field = 'int_val'
else:
raise Exception('Wrong values types')
msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
for value in values:
msg += '%s: %s ' % (field, str(value))
return msg + '}'
def addSlice(inp, out, begins, sizes):
beginsNode = NodeDef()
beginsNode.name = out + '/begins'
beginsNode.op = 'Const'
text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
graph_def.node.extend([beginsNode])
sizesNode = NodeDef()
sizesNode.name = out + '/sizes'
sizesNode.op = 'Const'
text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
graph_def.node.extend([sizesNode])
sliced = NodeDef()
sliced.name = out
sliced.op = 'Slice'
sliced.input.append(inp)
sliced.input.append(beginsNode.name)
sliced.input.append(sizesNode.name)
graph_def.node.extend([sliced])
def addReshape(inp, out, shape):
shapeNode = NodeDef()
shapeNode.name = out + '/shape'
shapeNode.op = 'Const'
text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
graph_def.node.extend([shapeNode])
reshape = NodeDef()
reshape.name = out
reshape.op = 'Reshape'
reshape.input.append(inp)
reshape.input.append(shapeNode.name)
graph_def.node.extend([reshape])
def addSoftMax(inp, out):
softmax = NodeDef()
softmax.name = out
softmax.op = 'Softmax'
text_format.Merge('i: -1', softmax.attr['axis'])
softmax.input.append(inp)
graph_def.node.extend([softmax])
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
'FirstStageBoxPredictor/ClassPredictor/softmax') # Compare with Reshape_4
flatten = NodeDef()
flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten' # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
flatten.op = 'Flatten'
flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
graph_def.node.extend([flatten])
proposals = NodeDef()
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name) # image_tensor
text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
widths = []
heights = []
for a in args.aspect_ratios:
for s in args.scales:
ar = np.sqrt(a)
heights.append((args.features_stride**2) * s / ar)
widths.append((args.features_stride**2) * s * ar)
text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])
graph_def.node.extend([proposals])
# Compare with Reshape_5
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'
detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax')
detectionOut.input.append('proposals')
text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
graph_def.node.extend([detectionOut])
# Save as text.
for node in reversed(topNodes):
graph_def.node.extend([node])
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax')
addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
'SecondStageBoxPredictor/Reshape_1/slice',
[0, 0, 1], [-1, -1, -1])
addReshape('SecondStageBoxPredictor/Reshape_1/slice',
'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1])
# Replace Flatten subgraph onto a single node.
for i in reversed(range(len(graph_def.node))):
if graph_def.node[i].op == 'CropAndResize':
graph_def.node[i].input.insert(1, 'detection_out')
if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
shapeNode = NodeDef()
shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2'
shapeNode.op = 'Const'
text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"])
graph_def.node.extend([shapeNode])
graph_def.node[i].input.pop()
graph_def.node[i].input.append(shapeNode.name)
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
del graph_def.node[i]
for node in graph_def.node:
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
node.op = 'Flatten'
node.input.pop()
break
################################################################################
### Postprocessing
################################################################################
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
graph_def.node.extend([variance])
varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
graph_def.node.extend([varianceEncoder])
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'
detectionOut.input.append('variance_encoded')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')
text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
graph_def.node.extend([detectionOut])
tf.train.write_graph(graph_def, "", args.output, as_text=True)