Mirror of https://github.com/opencv/opencv.git
Merge pull request #13892 from dkurt:onnx_upsample_unsqueeze
commit 865c29a754
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -392,10 +392,10 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ceil_mode", isCeilMode(layerParams));
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
-        else if (layer_type == "GlobalAveragePool")
+        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool")
         {
             layerParams.type = "Pooling";
-            layerParams.set("pool", "AVE");
+            layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX");
             layerParams.set("global_pooling", true);
         }
         else if (layer_type == "Add" || layer_type == "Sum")
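
Note: with global_pooling set, OpenCV's Pooling layer reduces each channel's whole spatial extent to a single value, so GlobalAveragePool and GlobalMaxPool differ only in the "pool" mode. A minimal standalone sketch of the layer the importer emits for GlobalMaxPool (the layer name and the 1x2x3x3 input are illustrative, not part of this patch):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Same parameters the importer sets for an ONNX GlobalMaxPool node.
        cv::dnn::LayerParams lp;
        lp.type = "Pooling";
        lp.name = "global_max";            // illustrative name
        lp.set("pool", "MAX");
        lp.set("global_pooling", true);

        cv::dnn::Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        // A 1x2x3x3 blob filled with 0..17: each 3x3 plane collapses to one value.
        cv::Mat inp(std::vector<int>{1, 2, 3, 3}, CV_32F);
        for (int i = 0; i < (int)inp.total(); i++)
            inp.ptr<float>()[i] = (float)i;

        net.setInput(inp);
        cv::Mat out = net.forward();       // shape 1x2x1x1
        const float* o = out.ptr<float>();
        std::cout << o[0] << " " << o[1] << std::endl;   // prints "8 17"
        return 0;
    }
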
@@ -448,6 +448,11 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.set("bias_term", false);
             }
         }
+        else if (layer_type == "Neg")
+        {
+            layerParams.type = "Power";
+            layerParams.set("scale", -1);
+        }
         else if (layer_type == "Constant")
         {
             CV_Assert(node_proto.input_size() == 0);
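
The Neg mapping works because the Power layer computes (shift + scale * x)^power with shift = 0 and power = 1 by default, so scale = -1 is plain elementwise negation. A small sketch of that equivalence (layer name and input values are made up):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Same parameters the importer sets for an ONNX Neg node.
        cv::dnn::LayerParams lp;
        lp.type = "Power";
        lp.name = "neg";                   // illustrative name
        lp.set("scale", -1);

        cv::dnn::Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        cv::Mat inp(std::vector<int>{1, 1, 1, 3}, CV_32F);
        inp.ptr<float>()[0] = 1.f;
        inp.ptr<float>()[1] = -2.f;
        inp.ptr<float>()[2] = 3.f;

        net.setInput(inp);
        cv::Mat out = net.forward();
        const float* o = out.ptr<float>();
        std::cout << o[0] << " " << o[1] << " " << o[2] << std::endl;   // -1 2 -3
        return 0;
    }
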
@@ -595,21 +600,35 @@ void ONNXImporter::populateNet(Net dstNet)
         else if (layer_type == "Unsqueeze")
         {
             CV_Assert(node_proto.input_size() == 1);
-            Mat input = getBlob(node_proto, constBlobs, 0);
-
             DictValue axes = layerParams.get("axes");
-            std::vector<int> dims;
-            for (int j = 0; j < input.dims; j++) {
-                dims.push_back(input.size[j]);
-            }
-            CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
-            for (int j = 0; j < axes.size(); j++) {
-                dims.insert(dims.begin() + axes.getIntValue(j), 1);
-            }
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                // Constant input.
+                Mat input = getBlob(node_proto, constBlobs, 0);
+
+                std::vector<int> dims;
+                for (int j = 0; j < input.dims; j++) {
+                    dims.push_back(input.size[j]);
+                }
+                CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
+                for (int j = 0; j < axes.size(); j++) {
+                    dims.insert(dims.begin() + axes.getIntValue(j), 1);
+                }
+
+                Mat out = input.reshape(0, dims);
+                constBlobs.insert(std::make_pair(layerParams.name, out));
+                continue;
+            }

-            Mat out = input.reshape(0, dims);
-            constBlobs.insert(std::make_pair(layerParams.name, out));
-            continue;
+            // Variable input.
+            if (axes.size() != 1)
+                CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
+
+            int dims[] = {1, -1};
+            layerParams.type = "Reshape";
+            layerParams.set("axis", axes.getIntValue(0));
+            layerParams.set("num_axes", 1);
+            layerParams.set("dim", DictValue::arrayInt(&dims[0], 2));
         }
         else if (layer_type == "Reshape")
         {
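
The constant-input branch recomputes the blob shape and reshapes it in place; the variable-input branch emits a Reshape layer whose dim = {1, -1} with num_axes = 1 in effect replaces the single dimension at axes[0] with a 1 followed by its original size, i.e. a one-axis unsqueeze. A self-contained sketch of the constant-path shape arithmetic (the 3x4 input and axes {0, 3} are made-up values):

    #include <opencv2/core.hpp>
    #include <iostream>
    #include <vector>

    int main()
    {
        // A 3x4 blob unsqueezed along axes {0, 3}, mirroring the constant-input
        // branch above: the result is a 1x3x4x1 blob.
        cv::Mat input(3, 4, CV_32F, cv::Scalar(0));
        const int axes[] = {0, 3};

        std::vector<int> dims;
        for (int j = 0; j < input.dims; j++)
            dims.push_back(input.size[j]);
        for (int j = 0; j < 2; j++)
            dims.insert(dims.begin() + axes[j], 1);

        cv::Mat out = input.reshape(0, dims);
        std::cout << out.size[0] << "x" << out.size[1] << "x"
                  << out.size[2] << "x" << out.size[3] << std::endl;   // 1x3x4x1
        return 0;
    }
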
@@ -707,6 +726,25 @@ void ONNXImporter::populateNet(Net dstNet)
                 continue;
             }
         }
+        else if (layer_type == "Upsample")
+        {
+            layerParams.type = "Resize";
+            if (layerParams.has("scales"))
+            {
+                // Pytorch layer
+                DictValue scales = layerParams.get("scales");
+                CV_Assert(scales.size() == 4);
+                layerParams.set("zoom_factor_y", scales.getIntValue(2));
+                layerParams.set("zoom_factor_x", scales.getIntValue(3));
+            }
+            else
+            {
+                // Caffe2 layer
+                replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
+                replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
+            }
+            replaceLayerParam(layerParams, "mode", "interpolation");
+        }
         else
         {
             for (int j = 0; j < node_proto.input_size(); j++) {
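
The PyTorch branch reads a four-element scales attribute laid out as (N, C, H, W) and keeps only the two spatial factors. A sketch of the Resize layer this conversion produces, assuming scales = {1, 1, 2, 2} and mode = "nearest" (all values illustrative):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Equivalent of an ONNX Upsample with scales = {1, 1, 2, 2}, mode = "nearest".
        cv::dnn::LayerParams lp;
        lp.type = "Resize";
        lp.name = "upsample";                  // illustrative name
        lp.set("zoom_factor_y", 2);            // scales[2]
        lp.set("zoom_factor_x", 2);            // scales[3]
        lp.set("interpolation", "nearest");    // renamed from the ONNX "mode" attribute

        cv::dnn::Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        cv::Mat inp(std::vector<int>{1, 1, 2, 2}, CV_32F);
        for (int i = 0; i < 4; i++)
            inp.ptr<float>()[i] = (float)i;

        net.setInput(inp);
        cv::Mat out = net.forward();
        std::cout << out.size[2] << "x" << out.size[3] << std::endl;   // 4x4
        return 0;
    }
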
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -141,6 +141,11 @@ TEST_P(Test_ONNX_layers, Padding)
     testONNXModels("padding");
 }

+TEST_P(Test_ONNX_layers, Resize)
+{
+    testONNXModels("resize_nearest");
+}
+
 TEST_P(Test_ONNX_layers, MultyInputs)
 {
     const String model = _tf("models/multy_inputs.onnx");
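
The new Resize test runs a pre-generated resize_nearest.onnx model and checks its output against reference data. Outside the test harness, such a model could be exercised roughly as below (the local file path and the 1x1x2x2 input are assumptions, not taken from this patch):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Assumed local path; the test itself resolves the model through its
        // own test-data helpers.
        cv::dnn::Net net = cv::dnn::readNetFromONNX("resize_nearest.onnx");

        cv::Mat inp(std::vector<int>{1, 1, 2, 2}, CV_32F, cv::Scalar(1));
        net.setInput(inp);

        cv::Mat out = net.forward();
        std::cout << out.size[2] << "x" << out.size[3] << std::endl;
        return 0;
    }
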
@@ -170,6 +175,11 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
     testONNXModels("dynamic_reshape");
 }

+TEST_P(Test_ONNX_layers, Reshape)
+{
+    testONNXModels("unsqueeze");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

 class Test_ONNX_nets : public Test_ONNX_layers {};