mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 03:33:28 +08:00
Merge pull request #11580 from dkurt:dnn_fix_tf_ssd
This commit is contained in:
commit
8a17ae29b9
@ -64,36 +64,51 @@ removedNodes = []
|
|||||||
|
|
||||||
# Detect unfused batch normalization nodes and fuse them.
|
# Detect unfused batch normalization nodes and fuse them.
|
||||||
def fuse_batch_normalization():
|
def fuse_batch_normalization():
|
||||||
pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add']
|
# Add_0 <-- moving_variance, add_y
|
||||||
candidates = []
|
# Rsqrt <-- Add_0
|
||||||
|
# Mul_0 <-- Rsqrt, gamma
|
||||||
|
# Mul_1 <-- input, Mul_0
|
||||||
|
# Mul_2 <-- moving_mean, Mul_0
|
||||||
|
# Sub_0 <-- beta, Mul_2
|
||||||
|
# Add_1 <-- Mul_1, Sub_0
|
||||||
|
nodesMap = {node.name: node for node in graph_def.node}
|
||||||
|
subgraph = ['Add',
|
||||||
|
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
||||||
|
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||||
|
def checkSubgraph(node, targetNode, inputs, fusedNodes):
|
||||||
|
op = targetNode[0]
|
||||||
|
if node.op == op and (len(node.input) >= len(targetNode) - 1):
|
||||||
|
fusedNodes.append(node)
|
||||||
|
for i, inpOp in enumerate(targetNode[1:]):
|
||||||
|
if isinstance(inpOp, list):
|
||||||
|
if not node.input[i] in nodesMap or \
|
||||||
|
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
inputs[inpOp] = node.input[i]
|
||||||
|
|
||||||
for node in graph_def.node:
|
return True
|
||||||
if node.op == pattern[len(candidates)]:
|
|
||||||
candidates.append(node)
|
|
||||||
else:
|
else:
|
||||||
candidates = []
|
return False
|
||||||
|
|
||||||
if len(candidates) == len(pattern):
|
|
||||||
inp = candidates[3].input[0]
|
|
||||||
gamma = candidates[2].input[1]
|
|
||||||
beta = candidates[5].input[0]
|
|
||||||
moving_mean = candidates[4].input[0]
|
|
||||||
moving_variance = candidates[0].input[0]
|
|
||||||
|
|
||||||
|
nodesToRemove = []
|
||||||
|
for node in graph_def.node:
|
||||||
|
inputs = {}
|
||||||
|
fusedNodes = []
|
||||||
|
if checkSubgraph(node, subgraph, inputs, fusedNodes):
|
||||||
name = node.name
|
name = node.name
|
||||||
node.Clear()
|
node.Clear()
|
||||||
node.name = name
|
node.name = name
|
||||||
node.op = 'FusedBatchNorm'
|
node.op = 'FusedBatchNorm'
|
||||||
node.input.append(inp)
|
node.input.append(inputs['input'])
|
||||||
node.input.append(gamma)
|
node.input.append(inputs['gamma'])
|
||||||
node.input.append(beta)
|
node.input.append(inputs['beta'])
|
||||||
node.input.append(moving_mean)
|
node.input.append(inputs['moving_mean'])
|
||||||
node.input.append(moving_variance)
|
node.input.append(inputs['moving_variance'])
|
||||||
text_format.Merge('f: 0.001', node.attr["epsilon"])
|
text_format.Merge('f: 0.001', node.attr["epsilon"])
|
||||||
|
nodesToRemove += fusedNodes[1:]
|
||||||
for candidate in candidates[:-1]:
|
for node in nodesToRemove:
|
||||||
graph_def.node.remove(candidate)
|
graph_def.node.remove(node)
|
||||||
candidates = []
|
|
||||||
|
|
||||||
fuse_batch_normalization()
|
fuse_batch_normalization()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user