mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 14:13:15 +08:00
Import tensorflow to create text graphs if import cv is failed
This commit is contained in:
parent
70f38b4dfa
commit
b0ad7f759a
@ -302,3 +302,26 @@ def removeUnusedNodesAndAttrs(to_remove, graph_def):
|
|||||||
for i in reversed(range(len(node.input))):
|
for i in reversed(range(len(node.input))):
|
||||||
if node.input[i] in removedNodes:
|
if node.input[i] in removedNodes:
|
||||||
del node.input[i]
|
del node.input[i]
|
||||||
|
|
||||||
|
|
||||||
|
def writeTextGraph(modelPath, outputPath, outNodes):
|
||||||
|
try:
|
||||||
|
import cv2 as cv
|
||||||
|
|
||||||
|
cv.dnn.writeTextGraph(modelPath, outputPath)
|
||||||
|
except:
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.tools.graph_transforms import TransformGraph
|
||||||
|
|
||||||
|
with tf.gfile.FastGFile(modelPath, 'rb') as f:
|
||||||
|
graph_def = tf.GraphDef()
|
||||||
|
graph_def.ParseFromString(f.read())
|
||||||
|
|
||||||
|
graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])
|
||||||
|
|
||||||
|
for node in graph_def.node:
|
||||||
|
if node.op == 'Const':
|
||||||
|
if 'value' in node.attr:
|
||||||
|
del node.attr['value']
|
||||||
|
|
||||||
|
tf.train.write_graph(graph_def, "", outputPath, as_text=True)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2 as cv
|
|
||||||
from tf_text_graph_common import *
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
|||||||
print('Features stride: %f' % features_stride)
|
print('Features stride: %f' % features_stride)
|
||||||
|
|
||||||
# Read the graph.
|
# Read the graph.
|
||||||
cv.dnn.writeTextGraph(modelPath, outputPath)
|
writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])
|
||||||
graph_def = parseTextGraph(outputPath)
|
graph_def = parseTextGraph(outputPath)
|
||||||
|
|
||||||
removeIdentity(graph_def)
|
removeIdentity(graph_def)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2 as cv
|
|
||||||
from tf_text_graph_common import *
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
||||||
@ -48,7 +47,7 @@ print('Height stride: %f' % height_stride)
|
|||||||
print('Features stride: %f' % features_stride)
|
print('Features stride: %f' % features_stride)
|
||||||
|
|
||||||
# Read the graph.
|
# Read the graph.
|
||||||
cv.dnn.writeTextGraph(args.input, args.output)
|
writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks'])
|
||||||
graph_def = parseTextGraph(args.output)
|
graph_def = parseTextGraph(args.output)
|
||||||
|
|
||||||
removeIdentity(graph_def)
|
removeIdentity(graph_def)
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
|
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
|
||||||
import argparse
|
import argparse
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
import cv2 as cv
|
|
||||||
from tf_text_graph_common import *
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
def createSSDGraph(modelPath, configPath, outputPath):
|
def createSSDGraph(modelPath, configPath, outputPath):
|
||||||
@ -52,12 +51,12 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
|||||||
print('Input image size: %dx%d' % (image_width, image_height))
|
print('Input image size: %dx%d' % (image_width, image_height))
|
||||||
|
|
||||||
# Read the graph.
|
# Read the graph.
|
||||||
cv.dnn.writeTextGraph(modelPath, outputPath)
|
|
||||||
graph_def = parseTextGraph(outputPath)
|
|
||||||
|
|
||||||
inpNames = ['image_tensor']
|
inpNames = ['image_tensor']
|
||||||
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
|
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
|
||||||
|
|
||||||
|
writeTextGraph(modelPath, outputPath, outNames)
|
||||||
|
graph_def = parseTextGraph(outputPath)
|
||||||
|
|
||||||
def getUnconnectedNodes():
|
def getUnconnectedNodes():
|
||||||
unconnected = []
|
unconnected = []
|
||||||
for node in graph_def.node:
|
for node in graph_def.node:
|
||||||
|
Loading…
Reference in New Issue
Block a user