Merge pull request #21159 from rogday:ceil_mode

fix ceil_mode for Average/MaxPooling

* fix ceil_mode

* add a comment
This commit is contained in:
rogday 2021-12-02 20:11:11 +03:00 committed by GitHub
parent b6df9debaf
commit 1613d30544
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -612,11 +612,24 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
} }
} }
void setCeilMode(LayerParams& layerParams)
{
// auto_pad attribute is deprecated and uses ceil
if (layerParams.has("pad_mode"))
{
layerParams.set("ceil_mode", true);
}
else if (!layerParams.has("ceil_mode"))
{
layerParams.set("ceil_mode", false);
}
}
void ONNXImporter::parseMaxPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) void ONNXImporter::parseMaxPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{ {
layerParams.type = "Pooling"; layerParams.type = "Pooling";
layerParams.set("pool", "MAX"); layerParams.set("pool", "MAX");
layerParams.set("ceil_mode", layerParams.has("pad_mode")); setCeilMode(layerParams);
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }
@ -624,7 +637,7 @@ void ONNXImporter::parseAveragePool(LayerParams& layerParams, const opencv_onnx:
{ {
layerParams.type = "Pooling"; layerParams.type = "Pooling";
layerParams.set("pool", "AVE"); layerParams.set("pool", "AVE");
layerParams.set("ceil_mode", layerParams.has("pad_mode")); setCeilMode(layerParams);
layerParams.set("ave_pool_padded_area", framework_name == "pytorch"); layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }