Merge pull request #22229 from zihaomu:bug_fix_22195_3_4

This commit is contained in:
Alexander Alekhin 2022-07-14 20:27:25 +00:00
commit a9354fc743
3 changed files with 34 additions and 15 deletions

View File

@ -91,6 +91,16 @@ public:
if (hasWeights && hasBias)
CV_CheckEQ(weights.total(), bias.total(), "Incompatible weights/bias blobs");
if (weights.total() == 1)
{
// The total() of bias should be same as weights.
if (hasBias)
inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0), bias.at<float>(0));
else
inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0));
return;
}
int endAxis;
for (endAxis = axis + 1; endAxis <= inpBlob.dims; ++endAxis)
{

View File

@ -1818,6 +1818,8 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis, int& broadAxis)
{
// Currently, this function can only complete 1-dimensional expansion of broadShape.
// If there are two dimensions in broadShape that need to be expended, it will fail.
const size_t diff = outShape.size() - broadShape.size();
// find the first non-one element of the broadcasting shape
@ -1982,6 +1984,10 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
const MatShape& outShape = outShapes[node_proto.input(0)];
size_t axis = 0;
if (total(broadShape) != 1)
{
// If broadShape is a scalar, we set axis as 0.
// Other-wise, we check broadcast is available.
int broadAxis = -1;
findBroadAxis(broadShape, outShape, axis, broadAxis);
@ -2002,6 +2008,7 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
addLayer(concatLP, concat_node_proto);
node_proto.set_input(1, concatLP.name);
}
}
CV_Assert(axis != outShape.size());
layerParams.set("axis", static_cast<int>(axis));

View File

@ -725,6 +725,8 @@ TEST_P(Test_ONNX_layers, Div)
normAssert(ref, out, "", default_l1, default_lInf);
expectNoFallbacksFromIE(net);
testONNXModels("div_test_1x1",npy, 0, 0, false, true, 2);
}
TEST_P(Test_ONNX_layers, DynamicReshape)