mirror of
https://github.com/opencv/opencv.git
synced 2024-11-29 05:29:54 +08:00
Add checks for LSTM initial h and c
This commit is contained in:
parent
8433620295
commit
467c3ef0ac
@ -496,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
runLayer(layerParams, inputs, sliced);
|
||||
CV_Assert(sliced.size() == 1);
|
||||
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
|
||||
outShapes[layerParams.name] = shape(sliced[0]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -630,6 +631,8 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
Mat Wx = getBlob(node_proto, constBlobs, 1);
|
||||
Mat Wh = getBlob(node_proto, constBlobs, 2);
|
||||
Mat b = getBlob(node_proto, constBlobs, 3);
|
||||
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 5)), 0, "Unsupported non zero initial_h");
|
||||
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 6)), 0, "Unsupported non zero initial_c");
|
||||
b = b.reshape(1, b.size[0]);
|
||||
|
||||
const int numHidden = lstmParams.get<int>("hidden_size");
|
||||
@ -1007,6 +1010,16 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
else
|
||||
layerParams.type = "Identity";
|
||||
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
Mat inp = getBlob(node_proto, constBlobs, 0);
|
||||
Mat out = inp.reshape(1, outShape);
|
||||
out.dims = outShape.size(); // to workaround dims == 1
|
||||
constBlobs.insert(std::make_pair(layerParams.name, out));
|
||||
outShapes[layerParams.name] = shape(out);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Flatten")
|
||||
{
|
||||
@ -1136,15 +1149,6 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
else
|
||||
layerParams.type = "Identity";
|
||||
}
|
||||
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
|
||||
{
|
||||
CV_Assert_N(node_proto.input_size());
|
||||
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
|
||||
float value = layerParams.get("value", 0);
|
||||
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
|
||||
constBlobs.insert(std::make_pair(layerParams.name, fill));
|
||||
continue;
|
||||
}
|
||||
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
|
||||
{
|
||||
float fill_value;
|
||||
|
@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
|
||||
|
||||
TEST_P(Test_ONNX_layers, Squeeze)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
testONNXModels("squeeze");
|
||||
}
|
||||
|
||||
@ -453,12 +455,12 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
|
||||
|
||||
TEST_P(Test_ONNX_layers, LSTM)
|
||||
{
|
||||
testONNXModels("lstm");
|
||||
testONNXModels("lstm", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
|
||||
{
|
||||
testONNXModels("lstm_bidirectional");
|
||||
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||
|
Loading…
Reference in New Issue
Block a user