mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #24672 from dkurt:adjust_slice_optional_inputs
Replace Slice optional inputs removal to adjustment
This commit is contained in:
commit
098efb6d3d
@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
|
||||
}
|
||||
|
||||
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
|
||||
Slice with optional inputs of default values, some of them don't. This Subgraph removes
|
||||
all optional inputs of Slice if values are default.
|
||||
Slice with optional inputs of default values, some of them don't. This Subgraph adjusts
|
||||
all optional inputs of Slice up to 5.
|
||||
*/
|
||||
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
|
||||
class AdjustSliceAllOptionalInputsSubgraph : public Subgraph {
|
||||
public:
|
||||
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
|
||||
AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
|
||||
num_inputs_ = num_inputs;
|
||||
|
||||
int input = addNodeToMatch("");
|
||||
@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
|
||||
|
||||
slice_id = addNodeToMatch("Slice", inputs);
|
||||
|
||||
setFusedNode("Slice", std::vector<int>{input, starts, ends});
|
||||
setFusedNode("Slice", inputs);
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
|
||||
if (num_inputs_ >= 4) { // with axes
|
||||
// Check if axes are -1 or last axis
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);
|
||||
|
||||
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
|
||||
for (size_t i = 0; i < axes.total(); i++) {
|
||||
const int axis = *(axes.ptr<const int>() + i);
|
||||
if (axis != -1 && axis != shape_size - 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (num_inputs_ == 5) {
|
||||
// Check if steps are 1
|
||||
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
|
||||
if (countNonZero(steps != 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
virtual void finalize(const Ptr<ImportGraphWrapper>&,
|
||||
const Ptr<ImportNodeWrapper>& fusedNode,
|
||||
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
|
||||
{
|
||||
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
for (int i = num_inputs_; i < 5; ++i) {
|
||||
node->add_input("");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1119,7 +1101,11 @@ public:
|
||||
ResizeSubgraph1() : ExtractScalesSubgraph()
|
||||
{
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||
int slice = addNodeToMatch("Slice", {shape,
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch("")});
|
||||
|
||||
int castConcat = addNodeToMatch("Cast", concatId);
|
||||
int concat = addNodeToMatch("Concat", slice, castConcat);
|
||||
@ -1163,7 +1149,11 @@ public:
|
||||
int cast = addNodeToMatch("Cast", concat1);
|
||||
|
||||
int shape2 = addNodeToMatch("Shape", input);
|
||||
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
|
||||
int slice = addNodeToMatch("Slice", {shape2,
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch(""),
|
||||
addNodeToMatch("")});
|
||||
int concat2 = addNodeToMatch("Concat", slice, cast);
|
||||
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
|
||||
|
||||
@ -1235,8 +1225,8 @@ public:
|
||||
void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
|
||||
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
|
||||
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(3));
|
||||
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(4));
|
||||
subgraphs.push_back(makePtr<GeluSubGraph>());
|
||||
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
|
||||
subgraphs.push_back(makePtr<LayerNormSubGraph>());
|
||||
|
@ -1235,7 +1235,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
|
||||
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
|
||||
|
||||
if (inp_size > 3)
|
||||
if (inp_size > 3 && !getBlob(node_proto, 3).empty())
|
||||
{
|
||||
Mat axes_blob = getBlob(node_proto, 3);
|
||||
CV_Assert(axes_blob.total() == start_blob.total());
|
||||
@ -1244,7 +1244,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
has_axes = true;
|
||||
}
|
||||
|
||||
if (inp_size == 5)
|
||||
if (inp_size == 5 && !getBlob(node_proto, 4).empty())
|
||||
{
|
||||
Mat step_blob = getBlob(node_proto, 4);
|
||||
CV_Assert(step_blob.total() == start_blob.total());
|
||||
|
Loading…
Reference in New Issue
Block a user