mirror of
https://github.com/opencv/opencv.git
synced 2025-01-21 08:37:57 +08:00
Merge pull request #14459 from dkurt:tf_strided_slice
This commit is contained in:
commit
f629cdfe2c
@ -820,6 +820,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
* * `*.t7` | `*.net` (Torch, http://torch.ch/)
|
||||
* * `*.weights` (Darknet, https://pjreddie.com/darknet/)
|
||||
* * `*.bin` (DLDT, https://software.intel.com/openvino-toolkit)
|
||||
* * `*.onnx` (ONNX, https://onnx.ai/)
|
||||
* @param[in] config Text file contains network configuration. It could be a
|
||||
* file with the following extensions:
|
||||
* * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)
|
||||
|
@ -1423,6 +1423,43 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "StridedSlice")
|
||||
{
|
||||
CV_Assert(layer.input_size() == 4);
|
||||
Mat begins = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||
Mat ends = getTensorContent(getConstBlob(layer, value_id, 2));
|
||||
Mat strides = getTensorContent(getConstBlob(layer, value_id, 3));
|
||||
CV_CheckTypeEQ(begins.type(), CV_32SC1, "");
|
||||
CV_CheckTypeEQ(ends.type(), CV_32SC1, "");
|
||||
CV_CheckTypeEQ(strides.type(), CV_32SC1, "");
|
||||
const int num = begins.total();
|
||||
CV_Assert_N(num == ends.total(), num == strides.total());
|
||||
|
||||
int end_mask = getLayerAttr(layer, "end_mask").i();
|
||||
for (int i = 0; i < num; ++i)
|
||||
{
|
||||
if (end_mask & (1 << i))
|
||||
ends.at<int>(i) = -1;
|
||||
if (strides.at<int>(i) != 1)
|
||||
CV_Error(Error::StsNotImplemented,
|
||||
format("StridedSlice with stride %d", strides.at<int>(i)));
|
||||
}
|
||||
if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
|
||||
{
|
||||
// Swap NHWC parameters' order to NCHW.
|
||||
std::swap(begins.at<int>(2), begins.at<int>(3));
|
||||
std::swap(begins.at<int>(1), begins.at<int>(2));
|
||||
std::swap(ends.at<int>(2), ends.at<int>(3));
|
||||
std::swap(ends.at<int>(1), ends.at<int>(2));
|
||||
}
|
||||
layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
|
||||
layerParams.set("end", DictValue::arrayInt((int*)ends.data, ends.total()));
|
||||
|
||||
int id = dstNet.addLayer(name, "Slice", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Mul")
|
||||
{
|
||||
bool haveConst = false;
|
||||
|
@ -248,7 +248,7 @@ TEST_P(Test_ONNX_layers, Reshape)
|
||||
TEST_P(Test_ONNX_layers, Softmax)
|
||||
{
|
||||
testONNXModels("softmax");
|
||||
testONNXModels("log_softmax");
|
||||
testONNXModels("log_softmax", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||
|
@ -663,6 +663,7 @@ TEST_P(Test_TensorFlow_layers, slice)
|
||||
(target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
|
||||
throw SkipTestException("");
|
||||
runTensorFlowNet("slice_4d");
|
||||
runTensorFlowNet("strided_slice");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, softmax)
|
||||
|
@ -31,7 +31,13 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
|
||||
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
|
||||
width_stride = float(grid_anchor_generator['width_stride'][0])
|
||||
height_stride = float(grid_anchor_generator['height_stride'][0])
|
||||
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
|
||||
|
||||
feature_extractor = config['feature_extractor'][0]
|
||||
if 'type' in feature_extractor and feature_extractor['type'][0] == 'faster_rcnn_nas':
|
||||
features_stride = 16.0
|
||||
else:
|
||||
features_stride = float(feature_extractor['first_stage_features_stride'][0])
|
||||
|
||||
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
|
||||
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user