mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Added int64 values support to scatter, scatterND and maxunpool layers
This commit is contained in:
parent
85cc02f4de
commit
d2d6869a26
@ -94,14 +94,34 @@ public:
|
|||||||
Mat& input = inputs[0];
|
Mat& input = inputs[0];
|
||||||
Mat& indices = inputs[1];
|
Mat& indices = inputs[1];
|
||||||
|
|
||||||
if (input.type() == CV_32F && indices.type() == CV_32S)
|
if (indices.depth() == CV_32S)
|
||||||
run<float, int32_t>(input, indices, outputs);
|
typeDispatch<int32_t>(input.type(), input, indices, outputs);
|
||||||
else if (input.type() == CV_32F && indices.type() == CV_64S)
|
else if (indices.depth() == CV_64S)
|
||||||
run<float, int64_t>(input, indices, outputs);
|
typeDispatch<int64_t>(input.type(), input, indices, outputs);
|
||||||
else if (input.type() == CV_16F && indices.type() == CV_32S)
|
else
|
||||||
run<int16_t, int32_t>(input, indices, outputs);
|
CV_Error(cv::Error::BadDepth, "Unsupported type.");
|
||||||
else if (input.type() == CV_16F && indices.type() == CV_64S)
|
}
|
||||||
run<int16_t, int64_t>(input, indices, outputs);
|
|
||||||
|
template<typename T_INDEX, typename... Args>
|
||||||
|
inline void typeDispatch(const int type, Args&&... args)
|
||||||
|
{
|
||||||
|
switch (type)
|
||||||
|
{
|
||||||
|
case CV_32S:
|
||||||
|
run<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
|
case CV_64S:
|
||||||
|
run<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
|
case CV_32F:
|
||||||
|
run<float, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
|
case CV_16F:
|
||||||
|
run<int16_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
CV_Error(cv::Error::BadDepth, "Unsupported type.");
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, typename INDEX_TYPE>
|
template<typename T, typename INDEX_TYPE>
|
||||||
|
@ -190,6 +190,9 @@ public:
|
|||||||
case CV_32S:
|
case CV_32S:
|
||||||
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
|
case CV_64S:
|
||||||
|
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
case CV_32F:
|
case CV_32F:
|
||||||
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
|
@ -185,6 +185,9 @@ public:
|
|||||||
case CV_32S:
|
case CV_32S:
|
||||||
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
|
case CV_64S:
|
||||||
|
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
case CV_32F:
|
case CV_32F:
|
||||||
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
|
Loading…
Reference in New Issue
Block a user