mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #11580 from dkurt:dnn_fix_tf_ssd
This commit is contained in:
commit
8a17ae29b9
@ -64,36 +64,51 @@ removedNodes = []
|
||||
|
||||
# Detect unfused batch normalization nodes and fuse them.
|
||||
def fuse_batch_normalization():
|
||||
pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add']
|
||||
candidates = []
|
||||
|
||||
for node in graph_def.node:
|
||||
if node.op == pattern[len(candidates)]:
|
||||
candidates.append(node)
|
||||
# Add_0 <-- moving_variance, add_y
|
||||
# Rsqrt <-- Add_0
|
||||
# Mul_0 <-- Rsqrt, gamma
|
||||
# Mul_1 <-- input, Mul_0
|
||||
# Mul_2 <-- moving_mean, Mul_0
|
||||
# Sub_0 <-- beta, Mul_2
|
||||
# Add_1 <-- Mul_1, Sub_0
|
||||
nodesMap = {node.name: node for node in graph_def.node}
|
||||
subgraph = ['Add',
|
||||
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
||||
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||
def checkSubgraph(node, targetNode, inputs, fusedNodes):
|
||||
op = targetNode[0]
|
||||
if node.op == op and (len(node.input) >= len(targetNode) - 1):
|
||||
fusedNodes.append(node)
|
||||
for i, inpOp in enumerate(targetNode[1:]):
|
||||
if isinstance(inpOp, list):
|
||||
if not node.input[i] in nodesMap or \
|
||||
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
|
||||
return False
|
||||
else:
|
||||
candidates = []
|
||||
inputs[inpOp] = node.input[i]
|
||||
|
||||
if len(candidates) == len(pattern):
|
||||
inp = candidates[3].input[0]
|
||||
gamma = candidates[2].input[1]
|
||||
beta = candidates[5].input[0]
|
||||
moving_mean = candidates[4].input[0]
|
||||
moving_variance = candidates[0].input[0]
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
nodesToRemove = []
|
||||
for node in graph_def.node:
|
||||
inputs = {}
|
||||
fusedNodes = []
|
||||
if checkSubgraph(node, subgraph, inputs, fusedNodes):
|
||||
name = node.name
|
||||
node.Clear()
|
||||
node.name = name
|
||||
node.op = 'FusedBatchNorm'
|
||||
node.input.append(inp)
|
||||
node.input.append(gamma)
|
||||
node.input.append(beta)
|
||||
node.input.append(moving_mean)
|
||||
node.input.append(moving_variance)
|
||||
node.input.append(inputs['input'])
|
||||
node.input.append(inputs['gamma'])
|
||||
node.input.append(inputs['beta'])
|
||||
node.input.append(inputs['moving_mean'])
|
||||
node.input.append(inputs['moving_variance'])
|
||||
text_format.Merge('f: 0.001', node.attr["epsilon"])
|
||||
|
||||
for candidate in candidates[:-1]:
|
||||
graph_def.node.remove(candidate)
|
||||
candidates = []
|
||||
nodesToRemove += fusedNodes[1:]
|
||||
for node in nodesToRemove:
|
||||
graph_def.node.remove(node)
|
||||
|
||||
fuse_batch_normalization()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user