mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 04:52:53 +08:00
Merge pull request #16925 from dkurt:dnn_ssd.pytorch
This commit is contained in:
commit
245b2fec34
@ -154,16 +154,10 @@ private:
|
||||
int axis;
|
||||
};
|
||||
|
||||
class NormalizeSubgraph1 : public Subgraph
|
||||
class NormalizeSubgraphBase : public Subgraph
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph1() : axis(1)
|
||||
{
|
||||
input = addNodeToMatch("");
|
||||
norm = addNodeToMatch("ReduceL2", input);
|
||||
addNodeToMatch("Div", input, norm);
|
||||
setFusedNode("Normalize", input);
|
||||
}
|
||||
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
@ -171,7 +165,7 @@ public:
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[0]);
|
||||
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
|
||||
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
|
||||
for (int i = 0; i < node->attribute_size(); i++)
|
||||
@ -204,20 +198,51 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
int input, norm;
|
||||
int axis;
|
||||
int axis, normNodeOrder;
|
||||
};
|
||||
|
||||
|
||||
class NormalizeSubgraph2 : public NormalizeSubgraph1
|
||||
class NormalizeSubgraph1 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph2() : NormalizeSubgraph1()
|
||||
NormalizeSubgraph1()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int norm = addNodeToMatch("ReduceL2", input);
|
||||
addNodeToMatch("Div", input, norm);
|
||||
setFusedNode("Normalize", input);
|
||||
}
|
||||
};
|
||||
|
||||
class NormalizeSubgraph2 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph2()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int norm = addNodeToMatch("ReduceL2", input);
|
||||
int clip = addNodeToMatch("Clip", norm);
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int expand = addNodeToMatch("Expand", clip, shape);
|
||||
addNodeToMatch("Div", input, expand);
|
||||
setFusedNode("Normalize", input);
|
||||
}
|
||||
};
|
||||
|
||||
class NormalizeSubgraph3 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int power = addNodeToMatch("Constant");
|
||||
int squared = addNodeToMatch("Pow", input, power);
|
||||
int sum = addNodeToMatch("ReduceSum", squared);
|
||||
int sqrtNode = addNodeToMatch("Sqrt", sum);
|
||||
int eps = addNodeToMatch("Constant");
|
||||
int add = addNodeToMatch("Add", sqrtNode, eps);
|
||||
|
||||
addNodeToMatch("Div", input, add);
|
||||
setFusedNode("Normalize", input);
|
||||
}
|
||||
};
|
||||
|
||||
@ -368,6 +393,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
|
||||
|
||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||
}
|
||||
|
@ -1457,6 +1457,25 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.type = "Softmax";
|
||||
layerParams.set("log_softmax", layer_type == "LogSoftmax");
|
||||
}
|
||||
else if (layer_type == "DetectionOutput")
|
||||
{
|
||||
CV_CheckEQ(node_proto.input_size(), 3, "");
|
||||
if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
|
||||
{
|
||||
Mat priors = getBlob(node_proto, constBlobs, 2);
|
||||
|
||||
LayerParams constParams;
|
||||
constParams.name = layerParams.name + "/priors";
|
||||
constParams.type = "Const";
|
||||
constParams.blobs.push_back(priors);
|
||||
|
||||
opencv_onnx::NodeProto priorsProto;
|
||||
priorsProto.add_output(constParams.name);
|
||||
addLayer(dstNet, constParams, priorsProto, layer_id, outShapes);
|
||||
|
||||
node_proto.set_input(2, constParams.name);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < node_proto.input_size(); j++) {
|
||||
|
@ -440,6 +440,7 @@ TEST_P(Test_ONNX_layers, ReduceL2)
|
||||
{
|
||||
testONNXModels("reduceL2");
|
||||
testONNXModels("reduceL2_subgraph");
|
||||
testONNXModels("reduceL2_subgraph_2");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Split)
|
||||
|
Loading…
Reference in New Issue
Block a user