mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #23475 from Abdurrahheem:lstm_fix_initialization
Fix ONNX parser for single-layer LSTM hidden and cell states #23475 ### Fix ONNX parser for single-layer LSTM hidden and cell states ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake This PR addresses #21118 [issue](https://github.com/opencv/opencv/issues/21118). The problem is that the ONNX parser is unable to read the hidden state and cell state for single-layer LSTMs. This PR fixes the issue by updating the parser to correctly read hidden and cell states.
This commit is contained in:
parent
89c5a7584a
commit
e4e774d42b
@ -173,9 +173,17 @@ public:
|
|||||||
CV_CheckEQ(Wh.rows, Wx.rows, "");
|
CV_CheckEQ(Wh.rows, Wx.rows, "");
|
||||||
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
|
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
|
||||||
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
|
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
|
||||||
CV_CheckEQ(hInternal.cols, Wh.cols, "");
|
// Only perform these checks if hInternal and cInternal are not empty matrices
|
||||||
CV_CheckEQ(hInternal.cols, cInternal.cols, "");
|
// e.g. inputs are not given by a user
|
||||||
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
|
if(!hInternal.empty()){
|
||||||
|
CV_CheckEQ(hInternal.cols, Wh.cols, "");
|
||||||
|
}
|
||||||
|
if(!cInternal.empty()){
|
||||||
|
CV_CheckEQ(cInternal.cols, Wh.cols, "");
|
||||||
|
}
|
||||||
|
if (!hInternal.empty() && !cInternal.empty()){ //otherwise check in forward
|
||||||
|
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
|
||||||
|
}
|
||||||
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
|
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
|
||||||
|
|
||||||
// Peephole weights.
|
// Peephole weights.
|
||||||
@ -266,7 +274,7 @@ public:
|
|||||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||||
{
|
{
|
||||||
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
||||||
CV_Assert(inputs.size() == 1);
|
CV_Assert((inputs.size() == 1 || inputs.size() == 3));
|
||||||
const MatShape& inp0 = inputs[0];
|
const MatShape& inp0 = inputs[0];
|
||||||
|
|
||||||
const Mat &Wh = blobs[0], &Wx = blobs[1];
|
const Mat &Wh = blobs[0], &Wx = blobs[1];
|
||||||
@ -326,7 +334,7 @@ public:
|
|||||||
inputs_arr.getMatVector(input);
|
inputs_arr.getMatVector(input);
|
||||||
|
|
||||||
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
|
||||||
CV_Assert(input.size() == 1);
|
CV_Assert((input.size() == 1 || input.size() == 3));
|
||||||
const Mat& inp0 = input[0];
|
const Mat& inp0 = input[0];
|
||||||
|
|
||||||
Mat &Wh = blobs[0], &Wx = blobs[1];
|
Mat &Wh = blobs[0], &Wx = blobs[1];
|
||||||
@ -383,8 +391,20 @@ public:
|
|||||||
Mat Wh = blobs[0];
|
Mat Wh = blobs[0];
|
||||||
Mat Wx = blobs[1];
|
Mat Wx = blobs[1];
|
||||||
Mat bias = blobs[2];
|
Mat bias = blobs[2];
|
||||||
Mat h_0 = blobs[3];
|
|
||||||
Mat c_0 = blobs[4];
|
Mat h_0, c_0;
|
||||||
|
// Handle h_0 and c_0 based on input size
|
||||||
|
h_0 = (input.size() >= 2) ? input[1].reshape(1, input[1].size[0] * input[1].size[1]) : blobs[3];
|
||||||
|
c_0 = (input.size() == 3) ? input[2].reshape(1, input[2].size[0] * input[2].size[1]) : blobs[4];
|
||||||
|
|
||||||
|
// Perform checks if input size is 2 or 3
|
||||||
|
if (input.size() >= 2) {
|
||||||
|
CV_CheckEQ(h_0.cols, Wh.cols, "");
|
||||||
|
CV_CheckEQ(h_0.cols, c_0.cols, "");
|
||||||
|
CV_CheckEQ(h_0.rows, c_0.rows, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Mat pI, pF, pO;
|
Mat pI, pF, pO;
|
||||||
|
|
||||||
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
|
Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
|
||||||
|
@ -1539,10 +1539,17 @@ void transformBlobs(std::vector<Mat>& blobs)
|
|||||||
|
|
||||||
const int numHidden = Wh.size[2];
|
const int numHidden = Wh.size[2];
|
||||||
|
|
||||||
Mat h0 = blobs[3];
|
Mat h0, c0;
|
||||||
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
// check weather input is dynamic or not: hx, cx are given by user.
|
||||||
Mat c0 = blobs[4];
|
// Resahpe if only they are given
|
||||||
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
|
if (!blobs[3].empty()){
|
||||||
|
h0 = blobs[3];
|
||||||
|
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
|
||||||
|
}
|
||||||
|
if (!blobs[4].empty()){
|
||||||
|
c0 = blobs[4];
|
||||||
|
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
|
||||||
|
}
|
||||||
|
|
||||||
b = b.reshape(1, b.size[0]);
|
b = b.reshape(1, b.size[0]);
|
||||||
Mat bx = b.colRange(0, b.cols / 2);
|
Mat bx = b.colRange(0, b.cols / 2);
|
||||||
@ -1569,8 +1576,13 @@ void transformBlobs(std::vector<Mat>& blobs)
|
|||||||
blobs[0] = Wh;
|
blobs[0] = Wh;
|
||||||
blobs[1] = Wx;
|
blobs[1] = Wx;
|
||||||
blobs[2] = b.reshape(1, 1);
|
blobs[2] = b.reshape(1, 1);
|
||||||
blobs[3] = h0;
|
|
||||||
blobs[4] = c0;
|
if (!blobs[3].empty()){
|
||||||
|
blobs[3] = h0;
|
||||||
|
}
|
||||||
|
if (!blobs[4].empty()){
|
||||||
|
blobs[4] = c0;
|
||||||
|
}
|
||||||
|
|
||||||
if (blobs.size() == 5) {
|
if (blobs.size() == 5) {
|
||||||
// so that future patch removing copies can leave all indexing as is
|
// so that future patch removing copies can leave all indexing as is
|
||||||
@ -1601,8 +1613,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn
|
|||||||
Mat blob;
|
Mat blob;
|
||||||
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
|
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
|
||||||
{
|
{
|
||||||
blob = getBlob(lstm_proto, idx);
|
if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end()))
|
||||||
CV_Assert(shape(blob) == blobShape);
|
{
|
||||||
|
blob = Mat();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
blob = getBlob(lstm_proto, idx);
|
||||||
|
CV_Assert(shape(blob) == blobShape);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user