mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 19:20:28 +08:00
Merge pull request #23475 from Abdurrahheem:lstm_fix_initialization
Fix ONNX parser for single-layer LSTM hidden and cell states #23475 ### Fix ONNX parser for single-layer LSTM hidden and cell states ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake This PR addresses #21118 [issue](https://github.com/opencv/opencv/issues/21118). The problem is that the ONNX parser is unable to read the hidden state and cell state for single-layer LSTMs. This PR fixes the issue by updating the parser to correctly read hidden and cell states.
This commit is contained in:
parent
89c5a7584a
commit
e4e774d42b
@ -173,9 +173,17 @@ public:
|
||||
CV_CheckEQ(Wh.rows, Wx.rows, "");
|
||||
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
|
||||
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
|
||||
CV_CheckEQ(hInternal.cols, Wh.cols, "");
|
||||
CV_CheckEQ(hInternal.cols, cInternal.cols, "");
|
||||
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
|
||||
// Only perform these checks if hInternal and cInternal are not empty matrices
|
||||
// e.g. inputs are not given by a user
|
||||
if(!hInternal.empty()){
|
||||
CV_CheckEQ(hInternal.cols, Wh.cols, "");
|
||||
}
|
||||
if(!cInternal.empty()){
|
||||
CV_CheckEQ(cInternal.cols, Wh.cols, "");
|
||||
}
|
||||
if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward
|
||||
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
|
||||
}
|
||||
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
|
||||
|
||||
// Peephole weights.
|
||||
@ -266,7 +274,7 @@ public:
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
||||
CV_Assert(inputs.size() == 1);
|
||||
CV_Assert((inputs.size() == 1 || inputs.size() == 3));
|
||||
const MatShape& inp0 = inputs[0];
|
||||
|
||||
const Mat &Wh = blobs[0], &Wx = blobs[1];
|
||||
@ -326,7 +334,7 @@ public:
|
||||
inputs_arr.getMatVector(input);
|
||||
|
||||
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
||||
CV_Assert(input.size() == 1);
|
||||
CV_Assert((input.size() == 1 || input.size() == 3));
|
||||
const Mat& inp0 = input[0];
|
||||
|
||||
Mat &Wh = blobs[0], &Wx = blobs[1];
|
||||
@ -383,8 +391,20 @@ public:
|
||||
Mat Wh = blobs[0];
|
||||
Mat Wx = blobs[1];
|
||||
Mat bias = blobs[2];
|
||||
Mat h_0 = blobs[3];
|
||||
Mat c_0 = blobs[4];
|
||||
|
||||
Mat h_0, c_0;
|
||||
// Handle h_0 and c_0 based on input size
|
||||
h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3];
|
||||
c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4];
|
||||
|
||||
// Perform checks if input size is 2 or 3
|
||||
if (input.size() >= 2) {
|
||||
CV_CheckEQ(h_0.cols, Wh.cols, "");
|
||||
CV_CheckEQ(h_0.cols, c_0.cols, "");
|
||||
CV_CheckEQ(h_0.rows, c_0.rows, "");
|
||||
}
|
||||
|
||||
|
||||
Mat pI, pF, pO;
|
||||
|
||||
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
|
||||
|
@ -1539,10 +1539,17 @@ void transformBlobs(std::vector<Mat>& blobs)
|
||||
|
||||
const int numHidden = Wh.size[2];
|
||||
|
||||
Mat h0 = blobs[3];
|
||||
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
||||
Mat c0 = blobs[4];
|
||||
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
|
||||
Mat h0, c0;
|
||||
// check weather input is dynamic or not: hx, cx are given by user.
|
||||
// Resahpe if only they are given
|
||||
if (!blobs[3].empty()){
|
||||
h0 = blobs[3];
|
||||
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
||||
}
|
||||
if (!blobs[4].empty()){
|
||||
c0 = blobs[4];
|
||||
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
|
||||
}
|
||||
|
||||
b = b.reshape(1, b.size[0]);
|
||||
Mat bx = b.colRange(0, b.cols / 2);
|
||||
@ -1569,8 +1576,13 @@ void transformBlobs(std::vector<Mat>& blobs)
|
||||
blobs[0] = Wh;
|
||||
blobs[1] = Wx;
|
||||
blobs[2] = b.reshape(1, 1);
|
||||
blobs[3] = h0;
|
||||
blobs[4] = c0;
|
||||
|
||||
if (!blobs[3].empty()){
|
||||
blobs[3] = h0;
|
||||
}
|
||||
if (!blobs[4].empty()){
|
||||
blobs[4] = c0;
|
||||
}
|
||||
|
||||
if (blobs.size() == 5) {
|
||||
// so that future patch removing copies can leave all indexing as is
|
||||
@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn
|
||||
Mat blob;
|
||||
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
|
||||
{
|
||||
blob = getBlob(lstm_proto, idx);
|
||||
CV_Assert(shape(blob) == blobShape);
|
||||
if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end()))
|
||||
{
|
||||
blob = Mat();
|
||||
}
|
||||
else
|
||||
{
|
||||
blob = getBlob(lstm_proto, idx);
|
||||
CV_Assert(shape(blob) == blobShape);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user