From 20400aa9f7a6f8f2cc75e3f18bd944a66108f717 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Thu, 21 Feb 2019 19:48:46 +0300 Subject: [PATCH] Import Upsample and Unsqueeze from ONNX --- modules/dnn/src/onnx/onnx_importer.cpp | 66 +++++++++++++++++++------ modules/dnn/test/test_onnx_importer.cpp | 10 ++++ 2 files changed, 62 insertions(+), 14 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 218775b39f..98c3563573 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -392,10 +392,10 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.set("ceil_mode", isCeilMode(layerParams)); layerParams.set("ave_pool_padded_area", framework_name == "pytorch"); } - else if (layer_type == "GlobalAveragePool") + else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool") { layerParams.type = "Pooling"; - layerParams.set("pool", "AVE"); + layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX"); layerParams.set("global_pooling", true); } else if (layer_type == "Add" || layer_type == "Sum") @@ -448,6 +448,11 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.set("bias_term", false); } } + else if (layer_type == "Neg") + { + layerParams.type = "Power"; + layerParams.set("scale", -1); + } else if (layer_type == "Constant") { CV_Assert(node_proto.input_size() == 0); @@ -595,21 +600,35 @@ void ONNXImporter::populateNet(Net dstNet) else if (layer_type == "Unsqueeze") { CV_Assert(node_proto.input_size() == 1); - Mat input = getBlob(node_proto, constBlobs, 0); - DictValue axes = layerParams.get("axes"); - std::vector dims; - for (int j = 0; j < input.dims; j++) { - dims.push_back(input.size[j]); - } - CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size()); - for (int j = 0; j < axes.size(); j++) { - dims.insert(dims.begin() + axes.getIntValue(j), 1); + if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) + { + // Constant input. + Mat input = getBlob(node_proto, constBlobs, 0); + + std::vector dims; + for (int j = 0; j < input.dims; j++) { + dims.push_back(input.size[j]); + } + CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size()); + for (int j = 0; j < axes.size(); j++) { + dims.insert(dims.begin() + axes.getIntValue(j), 1); + } + + Mat out = input.reshape(0, dims); + constBlobs.insert(std::make_pair(layerParams.name, out)); + continue; } - Mat out = input.reshape(0, dims); - constBlobs.insert(std::make_pair(layerParams.name, out)); - continue; + // Variable input. + if (axes.size() != 1) + CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze"); + + int dims[] = {1, -1}; + layerParams.type = "Reshape"; + layerParams.set("axis", axes.getIntValue(0)); + layerParams.set("num_axes", 1); + layerParams.set("dim", DictValue::arrayInt(&dims[0], 2)); } else if (layer_type == "Reshape") { @@ -707,6 +726,25 @@ void ONNXImporter::populateNet(Net dstNet) continue; } } + else if (layer_type == "Upsample") + { + layerParams.type = "Resize"; + if (layerParams.has("scales")) + { + // Pytorch layer + DictValue scales = layerParams.get("scales"); + CV_Assert(scales.size() == 4); + layerParams.set("zoom_factor_y", scales.getIntValue(2)); + layerParams.set("zoom_factor_x", scales.getIntValue(3)); + } + else + { + // Caffe2 layer + replaceLayerParam(layerParams, "height_scale", "zoom_factor_y"); + replaceLayerParam(layerParams, "width_scale", "zoom_factor_x"); + } + replaceLayerParam(layerParams, "mode", "interpolation"); + } else { for (int j = 0; j < node_proto.input_size(); j++) { diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 217ef34421..72112d2396 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -140,6 +140,11 @@ TEST_P(Test_ONNX_layers, Padding) testONNXModels("padding"); } +TEST_P(Test_ONNX_layers, Resize) +{ + testONNXModels("resize_nearest"); +} + TEST_P(Test_ONNX_layers, MultyInputs) { const String model = _tf("models/multy_inputs.onnx"); @@ -169,6 +174,11 @@ TEST_P(Test_ONNX_layers, DynamicReshape) testONNXModels("dynamic_reshape"); } +TEST_P(Test_ONNX_layers, Reshape) +{ + testONNXModels("unsqueeze"); +} + INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); class Test_ONNX_nets : public Test_ONNX_layers {};