mirror of
https://github.com/opencv/opencv.git
synced 2024-12-01 14:59:54 +08:00
Update script to generate MobileNet-SSD V2 text graph
This commit is contained in:
parent
684cf43360
commit
d381948cee
@ -13,6 +13,7 @@ import tensorflow as tf
|
||||
import argparse
|
||||
from math import sqrt
|
||||
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
||||
from tensorflow.tools.graph_transforms import TransformGraph
|
||||
from google.protobuf import text_format
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
||||
@ -32,7 +33,7 @@ args = parser.parse_args()
|
||||
|
||||
# Nodes that should be kept.
|
||||
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool']
|
||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
|
||||
|
||||
# Nodes attributes that could be removed because they are not used during import.
|
||||
unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
|
||||
@ -46,6 +47,10 @@ with tf.gfile.FastGFile(args.input, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
|
||||
inpNames = ['image_tensor']
|
||||
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
|
||||
graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order'])
|
||||
|
||||
def getUnconnectedNodes():
|
||||
unconnected = []
|
||||
for node in graph_def.node:
|
||||
@ -98,6 +103,7 @@ def removeIdentity():
|
||||
for node in graph_def.node:
|
||||
if node.op == 'Identity':
|
||||
identities[node.name] = node.input[0]
|
||||
graph_def.node.remove(node)
|
||||
|
||||
for node in graph_def.node:
|
||||
for i in range(len(node.input)):
|
||||
|
Loading…
Reference in New Issue
Block a user