// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. // Copyright (C) 2020, Intel Corporation, all rights reserved. // Third party copyrights are property of their respective owners. #ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ #define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__ #include #include namespace cv { namespace dnn { class ImportNodeWrapper { public: virtual ~ImportNodeWrapper() {} virtual int getNumInputs() const = 0; virtual std::string getInputName(int idx) const = 0; virtual std::string getType() const = 0; virtual void setType(const std::string& type) = 0; virtual void setInputNames(const std::vector& inputs) = 0; }; class ImportGraphWrapper { public: virtual ~ImportGraphWrapper() {} virtual Ptr getNode(int idx) const = 0; virtual int getNumNodes() const = 0; virtual int getNumOutputs(int nodeId) const = 0; virtual std::string getOutputName(int nodeId, int outId) const = 0; virtual void removeNode(int idx) = 0; virtual bool isCommutativeOp(const std::string& type) const = 0; }; class Subgraph // Interface to match and replace subgraphs. { public: virtual ~Subgraph(); // Add a node to be matched in the origin graph. Specify ids of nodes that // are expected to be inputs. Returns id of a newly added node. // TODO: Replace inputs to std::vector in C++11 int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1, int input_2 = -1, int input_3 = -1); int addNodeToMatch(const std::string& op, const std::vector& inputs_); // Specify resulting node. All the matched nodes in subgraph excluding // input nodes will be fused into this single node. // TODO: Replace inputs to std::vector in C++11 void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1, int input_2 = -1, int input_3 = -1, int input_4 = -1, int input_5 = -1); void setFusedNode(const std::string& op, const std::vector& inputs_); static int getInputNodeId(const Ptr& net, const Ptr& node, int inpId); // Match TensorFlow subgraph starting from with a set of nodes to be fused. // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds); // Fuse matched subgraph. void replace(const Ptr& net, const std::vector& matchedNodesIds); virtual void finalize(const Ptr& net, const Ptr& fusedNode, std::vector >& inputs); private: std::vector nodes; // Nodes to be matched in the origin graph. std::vector > inputs; // Connections of an every node to it's inputs. std::string fusedNodeOp; // Operation name of resulting fused node. std::vector fusedNodeInputs; // Inputs of fused node. }; void simplifySubgraphs(const Ptr& net, const std::vector >& patterns); }} // namespace dnn, namespace cv #endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__