mirror of
https://github.com/opencv/opencv.git
synced 2025-06-10 19:24:07 +08:00
add new (Log)SoftMax simplification passes
This commit is contained in:
parent
a6277370ca
commit
829410729c
@ -107,17 +107,10 @@ private:
|
|||||||
opencv_onnx::GraphProto& net;
|
opencv_onnx::GraphProto& net;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SoftMaxSubgraph : public Subgraph
|
class SoftMaxSubgraphBase : public Subgraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SoftMaxSubgraph() : axis(1)
|
SoftMaxSubgraphBase() : axis(1), id(-1) {}
|
||||||
{
|
|
||||||
int input = addNodeToMatch("");
|
|
||||||
int inpExp = addNodeToMatch("Exp", input);
|
|
||||||
int sum = addNodeToMatch("ReduceSum", inpExp);
|
|
||||||
addNodeToMatch("Div", inpExp, sum);
|
|
||||||
setFusedNode("Softmax", input);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds,
|
||||||
@ -125,7 +118,8 @@ public:
|
|||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
|
CV_Assert(id >= 0 && id < matchedNodesIds.size());
|
||||||
|
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
|
||||||
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
|
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
|
||||||
|
|
||||||
for (int i = 0; i < node->attribute_size(); i++)
|
for (int i = 0; i < node->attribute_size(); i++)
|
||||||
@ -153,8 +147,60 @@ public:
|
|||||||
attr->set_i(axis);
|
attr->set_i(axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
int axis;
|
int axis;
|
||||||
|
int id;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SoftMaxSubgraph : public SoftMaxSubgraphBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SoftMaxSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int inpExp = addNodeToMatch("Exp", input);
|
||||||
|
|
||||||
|
int sum = addNodeToMatch("ReduceSum", inpExp);
|
||||||
|
id = 1;
|
||||||
|
|
||||||
|
addNodeToMatch("Div", inpExp, sum);
|
||||||
|
setFusedNode("Softmax", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
|
||||||
|
public:
|
||||||
|
SoftMaxSubgraph2() {
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
|
||||||
|
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||||
|
id = 0;
|
||||||
|
|
||||||
|
int sub = addNodeToMatch("Sub", input, reducemax);
|
||||||
|
int exp = addNodeToMatch("Exp", sub);
|
||||||
|
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
|
||||||
|
addNodeToMatch("Div", exp, reducesum);
|
||||||
|
setFusedNode("Softmax", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LogSoftMaxSubgraph()
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
|
||||||
|
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||||
|
id = 0;
|
||||||
|
|
||||||
|
int sub_1 = addNodeToMatch("Sub", input, reducemax);
|
||||||
|
int exp = addNodeToMatch("Exp", sub_1);
|
||||||
|
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
|
||||||
|
int log = addNodeToMatch("Log", reducesum);
|
||||||
|
addNodeToMatch("Sub", sub_1, log);
|
||||||
|
setFusedNode("LogSoftmax", input);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class NormalizeSubgraphBase : public Subgraph
|
class NormalizeSubgraphBase : public Subgraph
|
||||||
@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
|||||||
subgraphs.push_back(makePtr<ResizeSubgraph1>());
|
subgraphs.push_back(makePtr<ResizeSubgraph1>());
|
||||||
subgraphs.push_back(makePtr<ResizeSubgraph2>());
|
subgraphs.push_back(makePtr<ResizeSubgraph2>());
|
||||||
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
|
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
|
||||||
|
subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
|
||||||
|
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
|
||||||
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
|
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
|
||||||
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
|
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
|
||||||
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());
|
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());
|
||||||
|
Loading…
Reference in New Issue
Block a user