import argparse import numpy as np from tf_text_graph_common import * parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' 'Mask-RCNN model from TensorFlow Object Detection API. ' 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.') parser.add_argument('--output', required=True, help='Path to output text graph.') parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.') args = parser.parse_args() scopesToKeep = ('FirstStageFeatureExtractor', 'Conv', 'FirstStageBoxPredictor/BoxEncodingPredictor', 'FirstStageBoxPredictor/ClassPredictor', 'CropAndResize', 'MaxPool2D', 'SecondStageFeatureExtractor', 'SecondStageBoxPredictor', 'Preprocessor/sub', 'Preprocessor/mul', 'image_tensor') scopesToIgnore = ('FirstStageFeatureExtractor/Assert', 'FirstStageFeatureExtractor/Shape', 'FirstStageFeatureExtractor/strided_slice', 'FirstStageFeatureExtractor/GreaterEqual', 'FirstStageFeatureExtractor/LogicalAnd', 'Conv/required_space_to_batch_paddings') # Load a config file. config = readTextMessage(args.config) config = config['model'][0]['faster_rcnn'][0] num_classes = int(config['num_classes'][0]) grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0] scales = [float(s) for s in grid_anchor_generator['scales']] aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']] width_stride = float(grid_anchor_generator['width_stride'][0]) height_stride = float(grid_anchor_generator['height_stride'][0]) features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0]) first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0]) first_stage_max_proposals = int(config['first_stage_max_proposals'][0]) print('Number of classes: %d' % num_classes) print('Scales: %s' % str(scales)) print('Aspect ratios: %s' % str(aspect_ratios)) print('Width stride: %f' % width_stride) print('Height stride: %f' % height_stride) print('Features stride: %f' % features_stride) # Read the graph. writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks']) graph_def = parseTextGraph(args.output) removeIdentity(graph_def) nodesToKeep = [] def to_remove(name, op): if name in nodesToKeep: return False return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ (name.startswith('CropAndResize') and op != 'CropAndResize') # Fuse atrous convolutions (with dilations). nodesMap = {node.name: node for node in graph_def.node} for node in reversed(graph_def.node): if node.op == 'BatchToSpaceND': del node.input[2] conv = nodesMap[node.input[0]] spaceToBatchND = nodesMap[conv.input[0]] paddingsNode = NodeDef() paddingsNode.name = conv.name + '/paddings' paddingsNode.op = 'Const' paddingsNode.addAttr('value', [2, 2, 2, 2]) graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode) nodesToKeep.append(paddingsNode.name) spaceToBatchND.input[2] = paddingsNode.name removeUnusedNodesAndAttrs(to_remove, graph_def) # Connect input node to the first layer assert(graph_def.node[0].op == 'Placeholder') graph_def.node[1].input.insert(0, graph_def.node[0].name) # Temporarily remove top nodes. topNodes = [] numCropAndResize = 0 while True: node = graph_def.node.pop() topNodes.append(node) if node.op == 'CropAndResize': numCropAndResize += 1 if numCropAndResize == 2: break addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd', 'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def) addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1', 'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def) # Compare with Reshape_4 addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax', 'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def) # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd', 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def) proposals = NodeDef() proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized) proposals.op = 'PriorBox' proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd') proposals.input.append(graph_def.node[0].name) # image_tensor proposals.addAttr('flip', False) proposals.addAttr('clip', True) proposals.addAttr('step', features_stride) proposals.addAttr('offset', 0.0) proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2]) widths = [] heights = [] for a in aspect_ratios: for s in scales: ar = np.sqrt(a) heights.append((height_stride**2) * s / ar) widths.append((width_stride**2) * s * ar) proposals.addAttr('width', widths) proposals.addAttr('height', heights) graph_def.node.extend([proposals]) # Compare with Reshape_5 detectionOut = NodeDef() detectionOut.name = 'detection_out' detectionOut.op = 'DetectionOutput' detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten') detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten') detectionOut.input.append('proposals') detectionOut.addAttr('num_classes', 2) detectionOut.addAttr('share_location', True) detectionOut.addAttr('background_label_id', 0) detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold) detectionOut.addAttr('top_k', 6000) detectionOut.addAttr('code_type', "CENTER_SIZE") detectionOut.addAttr('keep_top_k', first_stage_max_proposals) detectionOut.addAttr('clip', True) graph_def.node.extend([detectionOut]) # Save as text. cropAndResizeNodesNames = [] for node in reversed(topNodes): if node.op != 'CropAndResize': graph_def.node.extend([node]) topNodes.pop() else: cropAndResizeNodesNames.append(node.name) if numCropAndResize == 1: break else: graph_def.node.extend([node]) topNodes.pop() numCropAndResize -= 1 addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def) addSlice('SecondStageBoxPredictor/Reshape_1/softmax', 'SecondStageBoxPredictor/Reshape_1/slice', [0, 0, 1], [-1, -1, -1], graph_def) addReshape('SecondStageBoxPredictor/Reshape_1/slice', 'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def) # Replace Flatten subgraph onto a single node. for i in reversed(range(len(graph_def.node))): if graph_def.node[i].op == 'CropAndResize': graph_def.node[i].input.insert(1, 'detection_out') if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape': addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def) graph_def.node[i].input.pop() graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2') if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', 'SecondStageBoxPredictor/Flatten/flatten/strided_slice', 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape', 'SecondStageBoxPredictor/Flatten_1/flatten/Shape', 'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice', 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']: del graph_def.node[i] for node in graph_def.node: if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \ node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape': node.op = 'Flatten' node.input.pop() if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D', 'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']: node.addAttr('loc_pred_transposed', True) if node.name.startswith('MaxPool2D'): assert(node.op == 'MaxPool') assert(len(cropAndResizeNodesNames) == 2) node.input = [cropAndResizeNodesNames[0]] del cropAndResizeNodesNames[0] ################################################################################ ### Postprocessing ################################################################################ addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def) variance = NodeDef() variance.name = 'proposals/variance' variance.op = 'Const' variance.addAttr('value', [0.1, 0.1, 0.2, 0.2]) graph_def.node.extend([variance]) varianceEncoder = NodeDef() varianceEncoder.name = 'variance_encoded' varianceEncoder.op = 'Mul' varianceEncoder.input.append('SecondStageBoxPredictor/Reshape') varianceEncoder.input.append(variance.name) varianceEncoder.addAttr('axis', 2) graph_def.node.extend([varianceEncoder]) addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def) addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def) detectionOut = NodeDef() detectionOut.name = 'detection_out_final' detectionOut.op = 'DetectionOutput' detectionOut.input.append('variance_encoded/flatten') detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape') detectionOut.input.append('detection_out/slice/reshape') detectionOut.addAttr('num_classes', num_classes) detectionOut.addAttr('share_location', False) detectionOut.addAttr('background_label_id', num_classes + 1) detectionOut.addAttr('nms_threshold', 0.6) detectionOut.addAttr('code_type', "CENTER_SIZE") detectionOut.addAttr('keep_top_k',100) detectionOut.addAttr('clip', True) detectionOut.addAttr('variance_encoded_in_target', True) detectionOut.addAttr('confidence_threshold', 0.3) detectionOut.addAttr('group_by_classes', False) graph_def.node.extend([detectionOut]) for node in reversed(topNodes): graph_def.node.extend([node]) if node.name.startswith('MaxPool2D'): assert(node.op == 'MaxPool') assert(len(cropAndResizeNodesNames) == 1) node.input = [cropAndResizeNodesNames[0]] for i in reversed(range(len(graph_def.node))): if graph_def.node[i].op == 'CropAndResize': graph_def.node[i].input.insert(1, 'detection_out_final') break graph_def.node[-1].name = 'detection_masks' graph_def.node[-1].op = 'Sigmoid' graph_def.node[-1].input.pop() def getUnconnectedNodes(): unconnected = [node.name for node in graph_def.node] for node in graph_def.node: for inp in node.input: if inp in unconnected: unconnected.remove(inp) return unconnected while True: unconnectedNodes = getUnconnectedNodes() unconnectedNodes.remove(graph_def.node[-1].name) if not unconnectedNodes: break for name in unconnectedNodes: for i in range(len(graph_def.node)): if graph_def.node[i].name == name: del graph_def.node[i] break # Save as text. graph_def.save(args.output)