# Mirror of https://github.com/opencv/opencv.git
# Synced 2024-12-05 01:39:13 +08:00 (257 lines, 10 KiB, Python)
import argparse
|
|
import numpy as np
|
|
from tf_text_graph_common import *
|
|
|
|
# Command-line interface: the frozen graph to convert, the path of the text
# graph to produce, and the training pipeline config the model was built with.
parser = argparse.ArgumentParser(
    description=('Run this script to get a text graph of '
                 'Mask-RCNN model from TensorFlow Object Detection API. '
                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.'))
parser.add_argument('--input', help='Path to frozen TensorFlow graph.', required=True)
parser.add_argument('--output', help='Path to output text graph.', required=True)
parser.add_argument('--config', help='Path to a *.config file is used for training.',
                    required=True)
args = parser.parse_args()
# Node-name prefixes that survive pruning: to_remove() drops every node whose
# name starts with none of them. Kept as tuples because str.startswith accepts
# a tuple of prefixes directly.
scopesToKeep = (
    'FirstStageFeatureExtractor',
    'Conv',
    'FirstStageBoxPredictor/BoxEncodingPredictor',
    'FirstStageBoxPredictor/ClassPredictor',
    'CropAndResize',
    'MaxPool2D',
    'SecondStageFeatureExtractor',
    'SecondStageBoxPredictor',
    'Preprocessor/sub',
    'Preprocessor/mul',
    'image_tensor',
)

# Prefixes that are dropped even though they also match the kept
# 'FirstStageFeatureExtractor' scope (presumably the extractor's input-shape
# check ops, judging by their names).
scopesToIgnore = (
    'FirstStageFeatureExtractor/Assert',
    'FirstStageFeatureExtractor/Shape',
    'FirstStageFeatureExtractor/strided_slice',
    'FirstStageFeatureExtractor/GreaterEqual',
    'FirstStageFeatureExtractor/LogicalAnd',
)
# Load the training pipeline config and narrow it to the 'faster_rcnn' model
# section; every hyper-parameter below comes from there. Values are accessed
# as lists (hence the [0]) and converted from text with int()/float().
config = readTextMessage(args.config)
config = config['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])

# First-stage anchor generator and region-proposal settings.
grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = list(map(float, grid_anchor_generator['scales']))
aspect_ratios = list(map(float, grid_anchor_generator['aspect_ratios']))
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])

# Echo the parsed settings so the conversion can be sanity-checked.
print('Number of classes: %d' % num_classes)
print('Scales: %s' % str(scales))
print('Aspect ratios: %s' % str(aspect_ratios))
print('Width stride: %f' % width_stride)
print('Height stride: %f' % height_stride)
print('Features stride: %f' % features_stride)
# Read the graph: dump the frozen model to a text graph at --output, keeping
# the listed detection outputs (presumably the output nodes writeTextGraph
# preserves; see tf_text_graph_common). Then parse that text graph back for
# editing and drop Identity nodes via removeIdentity.
outputNames = ['num_detections', 'detection_scores', 'detection_boxes',
               'detection_classes', 'detection_masks']
writeTextGraph(args.input, args.output, outputNames)
del outputNames
graph_def = parseTextGraph(args.output)
removeIdentity(graph_def)
def to_remove(name, op):
    """Return True if the node called ``name`` with type ``op`` should be pruned.

    A node is pruned when its name falls under one of ``scopesToIgnore``,
    when it falls under none of ``scopesToKeep``, or when it sits in the
    ``CropAndResize`` scope without being the CropAndResize op itself.
    Used as the predicate for removeUnusedNodesAndAttrs.
    """
    if name.startswith(scopesToIgnore):
        return True
    if not name.startswith(scopesToKeep):
        return True
    return name.startswith('CropAndResize') and op != 'CropAndResize'
# Prune the graph using to_remove as the predicate (presumably removing the
# nodes for which it returns True; see tf_text_graph_common).
removeUnusedNodesAndAttrs(to_remove, graph_def)

# Connect the input node to the first layer. The rest of the script relies on
# node 0 being the image placeholder, so validate with explicit exceptions
# rather than `assert`, which is stripped when Python runs with -O.
if len(graph_def.node) < 2:
    raise ValueError('Expected at least 2 nodes in the pruned graph, found %d'
                     % len(graph_def.node))
if graph_def.node[0].op != 'Placeholder':
    raise ValueError("Expected the first node to be a 'Placeholder', got '%s' (node '%s')"
                     % (graph_def.node[0].op, graph_def.node[0].name))
graph_def.node[1].input.insert(0, graph_def.node[0].name)
# Temporarily detach the tail of the graph: pop nodes off the end until the
# second CropAndResize node (counting from the end) has been removed. topNodes
# therefore holds them in reverse graph order; they are re-attached further
# down, after the first-stage post-processing nodes have been appended.
topNodes = []
numCropAndResize = 0
while numCropAndResize < 2:
    if not graph_def.node:
        # Previously surfaced as an opaque "pop from empty list".
        raise ValueError('Expected two CropAndResize nodes in the graph, found %d'
                         % numCropAndResize)
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        numCropAndResize += 1
# First stage: reshape the class predictor output so its last dimension is 2
# (the two first-stage classes, see num_classes=2 on detection_out), apply
# softmax, and flatten both predictor outputs for the DetectionOutput layer.
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

# PriorBox node that generates the first-stage anchors from the box predictor
# feature map and the input image.
proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

proposals.addAttr('flip', False)
proposals.addAttr('clip', True)
proposals.addAttr('step', features_stride)
proposals.addAttr('offset', 0.0)
proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

# One anchor per (aspect ratio, scale) pair, aspect-ratio major. Each axis is
# sized from its own anchor-generator stride. Previously both axes used
# first_stage_features_stride, leaving the parsed width/height strides unused.
# This is identical whenever all three strides are equal.
# NOTE(review): squaring the stride presumably stands in for TF's base anchor
# size (256 for stride 16); confirm for models with other strides.
widths = []
heights = []
for a in aspect_ratios:
    ar = np.sqrt(a)  # loop-invariant for the inner loop
    for s in scales:
        heights.append((height_stride ** 2) * s / ar)
        widths.append((width_stride ** 2) * s * ar)

proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])
# First-stage DetectionOutput node (compare with Reshape_5). Inputs: flattened
# box encodings, flattened objectness softmax, and the 'proposals' prior boxes.
# Two classes with background label 0. NMS uses the configured first-stage IoU
# threshold, and at most first_stage_max_proposals boxes are kept.
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'
detectionOut.input.extend(['FirstStageBoxPredictor/BoxEncodingPredictor/flatten',
                           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten',
                           'proposals'])

