Import tensorflow to create text graphs if import cv is failed

This commit is contained in:
Dmitry Kurtaev 2018-09-18 09:04:28 +03:00
parent 70f38b4dfa
commit b0ad7f759a
4 changed files with 28 additions and 8 deletions

View File

@ -302,3 +302,26 @@ def removeUnusedNodesAndAttrs(to_remove, graph_def):
for i in reversed(range(len(node.input))): for i in reversed(range(len(node.input))):
if node.input[i] in removedNodes: if node.input[i] in removedNodes:
del node.input[i] del node.input[i]
def writeTextGraph(modelPath, outputPath, outNodes):
try:
import cv2 as cv
cv.dnn.writeTextGraph(modelPath, outputPath)
except:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
with tf.gfile.FastGFile(modelPath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])
for node in graph_def.node:
if node.op == 'Const':
if 'value' in node.attr:
del node.attr['value']
tf.train.write_graph(graph_def, "", outputPath, as_text=True)

View File

@ -1,6 +1,5 @@
import argparse import argparse
import numpy as np import numpy as np
import cv2 as cv
from tf_text_graph_common import * from tf_text_graph_common import *
@ -42,7 +41,7 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
print('Features stride: %f' % features_stride) print('Features stride: %f' % features_stride)
# Read the graph. # Read the graph.
cv.dnn.writeTextGraph(modelPath, outputPath) writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])
graph_def = parseTextGraph(outputPath) graph_def = parseTextGraph(outputPath)
removeIdentity(graph_def) removeIdentity(graph_def)

View File

@ -1,6 +1,5 @@
import argparse import argparse
import numpy as np import numpy as np
import cv2 as cv
from tf_text_graph_common import * from tf_text_graph_common import *
parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
@ -48,7 +47,7 @@ print('Height stride: %f' % height_stride)
print('Features stride: %f' % features_stride) print('Features stride: %f' % features_stride)
# Read the graph. # Read the graph.
cv.dnn.writeTextGraph(args.input, args.output) writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks'])
graph_def = parseTextGraph(args.output) graph_def = parseTextGraph(args.output)
removeIdentity(graph_def) removeIdentity(graph_def)

View File

@ -11,7 +11,6 @@
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API # See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
import argparse import argparse
from math import sqrt from math import sqrt
import cv2 as cv
from tf_text_graph_common import * from tf_text_graph_common import *
def createSSDGraph(modelPath, configPath, outputPath): def createSSDGraph(modelPath, configPath, outputPath):
@ -52,12 +51,12 @@ def createSSDGraph(modelPath, configPath, outputPath):
print('Input image size: %dx%d' % (image_width, image_height)) print('Input image size: %dx%d' % (image_width, image_height))
# Read the graph. # Read the graph.
cv.dnn.writeTextGraph(modelPath, outputPath)
graph_def = parseTextGraph(outputPath)
inpNames = ['image_tensor'] inpNames = ['image_tensor']
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'] outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
writeTextGraph(modelPath, outputPath, outNames)
graph_def = parseTextGraph(outputPath)
def getUnconnectedNodes(): def getUnconnectedNodes():
unconnected = [] unconnected = []
for node in graph_def.node: for node in graph_def.node: