mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
add new (Log)SoftMax simplification passes
This commit is contained in:
parent
a6277370ca
commit
829410729c
@ -107,17 +107,10 @@ private:
|
||||
opencv_onnx::GraphProto& net;
|
||||
};
|
||||
|
||||
class SoftMaxSubgraph : public Subgraph
|
||||
class SoftMaxSubgraphBase : public Subgraph
|
||||
{
|
||||
public:
|
||||
SoftMaxSubgraph() : axis(1)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int inpExp = addNodeToMatch("Exp", input);
|
||||
int sum = addNodeToMatch("ReduceSum", inpExp);
|
||||
addNodeToMatch("Div", inpExp, sum);
|
||||
setFusedNode("Softmax", input);
|
||||
}
|
||||
SoftMaxSubgraphBase() : axis(1), id(-1) {}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
@ -125,7 +118,8 @@ public:
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
|
||||
CV_Assert(id >= 0 && id < matchedNodesIds.size());
|
||||
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
|
||||
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
|
||||
for (int i = 0; i < node->attribute_size(); i++)
|
||||
@ -153,8 +147,60 @@ public:
|
||||
attr->set_i(axis);
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
int axis;
|
||||
int id;
|
||||
};
|
||||
|
||||
class SoftMaxSubgraph : public SoftMaxSubgraphBase
|
||||
{
|
||||
public:
|
||||
SoftMaxSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int inpExp = addNodeToMatch("Exp", input);
|
||||
|
||||
int sum = addNodeToMatch("ReduceSum", inpExp);
|
||||
id = 1;
|
||||
|
||||
addNodeToMatch("Div", inpExp, sum);
|
||||
setFusedNode("Softmax", input);
|
||||
}
|
||||
};
|
||||
|
||||
class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
|
||||
public:
|
||||
SoftMaxSubgraph2() {
|
||||
int input = addNodeToMatch("");
|
||||
|
||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||
id = 0;
|
||||
|
||||
int sub = addNodeToMatch("Sub", input, reducemax);
|
||||
int exp = addNodeToMatch("Exp", sub);
|
||||
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
|
||||
addNodeToMatch("Div", exp, reducesum);
|
||||
setFusedNode("Softmax", input);
|
||||
}
|
||||
};
|
||||
|
||||
class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
|
||||
{
|
||||
public:
|
||||
LogSoftMaxSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
|
||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||
id = 0;
|
||||
|
||||
int sub_1 = addNodeToMatch("Sub", input, reducemax);
|
||||
int exp = addNodeToMatch("Exp", sub_1);
|
||||
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
|
||||
int log = addNodeToMatch("Log", reducesum);
|
||||
addNodeToMatch("Sub", sub_1, log);
|
||||
setFusedNode("LogSoftmax", input);
|
||||
}
|
||||
};
|
||||
|
||||
class NormalizeSubgraphBase : public Subgraph
|
||||
@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph2>());
|
||||
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
|
||||
subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
|
||||
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());
|
||||
|
Loading…
Reference in New Issue
Block a user