mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 09:25:45 +08:00
Enable ResNet-based Mask-RCNN models from TensorFlow Object Detection API
This commit is contained in:
parent
a63f66c90e
commit
6ad3bf3130
@ -25,7 +25,8 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
|
|||||||
'FirstStageFeatureExtractor/Shape',
|
'FirstStageFeatureExtractor/Shape',
|
||||||
'FirstStageFeatureExtractor/strided_slice',
|
'FirstStageFeatureExtractor/strided_slice',
|
||||||
'FirstStageFeatureExtractor/GreaterEqual',
|
'FirstStageFeatureExtractor/GreaterEqual',
|
||||||
'FirstStageFeatureExtractor/LogicalAnd')
|
'FirstStageFeatureExtractor/LogicalAnd',
|
||||||
|
'Conv/required_space_to_batch_paddings')
|
||||||
|
|
||||||
# Load a config file.
|
# Load a config file.
|
||||||
config = readTextMessage(args.config)
|
config = readTextMessage(args.config)
|
||||||
@ -54,10 +55,30 @@ graph_def = parseTextGraph(args.output)
|
|||||||
|
|
||||||
removeIdentity(graph_def)
|
removeIdentity(graph_def)
|
||||||
|
|
||||||
|
nodesToKeep = []
|
||||||
def to_remove(name, op):
|
def to_remove(name, op):
|
||||||
|
if name in nodesToKeep:
|
||||||
|
return False
|
||||||
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
||||||
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
||||||
|
|
||||||
|
# Fuse atrous convolutions (with dilations).
|
||||||
|
nodesMap = {node.name: node for node in graph_def.node}
|
||||||
|
for node in reversed(graph_def.node):
|
||||||
|
if node.op == 'BatchToSpaceND':
|
||||||
|
del node.input[2]
|
||||||
|
conv = nodesMap[node.input[0]]
|
||||||
|
spaceToBatchND = nodesMap[conv.input[0]]
|
||||||
|
|
||||||
|
paddingsNode = NodeDef()
|
||||||
|
paddingsNode.name = conv.name + '/paddings'
|
||||||
|
paddingsNode.op = 'Const'
|
||||||
|
paddingsNode.addAttr('value', [2, 2, 2, 2])
|
||||||
|
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
|
||||||
|
nodesToKeep.append(paddingsNode.name)
|
||||||
|
|
||||||
|
spaceToBatchND.input[2] = paddingsNode.name
|
||||||
|
|
||||||
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||||
|
|
||||||
|
|
||||||
@ -106,8 +127,8 @@ heights = []
|
|||||||
for a in aspect_ratios:
|
for a in aspect_ratios:
|
||||||
for s in scales:
|
for s in scales:
|
||||||
ar = np.sqrt(a)
|
ar = np.sqrt(a)
|
||||||
heights.append((features_stride**2) * s / ar)
|
heights.append((height_stride**2) * s / ar)
|
||||||
widths.append((features_stride**2) * s * ar)
|
widths.append((width_stride**2) * s * ar)
|
||||||
|
|
||||||
proposals.addAttr('width', widths)
|
proposals.addAttr('width', widths)
|
||||||
proposals.addAttr('height', heights)
|
proposals.addAttr('height', heights)
|
||||||
@ -252,5 +273,25 @@ graph_def.node[-1].name = 'detection_masks'
|
|||||||
graph_def.node[-1].op = 'Sigmoid'
|
graph_def.node[-1].op = 'Sigmoid'
|
||||||
graph_def.node[-1].input.pop()
|
graph_def.node[-1].input.pop()
|
||||||
|
|
||||||
|
def getUnconnectedNodes():
|
||||||
|
unconnected = [node.name for node in graph_def.node]
|
||||||
|
for node in graph_def.node:
|
||||||
|
for inp in node.input:
|
||||||
|
if inp in unconnected:
|
||||||
|
unconnected.remove(inp)
|
||||||
|
return unconnected
|
||||||
|
|
||||||
|
while True:
|
||||||
|
unconnectedNodes = getUnconnectedNodes()
|
||||||
|
unconnectedNodes.remove(graph_def.node[-1].name)
|
||||||
|
if not unconnectedNodes:
|
||||||
|
break
|
||||||
|
|
||||||
|
for name in unconnectedNodes:
|
||||||
|
for i in range(len(graph_def.node)):
|
||||||
|
if graph_def.node[i].name == name:
|
||||||
|
del graph_def.node[i]
|
||||||
|
break
|
||||||
|
|
||||||
# Save as text.
|
# Save as text.
|
||||||
graph_def.save(args.output)
|
graph_def.save(args.output)
|
||||||
|
Loading…
Reference in New Issue
Block a user