Merge pull request #23491 from fengyuentau:patch_for_segment_anything

Fixes for Segment Anything
This commit is contained in:
Alexander Smorkalov 2023-05-04 21:07:58 +03:00 committed by GitHub
commit 351589e5fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2565,8 +2565,11 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
node_proto.set_input(0, constParams.name);
node_proto.set_input(1, srcName);
}
else if (broadcast_axes.size() == 1 && broadcast_axes[0] <= 1)
else if (broadcast_axes.size() == 1)
{
// FIXME: this will end up creating massive amount of Identity nodes for broadcasting,
// for example, broadcast 1 to 256 needs 256 Identity nodes and 1 Concat node.
// Possible improvement is to use "Scale".
expandMid(layerParams.name, node_proto, srcName, targetShape[broadcast_axes[0]]);
layerParams.set("axis", broadcast_axes[0]);
@ -2638,7 +2641,8 @@ void ONNXImporter::parsePad(LayerParams& layerParams, const opencv_onnx::NodePro
paddings = paddings.t();
layerParams.set("paddings", DictValue::arrayInt(paddings.ptr<int>(), paddings.total()));
if (node_proto.input_size() == 3)
// check for non-null constant_value
if (node_proto.input_size() == 3 && !node_proto.input(2).empty())
{
Mat value = getBlob(node_proto, 2);
float padValue = (depth == CV_8S) ? (float)value.ptr<int8_t>()[0] : value.ptr<float>()[0];