Mirror of https://github.com/opencv/opencv.git
Merge pull request #12082 from dkurt:dnn_ie_faster_rcnn
Commit: e0c93bcf6c
@@ -258,6 +258,17 @@ PERF_TEST_P_(DNNTestNetwork, FastNeuralStyle_eccv16)
     processNet("dnn/fast_neural_style_eccv16_starry_night.t7", "", "", Mat(cv::Size(320, 240), CV_32FC3));
 }
 
+PERF_TEST_P_(DNNTestNetwork, Inception_v2_Faster_RCNN)
+{
+    if (backend == DNN_BACKEND_HALIDE ||
+        (backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU) ||
+        (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16))
+        throw SkipTestException("");
+    processNet("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb",
+               "dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", "",
+               Mat(cv::Size(800, 600), CV_32FC3));
+}
+
 const tuple<DNNBackend, DNNTarget> testCases[] = {
 #ifdef HAVE_HALIDE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_HALIDE, DNN_TARGET_CPU),
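For reference, a minimal Python sketch of exercising the same model through the public dnn API, mirroring what processNet does above. The file paths are placeholders (the perf test pulls them from the dnn test data), and the backend/target choice matches the only non-skipped Inference Engine configuration:

    import cv2 as cv
    import numpy as np

    # Placeholders for the frozen graph and the generated text graph.
    net = cv.dnn.readNetFromTensorflow('faster_rcnn_inception_v2_coco_2018_01_28.pb',
                                       'faster_rcnn_inception_v2_coco_2018_01_28.pbtxt')
    net.setPreferableBackend(cv.dnn.DNN_BACKEND_INFERENCE_ENGINE)
    net.setPreferableTarget(cv.dnn.DNN_TARGET_CPU)

    # 800x600 3-channel input, matching Mat(cv::Size(800, 600), CV_32FC3) in the test.
    img = np.random.randint(0, 255, (600, 800, 3), dtype=np.uint8)
    net.setInput(cv.dnn.blobFromImage(img, size=(800, 600), swapRB=True))
    out = net.forward()  # 1x1xNx7 detections: [batchId, classId, score, x1, y1, x2, y2]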
@@ -1408,7 +1408,7 @@ struct Net::Impl
             bool fused = ld.skip;
 
             Ptr<Layer> layer = ld.layerInstance;
-            if (!layer->supportBackend(preferableBackend))
+            if (!fused && !layer->supportBackend(preferableBackend))
             {
                 addInfEngineNetOutputs(ld);
                 net = Ptr<InfEngineBackendNet>();
@@ -2050,10 +2050,10 @@ struct Net::Impl
         TickMeter tm;
         tm.start();
 
-        if (preferableBackend == DNN_BACKEND_OPENCV ||
-            !layer->supportBackend(preferableBackend))
+        if( !ld.skip )
         {
-            if( !ld.skip )
+            std::map<int, Ptr<BackendNode> >::iterator it = ld.backendNodes.find(preferableBackend);
+            if (preferableBackend == DNN_BACKEND_OPENCV || it == ld.backendNodes.end() || it->second.empty())
             {
                 if (preferableBackend == DNN_BACKEND_OPENCV && IS_DNN_OPENCL_TARGET(preferableTarget))
                 {
@@ -2196,24 +2196,25 @@ struct Net::Impl
                 }
             }
             else
-                tm.reset();
-        }
-        else if (!ld.skip)
-        {
-            Ptr<BackendNode> node = ld.backendNodes[preferableBackend];
-            if (preferableBackend == DNN_BACKEND_HALIDE)
-            {
-                forwardHalide(ld.outputBlobsWrappers, node);
-            }
-            else if (preferableBackend == DNN_BACKEND_INFERENCE_ENGINE)
-            {
-                forwardInfEngine(node);
-            }
-            else
             {
-                CV_Error(Error::StsNotImplemented, "Unknown backend identifier");
+                Ptr<BackendNode> node = it->second;
+                CV_Assert(!node.empty());
+                if (preferableBackend == DNN_BACKEND_HALIDE)
+                {
+                    forwardHalide(ld.outputBlobsWrappers, node);
+                }
+                else if (preferableBackend == DNN_BACKEND_INFERENCE_ENGINE)
+                {
+                    forwardInfEngine(node);
+                }
+                else
+                {
+                    CV_Error(Error::StsNotImplemented, "Unknown backend identifier");
+                }
             }
         }
+        else
+            tm.reset();
 
         tm.stop();
         layersTimings[ld.id] = tm.getTimeTicks();
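The dnn.cpp hunks above change how forwardLayer dispatches: a layer is skipped only when it was fused (ld.skip), and the backend node is looked up per layer, falling back to the default OpenCV implementation when the preferred backend produced no node for it. A rough Python illustration of that fallback pattern (not OpenCV code, just the control flow):

    def forward_layer(ld, preferable_backend, OPENCV_BACKEND='opencv'):
        # Sketch of the refactored dispatch in Net::Impl::forwardLayer.
        if ld['skip']:
            return  # the layer was fused into another one
        node = ld['backend_nodes'].get(preferable_backend)
        if preferable_backend == OPENCV_BACKEND or node is None:
            ld['layer'].forward()   # default implementation
        else:
            node.forward()          # Halide / Inference Engine node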
@@ -196,7 +196,7 @@ public:
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_INFERENCE_ENGINE && !_locPredTransposed && _bboxesNormalized;
+               backendId == DNN_BACKEND_INFERENCE_ENGINE && !_locPredTransposed && _bboxesNormalized && !_clip;
     }
 
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -48,9 +48,8 @@ public:
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
-        return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_HALIDE && haveHalide() ||
-               backendId == DNN_BACKEND_INFERENCE_ENGINE && haveInfEngine();
+        return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE ||
+               backendId == DNN_BACKEND_INFERENCE_ENGINE && axis == 1;
     }
 
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
@@ -111,7 +111,7 @@ public:
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_INFERENCE_ENGINE && sliceRanges.size() == 1;
+               backendId == DNN_BACKEND_INFERENCE_ENGINE && sliceRanges.size() == 1 && sliceRanges[0].size() == 4;
     }
 
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -307,15 +307,17 @@ public:
         return Ptr<BackendNode>();
     }
 
-    virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
+    virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >& inputs) CV_OVERRIDE
     {
 #ifdef HAVE_INF_ENGINE
+        InferenceEngine::DataPtr input = infEngineDataNode(inputs[0]);
+
         InferenceEngine::LayerParams lp;
         lp.name = name;
         lp.type = "SoftMax";
         lp.precision = InferenceEngine::Precision::FP32;
         std::shared_ptr<InferenceEngine::SoftMaxLayer> ieLayer(new InferenceEngine::SoftMaxLayer(lp));
-        ieLayer->axis = axisRaw;
+        ieLayer->axis = clamp(axisRaw, input->dims.size());
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif // HAVE_INF_ENGINE
         return Ptr<BackendNode>();
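ieLayer->axis is now normalized against the actual input rank instead of being passed through raw. A short Python sketch of the intended normalization, assuming the dnn clamp() helper follows the usual convention of mapping negative axes into [0, dims):

    def clamp_axis(axis, dims):
        # Assumed semantics: negative axes count from the end; result must be a valid index.
        axis = axis + dims if axis < 0 else axis
        assert 0 <= axis < dims
        return axis

    print(clamp_axis(-1, 4))  # 3: last dimension of a 4-D blob
    print(clamp_axis(1, 2))   # 1: unchanged for a 2-D (flattened) blob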
@@ -954,6 +954,13 @@ void TFImporter::populateNet(Net dstNet)
         {
             CV_Assert(layer.input_size() == 2);
 
+            // For the object detection networks, TensorFlow Object Detection API
+            // predicts deltas for bounding boxes in yxYX (ymin, xmin, ymax, xmax)
+            // order. We can manage it at DetectionOutput layer parsing predictions
+            // or shuffle last Faster-RCNN's matmul weights.
+            bool locPredTransposed = hasLayerAttr(layer, "loc_pred_transposed") &&
+                                     getLayerAttr(layer, "loc_pred_transposed").b();
+
             layerParams.set("bias_term", false);
             layerParams.blobs.resize(1);
 
@@ -970,6 +977,17 @@ void TFImporter::populateNet(Net dstNet)
                 blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
                 ExcludeLayer(net, weights_layer_index, 0, false);
                 layers_to_ignore.insert(next_layers[0].first);
+
+                if (locPredTransposed)
+                {
+                    const int numWeights = layerParams.blobs[1].total();
+                    float* biasData = reinterpret_cast<float*>(layerParams.blobs[1].data);
+                    CV_Assert(numWeights % 4 == 0);
+                    for (int i = 0; i < numWeights; i += 2)
+                    {
+                        std::swap(biasData[i], biasData[i + 1]);
+                    }
+                }
             }
 
             int kernel_blob_index = -1;
@@ -983,6 +1001,16 @@ void TFImporter::populateNet(Net dstNet)
             }
 
             layerParams.set("num_output", layerParams.blobs[0].size[0]);
+            if (locPredTransposed)
+            {
+                CV_Assert(layerParams.blobs[0].dims == 2);
+                for (int i = 0; i < layerParams.blobs[0].size[0]; i += 2)
+                {
+                    cv::Mat src = layerParams.blobs[0].row(i);
+                    cv::Mat dst = layerParams.blobs[0].row(i + 1);
+                    std::swap_ranges(src.begin<float>(), src.end<float>(), dst.begin<float>());
+                }
+            }
 
             int id = dstNet.addLayer(name, "InnerProduct", layerParams);
             layer_id[name] = id;
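The comment in the importer explains the intent: the TensorFlow Object Detection API predicts box deltas in (ymin, xmin, ymax, xmax) order, so adjacent bias values and adjacent weight rows of the last MatMul are swapped in pairs to obtain the (xmin, ymin, xmax, ymax) order expected by DetectionOutput. A standalone Python illustration of the same pairwise swap on a toy bias vector (illustration only, not importer code):

    import numpy as np

    # One box encoding in TF order: (ymin, xmin, ymax, xmax) deltas.
    bias = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

    # Swap every adjacent pair, like std::swap(biasData[i], biasData[i + 1]) above.
    for i in range(0, bias.size, 2):
        bias[i], bias[i + 1] = bias[i + 1], bias[i]

    print(bias)  # [0.2 0.1 0.4 0.3] -> (xmin, ymin, xmax, ymax) order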
@@ -1010,6 +1038,7 @@ void TFImporter::populateNet(Net dstNet)
                 layer_id[permName] = permId;
                 connect(layer_id, dstNet, inpId, permId, 0);
                 inpId = Pin(permName);
+                inpLayout = DATA_LAYOUT_NCHW;
             }
             else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC)
             {
@@ -1024,7 +1053,7 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, inpId, id, 0);
-            data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : DATA_LAYOUT_UNKNOWN;
+            data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
         }
         else if (type == "Flatten" || type == "Squeeze")
         {
@@ -1696,41 +1725,6 @@ void TFImporter::populateNet(Net dstNet)
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
             data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
-        else if (type == "DetectionOutput")
-        {
-            // op: "DetectionOutput"
-            // input_0: "locations"
-            // input_1: "classifications"
-            // input_2: "prior_boxes"
-            if (hasLayerAttr(layer, "num_classes"))
-                layerParams.set("num_classes", getLayerAttr(layer, "num_classes").i());
-            if (hasLayerAttr(layer, "share_location"))
-                layerParams.set("share_location", getLayerAttr(layer, "share_location").b());
-            if (hasLayerAttr(layer, "background_label_id"))
-                layerParams.set("background_label_id", getLayerAttr(layer, "background_label_id").i());
-            if (hasLayerAttr(layer, "nms_threshold"))
-                layerParams.set("nms_threshold", getLayerAttr(layer, "nms_threshold").f());
-            if (hasLayerAttr(layer, "top_k"))
-                layerParams.set("top_k", getLayerAttr(layer, "top_k").i());
-            if (hasLayerAttr(layer, "code_type"))
-                layerParams.set("code_type", getLayerAttr(layer, "code_type").s());
-            if (hasLayerAttr(layer, "keep_top_k"))
-                layerParams.set("keep_top_k", getLayerAttr(layer, "keep_top_k").i());
-            if (hasLayerAttr(layer, "confidence_threshold"))
-                layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
-            if (hasLayerAttr(layer, "loc_pred_transposed"))
-                layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
-            if (hasLayerAttr(layer, "clip"))
-                layerParams.set("clip", getLayerAttr(layer, "clip").b());
-            if (hasLayerAttr(layer, "variance_encoded_in_target"))
-                layerParams.set("variance_encoded_in_target", getLayerAttr(layer, "variance_encoded_in_target").b());
-
-            int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
-            layer_id[name] = id;
-            for (int i = 0; i < 3; ++i)
-                connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
-            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
-        }
         else if (type == "Softmax")
         {
             if (hasLayerAttr(layer, "axis"))
@@ -323,7 +323,7 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
 TEST_P(Test_TensorFlow_nets, Inception_v2_Faster_RCNN)
 {
     checkBackend();
-    if (backend == DNN_BACKEND_INFERENCE_ENGINE ||
+    if ((backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU) ||
         (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16))
         throw SkipTestException("");
 
samples/dnn/tf_text_graph_common.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import tensorflow as tf
+from tensorflow.core.framework.node_def_pb2 import NodeDef
+from google.protobuf import text_format
+
+def tensorMsg(values):
+    if all([isinstance(v, float) for v in values]):
+        dtype = 'DT_FLOAT'
+        field = 'float_val'
+    elif all([isinstance(v, int) for v in values]):
+        dtype = 'DT_INT32'
+        field = 'int_val'
+    else:
+        raise Exception('Wrong values types')
+
+    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
+    for value in values:
+        msg += '%s: %s ' % (field, str(value))
+    return msg + '}'
+
+def addConstNode(name, values, graph_def):
+    node = NodeDef()
+    node.name = name
+    node.op = 'Const'
+    text_format.Merge(tensorMsg(values), node.attr["value"])
+    graph_def.node.extend([node])
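A quick usage sketch of the shared helpers: tensorMsg builds a text-format tensor proto that text_format.Merge can parse into a NodeDef attribute, and addConstNode appends a ready-made Const node to a GraphDef (the GraphDef here is created just for the example):

    from tensorflow.core.framework.graph_pb2 import GraphDef
    from tf_text_graph_common import tensorMsg, addConstNode

    print(tensorMsg([0.0, 1.0]))
    # tensor { dtype: DT_FLOAT tensor_shape { dim { size: 2 } }float_val: 0.0 float_val: 1.0 }

    graph_def = GraphDef()
    addConstNode('clip_by_value/lower', [0.0], graph_def)
    print(graph_def.node[0].op, graph_def.node[0].name)  # Const clip_by_value/lower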
@@ -6,6 +6,8 @@ from tensorflow.core.framework.node_def_pb2 import NodeDef
 from tensorflow.tools.graph_transforms import TransformGraph
 from google.protobuf import text_format
 
+from tf_text_graph_common import tensorMsg, addConstNode
+
 parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                              'SSD model from TensorFlow Object Detection API. '
                                              'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
@@ -93,21 +95,6 @@ while True:
     if node.op == 'CropAndResize':
         break
 
-def tensorMsg(values):
-    if all([isinstance(v, float) for v in values]):
-        dtype = 'DT_FLOAT'
-        field = 'float_val'
-    elif all([isinstance(v, int) for v in values]):
-        dtype = 'DT_INT32'
-        field = 'int_val'
-    else:
-        raise Exception('Wrong values types')
-
-    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
-    for value in values:
-        msg += '%s: %s ' % (field, str(value))
-    return msg + '}'
-
 def addSlice(inp, out, begins, sizes):
     beginsNode = NodeDef()
     beginsNode.name = out + '/begins'
@@ -151,17 +138,25 @@ def addSoftMax(inp, out):
     softmax.input.append(inp)
     graph_def.node.extend([softmax])
 
+def addFlatten(inp, out):
+    flatten = NodeDef()
+    flatten.name = out
+    flatten.op = 'Flatten'
+    flatten.input.append(inp)
+    graph_def.node.extend([flatten])
+
 addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
            'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])
 
 addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
            'FirstStageBoxPredictor/ClassPredictor/softmax')  # Compare with Reshape_4
 
-flatten = NodeDef()
-flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'  # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
-flatten.op = 'Flatten'
-flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
-graph_def.node.extend([flatten])
+addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
+           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
+
+# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
+addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
+           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
 
 proposals = NodeDef()
 proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
@@ -194,7 +189,7 @@ detectionOut.name = 'detection_out'
 detectionOut.op = 'DetectionOutput'
 
 detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
-detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax')
+detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
 detectionOut.input.append('proposals')
 
 text_format.Merge('i: 2', detectionOut.attr['num_classes'])
@@ -204,11 +199,21 @@ text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
 text_format.Merge('i: 6000', detectionOut.attr['top_k'])
 text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
 text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
-text_format.Merge('b: true', detectionOut.attr['clip'])
+text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
+text_format.Merge('b: false', detectionOut.attr['clip'])
 
 graph_def.node.extend([detectionOut])
 
+addConstNode('clip_by_value/lower', [0.0], graph_def)
+addConstNode('clip_by_value/upper', [1.0], graph_def)
+
+clipByValueNode = NodeDef()
+clipByValueNode.name = 'detection_out/clip_by_value'
+clipByValueNode.op = 'ClipByValue'
+clipByValueNode.input.append('detection_out')
+clipByValueNode.input.append('clip_by_value/lower')
+clipByValueNode.input.append('clip_by_value/upper')
+graph_def.node.extend([clipByValueNode])
+
 # Save as text.
 for node in reversed(topNodes):
     graph_def.node.extend([node])
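The new ClipByValue node clamps the normalized proposals to [0, 1] inside the graph itself; this lines up with setting the DetectionOutput clip attribute to false and with the supportBackend change above that excludes clipping DetectionOutput layers from the Inference Engine path. A small sketch for inspecting the appended node, assuming graph_def is the GraphDef being edited by this script:

    from google.protobuf import text_format

    # The last appended node should read roughly:
    #   name: "detection_out/clip_by_value"
    #   op: "ClipByValue"
    #   input: "detection_out"
    #   input: "clip_by_value/lower"
    #   input: "clip_by_value/upper"
    print(text_format.MessageToString(graph_def.node[-1]))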
@@ -225,17 +230,13 @@ addReshape('SecondStageBoxPredictor/Reshape_1/slice',
 # Replace Flatten subgraph onto a single node.
 for i in reversed(range(len(graph_def.node))):
     if graph_def.node[i].op == 'CropAndResize':
-        graph_def.node[i].input.insert(1, 'detection_out')
+        graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')
 
     if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
-        shapeNode = NodeDef()
-        shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2'
-        shapeNode.op = 'Const'
-        text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"])
-        graph_def.node.extend([shapeNode])
+        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)
 
         graph_def.node[i].input.pop()
-        graph_def.node[i].input.append(shapeNode.name)
+        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')
 
     if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                   'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
|
||||
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
|
||||
node.op = 'Flatten'
|
||||
node.input.pop()
|
||||
break
|
||||
|
||||
if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
|
||||
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
|
||||
text_format.Merge('b: true', node.attr["loc_pred_transposed"])
|
||||
|
||||
################################################################################
|
||||
### Postprocessing
|
||||
################################################################################
|
||||
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
|
||||
addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
|
||||
|
||||
variance = NodeDef()
|
||||
variance.name = 'proposals/variance'
|
||||
@ -268,12 +272,13 @@ text_format.Merge('i: 2', varianceEncoder.attr["axis"])
|
||||
graph_def.node.extend([varianceEncoder])
|
||||
|
||||
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])
|
||||
addFlatten('variance_encoded', 'variance_encoded/flatten')
|
||||
|
||||
detectionOut = NodeDef()
|
||||
detectionOut.name = 'detection_out_final'
|
||||
detectionOut.op = 'DetectionOutput'
|
||||
|
||||
detectionOut.input.append('variance_encoded')
|
||||
detectionOut.input.append('variance_encoded/flatten')
|
||||
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
|
||||
detectionOut.input.append('detection_out/slice/reshape')
|
||||
|
||||
@@ -283,7 +288,6 @@ text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['backgroun
 text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
 text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
 text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
 text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
-text_format.Merge('b: true', detectionOut.attr['clip'])
 text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
 graph_def.node.extend([detectionOut])
@@ -15,6 +15,7 @@ from math import sqrt
 from tensorflow.core.framework.node_def_pb2 import NodeDef
 from tensorflow.tools.graph_transforms import TransformGraph
 from google.protobuf import text_format
+from tf_text_graph_common import tensorMsg, addConstNode
 
 parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                              'SSD model from TensorFlow Object Detection API. '
@@ -160,28 +161,6 @@ graph_def.node[1].input.append(weights)
 # Create SSD postprocessing head ###############################################
 
 # Concatenate predictions of classes, predictions of bounding boxes and proposals.
-def tensorMsg(values):
-    if all([isinstance(v, float) for v in values]):
-        dtype = 'DT_FLOAT'
-        field = 'float_val'
-    elif all([isinstance(v, int) for v in values]):
-        dtype = 'DT_INT32'
-        field = 'int_val'
-    else:
-        raise Exception('Wrong values types')
-
-    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
-    for value in values:
-        msg += '%s: %s ' % (field, str(value))
-    return msg + '}'
-
-def addConstNode(name, values):
-    node = NodeDef()
-    node.name = name
-    node.op = 'Const'
-    text_format.Merge(tensorMsg(values), node.attr["value"])
-    graph_def.node.extend([node])
-
 def addConcatNode(name, inputs, axisNodeName):
     concat = NodeDef()
     concat.name = name