mirror of
https://github.com/opencv/opencv.git
synced 2025-07-26 07:07:37 +08:00
Merge pull request #22229 from zihaomu:bug_fix_22195_3_4
This commit is contained in:
commit
a9354fc743
@ -91,6 +91,16 @@ public:
|
|||||||
if (hasWeights && hasBias)
|
if (hasWeights && hasBias)
|
||||||
CV_CheckEQ(weights.total(), bias.total(), "Incompatible weights/bias blobs");
|
CV_CheckEQ(weights.total(), bias.total(), "Incompatible weights/bias blobs");
|
||||||
|
|
||||||
|
if (weights.total() == 1)
|
||||||
|
{
|
||||||
|
// The total() of bias should be same as weights.
|
||||||
|
if (hasBias)
|
||||||
|
inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0), bias.at<float>(0));
|
||||||
|
else
|
||||||
|
inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int endAxis;
|
int endAxis;
|
||||||
for (endAxis = axis + 1; endAxis <= inpBlob.dims; ++endAxis)
|
for (endAxis = axis + 1; endAxis <= inpBlob.dims; ++endAxis)
|
||||||
{
|
{
|
||||||
|
@ -1818,6 +1818,8 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
|
|||||||
|
|
||||||
void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis, int& broadAxis)
|
void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis, int& broadAxis)
|
||||||
{
|
{
|
||||||
|
// Currently, this function can only complete 1-dimensional expansion of broadShape.
|
||||||
|
// If there are two dimensions in broadShape that need to be expended, it will fail.
|
||||||
const size_t diff = outShape.size() - broadShape.size();
|
const size_t diff = outShape.size() - broadShape.size();
|
||||||
|
|
||||||
// find the first non-one element of the broadcasting shape
|
// find the first non-one element of the broadcasting shape
|
||||||
@ -1982,25 +1984,30 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
|
|||||||
const MatShape& outShape = outShapes[node_proto.input(0)];
|
const MatShape& outShape = outShapes[node_proto.input(0)];
|
||||||
|
|
||||||
size_t axis = 0;
|
size_t axis = 0;
|
||||||
int broadAxis = -1;
|
if (total(broadShape) != 1)
|
||||||
findBroadAxis(broadShape, outShape, axis, broadAxis);
|
|
||||||
|
|
||||||
// if there is a one dimension in the middle that should be broadcasted, broadcast it
|
|
||||||
if (broadAxis != -1)
|
|
||||||
{
|
{
|
||||||
opencv_onnx::NodeProto concat_node_proto = node_proto;
|
// If broadShape is a scalar, we set axis as 0.
|
||||||
const std::string& input1 = concat_node_proto.input(1);
|
// Other-wise, we check broadcast is available.
|
||||||
|
int broadAxis = -1;
|
||||||
|
findBroadAxis(broadShape, outShape, axis, broadAxis);
|
||||||
|
|
||||||
expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
|
// if there is a one dimension in the middle that should be broadcasted, broadcast it
|
||||||
|
if (broadAxis != -1)
|
||||||
|
{
|
||||||
|
opencv_onnx::NodeProto concat_node_proto = node_proto;
|
||||||
|
const std::string& input1 = concat_node_proto.input(1);
|
||||||
|
|
||||||
LayerParams concatLP;
|
expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
|
||||||
concatLP.name = layerParams.name + "/concat";
|
|
||||||
concatLP.set("axis", broadAxis);
|
|
||||||
concatLP.type = "Concat";
|
|
||||||
concat_node_proto.set_output(0, concatLP.name);
|
|
||||||
|
|
||||||
addLayer(concatLP, concat_node_proto);
|
LayerParams concatLP;
|
||||||
node_proto.set_input(1, concatLP.name);
|
concatLP.name = layerParams.name + "/concat";
|
||||||
|
concatLP.set("axis", broadAxis);
|
||||||
|
concatLP.type = "Concat";
|
||||||
|
concat_node_proto.set_output(0, concatLP.name);
|
||||||
|
|
||||||
|
addLayer(concatLP, concat_node_proto);
|
||||||
|
node_proto.set_input(1, concatLP.name);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CV_Assert(axis != outShape.size());
|
CV_Assert(axis != outShape.size());
|
||||||
|
@ -725,6 +725,8 @@ TEST_P(Test_ONNX_layers, Div)
|
|||||||
|
|
||||||
normAssert(ref, out, "", default_l1, default_lInf);
|
normAssert(ref, out, "", default_l1, default_lInf);
|
||||||
expectNoFallbacksFromIE(net);
|
expectNoFallbacksFromIE(net);
|
||||||
|
|
||||||
|
testONNXModels("div_test_1x1",npy, 0, 0, false, true, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, DynamicReshape)
|
TEST_P(Test_ONNX_layers, DynamicReshape)
|
||||||
|
Loading…
Reference in New Issue
Block a user