Added int support to CumSum layer (#25214)

* Added int support to CumSum layer

* Allowed int types in CumSum layer

---------

Co-authored-by: Alexander Lyulkov <alexander.lyulkov@opencv.ai>
This commit is contained in:
alexlyulkov 2024-03-22 04:35:43 +03:00 committed by GitHub
parent d188319b82
commit f8319de976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,16 @@ public:
return exclusive_raw == 0;
}
/**
 * Reports the element type of the CumSum layer output.
 *
 * The single output takes the element type of the first input. The first
 * input must be CV_32F, CV_32S, CV_64S or CV_16F; any other type fails the
 * CV_CheckType check below.
 *
 * @param inputs            element types of the layer inputs; only inputs[0] is examined.
 * @param requiredOutputs   not used by this implementation.
 * @param requiredInternals not used by this implementation.
 * @param outputs           overwritten with exactly one entry equal to inputs[0].
 * @param internals         left untouched.
 */
virtual void getTypes(const std::vector<MatType>& inputs,
                      const int requiredOutputs,
                      const int requiredInternals,
                      std::vector<MatType>& outputs,
                      std::vector<MatType>& internals) const CV_OVERRIDE
{
    // Fail with a clear assertion instead of indexing into an empty vector.
    CV_Assert(!inputs.empty());
    CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_64S || inputs[0] == CV_16F, "");
    outputs.assign(1, inputs[0]);
}
// Entry point of the layer's forward pass: fetches the input/output Mats and
// dispatches to forwardImpl<T> for the element depth of the first input.
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
@ -47,9 +57,30 @@ public:
// Materialize the array-of-arrays arguments as vectors of Mat.
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
// The first input and the first output must have the same depth.
CV_CheckTypeEQ(inputs[0].depth(), outputs[0].depth(), "");
// Select the typed implementation from the depth of the first input.
switch(inputs[0].depth())
{
case CV_32F:
forwardImpl<float>(inputs, outputs);
break;
case CV_32S:
forwardImpl<int32_t>(inputs, outputs);
break;
case CV_64S:
forwardImpl<int64_t>(inputs, outputs);
break;
default:
// Any other depth reaching this switch is reported as Error::BadDepth.
CV_Error(Error::BadDepth, "");
}
}
template <typename T>
void forwardImpl(const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
{
// Get input tensor.
const auto& src_mat = inputs[0];
const auto* src_ptr = src_mat.ptr<float>();
const T* src_ptr = src_mat.ptr<T>();
// Get target axis.
int axis = inputs.size() > 1 ? parseAxis(inputs[1]) : axis_raw;
@ -58,7 +89,7 @@ public:
// Get output tensor.
auto& dst_mat = outputs[0];
auto* dst_ptr = dst_mat.ptr<float>();
T* dst_ptr = dst_mat.ptr<T>();
// Get flags.
const auto exclusive = exclusive_raw == 1;
@ -89,7 +120,7 @@ public:
size_t first_inner_offset = target_offset + target_start * inner_size;
if (exclusive)
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++)
dst_ptr[first_inner_offset + inner_idx] = 0.0f;
dst_ptr[first_inner_offset + inner_idx] = 0;
else
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++)
dst_ptr[first_inner_offset + inner_idx] = src_ptr[first_inner_offset + inner_idx];