mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #24655 from fengyuentau:graph_simplifier_optional_input
dnn onnx graph simplifier: handle optional inputs of Slice #24655 Resolves https://github.com/opencv/opencv/issues/24609 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
22edfd2628
commit
f5ec92e4ca
@ -82,6 +82,23 @@ public:
|
|||||||
return makePtr<ONNXNodeWrapper>(node);
|
return makePtr<ONNXNodeWrapper>(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getTensorShapeSize(int node_id, int node_input_id) {
|
||||||
|
const auto node = getNode(node_id);
|
||||||
|
const auto &input_name = node->getInputName(node_input_id);
|
||||||
|
for (int i = 0; i < net.value_info_size(); i++) {
|
||||||
|
const auto value_info = net.value_info(i);
|
||||||
|
if (value_info.name() == input_name) {
|
||||||
|
if (value_info.has_type() && value_info.type().has_tensor_type() &&
|
||||||
|
value_info.type().tensor_type().has_shape()) {
|
||||||
|
return value_info.type().tensor_type().shape().dim_size();
|
||||||
|
} else {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
int getInputInitializerId(int node_id, int node_input_id)
|
int getInputInitializerId(int node_id, int node_input_id)
|
||||||
{
|
{
|
||||||
auto node = getNode(node_id);
|
auto node = getNode(node_id);
|
||||||
@ -164,6 +181,61 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
|
||||||
|
Slice with optional inputs of default values, some of them don't. This Subgraph removes
|
||||||
|
all optional inputs of Slice if values are default.
|
||||||
|
*/
|
||||||
|
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
|
||||||
|
public:
|
||||||
|
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
|
||||||
|
num_inputs_ = num_inputs;
|
||||||
|
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int starts = addNodeToMatch("");
|
||||||
|
int ends = addNodeToMatch("");
|
||||||
|
std::vector<int> inputs{input, starts, ends};
|
||||||
|
for (size_t i = 3; i < num_inputs_; i++) { // axes and steps
|
||||||
|
inputs.push_back(addNodeToMatch(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
slice_id = addNodeToMatch("Slice", inputs);
|
||||||
|
|
||||||
|
setFusedNode("Slice", std::vector<int>{input, starts, ends});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
|
||||||
|
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
|
||||||
|
if (num_inputs_ >= 4) { // with axes
|
||||||
|
// Check if axes are -1 or last axis
|
||||||
|
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||||
|
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);
|
||||||
|
|
||||||
|
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
|
||||||
|
for (size_t i = 0; i < axes.total(); i++) {
|
||||||
|
const int axis = *(axes.ptr<const int>() + i);
|
||||||
|
if (axis != -1 && axis != shape_size - 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (num_inputs_ == 5) {
|
||||||
|
// Check if steps are 1
|
||||||
|
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
|
||||||
|
if (countNonZero(steps != 1)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int slice_id;
|
||||||
|
size_t num_inputs_;
|
||||||
|
};
|
||||||
|
|
||||||
/* Fusion for Gelu.
|
/* Fusion for Gelu.
|
||||||
|
|
||||||
Graph before fusion:
|
Graph before fusion:
|
||||||
@ -1091,7 +1163,7 @@ public:
|
|||||||
int cast = addNodeToMatch("Cast", concat1);
|
int cast = addNodeToMatch("Cast", concat1);
|
||||||
|
|
||||||
int shape2 = addNodeToMatch("Shape", input);
|
int shape2 = addNodeToMatch("Shape", input);
|
||||||
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||||
int concat2 = addNodeToMatch("Concat", slice, cast);
|
int concat2 = addNodeToMatch("Concat", slice, cast);
|
||||||
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
|
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
|
||||||
|
|
||||||
@ -1163,6 +1235,8 @@ public:
|
|||||||
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||||
{
|
{
|
||||||
std::vector<Ptr<Subgraph> > subgraphs;
|
std::vector<Ptr<Subgraph> > subgraphs;
|
||||||
|
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
|
||||||
|
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
|
||||||
subgraphs.push_back(makePtr<GeluSubGraph>());
|
subgraphs.push_back(makePtr<GeluSubGraph>());
|
||||||
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
|
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
|
||||||
subgraphs.push_back(makePtr<LayerNormSubGraph>());
|
subgraphs.push_back(makePtr<LayerNormSubGraph>());
|
||||||
|
Loading…
Reference in New Issue
Block a user