diff --git a/modules/dnn/src/layers/max_unpooling_layer.cpp b/modules/dnn/src/layers/max_unpooling_layer.cpp index fe359ef293..89f0324f81 100644 --- a/modules/dnn/src/layers/max_unpooling_layer.cpp +++ b/modules/dnn/src/layers/max_unpooling_layer.cpp @@ -94,14 +94,34 @@ public: Mat& input = inputs[0]; Mat& indices = inputs[1]; - if (input.type() == CV_32F && indices.type() == CV_32S) - run(input, indices, outputs); - else if (input.type() == CV_32F && indices.type() == CV_64S) - run(input, indices, outputs); - else if (input.type() == CV_16F && indices.type() == CV_32S) - run(input, indices, outputs); - else if (input.type() == CV_16F && indices.type() == CV_64S) - run(input, indices, outputs); + if (indices.depth() == CV_32S) + typeDispatch(input.type(), input, indices, outputs); + else if (indices.depth() == CV_64S) + typeDispatch(input.type(), input, indices, outputs); + else + CV_Error(cv::Error::BadDepth, "Unsupported type."); + } + + template + inline void typeDispatch(const int type, Args&&... args) + { + switch (type) + { + case CV_32S: + run(std::forward(args)...); + break; + case CV_64S: + run(std::forward(args)...); + break; + case CV_32F: + run(std::forward(args)...); + break; + case CV_16F: + run(std::forward(args)...); + break; + default: + CV_Error(cv::Error::BadDepth, "Unsupported type."); + }; } template diff --git a/modules/dnn/src/layers/scatterND_layer.cpp b/modules/dnn/src/layers/scatterND_layer.cpp index f9dcb41647..b0d26938b4 100644 --- a/modules/dnn/src/layers/scatterND_layer.cpp +++ b/modules/dnn/src/layers/scatterND_layer.cpp @@ -190,6 +190,9 @@ public: case CV_32S: reductionDispatch(std::forward(args)...); break; + case CV_64S: + reductionDispatch(std::forward(args)...); + break; case CV_32F: reductionDispatch(std::forward(args)...); break; diff --git a/modules/dnn/src/layers/scatter_layer.cpp b/modules/dnn/src/layers/scatter_layer.cpp index e0e6f630b6..48757c6332 100644 --- a/modules/dnn/src/layers/scatter_layer.cpp +++ b/modules/dnn/src/layers/scatter_layer.cpp @@ -185,6 +185,9 @@ public: case CV_32S: reductionDispatch(std::forward(args)...); break; + case CV_64S: + reductionDispatch(std::forward(args)...); + break; case CV_32F: reductionDispatch(std::forward(args)...); break;