Add checks for LSTM initial h and c

This commit is contained in:
Dmitry Kurtaev 2020-03-22 16:04:30 +03:00
parent 8433620295
commit 467c3ef0ac
2 changed files with 17 additions and 11 deletions

View File

@ -496,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet)
runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
outShapes[layerParams.name] = shape(sliced[0]);
continue;
}
}
@ -630,6 +631,8 @@ void ONNXImporter::populateNet(Net dstNet)
Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3);
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 5)), 0, "Unsupported non zero initial_h");
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 6)), 0, "Unsupported non zero initial_c");
b = b.reshape(1, b.size[0]);
const int numHidden = lstmParams.get<int>("hidden_size");
@ -1007,6 +1010,16 @@ void ONNXImporter::populateNet(Net dstNet)
}
else
layerParams.type = "Identity";
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat inp = getBlob(node_proto, constBlobs, 0);
Mat out = inp.reshape(1, outShape);
out.dims = outShape.size(); // to workaround dims == 1
constBlobs.insert(std::make_pair(layerParams.name, out));
outShapes[layerParams.name] = shape(out);
continue;
}
}
else if (layer_type == "Flatten")
{
@ -1136,15 +1149,6 @@ void ONNXImporter::populateNet(Net dstNet)
else
layerParams.type = "Identity";
}
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
{
CV_Assert_N(node_proto.input_size());
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
float value = layerParams.get("value", 0);
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
constBlobs.insert(std::make_pair(layerParams.name, fill));
continue;
}
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
{
float fill_value;

View File

@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Squeeze)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("squeeze");
}
@ -453,12 +455,12 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm");
testONNXModels("lstm", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional");
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());