mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 04:12:52 +08:00
Create text graphs for Faster-RCNN from TensorFlow with dilated convolutions
This commit is contained in:
parent
3721c8bb06
commit
67e6a6077d
@ -48,10 +48,42 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
||||
|
||||
removeIdentity(graph_def)
|
||||
|
||||
nodesToKeep = []
|
||||
def to_remove(name, op):
|
||||
if name in nodesToKeep:
|
||||
return False
|
||||
return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
||||
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
||||
|
||||
# Fuse atrous convolutions (with dilations).
|
||||
nodesMap = {node.name: node for node in graph_def.node}
|
||||
for node in reversed(graph_def.node):
|
||||
if node.op == 'BatchToSpaceND':
|
||||
del node.input[2]
|
||||
conv = nodesMap[node.input[0]]
|
||||
spaceToBatchND = nodesMap[conv.input[0]]
|
||||
|
||||
# Extract paddings
|
||||
stridedSlice = nodesMap[spaceToBatchND.input[2]]
|
||||
assert(stridedSlice.op == 'StridedSlice')
|
||||
pack = nodesMap[stridedSlice.input[0]]
|
||||
assert(pack.op == 'Pack')
|
||||
|
||||
padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]]
|
||||
padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]]
|
||||
padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0])
|
||||
padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0])
|
||||
|
||||
paddingsNode = NodeDef()
|
||||
paddingsNode.name = conv.name + '/paddings'
|
||||
paddingsNode.op = 'Const'
|
||||
paddingsNode.addAttr('value', [padH, padH, padW, padW])
|
||||
graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
|
||||
nodesToKeep.append(paddingsNode.name)
|
||||
|
||||
spaceToBatchND.input[2] = paddingsNode.name
|
||||
|
||||
|
||||
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||
|
||||
|
||||
@ -225,6 +257,26 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
||||
detectionOut.addAttr('variance_encoded_in_target', True)
|
||||
graph_def.node.extend([detectionOut])
|
||||
|
||||
def getUnconnectedNodes():
|
||||
unconnected = [node.name for node in graph_def.node]
|
||||
for node in graph_def.node:
|
||||
for inp in node.input:
|
||||
if inp in unconnected:
|
||||
unconnected.remove(inp)
|
||||
return unconnected
|
||||
|
||||
while True:
|
||||
unconnectedNodes = getUnconnectedNodes()
|
||||
unconnectedNodes.remove(detectionOut.name)
|
||||
if not unconnectedNodes:
|
||||
break
|
||||
|
||||
for name in unconnectedNodes:
|
||||
for i in range(len(graph_def.node)):
|
||||
if graph_def.node[i].name == name:
|
||||
del graph_def.node[i]
|
||||
break
|
||||
|
||||
# Save as text.
|
||||
graph_def.save(outputPath)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user