diff --git a/samples/dnn/tf_text_graph_mask_rcnn.py b/samples/dnn/tf_text_graph_mask_rcnn.py index b92d4623b8..aaefe456ad 100644 --- a/samples/dnn/tf_text_graph_mask_rcnn.py +++ b/samples/dnn/tf_text_graph_mask_rcnn.py @@ -38,6 +38,8 @@ aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']] width_stride = float(grid_anchor_generator['width_stride'][0]) height_stride = float(grid_anchor_generator['height_stride'][0]) features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0]) +first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0]) +first_stage_max_proposals = int(config['first_stage_max_proposals'][0]) print('Number of classes: %d' % num_classes) print('Scales: %s' % str(scales)) @@ -53,7 +55,8 @@ graph_def = parseTextGraph(args.output) removeIdentity(graph_def) def to_remove(name, op): - return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) + return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ + (name.startswith('CropAndResize') and op != 'CropAndResize') removeUnusedNodesAndAttrs(to_remove, graph_def) @@ -123,20 +126,22 @@ detectionOut.input.append('proposals') detectionOut.addAttr('num_classes', 2) detectionOut.addAttr('share_location', True) detectionOut.addAttr('background_label_id', 0) -detectionOut.addAttr('nms_threshold', 0.7) +detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold) detectionOut.addAttr('top_k', 6000) detectionOut.addAttr('code_type', "CENTER_SIZE") -detectionOut.addAttr('keep_top_k', 100) +detectionOut.addAttr('keep_top_k', first_stage_max_proposals) detectionOut.addAttr('clip', True) graph_def.node.extend([detectionOut]) # Save as text. +cropAndResizeNodesNames = [] for node in reversed(topNodes): if node.op != 'CropAndResize': graph_def.node.extend([node]) topNodes.pop() else: + cropAndResizeNodesNames.append(node.name) if numCropAndResize == 1: break else: @@ -166,11 +171,15 @@ for i in reversed(range(len(graph_def.node))): if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', 'SecondStageBoxPredictor/Flatten/flatten/strided_slice', - 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']: + 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape', + 'SecondStageBoxPredictor/Flatten_1/flatten/Shape', + 'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice', + 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']: del graph_def.node[i] for node in graph_def.node: - if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape': + if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \ + node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape': node.op = 'Flatten' node.input.pop() @@ -178,6 +187,12 @@ for node in graph_def.node: 'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']: node.addAttr('loc_pred_transposed', True) + if node.name.startswith('MaxPool2D'): + assert(node.op == 'MaxPool') + assert(len(cropAndResizeNodesNames) == 2) + node.input = [cropAndResizeNodesNames[0]] + del cropAndResizeNodesNames[0] + ################################################################################ ### Postprocessing ################################################################################ @@ -223,6 +238,11 @@ graph_def.node.extend([detectionOut]) for node in reversed(topNodes): graph_def.node.extend([node]) + if node.name.startswith('MaxPool2D'): + assert(node.op == 'MaxPool') + assert(len(cropAndResizeNodesNames) == 1) + node.input = [cropAndResizeNodesNames[0]] + for i in reversed(range(len(graph_def.node))): if graph_def.node[i].op == 'CropAndResize': graph_def.node[i].input.insert(1, 'detection_out_final')