|
|
@ -6,6 +6,8 @@ from tensorflow.core.framework.node_def_pb2 import NodeDef |
|
|
|
from tensorflow.tools.graph_transforms import TransformGraph |
|
|
|
from tensorflow.tools.graph_transforms import TransformGraph |
|
|
|
from google.protobuf import text_format |
|
|
|
from google.protobuf import text_format |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from tf_text_graph_common import tensorMsg, addConstNode |
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' |
|
|
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' |
|
|
|
'SSD model from TensorFlow Object Detection API. ' |
|
|
|
'SSD model from TensorFlow Object Detection API. ' |
|
|
|
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') |
|
|
|
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') |
|
|
@ -93,21 +95,6 @@ while True: |
|
|
|
if node.op == 'CropAndResize': |
|
|
|
if node.op == 'CropAndResize': |
|
|
|
break |
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
def tensorMsg(values): |
|
|
|
|
|
|
|
if all([isinstance(v, float) for v in values]): |
|
|
|
|
|
|
|
dtype = 'DT_FLOAT' |
|
|
|
|
|
|
|
field = 'float_val' |
|
|
|
|
|
|
|
elif all([isinstance(v, int) for v in values]): |
|
|
|
|
|
|
|
dtype = 'DT_INT32' |
|
|
|
|
|
|
|
field = 'int_val' |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
raise Exception('Wrong values types') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values) |
|
|
|
|
|
|
|
for value in values: |
|
|
|
|
|
|
|
msg += '%s: %s ' % (field, str(value)) |
|
|
|
|
|
|
|
return msg + '}' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def addSlice(inp, out, begins, sizes): |
|
|
|
def addSlice(inp, out, begins, sizes): |
|
|
|
beginsNode = NodeDef() |
|
|
|
beginsNode = NodeDef() |
|
|
|
beginsNode.name = out + '/begins' |
|
|
|
beginsNode.name = out + '/begins' |
|
|
@ -151,17 +138,25 @@ def addSoftMax(inp, out): |
|
|
|
softmax.input.append(inp) |
|
|
|
softmax.input.append(inp) |
|
|
|
graph_def.node.extend([softmax]) |
|
|
|
graph_def.node.extend([softmax]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def addFlatten(inp, out): |
|
|
|
|
|
|
|
flatten = NodeDef() |
|
|
|
|
|
|
|
flatten.name = out |
|
|
|
|
|
|
|
flatten.op = 'Flatten' |
|
|
|
|
|
|
|
flatten.input.append(inp) |
|
|
|
|
|
|
|
graph_def.node.extend([flatten]) |
|
|
|
|
|
|
|
|
|
|
|
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd', |
|
|
|
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd', |
|
|
|
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2]) |
|
|
|
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2]) |
|
|
|
|
|
|
|
|
|
|
|
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1', |
|
|
|
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1', |
|
|
|
'FirstStageBoxPredictor/ClassPredictor/softmax') # Compare with Reshape_4 |
|
|
|
'FirstStageBoxPredictor/ClassPredictor/softmax') # Compare with Reshape_4 |
|
|
|
|
|
|
|
|
|
|
|
flatten = NodeDef() |
|
|
|
addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax', |
|
|
|
flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten' # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd |
|
|
|
'FirstStageBoxPredictor/ClassPredictor/softmax/flatten') |
|
|
|
flatten.op = 'Flatten' |
|
|
|
|
|
|
|
flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd') |
|
|
|
# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd |
|
|
|
graph_def.node.extend([flatten]) |
|
|
|
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd', |
|
|
|
|
|
|
|
'FirstStageBoxPredictor/BoxEncodingPredictor/flatten') |
|
|
|
|
|
|
|
|
|
|
|
proposals = NodeDef() |
|
|
|
proposals = NodeDef() |
|
|
|
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized) |
|
|
|
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized) |
|
|
@ -194,7 +189,7 @@ detectionOut.name = 'detection_out' |
|
|
|
detectionOut.op = 'DetectionOutput' |
|
|
|
detectionOut.op = 'DetectionOutput' |
|
|
|
|
|
|
|
|
|
|
|
detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten') |
|
|
|
detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten') |
|
|
|
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax') |
|
|
|
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten') |
|
|
|
detectionOut.input.append('proposals') |
|
|
|
detectionOut.input.append('proposals') |
|
|
|
|
|
|
|
|
|
|
|
text_format.Merge('i: 2', detectionOut.attr['num_classes']) |
|
|
|
text_format.Merge('i: 2', detectionOut.attr['num_classes']) |
|
|
@ -204,11 +199,21 @@ text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold']) |
|
|
|
text_format.Merge('i: 6000', detectionOut.attr['top_k']) |
|
|
|
text_format.Merge('i: 6000', detectionOut.attr['top_k']) |
|
|
|
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
|
|
|
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
|
|
|
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
|
|
|
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['clip']) |
|
|
|
text_format.Merge('b: false', detectionOut.attr['clip']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed']) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_def.node.extend([detectionOut]) |
|
|
|
graph_def.node.extend([detectionOut]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
addConstNode('clip_by_value/lower', [0.0], graph_def) |
|
|
|
|
|
|
|
addConstNode('clip_by_value/upper', [1.0], graph_def) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipByValueNode = NodeDef() |
|
|
|
|
|
|
|
clipByValueNode.name = 'detection_out/clip_by_value' |
|
|
|
|
|
|
|
clipByValueNode.op = 'ClipByValue' |
|
|
|
|
|
|
|
clipByValueNode.input.append('detection_out') |
|
|
|
|
|
|
|
clipByValueNode.input.append('clip_by_value/lower') |
|
|
|
|
|
|
|
clipByValueNode.input.append('clip_by_value/upper') |
|
|
|
|
|
|
|
graph_def.node.extend([clipByValueNode]) |
|
|
|
|
|
|
|
|
|
|
|
# Save as text. |
|
|
|
# Save as text. |
|
|
|
for node in reversed(topNodes): |
|
|
|
for node in reversed(topNodes): |
|
|
|
graph_def.node.extend([node]) |
|
|
|
graph_def.node.extend([node]) |
|
|
@ -225,17 +230,13 @@ addReshape('SecondStageBoxPredictor/Reshape_1/slice', |
|
|
|
# Replace Flatten subgraph onto a single node. |
|
|
|
# Replace Flatten subgraph onto a single node. |
|
|
|
for i in reversed(range(len(graph_def.node))): |
|
|
|
for i in reversed(range(len(graph_def.node))): |
|
|
|
if graph_def.node[i].op == 'CropAndResize': |
|
|
|
if graph_def.node[i].op == 'CropAndResize': |
|
|
|
graph_def.node[i].input.insert(1, 'detection_out') |
|
|
|
graph_def.node[i].input.insert(1, 'detection_out/clip_by_value') |
|
|
|
|
|
|
|
|
|
|
|
if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape': |
|
|
|
if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape': |
|
|
|
shapeNode = NodeDef() |
|
|
|
addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def) |
|
|
|
shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2' |
|
|
|
|
|
|
|
shapeNode.op = 'Const' |
|
|
|
|
|
|
|
text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"]) |
|
|
|
|
|
|
|
graph_def.node.extend([shapeNode]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_def.node[i].input.pop() |
|
|
|
graph_def.node[i].input.pop() |
|
|
|
graph_def.node[i].input.append(shapeNode.name) |
|
|
|
graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2') |
|
|
|
|
|
|
|
|
|
|
|
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', |
|
|
|
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', |
|
|
|
'SecondStageBoxPredictor/Flatten/flatten/strided_slice', |
|
|
|
'SecondStageBoxPredictor/Flatten/flatten/strided_slice', |
|
|
@ -246,12 +247,15 @@ for node in graph_def.node: |
|
|
|
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape': |
|
|
|
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape': |
|
|
|
node.op = 'Flatten' |
|
|
|
node.op = 'Flatten' |
|
|
|
node.input.pop() |
|
|
|
node.input.pop() |
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D', |
|
|
|
|
|
|
|
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']: |
|
|
|
|
|
|
|
text_format.Merge('b: true', node.attr["loc_pred_transposed"]) |
|
|
|
|
|
|
|
|
|
|
|
################################################################################ |
|
|
|
################################################################################ |
|
|
|
### Postprocessing |
|
|
|
### Postprocessing |
|
|
|
################################################################################ |
|
|
|
################################################################################ |
|
|
|
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4]) |
|
|
|
addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4]) |
|
|
|
|
|
|
|
|
|
|
|
variance = NodeDef() |
|
|
|
variance = NodeDef() |
|
|
|
variance.name = 'proposals/variance' |
|
|
|
variance.name = 'proposals/variance' |
|
|
@ -268,12 +272,13 @@ text_format.Merge('i: 2', varianceEncoder.attr["axis"]) |
|
|
|
graph_def.node.extend([varianceEncoder]) |
|
|
|
graph_def.node.extend([varianceEncoder]) |
|
|
|
|
|
|
|
|
|
|
|
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1]) |
|
|
|
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1]) |
|
|
|
|
|
|
|
addFlatten('variance_encoded', 'variance_encoded/flatten') |
|
|
|
|
|
|
|
|
|
|
|
detectionOut = NodeDef() |
|
|
|
detectionOut = NodeDef() |
|
|
|
detectionOut.name = 'detection_out_final' |
|
|
|
detectionOut.name = 'detection_out_final' |
|
|
|
detectionOut.op = 'DetectionOutput' |
|
|
|
detectionOut.op = 'DetectionOutput' |
|
|
|
|
|
|
|
|
|
|
|
detectionOut.input.append('variance_encoded') |
|
|
|
detectionOut.input.append('variance_encoded/flatten') |
|
|
|
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape') |
|
|
|
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape') |
|
|
|
detectionOut.input.append('detection_out/slice/reshape') |
|
|
|
detectionOut.input.append('detection_out/slice/reshape') |
|
|
|
|
|
|
|
|
|
|
@ -283,7 +288,6 @@ text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['backgroun |
|
|
|
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold']) |
|
|
|
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold']) |
|
|
|
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
|
|
|
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
|
|
|
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
|
|
|
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed']) |
|
|
|
|
|
|
|
text_format.Merge('b: true', detectionOut.attr['clip']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['clip']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target']) |
|
|
|
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target']) |
|
|
|
graph_def.node.extend([detectionOut]) |
|
|
|
graph_def.node.extend([detectionOut]) |
|
|
|