Merge pull request #14459 from dkurt:tf_strided_slice

This commit is contained in:
Alexander Alekhin 2019-05-22 11:30:34 +00:00
commit f629cdfe2c
5 changed files with 47 additions and 2 deletions

View File

@ -820,6 +820,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* * `*.t7` | `*.net` (Torch, http://torch.ch/)
* * `*.weights` (Darknet, https://pjreddie.com/darknet/)
* * `*.bin` (DLDT, https://software.intel.com/openvino-toolkit)
* * `*.onnx` (ONNX, https://onnx.ai/)
* @param[in] config Text file contains network configuration. It could be a
* file with the following extensions:
* * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)

View File

@ -1423,6 +1423,43 @@ void TFImporter::populateNet(Net dstNet)
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
else if (type == "StridedSlice")
{
CV_Assert(layer.input_size() == 4);
Mat begins = getTensorContent(getConstBlob(layer, value_id, 1));
Mat ends = getTensorContent(getConstBlob(layer, value_id, 2));
Mat strides = getTensorContent(getConstBlob(layer, value_id, 3));
CV_CheckTypeEQ(begins.type(), CV_32SC1, "");
CV_CheckTypeEQ(ends.type(), CV_32SC1, "");
CV_CheckTypeEQ(strides.type(), CV_32SC1, "");
const int num = begins.total();
CV_Assert_N(num == ends.total(), num == strides.total());
int end_mask = getLayerAttr(layer, "end_mask").i();
for (int i = 0; i < num; ++i)
{
if (end_mask & (1 << i))
ends.at<int>(i) = -1;
if (strides.at<int>(i) != 1)
CV_Error(Error::StsNotImplemented,
format("StridedSlice with stride %d", strides.at<int>(i)));
}
if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
{
// Swap NHWC parameters' order to NCHW.
std::swap(begins.at<int>(2), begins.at<int>(3));
std::swap(begins.at<int>(1), begins.at<int>(2));
std::swap(ends.at<int>(2), ends.at<int>(3));
std::swap(ends.at<int>(1), ends.at<int>(2));
}
layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
layerParams.set("end", DictValue::arrayInt((int*)ends.data, ends.total()));
int id = dstNet.addLayer(name, "Slice", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
else if (type == "Mul")
{
bool haveConst = false;

View File

@ -248,7 +248,7 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels("softmax");
testONNXModels("log_softmax");
testONNXModels("log_softmax", npy, 0, 0, false, false);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

View File

@ -663,6 +663,7 @@ TEST_P(Test_TensorFlow_layers, slice)
(target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
throw SkipTestException("");
runTensorFlowNet("slice_4d");
runTensorFlowNet("strided_slice");
}
TEST_P(Test_TensorFlow_layers, softmax)

View File

@ -31,7 +31,13 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
feature_extractor = config['feature_extractor'][0]
if 'type' in feature_extractor and feature_extractor['type'][0] == 'faster_rcnn_nas':
features_stride = 16.0
else:
features_stride = float(feature_extractor['first_stage_features_stride'][0])
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])