Merge pull request #13892 from dkurt:onnx_upsample_unsqueeze

This commit is contained in:
Alexander Alekhin 2019-02-26 10:17:06 +00:00
commit 865c29a754
2 changed files with 62 additions and 14 deletions

View File

@ -392,10 +392,10 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("ceil_mode", isCeilMode(layerParams)); layerParams.set("ceil_mode", isCeilMode(layerParams));
layerParams.set("ave_pool_padded_area", framework_name == "pytorch"); layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
} }
else if (layer_type == "GlobalAveragePool") else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool")
{ {
layerParams.type = "Pooling"; layerParams.type = "Pooling";
layerParams.set("pool", "AVE"); layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX");
layerParams.set("global_pooling", true); layerParams.set("global_pooling", true);
} }
else if (layer_type == "Add" || layer_type == "Sum") else if (layer_type == "Add" || layer_type == "Sum")
@ -448,6 +448,11 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("bias_term", false); layerParams.set("bias_term", false);
} }
} }
else if (layer_type == "Neg")
{
layerParams.type = "Power";
layerParams.set("scale", -1);
}
else if (layer_type == "Constant") else if (layer_type == "Constant")
{ {
CV_Assert(node_proto.input_size() == 0); CV_Assert(node_proto.input_size() == 0);
@ -595,21 +600,35 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Unsqueeze") else if (layer_type == "Unsqueeze")
{ {
CV_Assert(node_proto.input_size() == 1); CV_Assert(node_proto.input_size() == 1);
Mat input = getBlob(node_proto, constBlobs, 0);
DictValue axes = layerParams.get("axes"); DictValue axes = layerParams.get("axes");
std::vector<int> dims; if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
for (int j = 0; j < input.dims; j++) { {
dims.push_back(input.size[j]); // Constant input.
} Mat input = getBlob(node_proto, constBlobs, 0);
CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
for (int j = 0; j < axes.size(); j++) { std::vector<int> dims;
dims.insert(dims.begin() + axes.getIntValue(j), 1); for (int j = 0; j < input.dims; j++) {
dims.push_back(input.size[j]);
}
CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
for (int j = 0; j < axes.size(); j++) {
dims.insert(dims.begin() + axes.getIntValue(j), 1);
}
Mat out = input.reshape(0, dims);
constBlobs.insert(std::make_pair(layerParams.name, out));
continue;
} }
Mat out = input.reshape(0, dims); // Variable input.
constBlobs.insert(std::make_pair(layerParams.name, out)); if (axes.size() != 1)
continue; CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
int dims[] = {1, -1};
layerParams.type = "Reshape";
layerParams.set("axis", axes.getIntValue(0));
layerParams.set("num_axes", 1);
layerParams.set("dim", DictValue::arrayInt(&dims[0], 2));
} }
else if (layer_type == "Reshape") else if (layer_type == "Reshape")
{ {
@ -707,6 +726,25 @@ void ONNXImporter::populateNet(Net dstNet)
continue; continue;
} }
} }
else if (layer_type == "Upsample")
{
layerParams.type = "Resize";
if (layerParams.has("scales"))
{
// Pytorch layer
DictValue scales = layerParams.get("scales");
CV_Assert(scales.size() == 4);
layerParams.set("zoom_factor_y", scales.getIntValue(2));
layerParams.set("zoom_factor_x", scales.getIntValue(3));
}
else
{
// Caffe2 layer
replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
else else
{ {
for (int j = 0; j < node_proto.input_size(); j++) { for (int j = 0; j < node_proto.input_size(); j++) {

View File

@ -141,6 +141,11 @@ TEST_P(Test_ONNX_layers, Padding)
testONNXModels("padding"); testONNXModels("padding");
} }
TEST_P(Test_ONNX_layers, Resize)
{
testONNXModels("resize_nearest");
}
TEST_P(Test_ONNX_layers, MultyInputs) TEST_P(Test_ONNX_layers, MultyInputs)
{ {
const String model = _tf("models/multy_inputs.onnx"); const String model = _tf("models/multy_inputs.onnx");
@ -170,6 +175,11 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
testONNXModels("dynamic_reshape"); testONNXModels("dynamic_reshape");
} }
TEST_P(Test_ONNX_layers, Reshape)
{
testONNXModels("unsqueeze");
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers {}; class Test_ONNX_nets : public Test_ONNX_layers {};