mirror of
https://github.com/opencv/opencv.git
synced 2025-07-25 06:38:42 +08:00
modification for upsample node fused from unfused Resize subgraph
This commit is contained in:
parent
245b2fec34
commit
d37180a2c4
@ -1397,8 +1397,7 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
|
||||
|
||||
String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
|
||||
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "asymmetric",
|
||||
interp_mode != "tf_half_pixel_for_nn");
|
||||
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
|
||||
|
||||
layerParams.set("align_corners", interp_mode == "align_corners");
|
||||
Mat shapes = getBlob(node_proto, constBlobs, node_proto.input_size() - 1);
|
||||
@ -1426,6 +1425,22 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
else if (layer_type == "Upsample")
|
||||
{
|
||||
//fused from Resize Subgraph
|
||||
if (layerParams.has("coordinate_transformation_mode"))
|
||||
{
|
||||
String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
|
||||
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
|
||||
|
||||
layerParams.set("align_corners", interp_mode == "align_corners");
|
||||
if (layerParams.get<String>("mode") == "linear")
|
||||
{
|
||||
layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
|
||||
"opencv_linear" : "bilinear");
|
||||
}
|
||||
}
|
||||
if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
|
||||
layerParams.set("mode", "opencv_linear");
|
||||
|
||||
layerParams.type = "Resize";
|
||||
if (layerParams.has("scales"))
|
||||
{
|
||||
@ -1435,22 +1450,21 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("zoom_factor_y", scales.getIntValue(2));
|
||||
layerParams.set("zoom_factor_x", scales.getIntValue(3));
|
||||
}
|
||||
else
|
||||
else if (layerParams.has("height_scale") && layerParams.has("width_scale"))
|
||||
{
|
||||
// Caffe2 layer
|
||||
replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
|
||||
replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
|
||||
}
|
||||
replaceLayerParam(layerParams, "mode", "interpolation");
|
||||
|
||||
if (layerParams.get<String>("interpolation") == "linear" && framework_name == "pytorch") {
|
||||
layerParams.type = "Resize";
|
||||
else
|
||||
{
|
||||
// scales as input
|
||||
Mat scales = getBlob(node_proto, constBlobs, 1);
|
||||
CV_Assert(scales.total() == 4);
|
||||
layerParams.set("interpolation", "opencv_linear");
|
||||
layerParams.set("zoom_factor_y", scales.at<float>(2));
|
||||
layerParams.set("zoom_factor_x", scales.at<float>(3));
|
||||
}
|
||||
replaceLayerParam(layerParams, "mode", "interpolation");
|
||||
}
|
||||
else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
|
||||
{
|
||||
|
@ -369,6 +369,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
|
||||
testONNXModels("upsample_unfused_opset9_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.3");
|
||||
testONNXModels("resize_bilinear_unfused_opset11_torch1.4");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, MultyInputs)
|
||||
|
Loading…
Reference in New Issue
Block a user