Import SSDs from TensorFlow by training config (#12188)
* Remove TensorFlow and protobuf dependencies from object detection scripts
* Create text graphs for TensorFlow object detection networks from sample
This commit is contained in:
parent e3af72bb68
commit c7cf8fb35c
@@ -885,6 +885,14 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
                                        const std::vector<String>& layersTypes = std::vector<String>());
 
+    /** @brief Create a text representation for a binary network stored in protocol buffer format.
+     *  @param[in] model  A path to binary network.
+     *  @param[in] output A path to output text file to be created.
+     *
+     *  @note To reduce output file size, trained weights are not included.
+     */
+    CV_EXPORTS_W void writeTextGraph(const String& model, const String& output);
+
     /** @brief Performs non maximum suppression given boxes and corresponding scores.
      * @param bboxes a set of bounding boxes to apply NMS.
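Since the declaration above is marked CV_EXPORTS_W, the function is also exposed through the Python bindings as cv.dnn.writeTextGraph. A minimal usage sketch (file names here are placeholders, not part of the patch):

    import cv2 as cv

    # Dump the binary frozen graph as protobuf text, without the trained weights.
    cv.dnn.writeTextGraph('frozen_inference_graph.pb', 'graph.pbtxt')

    # After the .pbtxt has been adjusted (e.g. by the tf_text_graph_* scripts below),
    # it can be passed back to the importer together with the binary graph.
    net = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')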
@@ -782,6 +782,108 @@ void releaseTensor(tensorflow::TensorProto* tensor)
 }
 }
 
+static void permute(google::protobuf::RepeatedPtrField<tensorflow::NodeDef>* data,
+                    const std::vector<int>& indices)
+{
+    const int num = data->size();
+    CV_Assert(num == indices.size());
+
+    std::vector<int> elemIdToPos(num);
+    std::vector<int> posToElemId(num);
+    for (int i = 0; i < num; ++i)
+    {
+        elemIdToPos[i] = i;
+        posToElemId[i] = i;
+    }
+    for (int i = 0; i < num; ++i)
+    {
+        int elemId = indices[i];
+        int pos = elemIdToPos[elemId];
+        if (pos != i)
+        {
+            data->SwapElements(i, pos);
+            const int swappedElemId = posToElemId[i];
+            elemIdToPos[elemId] = i;
+            elemIdToPos[swappedElemId] = pos;
+
+            posToElemId[i] = elemId;
+            posToElemId[pos] = swappedElemId;
+        }
+    }
+}
+
+// Is based on tensorflow::graph_transforms::SortByExecutionOrder
+void sortByExecutionOrder(tensorflow::GraphDef& net)
+{
+    // Maps node's name to index at net.node() list.
+    std::map<std::string, int> nodesMap;
+    std::map<std::string, int>::iterator nodesMapIt;
+    for (int i = 0; i < net.node_size(); ++i)
+    {
+        const tensorflow::NodeDef& node = net.node(i);
+        nodesMap.insert(std::make_pair(node.name(), i));
+    }
+
+    // Indices of nodes which use specific node as input.
+    std::vector<std::vector<int> > edges(nodesMap.size());
+    std::vector<int> numRefsToAdd(nodesMap.size(), 0);
+    std::vector<int> nodesToAdd;
+    for (int i = 0; i < net.node_size(); ++i)
+    {
+        const tensorflow::NodeDef& node = net.node(i);
+        for (int j = 0; j < node.input_size(); ++j)
+        {
+            std::string inpName = node.input(j);
+            inpName = inpName.substr(0, inpName.rfind(':'));
+            inpName = inpName.substr(inpName.find('^') + 1);
+
+            nodesMapIt = nodesMap.find(inpName);
+            CV_Assert(nodesMapIt != nodesMap.end());
+            edges[nodesMapIt->second].push_back(i);
+        }
+        if (node.input_size() == 0)
+            nodesToAdd.push_back(i);
+        else
+        {
+            if (node.op() == "Merge" || node.op() == "RefMerge")
+            {
+                int numControlEdges = 0;
+                for (int j = 0; j < node.input_size(); ++j)
+                    numControlEdges += node.input(j)[0] == '^';
+                numRefsToAdd[i] = numControlEdges + 1;
+            }
+            else
+                numRefsToAdd[i] = node.input_size();
+        }
+    }
+
+    std::vector<int> permIds;
+    permIds.reserve(net.node_size());
+    while (!nodesToAdd.empty())
+    {
+        int nodeToAdd = nodesToAdd.back();
+        nodesToAdd.pop_back();
+
+        permIds.push_back(nodeToAdd);
+        // std::cout << net.node(nodeToAdd).name() << '\n';
+
+        for (int i = 0; i < edges[nodeToAdd].size(); ++i)
+        {
+            int consumerId = edges[nodeToAdd][i];
+            if (numRefsToAdd[consumerId] > 0)
+            {
+                if (numRefsToAdd[consumerId] == 1)
+                    nodesToAdd.push_back(consumerId);
+                else
+                    CV_Assert(numRefsToAdd[consumerId] >= 0);
+                numRefsToAdd[consumerId] -= 1;
+            }
+        }
+    }
+    CV_Assert(permIds.size() == net.node_size());
+
+    permute(net.mutable_node(), permIds);
+}
+
 CV__DNN_EXPERIMENTAL_NS_END
 }}  // namespace dnn, namespace cv
 
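The sortByExecutionOrder() routine above is a Kahn-style topological sort over the GraphDef node list: producers are emitted before their consumers, input names are stripped of ':port' suffixes and '^' control-edge markers, and Merge/RefMerge nodes are scheduled as soon as one data input plus their control inputs are resolved. An illustrative Python sketch of the same idea (not part of the patch; names are made up):

    def sort_by_execution_order(nodes, inputs_of):
        # nodes: list of node names; inputs_of: dict mapping a name to its input names
        consumers = {n: [] for n in nodes}                        # who reads each node
        pending = {n: len(inputs_of.get(n, [])) for n in nodes}   # unresolved inputs
        for n in nodes:
            for inp in inputs_of.get(n, []):
                consumers[inp].append(n)

        ready = [n for n in nodes if pending[n] == 0]             # graph sources
        order = []
        while ready:
            n = ready.pop()
            order.append(n)
            for c in consumers[n]:
                pending[c] -= 1
                if pending[c] == 0:
                    ready.append(c)
        return order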
@@ -25,6 +25,8 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor);
 
 void releaseTensor(tensorflow::TensorProto* tensor);
 
+void sortByExecutionOrder(tensorflow::GraphDef& net);
+
 CV__DNN_EXPERIMENTAL_NS_END
 }}  // namespace dnn, namespace cv
 
@@ -1950,5 +1950,34 @@ Net readNetFromTensorflow(const std::vector<uchar>& bufferModel, const std::vect
                                        bufferConfigPtr, bufferConfig.size());
 }
 
+void writeTextGraph(const String& _model, const String& output)
+{
+    String model = _model;
+    const std::string modelExt = model.substr(model.rfind('.') + 1);
+    if (modelExt != "pb")
+        CV_Error(Error::StsNotImplemented, "Only TensorFlow models support export to text file");
+
+    tensorflow::GraphDef net;
+    ReadTFNetParamsFromBinaryFileOrDie(model.c_str(), &net);
+
+    sortByExecutionOrder(net);
+
+    RepeatedPtrField<tensorflow::NodeDef>::iterator it;
+    for (it = net.mutable_node()->begin(); it != net.mutable_node()->end(); ++it)
+    {
+        if (it->op() == "Const")
+        {
+            it->mutable_attr()->at("value").mutable_tensor()->clear_tensor_content();
+        }
+    }
+
+    std::string content;
+    google::protobuf::TextFormat::PrintToString(net, &content);
+
+    std::ofstream ofs(output.c_str());
+    ofs << content;
+    ofs.close();
+}
+
 CV__DNN_EXPERIMENTAL_NS_END
 }}  // namespace
@@ -315,6 +315,29 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
     normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
 }
 
+TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD)
+{
+    checkBackend();
+
+    std::string model = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pb", false);
+    std::string proto = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pbtxt", false);
+
+    Net net = readNetFromTensorflow(model, proto);
+    Mat img = imread(findDataFile("dnn/dog416.png", false));
+    Mat blob = blobFromImage(img, 1.0f, Size(300, 300), Scalar(), true, false);
+
+    net.setPreferableBackend(backend);
+    net.setPreferableTarget(target);
+
+    net.setInput(blob);
+    Mat out = net.forward();
+
+    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/ssd_mobilenet_v1_coco_2017_11_17.detection_out.npy"));
+    float scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 7e-3 : 1e-5;
+    float iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.0098 : 1e-3;
+    normAssertDetections(ref, out, "", 0.3, scoreDiff, iouDiff);
+}
+
 TEST_P(Test_TensorFlow_nets, Faster_RCNN)
 {
     static std::string names[] = {"faster_rcnn_inception_v2_coco_2018_01_28",
@@ -360,7 +383,8 @@ TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)
 
     net.setInput(blob);
     Mat out = net.forward();
-    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.008 : default_l1;
+
+    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.011 : default_l1;
     double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.021 : default_lInf;
     normAssertDetections(ref, out, "", 0.4, scoreDiff, iouDiff);
 }
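The new MobileNet_v1_SSD test exercises exactly the pipeline this patch enables: a frozen .pb plus a generated .pbtxt text graph. A rough Python equivalent of the test flow, assuming the model files are available locally (e.g. from the opencv_extra testdata):

    import cv2 as cv

    net = cv.dnn.readNetFromTensorflow('ssd_mobilenet_v1_coco_2017_11_17.pb',
                                       'ssd_mobilenet_v1_coco_2017_11_17.pbtxt')
    img = cv.imread('dog416.png')
    blob = cv.dnn.blobFromImage(img, 1.0, (300, 300), (0, 0, 0), swapRB=True, crop=False)
    net.setInput(blob)
    out = net.forward()  # 1x1xNx7: [batchId, classId, confidence, left, top, right, bottom]
    for detection in out[0, 0]:
        if detection[2] > 0.3:
            print('class %d  confidence %.3f' % (int(detection[1]), detection[2]))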
@@ -3,6 +3,10 @@ import argparse
 import sys
 import numpy as np
 
+from tf_text_graph_common import readTextMessage
+from tf_text_graph_ssd import createSSDGraph
+from tf_text_graph_faster_rcnn import createFasterRCNNGraph
+
 backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
 targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD)
 
@@ -11,11 +15,15 @@ parser.add_argument('--input', help='Path to input image or video file. Skip thi
 parser.add_argument('--model', required=True,
                     help='Path to a binary file of model contains trained weights. '
                          'It could be a file with extensions .caffemodel (Caffe), '
-                         '.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet)')
+                         '.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet), .bin (OpenVINO)')
 parser.add_argument('--config',
                     help='Path to a text file of model contains network configuration. '
-                         'It could be a file with extensions .prototxt (Caffe), .pbtxt (TensorFlow), .cfg (Darknet)')
-parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
+                         'It could be a file with extensions .prototxt (Caffe), .pbtxt or .config (TensorFlow), .cfg (Darknet), .xml (OpenVINO)')
+parser.add_argument('--out_tf_graph', default='graph.pbtxt',
+                    help='For models from TensorFlow Object Detection API, you may '
+                         'pass a .config file which was used for training through --config '
+                         'argument. This way an additional .pbtxt file with TensorFlow graph will be created.')
+parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet', 'dldt'],
                     help='Optional name of an origin framework of the model. '
                          'Detect it automatically if it does not set.')
 parser.add_argument('--classes', help='Optional path to a text file with names of classes to label detected objects.')
@@ -46,6 +54,20 @@ parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU,
                          '%d: VPU' % targets)
 args = parser.parse_args()
 
+# If config specified, try to load it as TensorFlow Object Detection API's pipeline.
+config = readTextMessage(args.config)
+if 'model' in config:
+    print('TensorFlow Object Detection API config detected')
+    if 'ssd' in config['model'][0]:
+        print('Preparing text graph representation for SSD model: ' + args.out_tf_graph)
+        createSSDGraph(args.model, args.config, args.out_tf_graph)
+        args.config = args.out_tf_graph
+    elif 'faster_rcnn' in config['model'][0]:
+        print('Preparing text graph representation for Faster-RCNN model: ' + args.out_tf_graph)
+        createFasterRCNNGraph(args.model, args.config, args.out_tf_graph)
+        args.config = args.out_tf_graph
+
 # Load names of classes
 classes = None
 if args.classes:
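With this block the sample accepts a TensorFlow Object Detection API training config directly: if --config parses as a pipeline message (it has a top-level 'model' field), a text graph is generated on the fly and substituted as the importer config. A condensed sketch of the same flow outside the sample (paths are placeholders):

    import cv2 as cv
    from tf_text_graph_common import readTextMessage
    from tf_text_graph_ssd import createSSDGraph
    from tf_text_graph_faster_rcnn import createFasterRCNNGraph

    model, config = 'frozen_inference_graph.pb', 'pipeline.config'
    parsed = readTextMessage(config)
    if 'model' in parsed:                      # a training config, not a .pbtxt
        if 'ssd' in parsed['model'][0]:
            createSSDGraph(model, config, 'graph.pbtxt')
        elif 'faster_rcnn' in parsed['model'][0]:
            createFasterRCNNGraph(model, config, 'graph.pbtxt')
        config = 'graph.pbtxt'
    net = cv.dnn.readNetFromTensorflow(model, config)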
@@ -1,8 +1,86 @@
-import tensorflow as tf
-from tensorflow.core.framework.node_def_pb2 import NodeDef
-from google.protobuf import text_format
-
-def tensorMsg(values):
+def tokenize(s):
+    tokens = []
+    token = ""
+    isString = False
+    isComment = False
+    for symbol in s:
+        isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
+        if isComment:
+            continue
+
+        if symbol == ' ' or symbol == '\t' or symbol == '\r' or symbol == '\'' or \
+           symbol == '\n' or symbol == ':' or symbol == '\"' or symbol == ';' or \
+           symbol == ',':
+
+            if (symbol == '\"' or symbol == '\'') and isString:
+                tokens.append(token)
+                token = ""
+            else:
+                if isString:
+                    token += symbol
+                elif token:
+                    tokens.append(token)
+                    token = ""
+            isString = (symbol == '\"' or symbol == '\'') ^ isString;
+
+        elif symbol == '{' or symbol == '}' or symbol == '[' or symbol == ']':
+            if token:
+                tokens.append(token)
+                token = ""
+            tokens.append(symbol)
+        else:
+            token += symbol
+    if token:
+        tokens.append(token)
+    return tokens
+
+
+def parseMessage(tokens, idx):
+    msg = {}
+    assert(tokens[idx] == '{')
+
+    isArray = False
+    while True:
+        if not isArray:
+            idx += 1
+            if idx < len(tokens):
+                fieldName = tokens[idx]
+            else:
+                return None
+            if fieldName == '}':
+                break
+
+        idx += 1
+        fieldValue = tokens[idx]
+
+        if fieldValue == '{':
+            embeddedMsg, idx = parseMessage(tokens, idx)
+            if fieldName in msg:
+                msg[fieldName].append(embeddedMsg)
+            else:
+                msg[fieldName] = [embeddedMsg]
+        elif fieldValue == '[':
+            isArray = True
+        elif fieldValue == ']':
+            isArray = False
+        else:
+            if fieldName in msg:
+                msg[fieldName].append(fieldValue)
+            else:
+                msg[fieldName] = [fieldValue]
+    return msg, idx
+
+
+def readTextMessage(filePath):
+    with open(filePath, 'rt') as f:
+        content = f.read()
+
+    tokens = tokenize('{' + content + '}')
+    msg = parseMessage(tokens, 0)
+    return msg[0] if msg else {}
+
+
+def listToTensor(values):
     if all([isinstance(v, float) for v in values]):
         dtype = 'DT_FLOAT'
         field = 'float_val'
@@ -12,16 +90,25 @@ def tensorMsg(values):
     else:
         raise Exception('Wrong values types')
 
-    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
-    for value in values:
-        msg += '%s: %s ' % (field, str(value))
-    return msg + '}'
+    msg = {
+        'tensor': {
+            'dtype': dtype,
+            'tensor_shape': {
+                'dim': {
+                    'size': len(values)
+                }
+            }
+        }
+    }
+    msg['tensor'][field] = values
+    return msg
 
 
 def addConstNode(name, values, graph_def):
     node = NodeDef()
     node.name = name
     node.op = 'Const'
-    text_format.Merge(tensorMsg(values), node.attr["value"])
+    node.addAttr('value', values)
     graph_def.node.extend([node])
 
 
@@ -29,13 +116,13 @@ def addSlice(inp, out, begins, sizes, graph_def):
     beginsNode = NodeDef()
     beginsNode.name = out + '/begins'
     beginsNode.op = 'Const'
-    text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
+    beginsNode.addAttr('value', begins)
     graph_def.node.extend([beginsNode])
 
     sizesNode = NodeDef()
     sizesNode.name = out + '/sizes'
     sizesNode.op = 'Const'
-    text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
+    sizesNode.addAttr('value', sizes)
     graph_def.node.extend([sizesNode])
 
     sliced = NodeDef()
@@ -51,7 +138,7 @@ def addReshape(inp, out, shape, graph_def):
     shapeNode = NodeDef()
     shapeNode.name = out + '/shape'
     shapeNode.op = 'Const'
-    text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
+    shapeNode.addAttr('value', shape)
     graph_def.node.extend([shapeNode])
 
     reshape = NodeDef()
@@ -66,7 +153,7 @@ def addSoftMax(inp, out, graph_def):
     softmax = NodeDef()
     softmax.name = out
     softmax.op = 'Softmax'
-    text_format.Merge('i: -1', softmax.attr['axis'])
+    softmax.addAttr('axis', -1)
     softmax.input.append(inp)
     graph_def.node.extend([softmax])
 
@@ -79,6 +166,103 @@ def addFlatten(inp, out, graph_def):
     graph_def.node.extend([flatten])
 
 
+class NodeDef:
+    def __init__(self):
+        self.input = []
+        self.name = ""
+        self.op = ""
+        self.attr = {}
+
+    def addAttr(self, key, value):
+        assert(not key in self.attr)
+        if isinstance(value, bool):
+            self.attr[key] = {'b': value}
+        elif isinstance(value, int):
+            self.attr[key] = {'i': value}
+        elif isinstance(value, float):
+            self.attr[key] = {'f': value}
+        elif isinstance(value, str):
+            self.attr[key] = {'s': value}
+        elif isinstance(value, list):
+            self.attr[key] = listToTensor(value)
+        else:
+            raise Exception('Unknown type of attribute ' + key)
+
+    def Clear(self):
+        self.input = []
+        self.name = ""
+        self.op = ""
+        self.attr = {}
+
+
+class GraphDef:
+    def __init__(self):
+        self.node = []
+
+    def save(self, filePath):
+        with open(filePath, 'wt') as f:
+
+            def printAttr(d, indent):
+                indent = ' ' * indent
+                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
+                    value = value if isinstance(value, list) else [value]
+                    for v in value:
+                        if isinstance(v, dict):
+                            f.write(indent + key + ' {\n')
+                            printAttr(v, len(indent) + 2)
+                            f.write(indent + '}\n')
+                        else:
+                            isString = False
+                            if isinstance(v, str) and not v.startswith('DT_'):
+                                try:
+                                    float(v)
+                                except:
+                                    isString = True
+
+                            if isinstance(v, bool):
+                                printed = 'true' if v else 'false'
+                            elif v == 'true' or v == 'false':
+                                printed = 'true' if v == 'true' else 'false'
+                            elif isString:
+                                printed = '\"%s\"' % v
+                            else:
+                                printed = str(v)
+                            f.write(indent + key + ': ' + printed + '\n')
+
+            for node in self.node:
+                f.write('node {\n')
+                f.write('  name: \"%s\"\n' % node.name)
+                f.write('  op: \"%s\"\n' % node.op)
+                for inp in node.input:
+                    f.write('  input: \"%s\"\n' % inp)
+                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
+                    f.write('  attr {\n')
+                    f.write('    key: \"%s\"\n' % key)
+                    f.write('    value {\n')
+                    printAttr(value, 6)
+                    f.write('    }\n')
+                    f.write('  }\n')
+                f.write('}\n')
+
+
+def parseTextGraph(filePath):
+    msg = readTextMessage(filePath)
+
+    graph = GraphDef()
+    for node in msg['node']:
+        graphNode = NodeDef()
+        graphNode.name = node['name'][0]
+        graphNode.op = node['op'][0]
+        graphNode.input = node['input'] if 'input' in node else []
+
+        if 'attr' in node:
+            for attr in node['attr']:
+                graphNode.attr[attr['key'][0]] = attr['value'][0]
+
+        graph.node.append(graphNode)
+    return graph
+
+
 # Removes Identity nodes
 def removeIdentity(graph_def):
     identities = {}
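These additions replace the TensorFlow and protobuf imports with a small self-contained toolkit: readTextMessage() parses protobuf text format into plain dictionaries (every field becomes a list), and NodeDef/GraphDef mimic just enough of the protobuf classes (name, op, input, addAttr(), save()) for the text-graph scripts. A short usage sketch, assuming 'pipeline.config' and 'graph.pbtxt' are placeholder paths:

    from tf_text_graph_common import readTextMessage, NodeDef, GraphDef

    config = readTextMessage('pipeline.config')
    num_classes = int(config['model'][0]['ssd'][0]['num_classes'][0])

    graph = GraphDef()
    node = NodeDef()
    node.name = 'proposals'
    node.op = 'PriorBox'
    node.input.append('image_tensor')
    node.addAttr('clip', True)                      # stored as {'b': True}
    node.addAttr('variance', [0.1, 0.1, 0.2, 0.2])  # lists go through listToTensor()
    graph.node.append(node)
    graph.save('graph.pbtxt')                       # hand-written protobuf text format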
@@ -1,28 +1,11 @@
 import argparse
 import numpy as np
-import tensorflow as tf
-
-from tensorflow.core.framework.node_def_pb2 import NodeDef
-from tensorflow.tools.graph_transforms import TransformGraph
-from google.protobuf import text_format
+import cv2 as cv
 
 from tf_text_graph_common import *
 
-parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
-                                             'SSD model from TensorFlow Object Detection API. '
-                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
-parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
-parser.add_argument('--output', required=True, help='Path to output text graph.')
-parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
-parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
-                    help='Hyper-parameter of grid_anchor_generator from a config file.')
-parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
-                    help='Hyper-parameter of grid_anchor_generator from a config file.')
-parser.add_argument('--features_stride', default=16, type=float, nargs='+',
-                    help='Hyper-parameter from a config file.')
-args = parser.parse_args()
-
-scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
+def createFasterRCNNGraph(modelPath, configPath, outputPath):
+    scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                     'FirstStageBoxPredictor/BoxEncodingPredictor',
                     'FirstStageBoxPredictor/ClassPredictor',
                     'CropAndResize',
@@ -33,121 +16,139 @@ scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                     'Preprocessor/mul',
                     'image_tensor')
 
     scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                       'FirstStageFeatureExtractor/Shape',
                       'FirstStageFeatureExtractor/strided_slice',
                       'FirstStageFeatureExtractor/GreaterEqual',
                       'FirstStageFeatureExtractor/LogicalAnd')
 
-# Read the graph.
-with tf.gfile.FastGFile(args.input, 'rb') as f:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(f.read())
-
-removeIdentity(graph_def)
-
-def to_remove(name, op):
+    # Load a config file.
+    config = readTextMessage(configPath)
+    config = config['model'][0]['faster_rcnn'][0]
+    num_classes = int(config['num_classes'][0])
+
+    grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
+    scales = [float(s) for s in grid_anchor_generator['scales']]
+    aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
+    width_stride = float(grid_anchor_generator['width_stride'][0])
+    height_stride = float(grid_anchor_generator['height_stride'][0])
+    features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
+
+    print('Number of classes: %d' % num_classes)
+    print('Scales: %s' % str(scales))
+    print('Aspect ratios: %s' % str(aspect_ratios))
+    print('Width stride: %f' % width_stride)
+    print('Height stride: %f' % height_stride)
+    print('Features stride: %f' % features_stride)
+
+    # Read the graph.
+    cv.dnn.writeTextGraph(modelPath, outputPath)
+    graph_def = parseTextGraph(outputPath)
+
+    removeIdentity(graph_def)
+
+    def to_remove(name, op):
         return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
 
     removeUnusedNodesAndAttrs(to_remove, graph_def)
 
 
     # Connect input node to the first layer
     assert(graph_def.node[0].op == 'Placeholder')
     graph_def.node[1].input.insert(0, graph_def.node[0].name)
 
     # Temporarily remove top nodes.
     topNodes = []
     while True:
         node = graph_def.node.pop()
         topNodes.append(node)
         if node.op == 'CropAndResize':
             break
 
     addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
                'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)
 
     addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
                'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4
 
     addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
                'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)
 
     # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
     addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
                'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
 
     proposals = NodeDef()
     proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
     proposals.op = 'PriorBox'
     proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
     proposals.input.append(graph_def.node[0].name)  # image_tensor
 
-text_format.Merge('b: false', proposals.attr["flip"])
-text_format.Merge('b: true', proposals.attr["clip"])
-text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
-text_format.Merge('f: 0.0', proposals.attr["offset"])
-text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
+    proposals.addAttr('flip', False)
+    proposals.addAttr('clip', True)
+    proposals.addAttr('step', features_stride)
+    proposals.addAttr('offset', 0.0)
+    proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
 
     widths = []
     heights = []
-for a in args.aspect_ratios:
-    for s in args.scales:
+    for a in aspect_ratios:
+        for s in scales:
             ar = np.sqrt(a)
-        heights.append((args.features_stride**2) * s / ar)
-        widths.append((args.features_stride**2) * s * ar)
+            heights.append((height_stride**2) * s / ar)
+            widths.append((width_stride**2) * s * ar)
 
-text_format.Merge(tensorMsg(widths), proposals.attr["width"])
-text_format.Merge(tensorMsg(heights), proposals.attr["height"])
+    proposals.addAttr('width', widths)
+    proposals.addAttr('height', heights)
 
     graph_def.node.extend([proposals])
 
     # Compare with Reshape_5
     detectionOut = NodeDef()
     detectionOut.name = 'detection_out'
     detectionOut.op = 'DetectionOutput'
 
     detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
     detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
     detectionOut.input.append('proposals')
 
-text_format.Merge('i: 2', detectionOut.attr['num_classes'])
-text_format.Merge('b: true', detectionOut.attr['share_location'])
-text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
-text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
-text_format.Merge('i: 6000', detectionOut.attr['top_k'])
-text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
-text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('b: false', detectionOut.attr['clip'])
+    detectionOut.addAttr('num_classes', 2)
+    detectionOut.addAttr('share_location', True)
+    detectionOut.addAttr('background_label_id', 0)
+    detectionOut.addAttr('nms_threshold', 0.7)
+    detectionOut.addAttr('top_k', 6000)
+    detectionOut.addAttr('code_type', "CENTER_SIZE")
+    detectionOut.addAttr('keep_top_k', 100)
+    detectionOut.addAttr('clip', False)
 
     graph_def.node.extend([detectionOut])
 
     addConstNode('clip_by_value/lower', [0.0], graph_def)
     addConstNode('clip_by_value/upper', [1.0], graph_def)
 
     clipByValueNode = NodeDef()
     clipByValueNode.name = 'detection_out/clip_by_value'
     clipByValueNode.op = 'ClipByValue'
     clipByValueNode.input.append('detection_out')
     clipByValueNode.input.append('clip_by_value/lower')
     clipByValueNode.input.append('clip_by_value/upper')
     graph_def.node.extend([clipByValueNode])
 
     # Save as text.
     for node in reversed(topNodes):
         graph_def.node.extend([node])
 
     addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)
 
     addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
              'SecondStageBoxPredictor/Reshape_1/slice',
              [0, 0, 1], [-1, -1, -1], graph_def)
 
     addReshape('SecondStageBoxPredictor/Reshape_1/slice',
                'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)
 
     # Replace Flatten subgraph onto a single node.
     for i in reversed(range(len(graph_def.node))):
         if graph_def.node[i].op == 'CropAndResize':
             graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')
@@ -162,53 +163,66 @@ for i in reversed(range(len(graph_def.node))):
                                         'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
             del graph_def.node[i]
 
     for node in graph_def.node:
         if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
             node.op = 'Flatten'
             node.input.pop()
 
         if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                          'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
-        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
+            node.addAttr('loc_pred_transposed', True)
 
     ################################################################################
     ### Postprocessing
     ################################################################################
     addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)
 
     variance = NodeDef()
     variance.name = 'proposals/variance'
     variance.op = 'Const'
-text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
+    variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
     graph_def.node.extend([variance])
 
     varianceEncoder = NodeDef()
     varianceEncoder.name = 'variance_encoded'
     varianceEncoder.op = 'Mul'
     varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
     varianceEncoder.input.append(variance.name)
-text_format.Merge('i: 2', varianceEncoder.attr["axis"])
+    varianceEncoder.addAttr('axis', 2)
     graph_def.node.extend([varianceEncoder])
 
     addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
     addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
 
     detectionOut = NodeDef()
     detectionOut.name = 'detection_out_final'
     detectionOut.op = 'DetectionOutput'
 
     detectionOut.input.append('variance_encoded/flatten')
     detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
     detectionOut.input.append('detection_out/slice/reshape')
 
-text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
-text_format.Merge('b: false', detectionOut.attr['share_location'])
-text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
-text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
-text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
-text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('b: true', detectionOut.attr['clip'])
-text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
+    detectionOut.addAttr('num_classes', num_classes)
+    detectionOut.addAttr('share_location', False)
+    detectionOut.addAttr('background_label_id', num_classes + 1)
+    detectionOut.addAttr('nms_threshold', 0.6)
+    detectionOut.addAttr('code_type', "CENTER_SIZE")
+    detectionOut.addAttr('keep_top_k', 100)
+    detectionOut.addAttr('clip', True)
+    detectionOut.addAttr('variance_encoded_in_target', True)
     graph_def.node.extend([detectionOut])
 
-tf.train.write_graph(graph_def, "", args.output, as_text=True)
+    # Save as text.
+    graph_def.save(outputPath)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
+                                                 'Faster-RCNN model from TensorFlow Object Detection API. '
+                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
+    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
+    parser.add_argument('--output', required=True, help='Path to output text graph.')
+    parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
+    args = parser.parse_args()
+
+    createFasterRCNNGraph(args.input, args.config, args.output)
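Because the body now lives in createFasterRCNNGraph(), the script can be used either as a command-line tool (with the --input, --config and --output arguments added above) or imported, which is how the object detection sample drives it. A minimal sketch with placeholder paths:

    from tf_text_graph_faster_rcnn import createFasterRCNNGraph

    createFasterRCNNGraph('frozen_inference_graph.pb', 'pipeline.config', 'graph.pbtxt')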
@@ -1,11 +1,6 @@
 import argparse
 import numpy as np
-import tensorflow as tf
-
-from tensorflow.core.framework.node_def_pb2 import NodeDef
-from tensorflow.tools.graph_transforms import TransformGraph
-from google.protobuf import text_format
+import cv2 as cv
 
 from tf_text_graph_common import *
 
 parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
@@ -13,13 +8,7 @@ parser = argparse.ArgumentParser(description='Run this script to get a text grap
                                              'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
 parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
 parser.add_argument('--output', required=True, help='Path to output text graph.')
-parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
-parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
-                    help='Hyper-parameter of grid_anchor_generator from a config file.')
-parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
-                    help='Hyper-parameter of grid_anchor_generator from a config file.')
-parser.add_argument('--features_stride', default=16, type=float, nargs='+',
-                    help='Hyper-parameter from a config file.')
+parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
 args = parser.parse_args()
 
 scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
@@ -39,11 +28,28 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                   'FirstStageFeatureExtractor/GreaterEqual',
                   'FirstStageFeatureExtractor/LogicalAnd')
 
+# Load a config file.
+config = readTextMessage(args.config)
+config = config['model'][0]['faster_rcnn'][0]
+num_classes = int(config['num_classes'][0])
+
+grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
+scales = [float(s) for s in grid_anchor_generator['scales']]
+aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
+width_stride = float(grid_anchor_generator['width_stride'][0])
+height_stride = float(grid_anchor_generator['height_stride'][0])
+features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
+
+print('Number of classes: %d' % num_classes)
+print('Scales: %s' % str(scales))
+print('Aspect ratios: %s' % str(aspect_ratios))
+print('Width stride: %f' % width_stride)
+print('Height stride: %f' % height_stride)
+print('Features stride: %f' % features_stride)
+
 # Read the graph.
-with tf.gfile.FastGFile(args.input, 'rb') as f:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(f.read())
+cv.dnn.writeTextGraph(args.input, args.output)
+graph_def = parseTextGraph(args.output)
 
 removeIdentity(graph_def)
 
@@ -87,22 +93,22 @@ proposals.op = 'PriorBox'
 proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
 proposals.input.append(graph_def.node[0].name)  # image_tensor
 
-text_format.Merge('b: false', proposals.attr["flip"])
-text_format.Merge('b: true', proposals.attr["clip"])
-text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
-text_format.Merge('f: 0.0', proposals.attr["offset"])
-text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
+proposals.addAttr('flip', False)
+proposals.addAttr('clip', True)
+proposals.addAttr('step', features_stride)
+proposals.addAttr('offset', 0.0)
+proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
 
 widths = []
 heights = []
-for a in args.aspect_ratios:
-    for s in args.scales:
+for a in aspect_ratios:
+    for s in scales:
         ar = np.sqrt(a)
-        heights.append((args.features_stride**2) * s / ar)
-        widths.append((args.features_stride**2) * s * ar)
+        heights.append((features_stride**2) * s / ar)
+        widths.append((features_stride**2) * s * ar)
 
-text_format.Merge(tensorMsg(widths), proposals.attr["width"])
-text_format.Merge(tensorMsg(heights), proposals.attr["height"])
+proposals.addAttr('width', widths)
+proposals.addAttr('height', heights)
 
 graph_def.node.extend([proposals])
 
@@ -115,14 +121,14 @@ detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
 detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
 detectionOut.input.append('proposals')
 
-text_format.Merge('i: 2', detectionOut.attr['num_classes'])
-text_format.Merge('b: true', detectionOut.attr['share_location'])
-text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
-text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
-text_format.Merge('i: 6000', detectionOut.attr['top_k'])
-text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
-text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('b: true', detectionOut.attr['clip'])
+detectionOut.addAttr('num_classes', 2)
+detectionOut.addAttr('share_location', True)
+detectionOut.addAttr('background_label_id', 0)
+detectionOut.addAttr('nms_threshold', 0.7)
+detectionOut.addAttr('top_k', 6000)
+detectionOut.addAttr('code_type', "CENTER_SIZE")
+detectionOut.addAttr('keep_top_k', 100)
+detectionOut.addAttr('clip', True)
 
 graph_def.node.extend([detectionOut])
 
@@ -171,7 +177,7 @@ for node in graph_def.node:
 
     if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                      'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
-        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
+        node.addAttr('loc_pred_transposed', True)
 
 ################################################################################
 ### Postprocessing
@@ -181,7 +187,7 @@ addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4],
 variance = NodeDef()
 variance.name = 'proposals/variance'
 variance.op = 'Const'
-text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
+variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
 graph_def.node.extend([variance])
 
 varianceEncoder = NodeDef()
@@ -189,7 +195,7 @@ varianceEncoder.name = 'variance_encoded'
 varianceEncoder.op = 'Mul'
 varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
 varianceEncoder.input.append(variance.name)
-text_format.Merge('i: 2', varianceEncoder.attr["axis"])
+varianceEncoder.addAttr('axis', 2)
 graph_def.node.extend([varianceEncoder])
 
 addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
@@ -203,16 +209,16 @@ detectionOut.input.append('variance_encoded/flatten')
 detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
 detectionOut.input.append('detection_out/slice/reshape')
 
-text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
-text_format.Merge('b: false', detectionOut.attr['share_location'])
-text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
-text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
-text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
-text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('b: true', detectionOut.attr['clip'])
-text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
-text_format.Merge('f: 0.3', detectionOut.attr['confidence_threshold'])
-text_format.Merge('b: false', detectionOut.attr['group_by_classes'])
+detectionOut.addAttr('num_classes', num_classes)
+detectionOut.addAttr('share_location', False)
+detectionOut.addAttr('background_label_id', num_classes + 1)
+detectionOut.addAttr('nms_threshold', 0.6)
+detectionOut.addAttr('code_type', "CENTER_SIZE")
+detectionOut.addAttr('keep_top_k', 100)
+detectionOut.addAttr('clip', True)
+detectionOut.addAttr('variance_encoded_in_target', True)
+detectionOut.addAttr('confidence_threshold', 0.3)
+detectionOut.addAttr('group_by_classes', False)
 graph_def.node.extend([detectionOut])
 
 for node in reversed(topNodes):
@@ -227,4 +233,5 @@ graph_def.node[-1].name = 'detection_masks'
 graph_def.node[-1].op = 'Sigmoid'
 graph_def.node[-1].input.pop()
 
-tf.train.write_graph(graph_def, "", args.output, as_text=True)
+# Save as text.
+graph_def.save(args.output)
@ -9,52 +9,56 @@
|
|||||||
# deep learning network trained in TensorFlow Object Detection API.
|
# deep learning network trained in TensorFlow Object Detection API.
|
||||||
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
|
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
|
||||||
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
|
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
|
||||||
import tensorflow as tf
|
|
||||||
import argparse
|
import argparse
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
import cv2 as cv
|
||||||
from tensorflow.tools.graph_transforms import TransformGraph
|
|
||||||
from google.protobuf import text_format
|
|
||||||
from tf_text_graph_common import *
|
from tf_text_graph_common import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
|
def createSSDGraph(modelPath, configPath, outputPath):
|
||||||
'SSD model from TensorFlow Object Detection API. '
|
# Nodes that should be kept.
|
||||||
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
|
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
||||||
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
|
|
||||||
parser.add_argument('--output', required=True, help='Path to output text graph.')
|
|
||||||
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
|
|
||||||
parser.add_argument('--min_scale', default=0.2, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
|
|
||||||
parser.add_argument('--max_scale', default=0.95, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.')
|
|
||||||
parser.add_argument('--num_layers', default=6, type=int, help='Hyper-parameter of ssd_anchor_generator from config file.')
|
|
||||||
parser.add_argument('--aspect_ratios', default=[1.0, 2.0, 0.5, 3.0, 0.333], type=float, nargs='+',
|
|
||||||
help='Hyper-parameter of ssd_anchor_generator from config file.')
|
|
||||||
parser.add_argument('--image_width', default=300, type=int, help='Training images width.')
|
|
||||||
parser.add_argument('--image_height', default=300, type=int, help='Training images height.')
|
|
||||||
parser.add_argument('--not_reduce_boxes_in_lowest_layer', default=False, action='store_true',
|
|
||||||
help='A boolean to indicate whether the fixed 3 boxes per '
|
|
||||||
'location is used in the lowest achors generation layer.')
|
|
||||||
parser.add_argument('--box_predictor', default='convolutional', type=str,
|
|
||||||
choices=['convolutional', 'weight_shared_convolutional'])
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Nodes that should be kept.
|
|
||||||
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
|
|
||||||
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
|
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
|
||||||
'Sub']
|
'Sub']
|
||||||
|
|
||||||
# Node with which prefixes should be removed
|
# Node with which prefixes should be removed
|
||||||
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/map')
|
prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/map')
|
||||||
|
|
||||||
-# Read the graph.
-with tf.gfile.FastGFile(args.input, 'rb') as f:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(f.read())
-
-inpNames = ['image_tensor']
-outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
-graph_def = TransformGraph(graph_def, inpNames, outNames, ['sort_by_execution_order'])
-
-def getUnconnectedNodes():
+    # Load a config file.
+    config = readTextMessage(configPath)
+    config = config['model'][0]['ssd'][0]
+    num_classes = int(config['num_classes'][0])
+
+    ssd_anchor_generator = config['anchor_generator'][0]['ssd_anchor_generator'][0]
+    min_scale = float(ssd_anchor_generator['min_scale'][0])
+    max_scale = float(ssd_anchor_generator['max_scale'][0])
+    num_layers = int(ssd_anchor_generator['num_layers'][0])
+    aspect_ratios = [float(ar) for ar in ssd_anchor_generator['aspect_ratios']]
+    reduce_boxes_in_lowest_layer = True
+    if 'reduce_boxes_in_lowest_layer' in ssd_anchor_generator:
+        reduce_boxes_in_lowest_layer = ssd_anchor_generator['reduce_boxes_in_lowest_layer'][0] == 'true'
+
+    fixed_shape_resizer = config['image_resizer'][0]['fixed_shape_resizer'][0]
+    image_width = int(fixed_shape_resizer['width'][0])
+    image_height = int(fixed_shape_resizer['height'][0])
+
+    box_predictor = 'convolutional' if 'convolutional_box_predictor' in config['box_predictor'][0] else 'weight_shared_convolutional'
+
+    print('Number of classes: %d' % num_classes)
+    print('Number of layers: %d' % num_layers)
+    print('Scale: [%f-%f]' % (min_scale, max_scale))
+    print('Aspect ratios: %s' % str(aspect_ratios))
+    print('Reduce boxes in the lowest layer: %s' % str(reduce_boxes_in_lowest_layer))
+    print('box predictor: %s' % box_predictor)
+    print('Input image size: %dx%d' % (image_width, image_height))
+
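For reference, readTextMessage() consumes the same pipeline .config that was used for training. An abridged excerpt of the fields the function reads is sketched below; the values shown simply mirror the defaults that the removed command-line flags used to carry and are illustrative only:

    model {
      ssd {
        num_classes: 90
        image_resizer {
          fixed_shape_resizer {
            height: 300
            width: 300
          }
        }
        anchor_generator {
          ssd_anchor_generator {
            num_layers: 6
            min_scale: 0.2
            max_scale: 0.95
            aspect_ratios: 1.0
            aspect_ratios: 2.0
            aspect_ratios: 0.5
            aspect_ratios: 3.0
            aspect_ratios: 0.3333
            reduce_boxes_in_lowest_layer: true
          }
        }
        box_predictor {
          convolutional_box_predictor {
            # ...
          }
        }
      }
    }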
+    # Read the graph.
+    cv.dnn.writeTextGraph(modelPath, outputPath)
+    graph_def = parseTextGraph(outputPath)
+
+    inpNames = ['image_tensor']
+    outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']
+
+    def getUnconnectedNodes():
     unconnected = []
     for node in graph_def.node:
         unconnected.append(node.name)
@@ -64,8 +68,8 @@ def getUnconnectedNodes():
     return unconnected


 # Detect unfused batch normalization nodes and fuse them.
 def fuse_batch_normalization():
     # Add_0 <-- moving_variance, add_y
     # Rsqrt <-- Add_0
     # Mul_0 <-- Rsqrt, gamma
@@ -107,34 +111,34 @@ def fuse_batch_normalization():
             node.input.append(inputs['beta'])
             node.input.append(inputs['moving_mean'])
             node.input.append(inputs['moving_variance'])
-            text_format.Merge('f: 0.001', node.attr["epsilon"])
+            node.addAttr('epsilon', 0.001)
             nodesToRemove += fusedNodes[1:]
     for node in nodesToRemove:
         graph_def.node.remove(node)

 fuse_batch_normalization()
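For background, the Add/Rsqrt/Mul/Sub pattern listed in the comments above is just the batch-normalization expression that a single FusedBatchNorm node evaluates. A minimal NumPy sketch of that equivalence, not part of the patch, using the same 0.001 epsilon:

    import numpy as np

    def unfused_bn(x, gamma, beta, moving_mean, moving_variance, eps=0.001):
        # Add_0 = moving_variance + eps; Rsqrt = 1 / sqrt(Add_0); Mul_0 = Rsqrt * gamma
        scale = gamma / np.sqrt(moving_variance + eps)
        # Mul_1 = x * scale, then the remaining Mul/Sub fold the mean and beta into a bias
        return x * scale + (beta - moving_mean * scale)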

 removeIdentity(graph_def)

 def to_remove(name, op):
     return (not op in keepOps) or name.startswith(prefixesToRemove)

 removeUnusedNodesAndAttrs(to_remove, graph_def)


 # Connect input node to the first layer
 assert(graph_def.node[0].op == 'Placeholder')
 # assert(graph_def.node[1].op == 'Conv2D')
 weights = graph_def.node[1].input[0]
 for i in range(len(graph_def.node[1].input)):
     graph_def.node[1].input.pop()
 graph_def.node[1].input.append(graph_def.node[0].name)
 graph_def.node[1].input.append(weights)

 # Create SSD postprocessing head ###############################################

 # Concatenate predictions of classes, predictions of bounding boxes and proposals.
 def addConcatNode(name, inputs, axisNodeName):
     concat = NodeDef()
     concat.name = name
     concat.op = 'ConcatV2'
@@ -143,15 +147,15 @@ def addConcatNode(name, inputs, axisNodeName):
     concat.input.append(axisNodeName)
     graph_def.node.extend([concat])

 addConstNode('concat/axis_flatten', [-1], graph_def)
 addConstNode('PriorBox/concat/axis', [-2], graph_def)

-for label in ['ClassPredictor', 'BoxEncodingPredictor' if args.box_predictor is 'convolutional' else 'BoxPredictor']:
+for label in ['ClassPredictor', 'BoxEncodingPredictor' if box_predictor is 'convolutional' else 'BoxPredictor']:
     concatInputs = []
-    for i in range(args.num_layers):
+    for i in range(num_layers):
         # Flatten predictions
         flatten = NodeDef()
-        if args.box_predictor is 'convolutional':
+        if box_predictor is 'convolutional':
             inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label)
         else:
             if i == 0:
@@ -166,25 +170,25 @@ for label in ['ClassPredictor', 'BoxEncodingPredictor' if args.box_predictor is
         graph_def.node.extend([flatten])
     addConcatNode('%s/concat' % label, concatInputs, 'concat/axis_flatten')

 idx = 0
 for node in graph_def.node:
     if node.name == ('BoxPredictor_%d/BoxEncodingPredictor/Conv2D' % idx) or \
        node.name == ('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/Conv2D' % idx) or \
        node.name == 'WeightSharedConvolutionalBoxPredictor/BoxPredictor/Conv2D':
-        text_format.Merge('b: true', node.attr["loc_pred_transposed"])
+        node.addAttr('loc_pred_transposed', True)
         idx += 1
-assert(idx == args.num_layers)
+assert(idx == num_layers)

 # Add layers that generate anchors (bounding boxes proposals).
-scales = [args.min_scale + (args.max_scale - args.min_scale) * i / (args.num_layers - 1)
-          for i in range(args.num_layers)] + [1.0]
+scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
+          for i in range(num_layers)] + [1.0]

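With the usual SSD hyper-parameters (min_scale=0.2, max_scale=0.95, num_layers=6, matching the defaults of the removed flags), this expands to the anchor scales below; the trailing 1.0 only feeds the sqrt(scales[i] * scales[i + 1]) term for the last layer:

    >>> num_layers, min_scale, max_scale = 6, 0.2, 0.95
    >>> [round(min_scale + (max_scale - min_scale) * i / (num_layers - 1), 2)
    ...  for i in range(num_layers)] + [1.0]
    [0.2, 0.35, 0.5, 0.65, 0.8, 0.95, 1.0]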
 priorBoxes = []
-for i in range(args.num_layers):
+for i in range(num_layers):
     priorBox = NodeDef()
     priorBox.name = 'PriorBox_%d' % i
     priorBox.op = 'PriorBox'
-    if args.box_predictor is 'convolutional':
+    if box_predictor is 'convolutional':
         priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i)
     else:
         if i == 0:
@@ -193,59 +197,59 @@ for i in range(args.num_layers):
             priorBox.input.append('WeightSharedConvolutionalBoxPredictor_%d/BoxPredictor/BiasAdd' % i)
     priorBox.input.append(graph_def.node[0].name) # image_tensor

-    text_format.Merge('b: false', priorBox.attr["flip"])
-    text_format.Merge('b: false', priorBox.attr["clip"])
+    priorBox.addAttr('flip', False)
+    priorBox.addAttr('clip', False)

-    if i == 0 and not args.not_reduce_boxes_in_lowest_layer:
-        widths = [0.1, args.min_scale * sqrt(2.0), args.min_scale * sqrt(0.5)]
-        heights = [0.1, args.min_scale / sqrt(2.0), args.min_scale / sqrt(0.5)]
+    if i == 0 and reduce_boxes_in_lowest_layer:
+        widths = [0.1, min_scale * sqrt(2.0), min_scale * sqrt(0.5)]
+        heights = [0.1, min_scale / sqrt(2.0), min_scale / sqrt(0.5)]
     else:
-        widths = [scales[i] * sqrt(ar) for ar in args.aspect_ratios]
-        heights = [scales[i] / sqrt(ar) for ar in args.aspect_ratios]
+        widths = [scales[i] * sqrt(ar) for ar in aspect_ratios]
+        heights = [scales[i] / sqrt(ar) for ar in aspect_ratios]

         widths += [sqrt(scales[i] * scales[i + 1])]
         heights += [sqrt(scales[i] * scales[i + 1])]
-    widths = [w * args.image_width for w in widths]
-    heights = [h * args.image_height for h in heights]
-    text_format.Merge(tensorMsg(widths), priorBox.attr["width"])
-    text_format.Merge(tensorMsg(heights), priorBox.attr["height"])
-    text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), priorBox.attr["variance"])
+    widths = [w * image_width for w in widths]
+    heights = [h * image_height for h in heights]
+    priorBox.addAttr('width', widths)
+    priorBox.addAttr('height', heights)
+    priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

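As a worked example of the branch above, with the 300x300 / min_scale = 0.2 defaults shown earlier: the lowest layer (i = 0 with reduce_boxes_in_lowest_layer) gets exactly three boxes per location, a small square plus 2:1 and 1:2 rectangles in pixels:

    >>> from math import sqrt
    >>> min_scale, image_size = 0.2, 300
    >>> [round(w * image_size, 1) for w in [0.1, min_scale * sqrt(2.0), min_scale * sqrt(0.5)]]
    [30.0, 84.9, 42.4]
    >>> [round(h * image_size, 1) for h in [0.1, min_scale / sqrt(2.0), min_scale / sqrt(0.5)]]
    [30.0, 42.4, 84.9]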
     graph_def.node.extend([priorBox])
     priorBoxes.append(priorBox.name)

 addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')

 # Sigmoid for classes predictions and DetectionOutput layer
 sigmoid = NodeDef()
 sigmoid.name = 'ClassPredictor/concat/sigmoid'
 sigmoid.op = 'Sigmoid'
 sigmoid.input.append('ClassPredictor/concat')
 graph_def.node.extend([sigmoid])

 detectionOut = NodeDef()
 detectionOut.name = 'detection_out'
 detectionOut.op = 'DetectionOutput'

-if args.box_predictor == 'convolutional':
+if box_predictor == 'convolutional':
     detectionOut.input.append('BoxEncodingPredictor/concat')
 else:
     detectionOut.input.append('BoxPredictor/concat')
 detectionOut.input.append(sigmoid.name)
 detectionOut.input.append('PriorBox/concat')

-text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['num_classes'])
-text_format.Merge('b: true', detectionOut.attr['share_location'])
-text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
-text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
-text_format.Merge('i: 100', detectionOut.attr['top_k'])
-text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
-text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('f: 0.01', detectionOut.attr['confidence_threshold'])
+detectionOut.addAttr('num_classes', num_classes + 1)
+detectionOut.addAttr('share_location', True)
+detectionOut.addAttr('background_label_id', 0)
+detectionOut.addAttr('nms_threshold', 0.6)
+detectionOut.addAttr('top_k', 100)
+detectionOut.addAttr('code_type', "CENTER_SIZE")
+detectionOut.addAttr('keep_top_k', 100)
+detectionOut.addAttr('confidence_threshold', 0.01)

 graph_def.node.extend([detectionOut])

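A note on what these attributes produce, as an aside rather than part of the diff: in OpenCV's dnn module a DetectionOutput layer emits a 1x1xNx7 blob whose rows hold [image_id, class_id, confidence, left, top, right, bottom], with box coordinates normalized to [0, 1]. A minimal consumer sketch, variable names chosen for illustration:

    out = net.forward()                 # shape (1, 1, N, 7) for the 'detection_out' node
    for det in out[0, 0]:
        class_id, score = int(det[1]), float(det[2])
        if score > 0.5:                 # illustrative confidence threshold
            left, top, right, bottom = det[3:7]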
 while True:
     unconnectedNodes = getUnconnectedNodes()
     unconnectedNodes.remove(detectionOut.name)
     if not unconnectedNodes:
@@ -257,5 +261,17 @@ while True:
                 del graph_def.node[i]
                 break

 # Save as text.
-tf.train.write_graph(graph_def, "", args.output, as_text=True)
+graph_def.save(outputPath)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
+                                                 'SSD model from TensorFlow Object Detection API. '
+                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
+    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
+    parser.add_argument('--output', required=True, help='Path to output text graph.')
+    parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
+    args = parser.parse_args()
+
+    createSSDGraph(args.input, args.config, args.output)
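A typical end-to-end use of the new entry point looks like the sketch below; paths and preprocessing values are placeholders, with 300x300 and swapRB=True matching the common SSD-MobileNet setup rather than anything mandated by this patch:

    # python tf_text_graph_ssd.py --input frozen_inference_graph.pb \
    #                             --config pipeline.config --output ssd.pbtxt
    import cv2 as cv

    net = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'ssd.pbtxt')
    blob = cv.dnn.blobFromImage(cv.imread('example.jpg'), size=(300, 300),
                                swapRB=True, crop=False)
    net.setInput(blob)
    detections = net.forward()  # parse as sketched above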