mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 11:40:44 +08:00
Fix LSTM from ONNX with batch==1
This commit is contained in:
parent
8d69dbdf49
commit
11d565ca62
@ -110,10 +110,11 @@ public:
|
||||
const Mat& Wh = blobs[0];
|
||||
const Mat& Wx = blobs[1];
|
||||
const Mat& bias = blobs[2];
|
||||
CV_Assert(Wh.dims == 2 && Wx.dims == 2);
|
||||
CV_Assert(Wh.rows == Wx.rows);
|
||||
CV_Assert(Wh.rows == 4*Wh.cols);
|
||||
CV_Assert(Wh.rows == (int)bias.total());
|
||||
CV_CheckEQ(Wh.dims, 2, "");
|
||||
CV_CheckEQ(Wx.dims, 2, "");
|
||||
CV_CheckEQ(Wh.rows, Wx.rows, "");
|
||||
CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
|
||||
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
|
||||
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
|
||||
|
||||
// Peephole weights.
|
||||
|
@ -49,6 +49,11 @@ class ONNXImporter
|
||||
LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
|
||||
bool isCeilMode(const LayerParams& layerParams);
|
||||
|
||||
void addLayer(Net& dstNet, LayerParams& layerParams,
|
||||
const opencv_onnx::NodeProto& node_proto,
|
||||
std::map<std::string, LayerInfo>& layer_id,
|
||||
std::map<std::string, MatShape>& outShapes);
|
||||
|
||||
public:
|
||||
|
||||
ONNXImporter(const char *onnxFile)
|
||||
@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
|
||||
return constBlob->second;
|
||||
}
|
||||
|
||||
void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
|
||||
const opencv_onnx::NodeProto& node_proto,
|
||||
std::map<std::string, LayerInfo>& layer_id,
|
||||
std::map<std::string, MatShape>& outShapes)
|
||||
{
|
||||
std::map<std::string, LayerInfo>::iterator layerId;
|
||||
std::map<std::string, MatShape>::iterator shapeIt;
|
||||
|
||||
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
|
||||
for (int i = 0; i < node_proto.output_size(); ++i)
|
||||
{
|
||||
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
|
||||
}
|
||||
|
||||
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
|
||||
int inpNum = 0;
|
||||
for (int j = 0; j < node_proto.input_size(); j++) {
|
||||
layerId = layer_id.find(node_proto.input(j));
|
||||
if (layerId != layer_id.end()) {
|
||||
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
|
||||
++inpNum;
|
||||
// Collect input shapes.
|
||||
shapeIt = outShapes.find(node_proto.input(j));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
layerInpShapes.push_back(shapeIt->second);
|
||||
}
|
||||
}
|
||||
// Compute shape of output blob for this layer.
|
||||
Ptr<Layer> layer = dstNet.getLayer(id);
|
||||
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
|
||||
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
|
||||
{
|
||||
outShapes[node_proto.output(i)] = layerOutShapes[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
CV_Assert(model_proto.has_graph());
|
||||
@ -581,13 +622,16 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
else if (layer_type == "LSTM")
|
||||
{
|
||||
LayerParams lstmParams = layerParams;
|
||||
lstmParams.name += "/lstm";
|
||||
|
||||
// https://pytorch.org/docs/stable/nn.html#lstm
|
||||
CV_Assert(node_proto.input_size() == 7);
|
||||
Mat Wx = getBlob(node_proto, constBlobs, 1);
|
||||
Mat Wh = getBlob(node_proto, constBlobs, 2);
|
||||
Mat b = getBlob(node_proto, constBlobs, 3);
|
||||
|
||||
const int numHidden = Wh.size[2];
|
||||
const int numHidden = lstmParams.get<int>("hidden_size");
|
||||
|
||||
Wx = Wx.reshape(1, Wx.size[1]);
|
||||
Wh = Wh.reshape(1, Wh.size[1]);
|
||||
@ -612,10 +656,24 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
|
||||
}
|
||||
layerParams.blobs.resize(3);
|
||||
layerParams.blobs[0] = Wh;
|
||||
layerParams.blobs[1] = Wx;
|
||||
layerParams.blobs[2] = b;
|
||||
|
||||
lstmParams.blobs.resize(3);
|
||||
lstmParams.blobs[0] = Wh;
|
||||
lstmParams.blobs[1] = Wx;
|
||||
lstmParams.blobs[2] = b;
|
||||
|
||||
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
|
||||
addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
|
||||
|
||||
MatShape lstmShape = outShapes[node_proto.output(0)];
|
||||
|
||||
// Add fake 1 as it is done in ONNX
|
||||
lstmShape.insert(lstmShape.begin() + 1, 1);
|
||||
|
||||
layerParams.type = "Reshape";
|
||||
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
|
||||
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
|
||||
node_proto.set_output(0, layerParams.name); // keep origin LSTM's name
|
||||
}
|
||||
else if (layer_type == "ImageScaler")
|
||||
{
|
||||
@ -1228,34 +1286,7 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
|
||||
}
|
||||
}
|
||||
|
||||
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
|
||||
for (int i = 0; i < node_proto.output_size(); ++i)
|
||||
{
|
||||
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
|
||||
}
|
||||
|
||||
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
|
||||
int inpNum = 0;
|
||||
for (int j = 0; j < node_proto.input_size(); j++) {
|
||||
layerId = layer_id.find(node_proto.input(j));
|
||||
if (layerId != layer_id.end()) {
|
||||
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
|
||||
++inpNum;
|
||||
// Collect input shapes.
|
||||
shapeIt = outShapes.find(node_proto.input(j));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
layerInpShapes.push_back(shapeIt->second);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute shape of output blob for this layer.
|
||||
Ptr<Layer> layer = dstNet.getLayer(id);
|
||||
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
|
||||
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
|
||||
{
|
||||
outShapes[node_proto.output(i)] = layerOutShapes[i];
|
||||
}
|
||||
addLayer(dstNet, layerParams, node_proto, layer_id, outShapes);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user