mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Create text graphs for Faster-RCNN from TensorFlow with dilated convolutions
This commit is contained in:
parent
3721c8bb06
commit
67e6a6077d
@ -48,10 +48,42 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
|||||||
|
|
||||||
removeIdentity(graph_def)
|
removeIdentity(graph_def)
|
||||||
|
|
||||||
|
nodesToKeep = []
|
||||||
def to_remove(name, op):
|
def to_remove(name, op):
|
||||||
|
if name in nodesToKeep:
|
||||||
|
return False
|
||||||
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
||||||
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
||||||
|
|
||||||
|
# Fuse atrous convolutions (with dilations).
|
||||||
|
nodesMap = {node.name: node for node in graph_def.node}
|
||||||
|
for node in reversed(graph_def.node):
|
||||||
|
if node.op == 'BatchToSpaceND':
|
||||||
|
del node.input[2]
|
||||||
|
conv = nodesMap[node.input[0]]
|
||||||
|
spaceToBatchND = nodesMap[conv.input[0]]
|
||||||
|
|
||||||
|
# Extract paddings
|
||||||
|
stridedSlice = nodesMap[spaceToBatchND.input[2]]
|
||||||
|
assert(stridedSlice.op == 'StridedSlice')
|
||||||
|
pack = nodesMap[stridedSlice.input[0]]
|
||||||
|
assert(pack.op == 'Pack')
|
||||||
|
|
||||||
|
padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]]
|
||||||
|
padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]]
|
||||||
|
padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0])
|
||||||
|
padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0])
|
||||||
|
|
||||||
|
paddingsNode = NodeDef()
|
||||||
|
paddingsNode.name = conv.name + '/paddings'
|
||||||
|
paddingsNode.op = 'Const'
|
||||||
|
paddingsNode.addAttr('value', [padH, padH, padW, padW])
|
||||||
|
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
|
||||||
|
nodesToKeep.append(paddingsNode.name)
|
||||||
|
|
||||||
|
spaceToBatchND.input[2] = paddingsNode.name
|
||||||
|
|
||||||
|
|
||||||
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||||
|
|
||||||
|
|
||||||
@ -225,6 +257,26 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
|||||||
detectionOut.addAttr('variance_encoded_in_target', True)
|
detectionOut.addAttr('variance_encoded_in_target', True)
|
||||||
graph_def.node.extend([detectionOut])
|
graph_def.node.extend([detectionOut])
|
||||||
|
|
||||||
|
def getUnconnectedNodes():
|
||||||
|
unconnected = [node.name for node in graph_def.node]
|
||||||
|
for node in graph_def.node:
|
||||||
|
for inp in node.input:
|
||||||
|
if inp in unconnected:
|
||||||
|
unconnected.remove(inp)
|
||||||
|
return unconnected
|
||||||
|
|
||||||
|
while True:
|
||||||
|
unconnectedNodes = getUnconnectedNodes()
|
||||||
|
unconnectedNodes.remove(detectionOut.name)
|
||||||
|
if not unconnectedNodes:
|
||||||
|
break
|
||||||
|
|
||||||
|
for name in unconnectedNodes:
|
||||||
|
for i in range(len(graph_def.node)):
|
||||||
|
if graph_def.node[i].name == name:
|
||||||
|
del graph_def.node[i]
|
||||||
|
break
|
||||||
|
|
||||||
# Save as text.
|
# Save as text.
|
||||||
graph_def.save(outputPath)
|
graph_def.save(outputPath)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user