Mirror of https://github.com/opencv/opencv.git

LSTM from ONNX works

parent 14da5ec311
commit 8d69dbdf49
@@ -215,8 +215,6 @@ public:
internals.push_back(shape(_numSamples, 1)); // dummyOnes
internals.push_back(shape(_numSamples, 4*_numOut)); // gates

std::cout << "LSTM out: " << outputs[0] << '\n';
return false;
}

@@ -303,8 +301,6 @@ public:
tsEnd = numTimeStamps;
tsInc = 1;
}
std::cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << '\n';
std::cout << tsStart << " " << tsEnd << '\n';
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
{
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
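// Note: blobs are packed time-major here, so time step ts occupies rows
// [ts*numSamples, (ts + 1)*numSamples) of the input and output blobs.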
@@ -318,7 +314,6 @@ public:
Mat gateF = gates.colRange(1*numOut, 2*numOut);
Mat gateO = gates.colRange(2*numOut, 3*numOut);
Mat gateG = gates.colRange(3*numOut, 4*numOut);
std::cout << "i " << gateI << '\n';

if (forgetBias)
add(gateF, forgetBias, gateF);

@@ -334,7 +329,6 @@ public:
{
Mat gatesIFO = gates.colRange(0, 3*numOut);
sigmoid(gatesIFO, gatesIFO);
std::cout << "ifo " << gatesIFO << '\n';
}

tanh(gateG, gateG);

@@ -351,15 +345,12 @@ public:
}
if (usePeephole)
{
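// Peephole connection: the previous cell state contributes to the output
// gate through an extra weight matrix (blobs[5]) before the sigmoid.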
std::cout << "if (usePeephole)" << '\n';
gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
sigmoid(gateO, gateO);
}

//compute h_t
tanh(cInternal, hInternal);
std::cout << "o " << gateO << '\n';
std::cout << "tanh(o) " << hInternal << '\n';
multiply(gateO, hInternal, hInternal);
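// Standard LSTM cell output: h_t = o . tanh(c_t); c_t itself is updated as
// c_t = f . c_{t-1} + i . g (elementwise) in the lines elided between hunks.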

//save results in output blobs

@@ -367,7 +358,6 @@ public:
if (produceCellOutput)
cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
std::cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << '\n';
}
};

@@ -290,30 +290,6 @@ public:
}
};

// // To remove Squeeze after LSTM for non-bidirectional LSTM
// class LSTMSqueeze : public Subgraph
// {
// public:
//     LSTMSqueeze()
//     {
//         int input = addNodeToMatch("");
//
//         std::vector<int> lstmInps(7);
//         lstmInps[0] = input;
//
//         for (int i = 1; i < 4; ++i)
//             lstmInps[i] = addNodeToMatch("Unsqueeze");
//         lstmInps[4] = addNodeToMatch("");
//         for (int i = 5; i < 7; ++i)
//             lstmInps[i] = addNodeToMatch("ConstantOfShape");
//
//         int lstm = addNodeToMatch("LSTM", lstmInps);
//         addNodeToMatch("Squeeze", lstm);
//
//         setFusedNode("LSTM", lstmInps);
//     }
// };

void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
@@ -323,7 +299,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
// subgraphs.push_back(makePtr<LSTMSqueeze>());

simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}

@@ -322,7 +322,7 @@ void ONNXImporter::populateNet(Net dstNet)
std::string layer_type = node_proto.op_type();
layerParams.type = layer_type;
std::cout << layerParams.name << " " << layer_type << '\n';

if (layer_type == "MaxPool")
{

@@ -457,19 +457,6 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
continue;
}

layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));

CV_Assert(node_proto.input_size() == 1);
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), sliced;
runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
continue;
}
}
else if (layer_type == "Split")
{

@@ -592,116 +579,43 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
continue;
}
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
{
CV_Assert_N(node_proto.input_size());
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
float value = layerParams.get("value", 0);
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
constBlobs.insert(std::make_pair(layerParams.name, fill));
continue;
}
else if (layer_type == "LSTM")
{
std::cout << "~~~~~~" << '\n';
std::cout << layerParams << '\n';
for (int i = 1; i < node_proto.input_size(); ++i) {
std::cout << "i: " << node_proto.input(i) << " " << constBlobs[node_proto.input(i)].size << '\n';
}

// https://pytorch.org/docs/stable/nn.html#lstm
CV_Assert(node_proto.input_size() == 7);
Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3);

const int numHidden = Wh.size[2];
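// Per the ONNX LSTM spec: W is [num_directions, 4*hidden_size, input_size]
// and R (Wh here) is [num_directions, 4*hidden_size, hidden_size], so
// Wh.size[2] is the hidden state size.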

std::cout << Wx.size << '\n';
std::cout << Wh.size << '\n';

int Wx_shape[] = {Wx.size[1], Wx.size[2]};
int Wh_shape[] = {Wh.size[1], Wh.size[2]};
std::cout << "b.size " << b.size << '\n';
int b_shape[] = {2, b.size[1] / 2};

Wx = Wx.reshape(1, 2, &Wx_shape[0]);
b = b.reshape(1, 2, &b_shape[0]);

std::cout << "b ----------------" << '\n';

std::cout << b << '\n';
Wx = Wx.reshape(1, Wx.size[1]);
Wh = Wh.reshape(1, Wh.size[1]);
b = b.reshape(1, 2);
reduce(b, b, 0, REDUCE_SUM);
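// ONNX concatenates the input bias Wb and recurrent bias Rb into one
// [num_directions, 8*hidden_size] tensor; reshaping to two rows and summing
// over axis 0 folds them into the single bias vector the layer applies.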
std::cout << b << '\n';

// https://pytorch.org/docs/stable/nn.html#lstm
// IFGO->IFOG
// swap each 3rd and 4th rows
// Wx = Wx.t();

float* weightData = (float*)Wx.data;
std::swap(weightData[1], weightData[2]);

// IFGO->IGFO
float* WxData = (float*)Wx.data;
float* WhData = (float*)Wh.data;
float* biasData = (float*)b.data;
std::swap(biasData[1], biasData[2]);

// std::swap(weightData[2], weightData[3]);
//
// weightData = (float*)Wh.data;
// std::swap(weightData[1], weightData[2]);
// std::swap(weightData[2], weightData[3]);

// const int outSize = Wx.cols / 4;
// for (int i = 0; i < Wx.rows; ++i)
//     for (int j = 0; j < outSize; ++j)
//     {
//         // std::swap(weightData[i * W.cols + 1 * outSize + j],
//         //           weightData[i * W.cols + 2 * outSize + j]);
//         std::swap(weightData[i * Wx.cols + 2 * outSize + j],
//                   weightData[i * Wx.cols + 3 * outSize + j]);
//     }

// float* weightData = Wx.ptr<float>();
// for (int j = 0; j < 5; ++j)
// {
//     std::cout << "swap " << (10 + j) << " " << (15 + j) << '\n';
//     for (int i = 0; i < 12; ++i)
//         std::swap(weightData[(10 + j) * 12 + i],
//                   weightData[(15 + j) * 12 + i]);
// }

for (int j = 0; j < numHidden; ++j)
{
for (int i = 0; i < Wx.cols; ++i)
{
std::swap(WxData[(numHidden + j) * Wx.cols + i],
          WxData[(numHidden * 2 + j) * Wx.cols + i]);
}
for (int i = 0; i < Wh.cols; ++i)
{
std::swap(WhData[(numHidden + j) * Wh.cols + i],
          WhData[(numHidden * 2 + j) * Wh.cols + i]);
}
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
}
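// Assuming the exporter follows the ONNX spec's i, o, f, c gate order, the
// loop above swaps the second and third blocks of numHidden rows in Wx and
// Wh (plus the matching bias entries), yielding the i, f, o, g layout that
// the recurrent layer slices out as gateI/gateF/gateO/gateG.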
layerParams.blobs.resize(3);
layerParams.blobs[0] = Wh.reshape(1, 2, &Wh_shape[0]);
layerParams.blobs[0] = Wh;
layerParams.blobs[1] = Wx;
layerParams.blobs[2] = b;
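// The recurrent layer consumes blobs in the order {Wh, Wx, bias}; note the
// first assignment to blobs[0] is dead, immediately overwritten by the 2-D Wh.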

std::cout << "Wx" << '\n';
std::cout << layerParams.blobs[1] << '\n';

std::cout << "Wh" << '\n';
std::cout << layerParams.blobs[0] << '\n';

// layerParams.set("reverse", true);

// layerParams.set("use_peephole", true);
// layerParams.blobs.resize(6);
// for (int i = 0; i < 3; ++i)
// {
//     Mat w = Mat::eye(layerParams.blobs[0].cols, layerParams.blobs[0].cols, CV_32F);
//     layerParams.blobs[3 + i] = w;
// }

// std::cout << layerParams.blobs[1] << '\n';

// int lstmId = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
//
// layerParams = LayerParams();
//
// // Add reshape
// int shape[] = {1, 10, 11, 5};
// layerParams.name = node_proto.output(0) + "/reshape";
// layerParams.type = "Reshape";
// layerParams.set("dim", DictValue::arrayInt(&shape[0], 4));
}
else if (layer_type == "ImageScaler")
{

@@ -1005,14 +919,29 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Squeeze")
{
CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
// DictValue axes_dict = layerParams.get("axes");
// if (axes_dict.size() != 1)
//     CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
//
// int axis = axes_dict.getIntValue(0);
// layerParams.set("axis", axis - 1);
// layerParams.set("end_axis", axis);
layerParams.type = "Identity";
DictValue axes_dict = layerParams.get("axes");
MatShape inpShape = outShapes[node_proto.input(0)];
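// An axis is dropped only if it is listed in "axes" and its size is really
// 1; if anything is dropped the node becomes a fixed-shape Reshape,
// otherwise it degenerates to Identity.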

std::vector<bool> maskedAxes(inpShape.size(), false);
for (int i = 0; i < axes_dict.size(); ++i)
{
int axis = axes_dict.getIntValue(i);
CV_CheckLE(axis, static_cast<int>(inpShape.size()), "Squeeze axis");
maskedAxes[axis] = inpShape[axis] == 1;
}
MatShape outShape;
for (int i = 0; i < inpShape.size(); ++i)
{
if (!maskedAxes[i])
outShape.push_back(inpShape[i]);
}
if (outShape.size() != inpShape.size())
{
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
}
else
layerParams.type = "Identity";
}
else if (layer_type == "Flatten")
{

@@ -1142,9 +1071,26 @@ void ONNXImporter::populateNet(Net dstNet)
else
layerParams.type = "Identity";
}
else if (layer_type == "ConstantOfShape")
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
{
float fill_value = layerParams.blobs.empty() ? 0 : layerParams.blobs[0].at<float>(0, 0);
CV_Assert_N(node_proto.input_size());
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
float value = layerParams.get("value", 0);
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
constBlobs.insert(std::make_pair(layerParams.name, fill));
continue;
}
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
{
float fill_value;
if (!layerParams.blobs.empty())
{
CV_Assert(!layerParams.has("value"));
fill_value = layerParams.blobs[0].at<float>(0, 0);
}
else
fill_value = layerParams.get("value", 0);

MatShape inpShape = getBlob(node_proto, constBlobs, 0);
for (int i = 0; i < inpShape.size(); i++)
CV_CheckGT(inpShape[i], 0, "");
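// ConstantOfShape carries its fill value as a one-element tensor attribute
// (surfaced here as blobs[0]); ConstantFill uses a scalar "value" attribute.
// Either way the node is materialized as a constant blob rather than a layer.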

@@ -1826,12 +1826,10 @@ void TFImporter::populateNet(Net dstNet)
const int outSize = W.cols / 4;

// IGFO->IFOG
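// Two in-place swaps per element rotate the g block from slot 1 to slot 3:
// i,g,f,o -> i,f,g,o -> i,f,o,g, matching the gate order used by the layer.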
std::cout << "(TF) W " << W.size << '\n';
float* weightData = (float*)W.data;
for (int i = 0; i < W.rows; ++i)
for (int j = 0; j < outSize; ++j)
{
// std::cout << "swap " << i * W.cols + 1 * outSize << " " << i * W.cols + 2 * outSize << '\n';
std::swap(weightData[i * W.cols + 1 * outSize + j],
          weightData[i * W.cols + 2 * outSize + j]);
std::swap(weightData[i * W.cols + 2 * outSize + j],

@@ -1840,11 +1838,6 @@ void TFImporter::populateNet(Net dstNet)
Wx = W.rowRange(0, W.rows - outSize).t();
Wh = W.rowRange(W.rows - outSize, W.rows).t();
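// TensorFlow's LSTM cell kernel appears to stack input and recurrent weights
// in one [inputSize + outSize, 4*outSize] matrix: the first rows form Wx,
// the last outSize rows form Wh; both are transposed to the layer's layout.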

std::cout << "(TF) Wx " << Wx.size << '\n';
std::cout << "(TF) Wh " << Wh.size << '\n';
std::cout << "(TF) b " << b.size << '\n';

layerParams.blobs.resize(3);
layerParams.blobs[0] = Wh;
layerParams.blobs[1] = Wx;

@@ -79,12 +79,6 @@ public:
netSoftmax.setInput(ref);
ref = netSoftmax.forward();
}
std::cout << "ref: " << ref.size << '\n';
std::cout << "out: " << out.size << '\n';
std::cout << ref.reshape(1, 1) << '\n';
std::cout << '\n';
std::cout << out.reshape(1, 1) << '\n';

normAssert(ref, out, "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
if (checkNoFallbacks)
expectNoFallbacksFromIE(net);