Merge pull request #11255 from dkurt:dnn_tf_faster_rcnn
commit c58cc4c2ff
@@ -581,6 +581,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<ProposalLayer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS CropAndResizeLayer : public Layer
+    {
+    public:
+        static Ptr<Layer> create(const LayerParams& params);
+    };
+
 //! @}
 //! @}
 CV__DNN_EXPERIMENTAL_NS_END
@@ -84,6 +84,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Reshape,        ReshapeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Flatten,        FlattenLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ResizeNearestNeighbor, ResizeNearestNeighborLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(CropAndResize,  CropAndResizeLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Convolution,    ConvolutionLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Deconvolution,  DeconvolutionLayer);
modules/dnn/src/layers/crop_and_resize_layer.cpp  (new file, 108 lines)
@@ -0,0 +1,108 @@
#include "../precomp.hpp"
#include "layers_common.hpp"

namespace cv { namespace dnn {

class CropAndResizeLayerImpl CV_FINAL : public CropAndResizeLayer
{
public:
    CropAndResizeLayerImpl(const LayerParams& params)
    {
        CV_Assert(params.has("width"), params.has("height"));
        outWidth = params.get<float>("width");
        outHeight = params.get<float>("height");
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 2, inputs[0].size() == 4);
        if (inputs[0][0] != 1)
            CV_Error(Error::StsNotImplemented, "");
        outputs.resize(1, MatShape(4));
        outputs[0][0] = inputs[1][2];  // Number of bounding boxes.
        outputs[0][1] = inputs[0][1];  // Number of channels.
        outputs[0][2] = outHeight;
        outputs[0][3] = outWidth;
        return false;
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
    }

    void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Mat& inp = *inputs[0];
        Mat& out = outputs[0];
        Mat boxes = inputs[1]->reshape(1, inputs[1]->total() / 7);
        const int numChannels = inp.size[1];
        const int inpHeight = inp.size[2];
        const int inpWidth = inp.size[3];
        const int inpSpatialSize = inpHeight * inpWidth;
        const int outSpatialSize = outHeight * outWidth;
        CV_Assert(inp.isContinuous(), out.isContinuous());

        for (int b = 0; b < boxes.rows; ++b)
        {
            float* outDataBox = out.ptr<float>(b);
            float left = boxes.at<float>(b, 3);
            float top = boxes.at<float>(b, 4);
            float right = boxes.at<float>(b, 5);
            float bottom = boxes.at<float>(b, 6);
            float boxWidth = right - left;
            float boxHeight = bottom - top;

            float heightScale = boxHeight * static_cast<float>(inpHeight - 1) / (outHeight - 1);
            float widthScale = boxWidth * static_cast<float>(inpWidth - 1) / (outWidth - 1);
            for (int y = 0; y < outHeight; ++y)
            {
                float input_y = top * (inpHeight - 1) + y * heightScale;
                int y0 = static_cast<int>(input_y);
                const float* inpData_row0 = (float*)inp.data + y0 * inpWidth;
                const float* inpData_row1 = (y0 + 1 < inpHeight) ? (inpData_row0 + inpWidth) : inpData_row0;
                for (int x = 0; x < outWidth; ++x)
                {
                    float input_x = left * (inpWidth - 1) + x * widthScale;
                    int x0 = static_cast<int>(input_x);
                    int x1 = std::min(x0 + 1, inpWidth - 1);

                    float* outData = outDataBox + y * outWidth + x;
                    const float* inpData_row0_c = inpData_row0;
                    const float* inpData_row1_c = inpData_row1;
                    for (int c = 0; c < numChannels; ++c)
                    {
                        *outData = inpData_row0_c[x0] +
                            (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
                            (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
                            (input_y - y0) * (inpData_row1_c[x1] - inpData_row1_c[x0] - inpData_row0_c[x1] + inpData_row0_c[x0]));

                        inpData_row0_c += inpSpatialSize;
                        inpData_row1_c += inpSpatialSize;
                        outData += outSpatialSize;
                    }
                }
            }
        }
    }

private:
    int outWidth, outHeight;
};

Ptr<Layer> CropAndResizeLayer::create(const LayerParams& params)
{
    return Ptr<CropAndResizeLayer>(new CropAndResizeLayerImpl(params));
}

}  // namespace dnn
}  // namespace cv
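The forward pass above is plain per-ROI bilinear sampling. A minimal NumPy sketch of the same computation (not part of the commit; assumes batch size 1, an NCHW float input, boxes as rows of the 7-column detection blob with normalized [left, top, right, bottom] in columns 3-6, and out_h, out_w > 1, as the C++ does):

import numpy as np

def crop_and_resize(inp, boxes, out_h, out_w):
    # inp: (1, C, H, W) float32; boxes: (N, 4) normalized [left, top, right, bottom].
    _, C, H, W = inp.shape
    out = np.zeros((len(boxes), C, out_h, out_w), dtype=np.float32)
    for b, (left, top, right, bottom) in enumerate(boxes):
        # Map every output pixel to a fractional source coordinate inside the box.
        ys = top * (H - 1) + np.arange(out_h) * (bottom - top) * (H - 1) / (out_h - 1)
        xs = left * (W - 1) + np.arange(out_w) * (right - left) * (W - 1) / (out_w - 1)
        y0 = ys.astype(int); y1 = np.minimum(y0 + 1, H - 1)
        x0 = xs.astype(int); x1 = np.minimum(x0 + 1, W - 1)
        wy = (ys - y0)[:, None]; wx = (xs - x0)[None, :]
        for c in range(C):
            p = inp[0, c]
            # Standard bilinear blend; algebraically equal to the C++ expression.
            out[b, c] = (p[y0][:, x0] * (1 - wy) * (1 - wx) +
                         p[y1][:, x0] * wy * (1 - wx) +
                         p[y0][:, x1] * (1 - wy) * wx +
                         p[y1][:, x1] * wy * wx)
    return out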
@@ -208,8 +208,9 @@ public:
+        CV_Assert(inputs[0][0] == inputs[1][0]);
 
         int numPriors = inputs[2][2] / 4;
-        CV_Assert((numPriors * _numLocClasses * 4) == inputs[0][1]);
-        CV_Assert(int(numPriors * _numClasses) == inputs[1][1]);
+        CV_Assert((numPriors * _numLocClasses * 4) == total(inputs[0], 1));
+        CV_Assert(int(numPriors * _numClasses) == total(inputs[1], 1));
         CV_Assert(inputs[2][1] == 1 + (int)(!_varianceEncodedInTarget));
 
         // num() and channels() are 1.
         // Since the number of bboxes to be kept is unknown before nms, we manually
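The relaxed checks use total(shape, 1) (the product of all dimensions after the batch axis) instead of dimension 1 alone, so location predictions that arrive in a multi-dimensional layout from the second stage still pass. A quick arithmetic sketch with hypothetical shapes:

import numpy as np

def total(shape, start=1):
    # Product of dimensions from `start` on, mirroring cv::dnn::total().
    return int(np.prod(shape[start:]))

num_priors, num_loc_classes = 100, 1
assert total([1, num_priors * num_loc_classes * 4]) == 400  # old flat layout still passes
assert total([1, num_priors, num_loc_classes * 4]) == 400   # multi-dim layout now passes too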
@@ -1094,9 +1094,9 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
                       sizes.type() == CV_32SC1);
 
-            if (begins.total() == 4)
+            if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC)
             {
-                // Perhabs, we have an NHWC order. Swap it to NCHW.
+                // Swap NHWC parameters' order to NCHW.
                 std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
                 std::swap(*begins.ptr<int32_t>(0, 1), *begins.ptr<int32_t>(0, 2));
                 std::swap(*sizes.ptr<int32_t>(0, 2), *sizes.ptr<int32_t>(0, 3));
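The two consecutive swaps implement the NHWC-to-NCHW permutation of the four Slice parameters; a sketch of the index shuffle:

order = ['n', 'h', 'w', 'c']               # NHWC, as stored in the TF graph
order[2], order[3] = order[3], order[2]    # ['n', 'h', 'c', 'w']
order[1], order[2] = order[2], order[1]    # ['n', 'c', 'h', 'w']
assert order == list('nchw')               # NCHW, as OpenCV expects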
@@ -1176,6 +1176,9 @@ void TFImporter::populateNet(Net dstNet)
                 layers_to_ignore.insert(next_layers[0].first);
             }
 
+            if (hasLayerAttr(layer, "axis"))
+                layerParams.set("axis", getLayerAttr(layer, "axis").i());
+
             id = dstNet.addLayer(name, "Scale", layerParams);
         }
         layer_id[name] = id;
@@ -1547,6 +1550,10 @@ void TFImporter::populateNet(Net dstNet)
                 layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
             if (hasLayerAttr(layer, "loc_pred_transposed"))
                 layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
+            if (hasLayerAttr(layer, "clip"))
+                layerParams.set("clip", getLayerAttr(layer, "clip").b());
+            if (hasLayerAttr(layer, "variance_encoded_in_target"))
+                layerParams.set("variance_encoded_in_target", getLayerAttr(layer, "variance_encoded_in_target").b());
 
             int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
             layer_id[name] = id;
@@ -1563,6 +1570,26 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
         }
+        else if (type == "CropAndResize")
+        {
+            // op: "CropAndResize"
+            // input: "input"
+            // input: "boxes"
+            // input: "sizes"
+            CV_Assert(layer.input_size() == 3);
+
+            Mat cropSize = getTensorContent(getConstBlob(layer, value_id, 2));
+            CV_Assert(cropSize.type() == CV_32SC1, cropSize.total() == 2);
+
+            layerParams.set("height", cropSize.at<int>(0));
+            layerParams.set("width", cropSize.at<int>(1));
+
+            int id = dstNet.addLayer(name, "CropAndResize", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+        }
         else if (type == "Mean")
         {
             Mat indices = getTensorContent(getConstBlob(layer, value_id, 1));
@@ -270,6 +270,22 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
     normAssertDetections(ref, out, "", 0.5);
 }
 
+TEST_P(Test_TensorFlow_nets, Inception_v2_Faster_RCNN)
+{
+    std::string proto = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", false);
+    std::string model = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb", false);
+
+    Net net = readNetFromTensorflow(model, proto);
+    Mat img = imread(findDataFile("dnn/dog416.png", false));
+    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(800, 600), Scalar(127.5, 127.5, 127.5), true, false);
+
+    net.setInput(blob);
+    Mat out = net.forward();
+
+    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/faster_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"));
+    normAssertDetections(ref, out, "", 0.3);
+}
+
 TEST_P(Test_TensorFlow_nets, opencv_face_detector_uint8)
 {
     std::string proto = findDataFile("dnn/opencv_face_detector.pbtxt", false);
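The same flow in Python, for reference (a sketch; the model, config, and image files follow the test above and must exist locally):

import cv2 as cv

net = cv.dnn.readNetFromTensorflow('faster_rcnn_inception_v2_coco_2018_01_28.pb',
                                   'faster_rcnn_inception_v2_coco_2018_01_28.pbtxt')
img = cv.imread('dog416.png')
# Same preprocessing as the C++ test: scale 2/255, 800x600, mean 127.5, BGR->RGB swap.
blob = cv.dnn.blobFromImage(img, 1.0 / 127.5, (800, 600), (127.5, 127.5, 127.5),
                            swapRB=True, crop=False)
net.setInput(blob)
out = net.forward()  # [1, 1, N, 7]: image id, class id, confidence, left, top, right, bottom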
@@ -11,8 +11,10 @@
 | [SSDs from TensorFlow](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 | [YOLO](https://pjreddie.com/darknet/yolo/) | `0.00392 (1/255)` | `416x416` | `0 0 0` | RGB |
 | [VGG16-SSD](https://github.com/weiliu89/caffe/tree/ssd) | `1.0` | `300x300` | `104 117 123` | BGR |
-| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
+| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
 | [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
+| [Faster-RCNN, ResNet backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `1.0` | `300x300` | `103.939 116.779 123.68` | RGB |
+| [Faster-RCNN, InceptionV2 backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 
 #### Face detection
 [An origin model](https://github.com/opencv/opencv/tree/master/samples/dnn/face_detector)
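Each table row maps directly onto `blobFromImage` arguments (scale factor, input size, mean, channel order). A worked check for the InceptionV2-backbone row, whose mean/scale pair maps pixels into [-1, 1]:

scale, mean = 2.0 / 255, 127.5
print([(v - mean) * scale for v in (0.0, 127.5, 255.0)])  # [-1.0, 0.0, 1.0]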
samples/dnn/tf_text_graph_faster_rcnn.py  (new file, 291 lines)
@@ -0,0 +1,291 @@
import argparse
import numpy as np
import tensorflow as tf

from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format

parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'a Faster-RCNN model from TensorFlow Object Detection API. '
                                             'Then pass it with the .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
                    help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--features_stride', default=16, type=float,
                    help='Hyper-parameter from a config file.')
args = parser.parse_args()
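# Example invocation (a sketch; the input path is a placeholder for your own
# frozen checkpoint from the TensorFlow Object Detection API):
#   python tf_text_graph_faster_rcnn.py --input frozen_inference_graph.pb \
#       --output faster_rcnn.pbtxt --num_classes 90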
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                'FirstStageBoxPredictor/BoxEncodingPredictor',
                'FirstStageBoxPredictor/ClassPredictor',
                'CropAndResize',
                'MaxPool2D',
                'SecondStageFeatureExtractor',
                'SecondStageBoxPredictor',
                'image_tensor')

scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                  'FirstStageFeatureExtractor/Shape',
                  'FirstStageFeatureExtractor/strided_slice',
                  'FirstStageFeatureExtractor/GreaterEqual',
                  'FirstStageFeatureExtractor/LogicalAnd')

unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
               'Index', 'Tperm', 'is_training', 'Tpaddings']

# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Remove Identity nodes and reroute their consumers to the original inputs.
def removeIdentity():
    identities = {}
    for node in graph_def.node:
        if node.op == 'Identity':
            identities[node.name] = node.input[0]

    # Delete in reverse so removal does not skip elements of the repeated field.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op == 'Identity':
            del graph_def.node[i]

    for node in graph_def.node:
        for i in range(len(node.input)):
            if node.input[i] in identities:
                node.input[i] = identities[node.input[i]]

removeIdentity()

removedNodes = []

for i in reversed(range(len(graph_def.node))):
    op = graph_def.node[i].op
    name = graph_def.node[i].name

    if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
        if op != 'Const':
            removedNodes.append(name)

        del graph_def.node[i]
    else:
        for attr in unusedAttrs:
            if attr in graph_def.node[i].attr:
                del graph_def.node[i].attr[attr]

# Remove references to removed nodes, except Const nodes.
for node in graph_def.node:
    for i in reversed(range(len(node.input))):
        if node.input[i] in removedNodes:
            del node.input[i]
# Connect input node to the first layer.
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily remove top nodes.
topNodes = []
while True:
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        break
def tensorMsg(values):
    if all([isinstance(v, float) for v in values]):
        dtype = 'DT_FLOAT'
        field = 'float_val'
    elif all([isinstance(v, int) for v in values]):
        dtype = 'DT_INT32'
        field = 'int_val'
    else:
        raise Exception('Wrong values types')

    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
    for value in values:
        msg += '%s: %s ' % (field, str(value))
    return msg + '}'
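# tensorMsg builds a protobuf text-format tensor attribute. For example (a sketch
# of the produced string):
#   tensorMsg([0.1, 0.1, 0.2, 0.2]) ->
#   'tensor { dtype: DT_FLOAT tensor_shape { dim { size: 4 } }float_val: 0.1 '
#   'float_val: 0.1 float_val: 0.2 float_val: 0.2 }'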
def addSlice(inp, out, begins, sizes):
    beginsNode = NodeDef()
    beginsNode.name = out + '/begins'
    beginsNode.op = 'Const'
    text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
    graph_def.node.extend([beginsNode])

    sizesNode = NodeDef()
    sizesNode.name = out + '/sizes'
    sizesNode.op = 'Const'
    text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
    graph_def.node.extend([sizesNode])

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.append(inp)
    sliced.input.append(beginsNode.name)
    sliced.input.append(sizesNode.name)
    graph_def.node.extend([sliced])

def addReshape(inp, out, shape):
    shapeNode = NodeDef()
    shapeNode.name = out + '/shape'
    shapeNode.op = 'Const'
    text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
    graph_def.node.extend([shapeNode])

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.append(inp)
    reshape.input.append(shapeNode.name)
    graph_def.node.extend([reshape])

def addSoftMax(inp, out):
    softmax = NodeDef()
    softmax.name = out
    softmax.op = 'Softmax'
    text_format.Merge('i: -1', softmax.attr['axis'])
    softmax.input.append(inp)
    graph_def.node.extend([softmax])
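# Usage sketch (hypothetical node names): addSlice('scores', 'scores/slice',
# [0, 0, 1], [-1, -1, -1]) emits Const nodes 'scores/slice/begins' and
# 'scores/slice/sizes' plus a 'Slice' node named 'scores/slice' wired to the
# input tensor and the two Const nodes.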
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax')  # Compare with Reshape_4

flatten = NodeDef()
flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'  # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
flatten.op = 'Flatten'
flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
graph_def.node.extend([flatten])

proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
widths = []
heights = []
for a in args.aspect_ratios:
    for s in args.scales:
        ar = np.sqrt(a)
        heights.append((args.features_stride**2) * s / ar)
        widths.append((args.features_stride**2) * s * ar)

text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])
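# A worked check with the defaults (features_stride=16, so 16**2 = 256):
#   a=1.0, s=0.25 -> width = height = 256 * 0.25 = 64
#   a=2.0, s=1.0  -> ar = sqrt(2), width = 256*sqrt(2) ~ 362, height = 256/sqrt(2) ~ 181
# Each (aspect_ratio, scale) pair keeps the anchor area at (256 * s)**2.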
graph_def.node.extend([proposals])

# Compare with Reshape_5
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax')
detectionOut.input.append('proposals')

text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])

graph_def.node.extend([detectionOut])

# Restore the temporarily removed top nodes.
for node in reversed(topNodes):
    graph_def.node.extend([node])

addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax')

addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1])

addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1])

for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out')

    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        shapeNode = NodeDef()
        shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2'
        shapeNode.op = 'Const'
        text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"])
        graph_def.node.extend([shapeNode])

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append(shapeNode.name)

    # Replace the Flatten subgraph with a single node.
    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
        del graph_def.node[i]

for node in graph_def.node:
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()
        break

################################################################################
### Postprocessing
################################################################################
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])

variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
graph_def.node.extend([variance])

varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])

detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'

detectionOut.input.append('variance_encoded')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')

text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
graph_def.node.extend([detectionOut])

# Save as text.
tf.train.write_graph(graph_def, "", args.output, as_text=True)
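The generated text graph is meant to be paired with the original frozen weights; a sketch of the hand-off to OpenCV (paths are placeholders):

import cv2 as cv

# frozen_inference_graph.pb keeps the weights; faster_rcnn.pbtxt is this script's output.
net = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'faster_rcnn.pbtxt')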