Added int64 values support to scatter, scatterND and maxunpool layers

This commit is contained in:
Alexander Lyulkov 2024-03-13 15:40:07 +03:00
parent 85cc02f4de
commit d2d6869a26
3 changed files with 34 additions and 8 deletions

View File

@ -94,14 +94,34 @@ public:
Mat& input = inputs[0]; Mat& input = inputs[0];
Mat& indices = inputs[1]; Mat& indices = inputs[1];
if (input.type() == CV_32F && indices.type() == CV_32S) if (indices.depth() == CV_32S)
run<float, int32_t>(input, indices, outputs); typeDispatch<int32_t>(input.type(), input, indices, outputs);
else if (input.type() == CV_32F && indices.type() == CV_64S) else if (indices.depth() == CV_64S)
run<float, int64_t>(input, indices, outputs); typeDispatch<int64_t>(input.type(), input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_32S) else
run<int16_t, int32_t>(input, indices, outputs); CV_Error(cv::Error::BadDepth, "Unsupported type.");
else if (input.type() == CV_16F && indices.type() == CV_64S) }
run<int16_t, int64_t>(input, indices, outputs);
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_32S:
run<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
run<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
run<float, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_16F:
run<int16_t, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
} }
template<typename T, typename INDEX_TYPE> template<typename T, typename INDEX_TYPE>

View File

@ -190,6 +190,9 @@ public:
case CV_32S: case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...); reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break; break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F: case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...); reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break; break;

View File

@ -185,6 +185,9 @@ public:
case CV_32S: case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...); reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break; break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F: case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...); reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break; break;