mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #24655 from fengyuentau:graph_simplifier_optional_input
dnn onnx graph simplifier: handle optional inputs of Slice #24655 Resolves https://github.com/opencv/opencv/issues/24609 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
22edfd2628
commit
f5ec92e4ca
@ -82,6 +82,23 @@ public:
|
||||
return makePtr<ONNXNodeWrapper>(node);
|
||||
}
|
||||
|
||||
int getTensorShapeSize(int node_id, int node_input_id) {
|
||||
const auto node = getNode(node_id);
|
||||
const auto &input_name = node->getInputName(node_input_id);
|
||||
for (int i = 0; i < net.value_info_size(); i++) {
|
||||
const auto value_info = net.value_info(i);
|
||||
if (value_info.name() == input_name) {
|
||||
if (value_info.has_type() && value_info.type().has_tensor_type() &&
|
||||
value_info.type().tensor_type().has_shape()) {
|
||||
return value_info.type().tensor_type().shape().dim_size();
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int getInputInitializerId(int node_id, int node_input_id)
|
||||
{
|
||||
auto node = getNode(node_id);
|
||||
@ -164,6 +181,61 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
|
||||
}
|
||||
}
|
||||
|
||||
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
|
||||
Slice with optional inputs of default values, some of them don't. This Subgraph removes
|
||||
all optional inputs of Slice if values are default.
|
||||
*/
|
||||
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
|
||||
public:
|
||||
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
|
||||
num_inputs_ = num_inputs;
|
||||
|
||||
int input = addNodeToMatch("");
|
||||
int starts = addNodeToMatch("");
|
||||
int ends = addNodeToMatch("");
|
||||
std::vector<int> inputs{input, starts, ends};
|
||||
for (size_t i = 3; i < num_inputs_; i++) { // axes and steps
|
||||
inputs.push_back(addNodeToMatch(""));
|
||||
}
|
||||
|
||||
slice_id = addNodeToMatch("Slice", inputs);
|
||||
|
||||
setFusedNode("Slice", std::vector<int>{input, starts, ends});
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
|
||||
if (num_inputs_ >= 4) { // with axes
|
||||
// Check if axes are -1 or last axis
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);
|
||||
|
||||
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
|
||||
for (size_t i = 0; i < axes.total(); i++) {
|
||||
const int axis = *(axes.ptr<const int>() + i);
|
||||
if (axis != -1 && axis != shape_size - 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (num_inputs_ == 5) {
|
||||
// Check if steps are 1
|
||||
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
|
||||
if (countNonZero(steps != 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int slice_id;
|
||||
size_t num_inputs_;
|
||||
};
|
||||
|
||||
/* Fusion for Gelu.
|
||||
|
||||
Graph before fusion:
|
||||
@ -1091,7 +1163,7 @@ public:
|
||||
int cast = addNodeToMatch("Cast", concat1);
|
||||
|
||||
int shape2 = addNodeToMatch("Shape", input);
|
||||
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||
int concat2 = addNodeToMatch("Concat", slice, cast);
|
||||
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
|
||||
|
||||
@ -1163,6 +1235,8 @@ public:
|
||||
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
|
||||
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
|
||||
subgraphs.push_back(makePtr<GeluSubGraph>());
|
||||
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
|
||||
subgraphs.push_back(makePtr<LayerNormSubGraph>());
|
||||
|
Loading…
Reference in New Issue
Block a user