Merge pull request #23475 from Abdurrahheem:lstm_fix_initialization

Fix ONNX parser for single-layer LSTM hidden and cell states #23475

### Fix ONNX parser for single-layer LSTM hidden and cell states

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake


This PR addresses #21118 [issue](https://github.com/opencv/opencv/issues/21118). The problem is that the ONNX parser is unable to read the hidden state and cell state for single-layer LSTMs. This PR fixes the issue by updating the parser to correctly read hidden and cell states.
This commit is contained in:
Abduragim Shtanchaev 2023-04-24 13:39:41 +03:00 committed by GitHub
parent 89c5a7584a
commit e4e774d42b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 15 deletions

View File

@ -173,9 +173,17 @@ public:
CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_CheckEQ(hInternal.cols, Wh.cols, "");
CV_CheckEQ(hInternal.cols, cInternal.cols, "");
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
// Only perform these checks if hInternal and cInternal are not empty matrices
// e.g. inputs are not given by a user
if(!hInternal.empty()){
CV_CheckEQ(hInternal.cols, Wh.cols, "");
}
if(!cInternal.empty()){
CV_CheckEQ(cInternal.cols, Wh.cols, "");
}
if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
}
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
// Peephole weights.
@ -266,7 +274,7 @@ public:
std::vector<MatShape> &internals) const CV_OVERRIDE
{
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(inputs.size() == 1);
CV_Assert((inputs.size() == 1 || inputs.size() == 3));
const MatShape& inp0 = inputs[0];
const Mat &Wh = blobs[0], &Wx = blobs[1];
@ -326,7 +334,7 @@ public:
inputs_arr.getMatVector(input);
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(input.size() == 1);
CV_Assert((input.size() == 1 || input.size() == 3));
const Mat& inp0 = input[0];
Mat &Wh = blobs[0], &Wx = blobs[1];
@ -383,8 +391,20 @@ public:
Mat Wh = blobs[0];
Mat Wx = blobs[1];
Mat bias = blobs[2];
Mat h_0 = blobs[3];
Mat c_0 = blobs[4];
Mat h_0, c_0;
// Handle h_0 and c_0 based on input size
h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3];
c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4];
// Perform checks if input size is 2 or 3
if (input.size() >= 2) {
CV_CheckEQ(h_0.cols, Wh.cols, "");
CV_CheckEQ(h_0.cols, c_0.cols, "");
CV_CheckEQ(h_0.rows, c_0.rows, "");
}
Mat pI, pF, pO;
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);

View File

@ -1539,10 +1539,17 @@ void transformBlobs(std::vector<Mat>& blobs)
const int numHidden = Wh.size[2];
Mat h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
Mat c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
Mat h0, c0;
// check weather input is dynamic or not: hx, cx are given by user.
// Resahpe if only they are given
if (!blobs[3].empty()){
h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
}
if (!blobs[4].empty()){
c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
}
b = b.reshape(1, b.size[0]);
Mat bx = b.colRange(0, b.cols / 2);
@ -1569,8 +1576,13 @@ void transformBlobs(std::vector<Mat>& blobs)
blobs[0] = Wh;
blobs[1] = Wx;
blobs[2] = b.reshape(1, 1);
blobs[3] = h0;
blobs[4] = c0;
if (!blobs[3].empty()){
blobs[3] = h0;
}
if (!blobs[4].empty()){
blobs[4] = c0;
}
if (blobs.size() == 5) {
// so that future patch removing copies can leave all indexing as is
@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn
Mat blob;
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end()))
{
blob = Mat();
}
else
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
}
}
else
{