mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
ONNX: upsample subgraph fusion added
This commit is contained in:
parent
1602a38fa9
commit
8559237d4e
@ -69,8 +69,12 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
|
||||
const int numNodes = net->getNumNodes();
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
{
|
||||
if (net->getNodeName(i) == name)
|
||||
return i;
|
||||
const int numOutputs = net->getNumOutputs(i);
|
||||
for (int j = 0; j < numOutputs; j++)
|
||||
{
|
||||
if (net->getOutputName(i, j) == name)
|
||||
return i;
|
||||
}
|
||||
}
|
||||
CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
|
||||
}
|
||||
@ -111,12 +115,12 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
continue;
|
||||
nodeId = getInputNodeId(net, node, j);
|
||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
|
||||
if (inpNode->getType() != "Const")
|
||||
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
|
||||
{
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(inputNodes[j]);
|
||||
}
|
||||
else if (nodes[inputNodes[j]] != "Const")
|
||||
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
|
||||
return false;
|
||||
}
|
||||
matchedNodesIds.push_back(nodeToMatch);
|
||||
|
@ -39,7 +39,9 @@ public:
|
||||
|
||||
virtual int getNumNodes() const = 0;
|
||||
|
||||
virtual std::string getNodeName(int idx) const = 0;
|
||||
virtual int getNumOutputs(int nodeId) const = 0;
|
||||
|
||||
virtual std::string getOutputName(int nodeId, int outId) const = 0;
|
||||
|
||||
virtual void removeNode(int idx) = 0;
|
||||
};
|
||||
|
@ -76,12 +76,21 @@ public:
|
||||
return numInputs + net.node_size();
|
||||
}
|
||||
|
||||
virtual std::string getNodeName(int idx) const CV_OVERRIDE
|
||||
virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
|
||||
{
|
||||
if (idx < numInputs)
|
||||
return net.input(idx).name();
|
||||
if (nodeId < numInputs)
|
||||
return 1;
|
||||
else
|
||||
return net.node(idx - numInputs).output(0);
|
||||
return net.node(nodeId - numInputs).output_size();
|
||||
}
|
||||
|
||||
virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(outId < getNumOutputs(nodeId));
|
||||
if (nodeId < numInputs)
|
||||
return net.input(nodeId).name();
|
||||
else
|
||||
return net.node(nodeId - numInputs).output(outId);
|
||||
}
|
||||
|
||||
virtual void removeNode(int idx) CV_OVERRIDE
|
||||
@ -145,13 +154,193 @@ private:
|
||||
int axis;
|
||||
};
|
||||
|
||||
class ExtractScalesSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
ExtractScalesSubgraph()
|
||||
{
|
||||
input = addNodeToMatch("");
|
||||
|
||||
int indexH = addNodeToMatch("Constant");
|
||||
int shape1 = addNodeToMatch("Shape", input);
|
||||
int gather1 = addNodeToMatch("Gather", shape1, indexH);
|
||||
int castG1 = addNodeToMatch("Cast", gather1);
|
||||
scaleHNode = addNodeToMatch("Constant");
|
||||
int mul1 = addNodeToMatch("Mul", castG1, scaleHNode);
|
||||
int castM1 = addNodeToMatch("Cast", mul1);
|
||||
int floor1 = addNodeToMatch("Floor", castM1);
|
||||
|
||||
int indexW = addNodeToMatch("Constant");
|
||||
int shape2 = addNodeToMatch("Shape", input);
|
||||
int gather2 = addNodeToMatch("Gather", shape2, indexW);
|
||||
int castG2 = addNodeToMatch("Cast", gather2);
|
||||
scaleWNode = addNodeToMatch("Constant");
|
||||
int mul2 = addNodeToMatch("Mul", castG2, scaleWNode);
|
||||
int castM2 = addNodeToMatch("Cast", mul2);
|
||||
int floor2 = addNodeToMatch("Floor", castM2);
|
||||
|
||||
int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
|
||||
int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
|
||||
concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2);
|
||||
}
|
||||
|
||||
void finalize(const Ptr<ImportGraphWrapper>& net,
|
||||
const Ptr<ImportNodeWrapper>& fusedNode,
|
||||
std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE
|
||||
{
|
||||
opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
|
||||
float scaleW = getMatFromTensor(tensor_proto).at<float>(0);
|
||||
|
||||
constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
|
||||
tensor_proto = constant_node->attribute(0).t();
|
||||
float scaleH = getMatFromTensor(tensor_proto).at<float>(0);
|
||||
|
||||
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::AttributeProto* attrH = node->add_attribute();
|
||||
attrH->set_name("height_scale");
|
||||
attrH->set_i(scaleH);
|
||||
opencv_onnx::AttributeProto* attrW = node->add_attribute();
|
||||
attrW->set_name("width_scale");
|
||||
attrW->set_i(scaleW);
|
||||
|
||||
node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
|
||||
}
|
||||
|
||||
protected:
|
||||
int input, concatId;
|
||||
int scaleHNode, scaleWNode;
|
||||
};
|
||||
|
||||
class UpsampleSubgraph : public ExtractScalesSubgraph
|
||||
{
|
||||
public:
|
||||
UpsampleSubgraph() : ExtractScalesSubgraph()
|
||||
{
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int slice = addNodeToMatch("Slice", shape);
|
||||
|
||||
int castConcat = addNodeToMatch("Cast", concatId);
|
||||
int castSlice = addNodeToMatch("Cast", slice);
|
||||
int divide = addNodeToMatch("Div", castConcat, castSlice);
|
||||
|
||||
int constant = addNodeToMatch("Constant");
|
||||
int concat = addNodeToMatch("Concat", constant, divide);
|
||||
|
||||
addNodeToMatch("Upsample", input, concat);
|
||||
setFusedNode("Upsample", input, scaleWNode, scaleHNode);
|
||||
}
|
||||
};
|
||||
|
||||
class ResizeSubgraph1 : public ExtractScalesSubgraph
|
||||
{
|
||||
public:
|
||||
ResizeSubgraph1() : ExtractScalesSubgraph()
|
||||
{
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||
|
||||
int castConcat = addNodeToMatch("Cast", concatId);
|
||||
int concat = addNodeToMatch("Concat", slice, castConcat);
|
||||
int constant = addNodeToMatch("Constant");
|
||||
|
||||
addNodeToMatch("Resize", input, constant, constant, concat);
|
||||
setFusedNode("Upsample", input, scaleWNode, scaleHNode);
|
||||
}
|
||||
};
|
||||
|
||||
class ResizeSubgraph2 : public ExtractScalesSubgraph
|
||||
{
|
||||
public:
|
||||
ResizeSubgraph2() : ExtractScalesSubgraph()
|
||||
{
|
||||
int constantConcat = addNodeToMatch("Constant");
|
||||
int castConcat = addNodeToMatch("Cast", concatId);
|
||||
int concat = addNodeToMatch("Concat", constantConcat, castConcat);
|
||||
int constant = addNodeToMatch("Constant");
|
||||
|
||||
addNodeToMatch("Resize", input, constant, constant, concat);
|
||||
setFusedNode("Upsample", input, scaleWNode, scaleHNode);
|
||||
}
|
||||
};
|
||||
|
||||
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
subgraphs.push_back(makePtr<UpsampleSubgraph>());
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<ResizeSubgraph2>());
|
||||
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
|
||||
|
||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||
}
|
||||
|
||||
Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
|
||||
{
|
||||
if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
|
||||
tensor_proto.double_data().empty() && tensor_proto.int64_data().empty())
|
||||
return Mat();
|
||||
|
||||
opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
|
||||
Mat blob;
|
||||
std::vector<int> sizes;
|
||||
for (int i = 0; i < tensor_proto.dims_size(); i++) {
|
||||
sizes.push_back(tensor_proto.dims(i));
|
||||
}
|
||||
if (sizes.empty())
|
||||
sizes.assign(1, 1);
|
||||
if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
|
||||
|
||||
if (!tensor_proto.float_data().empty()) {
|
||||
const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
|
||||
Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
|
||||
}
|
||||
else {
|
||||
char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
|
||||
Mat(sizes, CV_32FC1, val).copyTo(blob);
|
||||
}
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
|
||||
{
|
||||
const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
|
||||
CV_Assert(!field.empty());
|
||||
Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
|
||||
{
|
||||
blob.create(sizes, CV_32SC1);
|
||||
int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
|
||||
|
||||
if (!tensor_proto.int64_data().empty()) {
|
||||
::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
|
||||
convertInt64ToInt32(src, dst, blob.total());
|
||||
}
|
||||
else
|
||||
{
|
||||
const char* val = tensor_proto.raw_data().c_str();
|
||||
#if CV_STRONG_ALIGNMENT
|
||||
// Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
|
||||
// this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
|
||||
AutoBuffer<int64_t, 16> aligned_val;
|
||||
if (!isAligned<sizeof(int64_t)>(val))
|
||||
{
|
||||
size_t sz = tensor_proto.raw_data().size();
|
||||
aligned_val.allocate(divUp(sz, sizeof(int64_t)));
|
||||
memcpy(aligned_val.data(), val, sz);
|
||||
val = (const char*)aligned_val.data();
|
||||
}
|
||||
#endif
|
||||
const int64_t* src = reinterpret_cast<const int64_t*>(val);
|
||||
convertInt64ToInt32(src, dst, blob.total());
|
||||
}
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
|
||||
opencv_onnx::TensorProto_DataType_Name(datatype));
|
||||
if (tensor_proto.dims_size() == 0)
|
||||
blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
|
||||
return blob;
|
||||
}
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
}} // namespace cv::dnn
|
||||
|
@ -24,6 +24,19 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
|
||||
void simplifySubgraphs(opencv_onnx::GraphProto& net);
|
||||
|
||||
template<typename T1, typename T2>
|
||||
void convertInt64ToInt32(const T1& src, T2& dst, int size)
|
||||
{
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (src[i] < std::numeric_limits<int32_t>::min() || src[i] > std::numeric_limits<int32_t>::max()) {
|
||||
CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
|
||||
}
|
||||
dst[i] = saturate_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto);
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
}} // namespace dnn, namespace cv
|
||||
|
||||
|
@ -95,83 +95,6 @@ void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T1, typename T2>
|
||||
void convertInt64ToInt32(const T1& src, T2& dst, int size)
|
||||
{
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (src[i] < std::numeric_limits<int32_t>::min() || src[i] > std::numeric_limits<int32_t>::max()) {
|
||||
CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
|
||||
}
|
||||
dst[i] = saturate_cast<int32_t>(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
|
||||
{
|
||||
CV_Assert(!tensor_proto.raw_data().empty() || !tensor_proto.float_data().empty()
|
||||
|| !tensor_proto.double_data().empty() || !tensor_proto.int64_data().empty());
|
||||
|
||||
opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
|
||||
Mat blob;
|
||||
std::vector<int> sizes;
|
||||
for (int i = 0; i < tensor_proto.dims_size(); i++) {
|
||||
sizes.push_back(tensor_proto.dims(i));
|
||||
}
|
||||
if (sizes.empty())
|
||||
sizes.assign(1, 1);
|
||||
if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
|
||||
|
||||
if (!tensor_proto.float_data().empty()) {
|
||||
const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
|
||||
Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
|
||||
}
|
||||
else {
|
||||
char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
|
||||
Mat(sizes, CV_32FC1, val).copyTo(blob);
|
||||
}
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
|
||||
{
|
||||
const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
|
||||
CV_Assert(!field.empty());
|
||||
Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
|
||||
{
|
||||
blob.create(sizes, CV_32SC1);
|
||||
int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
|
||||
|
||||
if (!tensor_proto.int64_data().empty()) {
|
||||
::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
|
||||
convertInt64ToInt32(src, dst, blob.total());
|
||||
}
|
||||
else
|
||||
{
|
||||
const char* val = tensor_proto.raw_data().c_str();
|
||||
#if CV_STRONG_ALIGNMENT
|
||||
// Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
|
||||
// this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
|
||||
AutoBuffer<int64_t, 16> aligned_val;
|
||||
if (!isAligned<sizeof(int64_t)>(val))
|
||||
{
|
||||
size_t sz = tensor_proto.raw_data().size();
|
||||
aligned_val.allocate(divUp(sz, sizeof(int64_t)));
|
||||
memcpy(aligned_val.data(), val, sz);
|
||||
val = (const char*)aligned_val.data();
|
||||
}
|
||||
#endif
|
||||
const int64_t* src = reinterpret_cast<const int64_t*>(val);
|
||||
convertInt64ToInt32(src, dst, blob.total());
|
||||
}
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
|
||||
opencv_onnx::TensorProto_DataType_Name(datatype));
|
||||
if (tensor_proto.dims_size() == 0)
|
||||
blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
|
||||
return blob;
|
||||
}
|
||||
|
||||
void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
|
||||
std::vector<Mat>& outputs)
|
||||
{
|
||||
|
@ -69,9 +69,15 @@ public:
|
||||
return net.node_size();
|
||||
}
|
||||
|
||||
virtual std::string getNodeName(int idx) const CV_OVERRIDE
|
||||
virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
|
||||
{
|
||||
return net.node(idx).name();
|
||||
return 1;
|
||||
}
|
||||
|
||||
virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(outId == 0);
|
||||
return net.node(nodeId).name();
|
||||
}
|
||||
|
||||
virtual void removeNode(int idx) CV_OVERRIDE
|
||||
|
@ -316,6 +316,13 @@ TEST_P(Test_ONNX_layers, Resize)
|
||||
testONNXModels("resize_bilinear");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ResizeUnfused)
|
||||
{
|
||||
testONNXModels("upsample_unfused_opset9_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.4");
|
||||
testONNXModels("resize_nearest_unfused_opset11_torch1.3");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, MultyInputs)
|
||||
{
|
||||
const String model = _tf("models/multy_inputs.onnx");
|
||||
|
Loading…
Reference in New Issue
Block a user