Fix LogSoftmax for ONNX
Fix wrong indentation as well while at it
This commit is contained in:
parent 447116a93c
commit f0f50b757d
modules/dnn/src/onnx/onnx_importer.cpp
@@ -786,37 +786,42 @@ void ONNXImporter::populateNet(Net dstNet)
             }
             replaceLayerParam(layerParams, "mode", "interpolation");
         }
+        else if (layer_type == "LogSoftmax")
+        {
+            layerParams.type = "Softmax";
+            layerParams.set("log_softmax", true);
+        }
         else
         {
             for (int j = 0; j < node_proto.input_size(); j++) {
                 if (layer_id.find(node_proto.input(j)) == layer_id.end())
                     layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
             }
         }

         int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
         layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));


         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
         for (int j = 0; j < node_proto.input_size(); j++) {
             layerId = layer_id.find(node_proto.input(j));
             if (layerId != layer_id.end()) {
                 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
                 // Collect input shapes.
                 shapeIt = outShapes.find(node_proto.input(j));
                 CV_Assert(shapeIt != outShapes.end());
                 layerInpShapes.push_back(shapeIt->second);
             }
         }

         // Compute shape of output blob for this layer.
         Ptr<Layer> layer = dstNet.getLayer(id);
         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
         CV_Assert(!layerOutShapes.empty());
         outShapes[layerParams.name] = layerOutShapes[0];
     }
 }

 Net readNetFromONNX(const String& onnxFile)
 {
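For context: ONNX LogSoftmax computes log(Softmax(x)), so the importer maps it onto the existing Softmax layer and sets a log_softmax flag instead of introducing a new layer type. Below is a standalone sketch of the equivalent computation for a 1-D vector, illustrative only and not OpenCV's internal SoftmaxLayer code.

// Numerically stable log-softmax over a 1-D vector:
// log_softmax(x_i) = x_i - max(x) - log(sum_j exp(x_j - max(x)))
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> logSoftmax(const std::vector<float>& x)
{
    const float maxVal = *std::max_element(x.begin(), x.end());
    float expSum = 0.f;
    for (float v : x)
        expSum += std::exp(v - maxVal);      // shift by max to avoid overflow
    const float logSum = std::log(expSum);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] - maxVal - logSum;     // equals log(exp(x_i) / sum_j exp(x_j))
    return out;
}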
modules/dnn/test/test_onnx_importer.cpp
@@ -245,6 +245,12 @@ TEST_P(Test_ONNX_layers, Reshape)
     testONNXModels("unsqueeze");
 }

+TEST_P(Test_ONNX_layers, Softmax)
+{
+    testONNXModels("softmax");
+    testONNXModels("log_softmax");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

 class Test_ONNX_nets : public Test_ONNX_layers {};
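The new test case exercises the import path through the existing testONNXModels() helper, which loads reference models and data from opencv_extra. For reference, a hedged end-user sketch of the path this commit fixes; the model file name and the input shape are placeholders, not files from the repository.

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // "log_softmax.onnx" is a placeholder name for any model whose final node is LogSoftmax.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("log_softmax.onnx");

    // Placeholder input: a 1x10 blob of random values; the real shape depends on the model.
    int sz[] = {1, 10};
    cv::Mat input(2, sz, CV_32F);
    cv::randu(input, cv::Scalar(-1), cv::Scalar(1));

    net.setInput(input);
    cv::Mat out = net.forward();   // with this commit, LogSoftmax models yield log-probabilities
    std::cout << out << std::endl;
    return 0;
}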