mirror of
https://github.com/opencv/opencv.git
synced 2024-12-05 09:49:12 +08:00
Merge pull request #25212 from alexlyulkov:al/dnn-int64-scatter
Added int64 values support to scatter, scatterND and maxunpool layers
This commit is contained in:
commit
a33de44b0b
@ -94,14 +94,34 @@ public:
|
||||
Mat& input = inputs[0];
|
||||
Mat& indices = inputs[1];
|
||||
|
||||
if (input.type() == CV_32F && indices.type() == CV_32S)
|
||||
run<float, int32_t>(input, indices, outputs);
|
||||
else if (input.type() == CV_32F && indices.type() == CV_64S)
|
||||
run<float, int64_t>(input, indices, outputs);
|
||||
else if (input.type() == CV_16F && indices.type() == CV_32S)
|
||||
run<int16_t, int32_t>(input, indices, outputs);
|
||||
else if (input.type() == CV_16F && indices.type() == CV_64S)
|
||||
run<int16_t, int64_t>(input, indices, outputs);
|
||||
if (indices.depth() == CV_32S)
|
||||
typeDispatch<int32_t>(input.type(), input, indices, outputs);
|
||||
else if (indices.depth() == CV_64S)
|
||||
typeDispatch<int64_t>(input.type(), input, indices, outputs);
|
||||
else
|
||||
CV_Error(cv::Error::BadDepth, "Unsupported type.");
|
||||
}
|
||||
|
||||
template<typename T_INDEX, typename... Args>
|
||||
inline void typeDispatch(const int type, Args&&... args)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case CV_32S:
|
||||
run<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_64S:
|
||||
run<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_32F:
|
||||
run<float, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_16F:
|
||||
run<int16_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
default:
|
||||
CV_Error(cv::Error::BadDepth, "Unsupported type.");
|
||||
};
|
||||
}
|
||||
|
||||
template<typename T, typename INDEX_TYPE>
|
||||
|
@ -190,6 +190,9 @@ public:
|
||||
case CV_32S:
|
||||
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_64S:
|
||||
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_32F:
|
||||
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
|
@ -185,6 +185,9 @@ public:
|
||||
case CV_32S:
|
||||
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_64S:
|
||||
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_32F:
|
||||
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user