mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 03:30:34 +08:00
Gather-Cast, Mul-Cast fusion
This commit is contained in:
parent
4d0f13544d
commit
e18d5e94c7
@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
|
||||
{
|
||||
int numNodes = net->getNumNodes();
|
||||
std::vector<int> matchedNodesIds, targetNodesIds;
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
for (int j = 0; j < patterns.size(); ++j)
|
||||
{
|
||||
for (int j = 0; j < patterns.size(); ++j)
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
{
|
||||
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
|
||||
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -154,6 +154,32 @@ private:
|
||||
int axis;
|
||||
};
|
||||
|
||||
class GatherCastSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
GatherCastSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int index = addNodeToMatch("Constant");
|
||||
int gather = addNodeToMatch("Gather", input, index);
|
||||
addNodeToMatch("Cast", gather);
|
||||
setFusedNode("Gather", input, index);
|
||||
}
|
||||
};
|
||||
|
||||
class MulCastSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
MulCastSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int scaleNode = addNodeToMatch("Constant");
|
||||
int mul = addNodeToMatch("Mul", input, scaleNode);
|
||||
addNodeToMatch("Cast", mul);
|
||||
setFusedNode("Mul", input, scaleNode);
|
||||
}
|
||||
};
|
||||
|
||||
class ExtractScalesSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
@ -164,20 +190,16 @@ public:
|
||||
int indexH = addNodeToMatch("Constant");
|
||||
int shape1 = addNodeToMatch("Shape", input);
|
||||
int gather1 = addNodeToMatch("Gather", shape1, indexH);
|
||||
int castG1 = addNodeToMatch("Cast", gather1);
|
||||
scaleHNode = addNodeToMatch("Constant");
|
||||
int mul1 = addNodeToMatch("Mul", castG1, scaleHNode);
|
||||
int castM1 = addNodeToMatch("Cast", mul1);
|
||||
int floor1 = addNodeToMatch("Floor", castM1);
|
||||
int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
|
||||
int floor1 = addNodeToMatch("Floor", mul1);
|
||||
|
||||
int indexW = addNodeToMatch("Constant");
|
||||
int shape2 = addNodeToMatch("Shape", input);
|
||||
int gather2 = addNodeToMatch("Gather", shape2, indexW);
|
||||
int castG2 = addNodeToMatch("Cast", gather2);
|
||||
scaleWNode = addNodeToMatch("Constant");
|
||||
int mul2 = addNodeToMatch("Mul", castG2, scaleWNode);
|
||||
int castM2 = addNodeToMatch("Cast", mul2);
|
||||
int floor2 = addNodeToMatch("Floor", castM2);
|
||||
int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
|
||||
int floor2 = addNodeToMatch("Floor", mul2);
|
||||
|
||||
int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
|
||||
int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
|
||||
@ -190,19 +212,23 @@ public:
|
||||
{
|
||||
opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
|
||||
float scaleW = getMatFromTensor(tensor_proto).at<float>(0);
|
||||
Mat scaleW = getMatFromTensor(tensor_proto);
|
||||
CV_Assert(scaleW.total() == 1);
|
||||
scaleW.convertTo(scaleW, CV_32F);
|
||||
|
||||
constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
|
||||
tensor_proto = constant_node->attribute(0).t();
|
||||
float scaleH = getMatFromTensor(tensor_proto).at<float>(0);
|
||||
Mat scaleH = getMatFromTensor(tensor_proto);
|
||||
CV_Assert(scaleH.total() == 1);
|
||||
scaleH.convertTo(scaleH, CV_32F);
|
||||
|
||||
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::AttributeProto* attrH = node->add_attribute();
|
||||
attrH->set_name("height_scale");
|
||||
attrH->set_i(scaleH);
|
||||
attrH->set_i(scaleH.at<float>(0));
|
||||
opencv_onnx::AttributeProto* attrW = node->add_attribute();
|
||||
attrW->set_name("width_scale");
|
||||
attrW->set_i(scaleW);
|
||||
attrW->set_i(scaleW.at<float>(0));
|
||||
|
||||
node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
|
||||
}
|
||||
@ -267,6 +293,8 @@ public:
|
||||
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
subgraphs.push_back(makePtr<GatherCastSubgraph>());
|
||||
subgraphs.push_back(makePtr<MulCastSubgraph>());
|
||||
subgraphs.push_back(makePtr<UpsampleSubgraph>());
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph2>());
|
||||
|
@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
testONNXModels("upsample_unfused_torch1.2");
|
||||
testONNXModels("upsample_unfused_opset9_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.3");
|
||||
|
Loading…
Reference in New Issue
Block a user