Samples DNN: tf_text_graph_sd.py loads box coder variance and box NMS params from config file

pull/15956/head
Lorenzo Lucignano 5 years ago
parent 1f57eb93fd
commit c40fbad12e
  1. 29
      samples/dnn/tf_text_graph_ssd.py

@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
# Add layers that generate anchors (bounding boxes proposals).
priorBoxes = []
boxCoder = config['box_coder'][0]
fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
for i in range(num_layers):
priorBox = NodeDef()
priorBox.name = 'PriorBox_%d' % i
@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
priorBox.addAttr('width', widths)
priorBox.addAttr('height', heights)
priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
priorBox.addAttr('variance', boxCoderVariance)
graph_def.node.extend([priorBox])
priorBoxes.append(priorBox.name)
@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
detectionOut.addAttr('num_classes', num_classes + 1)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
postProcessing = config['post_processing'][0]
batchNMS = postProcessing['batch_non_max_suppression'][0]
if 'iou_threshold' in batchNMS:
detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
else:
detectionOut.addAttr('nms_threshold', 0.6)
if 'score_threshold' in batchNMS:
detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
else:
detectionOut.addAttr('confidence_threshold', 0.01)
if 'max_detections_per_class' in batchNMS:
detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
else:
detectionOut.addAttr('top_k', 100)
detectionOut.addAttr('code_type', "CENTER_SIZE")
if 'max_total_detections' in batchNMS:
detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
else:
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('confidence_threshold', 0.01)
detectionOut.addAttr('code_type', "CENTER_SIZE")
graph_def.node.extend([detectionOut])

Loading…
Cancel
Save