Merge pull request #25212 from alexlyulkov:al/dnn-int64-scatter

Added int64 values support to scatter, scatterND and maxunpool layers
This commit is contained in:
Alexander Smorkalov 2024-03-26 13:52:28 +03:00 committed by GitHub
commit a33de44b0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 8 deletions

View File

@ -94,14 +94,34 @@ public:
Mat& input = inputs[0];
Mat& indices = inputs[1];
if (input.type() == CV_32F && indices.type() == CV_32S)
run<float, int32_t>(input, indices, outputs);
else if (input.type() == CV_32F && indices.type() == CV_64S)
run<float, int64_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_32S)
run<int16_t, int32_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_64S)
run<int16_t, int64_t>(input, indices, outputs);
if (indices.depth() == CV_32S)
typeDispatch<int32_t>(input.type(), input, indices, outputs);
else if (indices.depth() == CV_64S)
typeDispatch<int64_t>(input.type(), input, indices, outputs);
else
CV_Error(cv::Error::BadDepth, "Unsupported type.");
}
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_32S:
run<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
run<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
run<float, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_16F:
run<int16_t, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T, typename INDEX_TYPE>

View File

@ -190,6 +190,9 @@ public:
case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;

View File

@ -185,6 +185,9 @@ public:
case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;