Import SSDs from TensorFlow by training config (#12188)

* Remove TensorFlow and protobuf dependencies from object detection scripts
* Create text graphs for TensorFlow object detection networks from sample

parent e3af72bb68
commit c7cf8fb35c
@@ -885,6 +885,14 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
    CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
                                       const std::vector<String>& layersTypes = std::vector<String>());

    /** @brief Create a text representation for a binary network stored in protocol buffer format.
     *  @param[in] model  A path to binary network.
     *  @param[in] output A path to output text file to be created.
     *
     *  @note To reduce output file size, trained weights are not included.
     */
    CV_EXPORTS_W void writeTextGraph(const String& model, const String& output);
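As a usage sketch (not part of the diff; the file names are hypothetical), the new function is callable from Python once the bindings are built:

    import cv2 as cv
    # Dump a weight-free .pbtxt representation of a frozen TensorFlow graph.
    cv.dnn.writeTextGraph('frozen_inference_graph.pb', 'graph.pbtxt')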

    /** @brief Performs non maximum suppression given boxes and corresponding scores.

     * @param bboxes a set of bounding boxes to apply NMS.
@@ -782,6 +782,108 @@ void releaseTensor(tensorflow::TensorProto* tensor)
}
}

static void permute(google::protobuf::RepeatedPtrField<tensorflow::NodeDef>* data,
                    const std::vector<int>& indices)
{
    const int num = data->size();
    CV_Assert(num == indices.size());

    std::vector<int> elemIdToPos(num);
    std::vector<int> posToElemId(num);
    for (int i = 0; i < num; ++i)
    {
        elemIdToPos[i] = i;
        posToElemId[i] = i;
    }
    for (int i = 0; i < num; ++i)
    {
        int elemId = indices[i];
        int pos = elemIdToPos[elemId];
        if (pos != i)
        {
            data->SwapElements(i, pos);
            const int swappedElemId = posToElemId[i];
            elemIdToPos[elemId] = i;
            elemIdToPos[swappedElemId] = pos;

            posToElemId[i] = elemId;
            posToElemId[pos] = swappedElemId;
        }
    }
}
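The function applies the permutation in place with pairwise swaps, tracking both where each original element currently sits (elemIdToPos) and which element currently sits at each position (posToElemId). A minimal Python sketch of the same bookkeeping, for illustration only:

    def permute(data, indices):
        # After the loop, data[i] holds the element originally at index indices[i].
        elem_id_to_pos = list(range(len(data)))
        pos_to_elem_id = list(range(len(data)))
        for i, elem_id in enumerate(indices):
            pos = elem_id_to_pos[elem_id]
            if pos != i:
                data[i], data[pos] = data[pos], data[i]
                swapped = pos_to_elem_id[i]
                elem_id_to_pos[elem_id], elem_id_to_pos[swapped] = i, pos
                pos_to_elem_id[i], pos_to_elem_id[pos] = elem_id, swapped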

// Is based on tensorflow::graph_transforms::SortByExecutionOrder
void sortByExecutionOrder(tensorflow::GraphDef& net)
{
    // Maps node's name to index at net.node() list.
    std::map<std::string, int> nodesMap;
    std::map<std::string, int>::iterator nodesMapIt;
    for (int i = 0; i < net.node_size(); ++i)
    {
        const tensorflow::NodeDef& node = net.node(i);
        nodesMap.insert(std::make_pair(node.name(), i));
    }

    // Indices of nodes which use specific node as input.
    std::vector<std::vector<int> > edges(nodesMap.size());
    std::vector<int> numRefsToAdd(nodesMap.size(), 0);
    std::vector<int> nodesToAdd;
    for (int i = 0; i < net.node_size(); ++i)
    {
        const tensorflow::NodeDef& node = net.node(i);
        for (int j = 0; j < node.input_size(); ++j)
        {
            std::string inpName = node.input(j);
            inpName = inpName.substr(0, inpName.rfind(':'));
            inpName = inpName.substr(inpName.find('^') + 1);

            nodesMapIt = nodesMap.find(inpName);
            CV_Assert(nodesMapIt != nodesMap.end());
            edges[nodesMapIt->second].push_back(i);
        }
        if (node.input_size() == 0)
            nodesToAdd.push_back(i);
        else
        {
            if (node.op() == "Merge" || node.op() == "RefMerge")
            {
                int numControlEdges = 0;
                for (int j = 0; j < node.input_size(); ++j)
                    numControlEdges += node.input(j)[0] == '^';
                numRefsToAdd[i] = numControlEdges + 1;
            }
            else
                numRefsToAdd[i] = node.input_size();
        }
    }

    std::vector<int> permIds;
    permIds.reserve(net.node_size());
    while (!nodesToAdd.empty())
    {
        int nodeToAdd = nodesToAdd.back();
        nodesToAdd.pop_back();

        permIds.push_back(nodeToAdd);
        // std::cout << net.node(nodeToAdd).name() << '\n';

        for (int i = 0; i < edges[nodeToAdd].size(); ++i)
        {
            int consumerId = edges[nodeToAdd][i];
            if (numRefsToAdd[consumerId] > 0)
            {
                if (numRefsToAdd[consumerId] == 1)
                    nodesToAdd.push_back(consumerId);
                else
                    CV_Assert(numRefsToAdd[consumerId] >= 0);
                numRefsToAdd[consumerId] -= 1;
            }
        }
    }
    CV_Assert(permIds.size() == net.node_size());
    permute(net.mutable_node(), permIds);
}
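This is Kahn's topological sort: nodes with no pending inputs are emitted, and each emission decrements the reference count of its consumers (Merge nodes are special-cased to wait for only one data input plus their control edges). A hedged Python sketch of the same traversal, with the Merge handling left out:

    def sort_by_execution_order(inputs_of):
        # inputs_of: dict mapping a node name to the list of its input names.
        consumers = {name: [] for name in inputs_of}
        refs = {}
        ready = []
        for name, inputs in inputs_of.items():
            for inp in inputs:
                consumers[inp].append(name)
            refs[name] = len(inputs)
            if not inputs:
                ready.append(name)
        order = []
        while ready:
            name = ready.pop()
            order.append(name)
            for consumer in consumers[name]:
                refs[consumer] -= 1
                if refs[consumer] == 0:
                    ready.append(consumer)
        return order  # a valid execution order if the graph is acyclic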

CV__DNN_EXPERIMENTAL_NS_END
}}  // namespace dnn, namespace cv
@@ -25,6 +25,8 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor);

void releaseTensor(tensorflow::TensorProto* tensor);

void sortByExecutionOrder(tensorflow::GraphDef& net);

CV__DNN_EXPERIMENTAL_NS_END
}}  // namespace dnn, namespace cv
@@ -1950,5 +1950,34 @@ Net readNetFromTensorflow(const std::vector<uchar>& bufferModel, const std::vector<uchar>& bufferConfig)
                                  bufferConfigPtr, bufferConfig.size());
}

void writeTextGraph(const String& _model, const String& output)
{
    String model = _model;
    const std::string modelExt = model.substr(model.rfind('.') + 1);
    if (modelExt != "pb")
        CV_Error(Error::StsNotImplemented, "Only TensorFlow models support export to text file");

    tensorflow::GraphDef net;
    ReadTFNetParamsFromBinaryFileOrDie(model.c_str(), &net);

    sortByExecutionOrder(net);

    RepeatedPtrField<tensorflow::NodeDef>::iterator it;
    for (it = net.mutable_node()->begin(); it != net.mutable_node()->end(); ++it)
    {
        if (it->op() == "Const")
        {
            it->mutable_attr()->at("value").mutable_tensor()->clear_tensor_content();
        }
    }

    std::string content;
    google::protobuf::TextFormat::PrintToString(net, &content);

    std::ofstream ofs(output.c_str());
    ofs << content;
    ofs.close();
}
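For comparison, the same weight-stripping can be sketched with the TensorFlow 1.x protobuf API in Python (illustration only; the point of this patch is that OpenCV itself no longer needs that dependency, and the paths are hypothetical):

    import tensorflow as tf
    from google.protobuf import text_format

    graph_def = tf.GraphDef()
    with open('frozen_inference_graph.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    for node in graph_def.node:
        if node.op == 'Const':
            node.attr['value'].tensor.ClearField('tensor_content')  # drop weights
    with open('graph.pbtxt', 'wt') as f:
        f.write(text_format.MessageToString(graph_def))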

CV__DNN_EXPERIMENTAL_NS_END
}}  // namespace
@@ -315,6 +315,29 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
    normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
}

TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD)
{
    checkBackend();

    std::string model = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pb", false);
    std::string proto = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pbtxt", false);

    Net net = readNetFromTensorflow(model, proto);
    Mat img = imread(findDataFile("dnn/dog416.png", false));
    Mat blob = blobFromImage(img, 1.0f, Size(300, 300), Scalar(), true, false);

    net.setPreferableBackend(backend);
    net.setPreferableTarget(target);

    net.setInput(blob);
    Mat out = net.forward();

    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/ssd_mobilenet_v1_coco_2017_11_17.detection_out.npy"));
    float scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 7e-3 : 1e-5;
    float iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.0098 : 1e-3;
    normAssertDetections(ref, out, "", 0.3, scoreDiff, iouDiff);
}
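The Python equivalent of the test's inference path, assuming the same (hypothetical) local file names, would be:

    import cv2 as cv
    net = cv.dnn.readNetFromTensorflow('ssd_mobilenet_v1_coco_2017_11_17.pb',
                                       'ssd_mobilenet_v1_coco_2017_11_17.pbtxt')
    img = cv.imread('dog416.png')
    blob = cv.dnn.blobFromImage(img, 1.0, (300, 300), (0, 0, 0), swapRB=True, crop=False)
    net.setInput(blob)
    out = net.forward()  # 1x1xNx7: [batchId, classId, score, left, top, right, bottom]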

TEST_P(Test_TensorFlow_nets, Faster_RCNN)
{
    static std::string names[] = {"faster_rcnn_inception_v2_coco_2018_01_28",

@@ -360,7 +383,8 @@ TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)

    net.setInput(blob);
    Mat out = net.forward();
    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.008 : default_l1;

    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.011 : default_l1;
    double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.021 : default_lInf;
    normAssertDetections(ref, out, "", 0.4, scoreDiff, iouDiff);
}
@@ -3,6 +3,10 @@ import argparse
import sys
import numpy as np

from tf_text_graph_common import readTextMessage
from tf_text_graph_ssd import createSSDGraph
from tf_text_graph_faster_rcnn import createFasterRCNNGraph

backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD)

@@ -11,11 +15,15 @@ parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', required=True,
                    help='Path to a binary file of a model that contains trained weights. '
                         'It could be a file with extensions .caffemodel (Caffe), '
                         '.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet)')
                         '.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet), .bin (OpenVINO)')
parser.add_argument('--config',
                    help='Path to a text file of a model that contains network configuration. '
                         'It could be a file with extensions .prototxt (Caffe), .pbtxt (TensorFlow), .cfg (Darknet)')
parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
                         'It could be a file with extensions .prototxt (Caffe), .pbtxt or .config (TensorFlow), .cfg (Darknet), .xml (OpenVINO)')
parser.add_argument('--out_tf_graph', default='graph.pbtxt',
                    help='For models from TensorFlow Object Detection API, you may '
                         'pass a .config file which was used for training through --config '
                         'argument. This way an additional .pbtxt file with TensorFlow graph will be created.')
parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet', 'dldt'],
                    help='Optional name of an origin framework of the model. '
                         'It is detected automatically if not set.')
parser.add_argument('--classes', help='Optional path to a text file with names of classes to label detected objects.')
@@ -46,6 +54,20 @@ parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU,
                         '%d: VPU' % targets)
args = parser.parse_args()

# If a config is specified, try to load it as a TensorFlow Object Detection API pipeline.
config = readTextMessage(args.config)
if 'model' in config:
    print('TensorFlow Object Detection API config detected')
    if 'ssd' in config['model'][0]:
        print('Preparing text graph representation for SSD model: ' + args.out_tf_graph)
        createSSDGraph(args.model, args.config, args.out_tf_graph)
        args.config = args.out_tf_graph
    elif 'faster_rcnn' in config['model'][0]:
        print('Preparing text graph representation for Faster-RCNN model: ' + args.out_tf_graph)
        createFasterRCNNGraph(args.model, args.config, args.out_tf_graph)
        args.config = args.out_tf_graph
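With that in place, a single invocation covers both graph generation and inference (paths are hypothetical):

    python object_detection.py --model frozen_inference_graph.pb \
                               --config pipeline.config \
                               --input image.jpg

The script recognizes the training config, writes graph.pbtxt (or whatever --out_tf_graph names), and then uses that text graph as the network configuration.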

# Load names of classes
classes = None
if args.classes:
@@ -1,8 +1,86 @@
import tensorflow as tf
from tensorflow.core.framework.node_def_pb2 import NodeDef
from google.protobuf import text_format

def tensorMsg(values):

def tokenize(s):
    tokens = []
    token = ""
    isString = False
    isComment = False
    for symbol in s:
        isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
        if isComment:
            continue

        if symbol == ' ' or symbol == '\t' or symbol == '\r' or symbol == '\'' or \
           symbol == '\n' or symbol == ':' or symbol == '\"' or symbol == ';' or \
           symbol == ',':

            if (symbol == '\"' or symbol == '\'') and isString:
                tokens.append(token)
                token = ""
            else:
                if isString:
                    token += symbol
                elif token:
                    tokens.append(token)
                    token = ""
            isString = (symbol == '\"' or symbol == '\'') ^ isString

        elif symbol == '{' or symbol == '}' or symbol == '[' or symbol == ']':
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens
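Under these rules, quotes, colons, commas and whitespace all act as separators while braces and brackets become tokens of their own, so a config fragment such as 'model { ssd { num_classes: 90 } }' tokenizes to ['model', '{', 'ssd', '{', 'num_classes', '90', '}', '}'].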


def parseMessage(tokens, idx):
    msg = {}
    assert(tokens[idx] == '{')

    isArray = False
    while True:
        if not isArray:
            idx += 1
            if idx < len(tokens):
                fieldName = tokens[idx]
            else:
                return None
            if fieldName == '}':
                break

        idx += 1
        fieldValue = tokens[idx]

        if fieldValue == '{':
            embeddedMsg, idx = parseMessage(tokens, idx)
            if fieldName in msg:
                msg[fieldName].append(embeddedMsg)
            else:
                msg[fieldName] = [embeddedMsg]
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            if fieldName in msg:
                msg[fieldName].append(fieldValue)
            else:
                msg[fieldName] = [fieldValue]
    return msg, idx


def readTextMessage(filePath):
    with open(filePath, 'rt') as f:
        content = f.read()

    tokens = tokenize('{' + content + '}')
    msg = parseMessage(tokens, 0)
    return msg[0] if msg else {}
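The whole file is wrapped in one artificial pair of braces, so a pipeline config comes back as nested dictionaries of lists: the fragment from the tokenizer example above parses to {'model': [{'ssd': [{'num_classes': ['90']}]}]}, which is why the callers below index the result as config['model'][0].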


def listToTensor(values):
    if all([isinstance(v, float) for v in values]):
        dtype = 'DT_FLOAT'
        field = 'float_val'

@@ -12,16 +90,25 @@ def tensorMsg(values):
    else:
        raise Exception('Wrong values types')

    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
    for value in values:
        msg += '%s: %s ' % (field, str(value))
    return msg + '}'
    msg = {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            }
        }
    }
    msg['tensor'][field] = values
    return msg


def addConstNode(name, values, graph_def):
    node = NodeDef()
    node.name = name
    node.op = 'Const'
    text_format.Merge(tensorMsg(values), node.attr["value"])
    node.addAttr('value', values)
    graph_def.node.extend([node])


@@ -29,13 +116,13 @@ def addSlice(inp, out, begins, sizes, graph_def):
    beginsNode = NodeDef()
    beginsNode.name = out + '/begins'
    beginsNode.op = 'Const'
    text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
    beginsNode.addAttr('value', begins)
    graph_def.node.extend([beginsNode])

    sizesNode = NodeDef()
    sizesNode.name = out + '/sizes'
    sizesNode.op = 'Const'
    text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
    sizesNode.addAttr('value', sizes)
    graph_def.node.extend([sizesNode])

    sliced = NodeDef()

@@ -51,7 +138,7 @@ def addReshape(inp, out, shape, graph_def):
    shapeNode = NodeDef()
    shapeNode.name = out + '/shape'
    shapeNode.op = 'Const'
    text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
    shapeNode.addAttr('value', shape)
    graph_def.node.extend([shapeNode])

    reshape = NodeDef()

@@ -66,7 +153,7 @@ def addSoftMax(inp, out, graph_def):
    softmax = NodeDef()
    softmax.name = out
    softmax.op = 'Softmax'
    text_format.Merge('i: -1', softmax.attr['axis'])
    softmax.addAttr('axis', -1)
    softmax.input.append(inp)
    graph_def.node.extend([softmax])

@@ -79,6 +166,103 @@ def addFlatten(inp, out, graph_def):
    graph_def.node.extend([flatten])


class NodeDef:
    def __init__(self):
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}

    def addAttr(self, key, value):
        assert(not key in self.attr)
        if isinstance(value, bool):
            self.attr[key] = {'b': value}
        elif isinstance(value, int):
            self.attr[key] = {'i': value}
        elif isinstance(value, float):
            self.attr[key] = {'f': value}
        elif isinstance(value, str):
            self.attr[key] = {'s': value}
        elif isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}


class GraphDef:
    def __init__(self):
        self.node = []

    def save(self, filePath):
        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                indent = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    value = value if isinstance(value, list) else [value]
                    for v in value:
                        if isinstance(v, dict):
                            f.write(indent + key + ' {\n')
                            printAttr(v, len(indent) + 2)
                            f.write(indent + '}\n')
                        else:
                            isString = False
                            if isinstance(v, str) and not v.startswith('DT_'):
                                try:
                                    float(v)
                                except:
                                    isString = True

                            if isinstance(v, bool):
                                printed = 'true' if v else 'false'
                            elif v == 'true' or v == 'false':
                                printed = 'true' if v == 'true' else 'false'
                            elif isString:
                                printed = '\"%s\"' % v
                            else:
                                printed = str(v)
                            f.write(indent + key + ': ' + printed + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write('  name: \"%s\"\n' % node.name)
                f.write('  op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write('  input: \"%s\"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: \"%s\"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')
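These two plain-Python classes mimic just enough of the protobuf-generated NodeDef/GraphDef interface (input, name, op, attr, Clear, and a node list) for the scripts to drop their TensorFlow and protobuf imports. A small usage sketch (hypothetical output path):

    node = NodeDef()
    node.name = 'proposals/variance'
    node.op = 'Const'
    node.addAttr('value', [0.1, 0.1, 0.2, 0.2])  # a list becomes a tensor attribute
    graph = GraphDef()
    graph.node.extend([node])
    graph.save('graph.pbtxt')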


def parseTextGraph(filePath):
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg['node']:
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node['input'] if 'input' in node else []

        if 'attr' in node:
            for attr in node['attr']:
                graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph


# Removes Identity nodes
def removeIdentity(graph_def):
    identities = {}

@@ -1,214 +1,228 @@
import argparse
import numpy as np
import tensorflow as tf

from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format

import cv2 as cv
from tf_text_graph_common import *

parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'SSD model from TensorFlow Object Detection API. '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--features_stride', default=16, type=float, nargs='+',
                    help='Hyper-parameter from a config file.')
args = parser.parse_args()

scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                'FirstStageBoxPredictor/BoxEncodingPredictor',
                'FirstStageBoxPredictor/ClassPredictor',
                'CropAndResize',
                'MaxPool2D',
                'SecondStageFeatureExtractor',
                'SecondStageBoxPredictor',
                'Preprocessor/sub',
                'Preprocessor/mul',
                'image_tensor')
def createFasterRCNNGraph(modelPath, configPath, outputPath):
    scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                    'FirstStageBoxPredictor/BoxEncodingPredictor',
                    'FirstStageBoxPredictor/ClassPredictor',
                    'CropAndResize',
                    'MaxPool2D',
                    'SecondStageFeatureExtractor',
                    'SecondStageBoxPredictor',
                    'Preprocessor/sub',
                    'Preprocessor/mul',
                    'image_tensor')

scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                  'FirstStageFeatureExtractor/Shape',
                  'FirstStageFeatureExtractor/strided_slice',
                  'FirstStageFeatureExtractor/GreaterEqual',
                  'FirstStageFeatureExtractor/LogicalAnd')
    scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                      'FirstStageFeatureExtractor/Shape',
                      'FirstStageFeatureExtractor/strided_slice',
                      'FirstStageFeatureExtractor/GreaterEqual',
                      'FirstStageFeatureExtractor/LogicalAnd')

# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # Load a config file.
    config = readTextMessage(configPath)
    config = config['model'][0]['faster_rcnn'][0]
    num_classes = int(config['num_classes'][0])

removeIdentity(graph_def)
    grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
    scales = [float(s) for s in grid_anchor_generator['scales']]
    aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
    width_stride = float(grid_anchor_generator['width_stride'][0])
    height_stride = float(grid_anchor_generator['height_stride'][0])
    features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])

def to_remove(name, op):
    return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
    print('Number of classes: %d' % num_classes)
    print('Scales: %s' % str(scales))
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Width stride: %f' % width_stride)
    print('Height stride: %f' % height_stride)
    print('Features stride: %f' % features_stride)

removeUnusedNodesAndAttrs(to_remove, graph_def)
    # Read the graph.
    cv.dnn.writeTextGraph(modelPath, outputPath)
    graph_def = parseTextGraph(outputPath)

    removeIdentity(graph_def)

    def to_remove(name, op):
        return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)

    removeUnusedNodesAndAttrs(to_remove, graph_def)


# Connect input node to the first layer
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)
    # Connect input node to the first layer
    assert(graph_def.node[0].op == 'Placeholder')
    graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily remove top nodes.
topNodes = []
while True:
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        break
    # Temporarily remove top nodes.
    topNodes = []
    while True:
        node = graph_def.node.pop()
        topNodes.append(node)
        if node.op == 'CropAndResize':
            break

addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)
    addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
               'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4
    addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
               'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)
    addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
               'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
    # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
    addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
               'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor
    proposals = NodeDef()
    proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
    proposals.op = 'PriorBox'
    proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
    proposals.input.append(graph_def.node[0].name)  # image_tensor

text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
    proposals.addAttr('flip', False)
    proposals.addAttr('clip', True)
    proposals.addAttr('step', features_stride)
    proposals.addAttr('offset', 0.0)
    proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

widths = []
heights = []
for a in args.aspect_ratios:
    for s in args.scales:
        ar = np.sqrt(a)
        heights.append((args.features_stride**2) * s / ar)
        widths.append((args.features_stride**2) * s * ar)
    widths = []
    heights = []
    for a in aspect_ratios:
        for s in scales:
            ar = np.sqrt(a)
            heights.append((height_stride**2) * s / ar)
            widths.append((width_stride**2) * s * ar)
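As a worked example of that anchor arithmetic: with width_stride = height_stride = 16, scale 1.0 and aspect ratio 1.0, both width and height come out to 16**2 = 256 input pixels; an aspect ratio of 2.0 stretches the same box to roughly 362x181.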

text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])
    proposals.addAttr('width', widths)
    proposals.addAttr('height', heights)

graph_def.node.extend([proposals])
    graph_def.node.extend([proposals])

# Compare with Reshape_5
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'
    # Compare with Reshape_5
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
detectionOut.input.append('proposals')
    detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
    detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
    detectionOut.input.append('proposals')

text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: false', detectionOut.attr['clip'])
    detectionOut.addAttr('num_classes', 2)
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', 0)
    detectionOut.addAttr('nms_threshold', 0.7)
    detectionOut.addAttr('top_k', 6000)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', False)

graph_def.node.extend([detectionOut])
    graph_def.node.extend([detectionOut])

addConstNode('clip_by_value/lower', [0.0], graph_def)
addConstNode('clip_by_value/upper', [1.0], graph_def)
    addConstNode('clip_by_value/lower', [0.0], graph_def)
    addConstNode('clip_by_value/upper', [1.0], graph_def)

clipByValueNode = NodeDef()
clipByValueNode.name = 'detection_out/clip_by_value'
clipByValueNode.op = 'ClipByValue'
clipByValueNode.input.append('detection_out')
clipByValueNode.input.append('clip_by_value/lower')
clipByValueNode.input.append('clip_by_value/upper')
graph_def.node.extend([clipByValueNode])
    clipByValueNode = NodeDef()
    clipByValueNode.name = 'detection_out/clip_by_value'
    clipByValueNode.op = 'ClipByValue'
    clipByValueNode.input.append('detection_out')
    clipByValueNode.input.append('clip_by_value/lower')
    clipByValueNode.input.append('clip_by_value/upper')
    graph_def.node.extend([clipByValueNode])

# Save as text.
for node in reversed(topNodes):
    graph_def.node.extend([node])
    # Save as text.
    for node in reversed(topNodes):
        graph_def.node.extend([node])

addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)
    addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)
    addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
             'SecondStageBoxPredictor/Reshape_1/slice',
             [0, 0, 1], [-1, -1, -1], graph_def)

addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)
    addReshape('SecondStageBoxPredictor/Reshape_1/slice',
               'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

# Replace Flatten subgraph onto a single node.
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')
    # Replace Flatten subgraph onto a single node.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op == 'CropAndResize':
            graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')

    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)
        if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
            addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')
            graph_def.node[i].input.pop()
            graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
        del graph_def.node[i]
        if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                      'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                      'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
            del graph_def.node[i]

for node in graph_def.node:
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()
    for node in graph_def.node:
        if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
            node.op = 'Flatten'
            node.input.pop()

    if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
        if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                         'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
            node.addAttr('loc_pred_transposed', True)

################################################################################
### Postprocessing
################################################################################
addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)
    ################################################################################
    ### Postprocessing
    ################################################################################
    addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
graph_def.node.extend([variance])
    variance = NodeDef()
    variance.name = 'proposals/variance'
    variance.op = 'Const'
    variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
    graph_def.node.extend([variance])

varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
graph_def.node.extend([varianceEncoder])
    varianceEncoder = NodeDef()
    varianceEncoder.name = 'variance_encoded'
    varianceEncoder.op = 'Mul'
    varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
    varianceEncoder.input.append(variance.name)
    varianceEncoder.addAttr('axis', 2)
    graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
    addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
    addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out_final'
    detectionOut.op = 'DetectionOutput'

detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')
    detectionOut.input.append('variance_encoded/flatten')
    detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
    detectionOut.input.append('detection_out/slice/reshape')

text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
graph_def.node.extend([detectionOut])
    detectionOut.addAttr('num_classes', num_classes)
    detectionOut.addAttr('share_location', False)
    detectionOut.addAttr('background_label_id', num_classes + 1)
    detectionOut.addAttr('nms_threshold', 0.6)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', True)
    detectionOut.addAttr('variance_encoded_in_target', True)
    graph_def.node.extend([detectionOut])

tf.train.write_graph(graph_def, "", args.output, as_text=True)
    # Save as text.
    graph_def.save(outputPath)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'Faster-RCNN model from TensorFlow Object Detection API. '
                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', required=True, help='Path to output text graph.')
    parser.add_argument('--config', required=True, help='Path to a *.config file used for training.')
    args = parser.parse_args()

    createFasterRCNNGraph(args.input, args.config, args.output)
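For example (hypothetical paths):

    python tf_text_graph_faster_rcnn.py --input frozen_inference_graph.pb \
                                        --config pipeline.config \
                                        --output faster_rcnn.pbtxt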
@@ -1,11 +1,6 @@
import argparse
import numpy as np
import tensorflow as tf

from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format

import cv2 as cv
from tf_text_graph_common import *

parser = argparse.ArgumentParser(description='Run this script to get a text graph of '

@@ -13,13 +8,7 @@ parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--features_stride', default=16, type=float, nargs='+',
                    help='Hyper-parameter from a config file.')
parser.add_argument('--config', required=True, help='Path to a *.config file used for training.')
args = parser.parse_args()

scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',

@@ -39,11 +28,28 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                  'FirstStageFeatureExtractor/GreaterEqual',
                  'FirstStageFeatureExtractor/LogicalAnd')

# Load a config file.
config = readTextMessage(args.config)
config = config['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])

grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = [float(s) for s in grid_anchor_generator['scales']]
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])

print('Number of classes: %d' % num_classes)
print('Scales: %s' % str(scales))
print('Aspect ratios: %s' % str(aspect_ratios))
print('Width stride: %f' % width_stride)
print('Height stride: %f' % height_stride)
print('Features stride: %f' % features_stride)

# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
cv.dnn.writeTextGraph(args.input, args.output)
graph_def = parseTextGraph(args.output)

removeIdentity(graph_def)

@@ -87,22 +93,22 @@ proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
proposals.addAttr('flip', False)
proposals.addAttr('clip', True)
proposals.addAttr('step', features_stride)
proposals.addAttr('offset', 0.0)
proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

widths = []
heights = []
for a in args.aspect_ratios:
    for s in args.scales:
for a in aspect_ratios:
    for s in scales:
        ar = np.sqrt(a)
        heights.append((args.features_stride**2) * s / ar)
        widths.append((args.features_stride**2) * s * ar)
        heights.append((features_stride**2) * s / ar)
        widths.append((features_stride**2) * s * ar)

text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])
proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])

@@ -115,14 +121,14 @@ detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
detectionOut.input.append('proposals')

text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
detectionOut.addAttr('num_classes', 2)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', 0.7)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('clip', True)

graph_def.node.extend([detectionOut])

@@ -171,7 +177,7 @@ for node in graph_def.node:

    if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
        node.addAttr('loc_pred_transposed', True)

################################################################################
### Postprocessing

@@ -181,7 +187,7 @@ addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])

varianceEncoder = NodeDef()

@@ -189,7 +195,7 @@ varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
varianceEncoder.addAttr('axis', 2)
graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)

@@ -203,16 +209,16 @@ detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')

text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
text_format.Merge('f: 0.3', detectionOut.attr['confidence_threshold'])
text_format.Merge('b: false', detectionOut.attr['group_by_classes'])
detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)
detectionOut.addAttr('background_label_id', num_classes + 1)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])

for node in reversed(topNodes):

@@ -227,4 +233,5 @@ graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()

tf.train.write_graph(graph_def, "", args.output, as_text=True)
# Save as text.
graph_def.save(args.output)
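This script (judging by the detection_masks/Sigmoid tail, the Mask-RCNN variant; the file name below is assumed) is invoked the same way as the others, for example:

    python tf_text_graph_mask_rcnn.py --input frozen_inference_graph.pb \
                                      --config pipeline.config \
                                      --output mask_rcnn.pbtxt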
@@ -9,253 +9,269 @@
# deep learning network trained in TensorFlow Object Detection API.
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
import tensorflow as tf
import argparse
from math import sqrt
from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import cv2 as cv
from tf_text_graph_common import *

parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'SSD model from TensorFlow Object Detection API. '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--min_scale', default=0.2, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
parser.add_argument('--max_scale', default=0.95, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
parser.add_argument('--num_layers', default=6, type=int, help='Hyper-parameter of ssd_anchor_generator from config file.')
parser.add_argument('--aspect_ratios', default=[1.0, 2.0, 0.5, 3.0, 0.333], type=float, nargs='+',
                    help='Hyper-parameter of ssd_anchor_generator from config file.')
parser.add_argument('--image_width', default=300, type=int, help='Training images width.')
parser.add_argument('--image_height', default=300, type=int, help='Training images height.')
parser.add_argument('--not_reduce_boxes_in_lowest_layer', default=False, action='store_true',
                    help='A boolean to indicate whether the fixed 3 boxes per '
                         'location are used in the lowest anchors generation layer.')
parser.add_argument('--box_predictor', default='convolutional', type=str,
                    choices=['convolutional', 'weight_shared_convolutional'])
args = parser.parse_args()
def createSSDGraph(modelPath, configPath, outputPath):
    # Nodes that should be kept.
    keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
               'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
               'Sub']

# Nodes that should be kept.
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
           'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
           'Sub']
# Node with which prefixes should be removed
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/map')

    # Node with which prefixes should be removed
    prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/map')
    # Load a config file.
    config = readTextMessage(configPath)
    config = config['model'][0]['ssd'][0]
    num_classes = int(config['num_classes'][0])

# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    ssd_anchor_generator = config['anchor_generator'][0]['ssd_anchor_generator'][0]
    min_scale = float(ssd_anchor_generator['min_scale'][0])
    max_scale = float(ssd_anchor_generator['max_scale'][0])
    num_layers = int(ssd_anchor_generator['num_layers'][0])
    aspect_ratios = [float(ar) for ar in ssd_anchor_generator['aspect_ratios']]
    reduce_boxes_in_lowest_layer = True
    if 'reduce_boxes_in_lowest_layer' in ssd_anchor_generator:
        reduce_boxes_in_lowest_layer = ssd_anchor_generator['reduce_boxes_in_lowest_layer'][0] == 'true'

inpNames = ['image_tensor']
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order'])
    fixed_shape_resizer = config['image_resizer'][0]['fixed_shape_resizer'][0]
    image_width = int(fixed_shape_resizer['width'][0])
    image_height = int(fixed_shape_resizer['height'][0])

def getUnconnectedNodes():
    unconnected = []
    for node in graph_def.node:
        unconnected.append(node.name)
        for inp in node.input:
            if inp in unconnected:
                unconnected.remove(inp)
    return unconnected
    box_predictor = 'convolutional' if 'convolutional_box_predictor' in config['box_predictor'][0] else 'weight_shared_convolutional'

    print('Number of classes: %d' % num_classes)
    print('Number of layers: %d' % num_layers)
    print('Scale: [%f-%f]' % (min_scale, max_scale))
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Reduce boxes in the lowest layer: %s' % str(reduce_boxes_in_lowest_layer))
    print('box predictor: %s' % box_predictor)
    print('Input image size: %dx%d' % (image_width, image_height))
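For the stock ssd_mobilenet_v1_coco pipeline.config, that summary would read roughly (values matching the old command-line defaults above): 90 classes, 6 layers, scale range [0.2, 0.95], aspect ratios [1.0, 2.0, 0.5, 3.0, 0.333], a convolutional box predictor and a 300x300 input.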
|
||||
|
||||
# Read the graph.
|
||||
cv.dnn.writeTextGraph(modelPath, outputPath)
|
||||
graph_def = parseTextGraph(outputPath)
|
||||
|
||||
inpNames = ['image_tensor']
|
||||
outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
|
||||
|
||||
def getUnconnectedNodes():
|
||||
unconnected = []
|
||||
for node in graph_def.node:
|
||||
unconnected.append(node.name)
|
||||
for inp in node.input:
|
||||
if inp in unconnected:
|
||||
unconnected.remove(inp)
|
||||
return unconnected
|
||||
|
||||
|
||||
# Detect unfused batch normalization nodes and fuse them.
|
||||
def fuse_batch_normalization():
|
||||
# Add_0 <-- moving_variance, add_y
|
||||
# Rsqrt <-- Add_0
|
||||
# Mul_0 <-- Rsqrt, gamma
|
||||
# Mul_1 <-- input, Mul_0
|
||||
# Mul_2 <-- moving_mean, Mul_0
|
||||
# Sub_0 <-- beta, Mul_2
|
||||
# Add_1 <-- Mul_1, Sub_0
|
||||
nodesMap = {node.name: node for node in graph_def.node}
|
||||
subgraph = ['Add',
|
||||
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
||||
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||
def checkSubgraph(node, targetNode, inputs, fusedNodes):
|
||||
op = targetNode[0]
|
||||
if node.op == op and (len(node.input) >= len(targetNode) - 1):
|
||||
fusedNodes.append(node)
|
||||
for i, inpOp in enumerate(targetNode[1:]):
|
||||
if isinstance(inpOp, list):
|
||||
if not node.input[i] in nodesMap or \
|
||||
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
|
||||
return False
|
||||
# Detect unfused batch normalization nodes and fuse them.
|
||||
def fuse_batch_normalization():
|
||||
# Add_0 <-- moving_variance, add_y
|
||||
# Rsqrt <-- Add_0
|
||||
# Mul_0 <-- Rsqrt, gamma
|
||||
# Mul_1 <-- input, Mul_0
|
||||
# Mul_2 <-- moving_mean, Mul_0
|
||||
# Sub_0 <-- beta, Mul_2
|
||||
# Add_1 <-- Mul_1, Sub_0
|
||||
nodesMap = {node.name: node for node in graph_def.node}
|
||||
subgraph = ['Add',
|
||||
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
|
||||
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
|
||||
def checkSubgraph(node, targetNode, inputs, fusedNodes):
|
||||
op = targetNode[0]
|
||||
if node.op == op and (len(node.input) >= len(targetNode) - 1):
|
||||
fusedNodes.append(node)
|
||||
for i, inpOp in enumerate(targetNode[1:]):
|
||||
if isinstance(inpOp, list):
|
||||
if not node.input[i] in nodesMap or \
|
||||
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
|
||||
return False
|
||||
else:
|
||||
inputs[inpOp] = node.input[i]
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
nodesToRemove = []
|
||||
for node in graph_def.node:
|
||||
inputs = {}
|
||||
fusedNodes = []
|
||||
if checkSubgraph(node, subgraph, inputs, fusedNodes):
|
||||
name = node.name
|
||||
node.Clear()
|
||||
node.name = name
|
||||
node.op = 'FusedBatchNorm'
|
||||
node.input.append(inputs['input'])
|
||||
node.input.append(inputs['gamma'])
|
||||
node.input.append(inputs['beta'])
|
||||
node.input.append(inputs['moving_mean'])
|
||||
node.input.append(inputs['moving_variance'])
|
||||
node.addAttr('epsilon', 0.001)
|
||||
nodesToRemove += fusedNodes[1:]
|
||||
for node in nodesToRemove:
|
||||
graph_def.node.remove(node)
|
||||
|
||||
fuse_batch_normalization()
|
||||
|
||||
removeIdentity(graph_def)
|
||||
|
||||
def to_remove(name, op):
|
||||
return (not op in keepOps) or name.startswith(prefixesToRemove)
|
||||
|
||||
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||
|
||||
|
||||
# Connect input node to the first layer
|
||||
assert(graph_def.node[0].op == 'Placeholder')
|
||||
# assert(graph_def.node[1].op == 'Conv2D')
|
||||
weights = graph_def.node[1].input[0]
|
||||
for i in range(len(graph_def.node[1].input)):
|
||||
graph_def.node[1].input.pop()
|
||||
graph_def.node[1].input.append(graph_def.node[0].name)
|
||||
graph_def.node[1].input.append(weights)
|
||||
|
||||
# Create SSD postprocessing head ###############################################

# Concatenate predictions of classes, predictions of bounding boxes and proposals.
def addConcatNode(name, inputs, axisNodeName):
    concat = NodeDef()
    concat.name = name
    concat.op = 'ConcatV2'
    for inp in inputs:
        concat.input.append(inp)
    concat.input.append(axisNodeName)
    graph_def.node.extend([concat])

addConstNode('concat/axis_flatten', [-1], graph_def)
addConstNode('PriorBox/concat/axis', [-2], graph_def)

for label in ['ClassPredictor', 'BoxEncodingPredictor' if box_predictor == 'convolutional' else 'BoxPredictor']:
    concatInputs = []
    for i in range(num_layers):
        # Flatten predictions
        flatten = NodeDef()
        if box_predictor == 'convolutional':
            inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label)
        else:
            if i == 0:
                inpName = 'WeightSharedConvolutionalBoxPredictor/%s/BiasAdd' % label
            else:
                inpName = 'WeightSharedConvolutionalBoxPredictor_%d/%s/BiasAdd' % (i, label)
        flatten.input.append(inpName)
        flatten.name = inpName + '/Flatten'
        flatten.op = 'Flatten'
        concatInputs.append(flatten.name)
        graph_def.node.extend([flatten])
    addConcatNode('%s/concat' % label, concatInputs, 'concat/axis_flatten')
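For illustration, with a convolutional box predictor and num_layers = 6 (assumed values; both come from the training config), the loop above emits one Flatten node per prediction head:

# Illustrative only: names produced for box_predictor == 'convolutional', num_layers == 6.
for label in ['ClassPredictor', 'BoxEncodingPredictor']:
    for i in range(6):
        print('BoxPredictor_%d/%s/BiasAdd/Flatten' % (i, label))
# BoxPredictor_0/ClassPredictor/BiasAdd/Flatten ... BoxPredictor_5/BoxEncodingPredictor/BiasAdd/Flatten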

idx = 0
for node in graph_def.node:
    if node.name == ('BoxPredictor_%d/BoxEncodingPredictor/Conv2D' % idx) or \
       node.name == ('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/Conv2D' % idx) or \
       node.name == 'WeightSharedConvolutionalBoxPredictor/BoxPredictor/Conv2D':
        node.addAttr('loc_pred_transposed', True)
        idx += 1
assert(idx == num_layers)

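A note on loc_pred_transposed, stated as an assumption about the importer's convention rather than documented behavior: the TF Object Detection API emits box deltas ordered (dy, dx, dh, dw), while OpenCV's DetectionOutput expects (dx, dy, dw, dh), so the flag asks the importer to swap the pairs. Schematically:

# Hypothetical illustration only; the real swap happens on the convolution
# weights inside the OpenCV importer, not on runtime tensors.
dy, dx, dh, dw = 0.12, -0.05, 0.30, 0.08   # TF OD API ordering
reordered = (dx, dy, dw, dh)               # ordering DetectionOutput expects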
# Add layers that generate anchors (bounding boxes proposals).
scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
          for i in range(num_layers)] + [1.0]
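As a worked example, assuming the common SSD defaults min_scale = 0.2, max_scale = 0.95, num_layers = 6 (these values come from the training config, not from this commit):

# Assumed config values, for illustration only.
min_scale, max_scale, num_layers = 0.2, 0.95, 6
scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
          for i in range(num_layers)] + [1.0]
# scales ~= [0.2, 0.35, 0.5, 0.65, 0.8, 0.95, 1.0]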
priorBoxes = []
for i in range(num_layers):
    priorBox = NodeDef()
    priorBox.name = 'PriorBox_%d' % i
    priorBox.op = 'PriorBox'
    if box_predictor == 'convolutional':
        priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i)
    else:
        if i == 0:
            priorBox.input.append('WeightSharedConvolutionalBoxPredictor/BoxPredictor/Conv2D')
        else:
            priorBox.input.append('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/BiasAdd' % i)
    priorBox.input.append(graph_def.node[0].name)  # image_tensor

    priorBox.addAttr('flip', False)
    priorBox.addAttr('clip', False)

    if i == 0 and reduce_boxes_in_lowest_layer:
        widths = [0.1, min_scale * sqrt(2.0), min_scale * sqrt(0.5)]
        heights = [0.1, min_scale / sqrt(2.0), min_scale / sqrt(0.5)]
    else:
        widths = [scales[i] * sqrt(ar) for ar in aspect_ratios]
        heights = [scales[i] / sqrt(ar) for ar in aspect_ratios]

        widths += [sqrt(scales[i] * scales[i + 1])]
        heights += [sqrt(scales[i] * scales[i + 1])]
    widths = [w * image_width for w in widths]
    heights = [h * image_height for h in heights]
    priorBox.addAttr('width', widths)
    priorBox.addAttr('height', heights)
    priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

    graph_def.node.extend([priorBox])
    priorBoxes.append(priorBox.name)

addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')
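To make the anchor arithmetic concrete, a small sketch assuming the scales above, aspect_ratios = [1.0, 2.0, 0.5], and a 300x300 input (all assumed values):

from math import sqrt

scales = [0.2, 0.35, 0.5, 0.65, 0.8, 0.95, 1.0]
aspect_ratios = [1.0, 2.0, 0.5]
image_width = image_height = 300
i = 1  # second feature map

widths = [scales[i] * sqrt(ar) for ar in aspect_ratios]   # ~[0.35, 0.495, 0.247]
heights = [scales[i] / sqrt(ar) for ar in aspect_ratios]  # ~[0.35, 0.247, 0.495]
widths += [sqrt(scales[i] * scales[i + 1])]               # extra box, ~0.418
heights += [sqrt(scales[i] * scales[i + 1])]
widths = [w * image_width for w in widths]                # relative -> pixels
heights = [h * image_height for h in heights]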

# Sigmoid for classes predictions and DetectionOutput layer
sigmoid = NodeDef()
sigmoid.name = 'ClassPredictor/concat/sigmoid'
sigmoid.op = 'Sigmoid'
sigmoid.input.append('ClassPredictor/concat')
graph_def.node.extend([sigmoid])

detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

if box_predictor == 'convolutional':
    detectionOut.input.append('BoxEncodingPredictor/concat')
else:
    detectionOut.input.append('BoxPredictor/concat')
detectionOut.input.append(sigmoid.name)
detectionOut.input.append('PriorBox/concat')

detectionOut.addAttr('num_classes', num_classes + 1)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('top_k', 100)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('confidence_threshold', 0.01)

graph_def.node.extend([detectionOut])

while True:
    unconnectedNodes = getUnconnectedNodes()
    unconnectedNodes.remove(detectionOut.name)
    if not unconnectedNodes:
        break

    for name in unconnectedNodes:
        for i in range(len(graph_def.node)):
            if graph_def.node[i].name == name:
                del graph_def.node[i]
                break

# Save as text.
graph_def.save(outputPath)
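getUnconnectedNodes is defined earlier in the script; conceptually, a node is unconnected when no other node consumes its output, so the loop above repeatedly trims dead branches until only detection_out remains unconsumed. A hypothetical standalone sketch of that test (assumed helper, not the script's actual implementation):

# Hypothetical sketch: a node is "unconnected" if its name is never used as an input.
def unconnectedNames(nodes):
    consumed = set()
    for n in nodes:
        for inp in n.input:
            consumed.add(inp.lstrip('^').split(':')[0])  # drop control marker and port
    return [n.name for n in nodes if n.name not in consumed]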
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'SSD model from TensorFlow Object Detection API. '
                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', required=True, help='Path to output text graph.')
    parser.add_argument('--config', required=True, help='Path to a *.config file used for training.')
    args = parser.parse_args()

    createSSDGraph(args.input, args.config, args.output)
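For reference, a typical workflow once the text graph is generated; file names here are placeholders, not values fixed by the commit:

# Generate the text graph (shell):
#   python tf_text_graph_ssd.py --input frozen_inference_graph.pb \
#       --config ssd_mobilenet_v1_coco.config --output graph.pbtxt
#
# Then load both files in OpenCV:
import cv2 as cv

net = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')
img = cv.imread('example.jpg')
blob = cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False)
net.setInput(blob)
out = net.forward()  # 1x1xNx7 blob: [batchId, classId, score, left, top, right, bottom]
                     # box coordinates are relative to the input image size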