mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Merge pull request #16722 from l-bat:reshape_opset_11
* Supported Div op for constants * Added Mul test
This commit is contained in:
parent
57cf120118
commit
9ed1332355
@ -465,31 +465,6 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Div")
|
||||
{
|
||||
if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
|
||||
{
|
||||
layerParams.type = "Eltwise";
|
||||
layerParams.set("operation", "div");
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, 1);
|
||||
CV_Assert_N(blob.type() == CV_32F, blob.total());
|
||||
if (blob.total() == 1)
|
||||
{
|
||||
layerParams.set("scale", 1.0f / blob.at<float>(0));
|
||||
layerParams.type = "Power";
|
||||
}
|
||||
else
|
||||
{
|
||||
layerParams.type = "Scale";
|
||||
divide(1.0, blob, blob);
|
||||
layerParams.blobs.push_back(blob);
|
||||
layerParams.set("bias_term", false);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Neg")
|
||||
{
|
||||
layerParams.type = "Power";
|
||||
@ -638,24 +613,58 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("bias_term", false);
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[0]);
|
||||
}
|
||||
else if (layer_type == "Mul")
|
||||
else if (layer_type == "Mul" || layer_type == "Div")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 2);
|
||||
if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
|
||||
Mat blob = getBlob(node_proto, constBlobs, 1);
|
||||
|
||||
bool isDiv = layer_type == "Div";
|
||||
int constId = -1;
|
||||
bool haveVariables = false;
|
||||
for (int i = 0; i < 2; ++i)
|
||||
{
|
||||
if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
|
||||
constId = i;
|
||||
else
|
||||
haveVariables = true;
|
||||
}
|
||||
if (constId != -1 && haveVariables)
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, constId);
|
||||
blob = blob.reshape(1, 1);
|
||||
if (blob.total() == 1) {
|
||||
layerParams.set("scale", blob.at<float>(0));
|
||||
float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
|
||||
layerParams.set("scale", coeff);
|
||||
layerParams.type = "Power";
|
||||
}
|
||||
else {
|
||||
if (isDiv)
|
||||
divide(1.0, blob, blob);
|
||||
layerParams.blobs.push_back(blob);
|
||||
layerParams.type = "Scale";
|
||||
}
|
||||
}
|
||||
else {
|
||||
layerParams.type = "Eltwise";
|
||||
layerParams.set("operation", "prod");
|
||||
layerParams.set("operation", isDiv ? "div" : "prod");
|
||||
}
|
||||
|
||||
if (!haveVariables)
|
||||
{
|
||||
Mat inp0 = getBlob(node_proto, constBlobs, 0);
|
||||
Mat inp1 = getBlob(node_proto, constBlobs, 1);
|
||||
if (inp0.size != inp1.size)
|
||||
CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
|
||||
|
||||
Mat out;
|
||||
if (isDiv)
|
||||
divide(inp0, inp1, out);
|
||||
else
|
||||
multiply(inp0, inp1, out);
|
||||
|
||||
out = out.reshape(1, inp0.dims, inp0.size);
|
||||
out.dims = inp0.dims; // to workaround dims == 1
|
||||
constBlobs.insert(std::make_pair(layerParams.name, out));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Conv")
|
||||
|
@ -382,6 +382,8 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
|
||||
if (target == DNN_TARGET_OPENCL) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
}
|
||||
testONNXModels("dynamic_reshape");
|
||||
testONNXModels("dynamic_reshape_opset_11");
|
||||
testONNXModels("flatten_by_prod");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Reshape)
|
||||
|
Loading…
Reference in New Issue
Block a user