From d2d6869a26dbeca08a38113d9367de8be93e8d74 Mon Sep 17 00:00:00 2001 From: Alexander Lyulkov Date: Wed, 13 Mar 2024 15:40:07 +0300 Subject: [PATCH] Added int64 values support to scatter, scatterND and maxunpool layers --- .../dnn/src/layers/max_unpooling_layer.cpp | 36 ++++++++++++++----- modules/dnn/src/layers/scatterND_layer.cpp | 3 ++ modules/dnn/src/layers/scatter_layer.cpp | 3 ++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/modules/dnn/src/layers/max_unpooling_layer.cpp b/modules/dnn/src/layers/max_unpooling_layer.cpp index fe359ef293..89f0324f81 100644 --- a/modules/dnn/src/layers/max_unpooling_layer.cpp +++ b/modules/dnn/src/layers/max_unpooling_layer.cpp @@ -94,14 +94,34 @@ public: Mat& input = inputs[0]; Mat& indices = inputs[1]; - if (input.type() == CV_32F && indices.type() == CV_32S) - run(input, indices, outputs); - else if (input.type() == CV_32F && indices.type() == CV_64S) - run(input, indices, outputs); - else if (input.type() == CV_16F && indices.type() == CV_32S) - run(input, indices, outputs); - else if (input.type() == CV_16F && indices.type() == CV_64S) - run(input, indices, outputs); + if (indices.depth() == CV_32S) + typeDispatch(input.type(), input, indices, outputs); + else if (indices.depth() == CV_64S) + typeDispatch(input.type(), input, indices, outputs); + else + CV_Error(cv::Error::BadDepth, "Unsupported type."); + } + + template + inline void typeDispatch(const int type, Args&&... args) + { + switch (type) + { + case CV_32S: + run(std::forward(args)...); + break; + case CV_64S: + run(std::forward(args)...); + break; + case CV_32F: + run(std::forward(args)...); + break; + case CV_16F: + run(std::forward(args)...); + break; + default: + CV_Error(cv::Error::BadDepth, "Unsupported type."); + }; } template diff --git a/modules/dnn/src/layers/scatterND_layer.cpp b/modules/dnn/src/layers/scatterND_layer.cpp index f9dcb41647..b0d26938b4 100644 --- a/modules/dnn/src/layers/scatterND_layer.cpp +++ b/modules/dnn/src/layers/scatterND_layer.cpp @@ -190,6 +190,9 @@ public: case CV_32S: reductionDispatch(std::forward(args)...); break; + case CV_64S: + reductionDispatch(std::forward(args)...); + break; case CV_32F: reductionDispatch(std::forward(args)...); break; diff --git a/modules/dnn/src/layers/scatter_layer.cpp b/modules/dnn/src/layers/scatter_layer.cpp index e0e6f630b6..48757c6332 100644 --- a/modules/dnn/src/layers/scatter_layer.cpp +++ b/modules/dnn/src/layers/scatter_layer.cpp @@ -185,6 +185,9 @@ public: case CV_32S: reductionDispatch(std::forward(args)...); break; + case CV_64S: + reductionDispatch(std::forward(args)...); + break; case CV_32F: reductionDispatch(std::forward(args)...); break;