Gather-Cast, Mul-Cast fusion

This commit is contained in:
ashishiva3@gmail.com 2020-03-01 15:09:15 +05:30
parent 4d0f13544d
commit e18d5e94c7
3 changed files with 43 additions and 15 deletions

View File

@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
{
int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds;
for (int i = 0; i < numNodes; ++i)
for (int j = 0; j < patterns.size(); ++j)
{
for (int j = 0; j < patterns.size(); ++j)
for (int i = 0; i < numNodes; ++i)
{
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
{
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
break;
}
}
}

View File

@ -154,6 +154,32 @@ private:
int axis;
};
// Subgraph pattern: Gather(<input>, Constant) whose result feeds a Cast.
// The replacement declared via setFusedNode() is a single "Gather" node that
// takes the same two inputs, i.e. the trailing Cast is not part of the fused node.
class GatherCastSubgraph : public Subgraph
{
public:
    GatherCastSubgraph()
    {
        // Producers are registered before their consumers; the ids returned by
        // addNodeToMatch() are then used to wire up the pattern.
        // NOTE(review): the empty op name presumably matches a node of any
        // type -- confirm against Subgraph::addNodeToMatch.
        const int dataId = addNodeToMatch("");
        const int indicesId = addNodeToMatch("Constant");
        const int gatherId = addNodeToMatch("Gather", dataId, indicesId);
        addNodeToMatch("Cast", gatherId);

        setFusedNode("Gather", dataId, indicesId);
    }
};
// Subgraph pattern: Mul(<input>, Constant) whose result feeds a Cast.
// The replacement declared via setFusedNode() is a single "Mul" node that
// takes the same two inputs, i.e. the trailing Cast is not part of the fused node.
class MulCastSubgraph : public Subgraph
{
public:
    MulCastSubgraph()
    {
        // Producers are registered before their consumers; the ids returned by
        // addNodeToMatch() are then used to wire up the pattern.
        // NOTE(review): the empty op name presumably matches a node of any
        // type -- confirm against Subgraph::addNodeToMatch.
        const int dataId = addNodeToMatch("");
        const int scaleId = addNodeToMatch("Constant");
        const int mulId = addNodeToMatch("Mul", dataId, scaleId);
        addNodeToMatch("Cast", mulId);

        setFusedNode("Mul", dataId, scaleId);
    }
};
class ExtractScalesSubgraph : public Subgraph
{
public:
@ -164,20 +190,16 @@ public:
int indexH = addNodeToMatch("Constant");
int shape1 = addNodeToMatch("Shape", input);
int gather1 = addNodeToMatch("Gather", shape1, indexH);
int castG1 = addNodeToMatch("Cast", gather1);
scaleHNode = addNodeToMatch("Constant");
int mul1 = addNodeToMatch("Mul", castG1, scaleHNode);
int castM1 = addNodeToMatch("Cast", mul1);
int floor1 = addNodeToMatch("Floor", castM1);
int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
int floor1 = addNodeToMatch("Floor", mul1);
int indexW = addNodeToMatch("Constant");
int shape2 = addNodeToMatch("Shape", input);
int gather2 = addNodeToMatch("Gather", shape2, indexW);
int castG2 = addNodeToMatch("Cast", gather2);
scaleWNode = addNodeToMatch("Constant");
int mul2 = addNodeToMatch("Mul", castG2, scaleWNode);
int castM2 = addNodeToMatch("Cast", mul2);
int floor2 = addNodeToMatch("Floor", castM2);
int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
int floor2 = addNodeToMatch("Floor", mul2);
int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
@ -190,19 +212,23 @@ public:
{
opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
float scaleW = getMatFromTensor(tensor_proto).at<float>(0);
Mat scaleW = getMatFromTensor(tensor_proto);
CV_Assert(scaleW.total() == 1);
scaleW.convertTo(scaleW, CV_32F);
constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
tensor_proto = constant_node->attribute(0).t();
float scaleH = getMatFromTensor(tensor_proto).at<float>(0);
Mat scaleH = getMatFromTensor(tensor_proto);
CV_Assert(scaleH.total() == 1);
scaleH.convertTo(scaleH, CV_32F);
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::AttributeProto* attrH = node->add_attribute();
attrH->set_name("height_scale");
attrH->set_i(scaleH);
attrH->set_i(scaleH.at<float>(0));
opencv_onnx::AttributeProto* attrW = node->add_attribute();
attrW->set_name("width_scale");
attrW->set_i(scaleW);
attrW->set_i(scaleW.at<float>(0));
node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
}
@ -267,6 +293,8 @@ public:
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<GatherCastSubgraph>());
subgraphs.push_back(makePtr<MulCastSubgraph>());
subgraphs.push_back(makePtr<UpsampleSubgraph>());
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());

View File

@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("upsample_unfused_torch1.2");
testONNXModels("upsample_unfused_opset9_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.3");