# Attribute order is kept as-is so the serialized graph text is unchanged.
detectionOut.addAttr('num_classes', 2)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
detectionOut.addAttr('clip', True)

graph_def.node.extend([detectionOut])
# Re-attach the detached tail nodes after the first-stage nodes added above.
# reversed(topNodes) walks them in original graph order. The first
# CropAndResize met is re-attached; at the second one the loop stops, leaving
# it and every later node in topNodes for the final re-attach pass. Both
# CropAndResize names are recorded for rewiring the MaxPool2D inputs.
cropAndResizeNodesNames = []
for node in reversed(topNodes):
    if node.op != 'CropAndResize':
        graph_def.node.extend([node])
        # Safe while iterating: reversed() has just yielded the last element,
        # and pop() removes exactly that one, so the next index stays valid.
        topNodes.pop()
    else:
        cropAndResizeNodesNames.append(node.name)
        if numCropAndResize == 1:
            # Second CropAndResize: keep it (and what follows) in topNodes.
            break
        else:
            graph_def.node.extend([node])
            topNodes.pop()
            numCropAndResize -= 1
# Second-stage class scores: softmax, then slice from index 1 on the last axis
# (begin [0, 0, 1], presumably dropping the background column), then reshape
# to [1, -1]. The result is the second input of detection_out_final.
addSoftMax('SecondStageBoxPredictor/Reshape_1',
           'SecondStageBoxPredictor/Reshape_1/softmax',
           graph_def)
addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1],
         [-1, -1, -1],
         graph_def)
addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape',
           [1, -1],
           graph_def)
# Rewire the second stage and strip the helper nodes of the two second-stage
# Flatten subgraphs (their Reshape nodes become single Flatten ops below).
# Indices are visited in reverse so `del graph_def.node[i]` does not shift any
# index still to be visited. Nodes added by addConstNode (presumably appended
# at the end) lie outside the precomputed range and are not visited.
for i in reversed(range(len(graph_def.node))):
    # Insert the first-stage detections as the second input of every
    # CropAndResize node currently in the graph.
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out')

    # Replace the last input of SecondStageBoxPredictor/Reshape with a new
    # constant shape [1, -1, 4].
    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
        del graph_def.node[i]
# Per-node fixups on the nodes currently in the graph.
for node in graph_def.node:
    # Turn the two second-stage Flatten Reshape nodes into Flatten ops and drop
    # their last input (presumably the shape input, whose producer nodes were
    # deleted above).
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
       node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()

    # Tag both box-encoding predictors with 'loc_pred_transposed'
    # (NOTE(review): presumably read by OpenCV's importer to interpret the box
    # layout; confirm).
    if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        node.addAttr('loc_pred_transposed', True)

    # Feed the MaxPool2D node from the first recorded CropAndResize node. The
    # length-2 assert followed by the del means only one MaxPool2D node may be
    # present at this point; the other is re-attached later.
    if node.name.startswith('MaxPool2D'):
        assert(node.op == 'MaxPool')
        assert(len(cropAndResizeNodesNames) == 2)
        node.input = [cropAndResizeNodesNames[0]]
        del cropAndResizeNodesNames[0]
################################################################################
### Postprocessing
################################################################################
# Slice detection_out from index 3 on its last axis (begin [0, 0, 0, 3] with
# [-1, -1, -1, 4]; presumably the box coordinates) and reshape it to
# [1, 1, -1], the prior-box input of detection_out_final.
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Constant with the same variances as the 'proposals' PriorBox node.
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])

# Multiply the second-stage box encodings by the variance constant (axis=2),
# so the final DetectionOutput can treat them as already variance-encoded.
varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.extend(['SecondStageBoxPredictor/Reshape', variance.name])
varianceEncoder.addAttr('axis', 2)
graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
# Second-stage DetectionOutput that produces the final detections. Inputs:
# variance-encoded box deltas, per-class scores (first column sliced off), and
# the reshaped first-stage boxes acting as priors.
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'
detectionOut.input.extend(['variance_encoded/flatten',
                           'SecondStageBoxPredictor/Reshape_1/Reshape',
                           'detection_out/slice/reshape'])

# Attribute order is kept as-is so the serialized graph text is unchanged.
detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)
# A label id outside [0, num_classes): NOTE(review) presumably so that no
# class is treated as background, since that column was already sliced off.
detectionOut.addAttr('background_label_id', num_classes + 1)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])
# Re-attach the remaining tail nodes (the second CropAndResize and everything
# after it) in original graph order, and feed the remaining MaxPool2D node
# from the last recorded CropAndResize name.
for node in reversed(topNodes):
    graph_def.node.extend([node])

    if node.name.startswith('MaxPool2D'):
        assert(node.op == 'MaxPool')
        assert(len(cropAndResizeNodesNames) == 1)
        node.input = [cropAndResizeNodesNames[0]]

# Insert the final detections as the second input of the last CropAndResize
# node in the graph.
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out_final')
        break

# Turn the last node into the mask output: rename it to 'detection_masks',
# make it a Sigmoid and drop its last input.
graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()

# Save as text, overwriting the intermediate text graph written to the same
# --output path at the start of the script.
graph_def.save(args.output)