mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
AddV2 from TensorFlow
This commit is contained in:
parent
b4759d7272
commit
76cfa65d55
@ -996,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN)
|
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN)
|
||||||
data_layouts[name] = DATA_LAYOUT_NHWC;
|
data_layouts[name] = DATA_LAYOUT_NHWC;
|
||||||
}
|
}
|
||||||
else if (type == "BiasAdd" || type == "Add" || type == "Sub" || type=="AddN")
|
else if (type == "BiasAdd" || type == "Add" || type == "AddV2" || type == "Sub" || type=="AddN")
|
||||||
{
|
{
|
||||||
bool haveConst = false;
|
bool haveConst = false;
|
||||||
for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii)
|
for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii)
|
||||||
|
@ -62,7 +62,7 @@ class MultiscaleAnchorGenerator:
|
|||||||
|
|
||||||
def createSSDGraph(modelPath, configPath, outputPath):
|
def createSSDGraph(modelPath, configPath, outputPath):
|
||||||
# Nodes that should be kept.
|
# Nodes that should be kept.
|
||||||
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
||||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
|
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
|
||||||
'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3']
|
'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3']
|
||||||
|
|
||||||
@ -151,6 +151,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
|||||||
subgraphBatchNorm = ['Add',
|
subgraphBatchNorm = ['Add',
|
||||||
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
||||||
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||||
|
subgraphBatchNormV2 = ['AddV2',
|
||||||
|
['Mul', 'input', ['Mul', ['Rsqrt', ['AddV2', 'moving_variance', 'add_y']], 'gamma']],
|
||||||
|
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||||
# Detect unfused nearest neighbor resize.
|
# Detect unfused nearest neighbor resize.
|
||||||
subgraphResizeNN = ['Reshape',
|
subgraphResizeNN = ['Reshape',
|
||||||
['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']],
|
['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']],
|
||||||
@ -177,7 +180,8 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
|||||||
for node in graph_def.node:
|
for node in graph_def.node:
|
||||||
inputs = {}
|
inputs = {}
|
||||||
fusedNodes = []
|
fusedNodes = []
|
||||||
if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes):
|
if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes) or \
|
||||||
|
checkSubgraph(node, subgraphBatchNormV2, inputs, fusedNodes):
|
||||||
name = node.name
|
name = node.name
|
||||||
node.Clear()
|
node.Clear()
|
||||||
node.name = name
|
node.name = name
|
||||||
|
Loading…
Reference in New Issue
Block a user