mirror of
https://github.com/opencv/opencv.git
synced 2025-01-19 06:53:50 +08:00
Added int support to CumSum layer (#25214)
* Added int support to CumSum layer * Allowed int types in CumSum layer --------- Co-authored-by: Alexander Lyulkov <alexander.lyulkov@opencv.ai>
This commit is contained in:
parent
d188319b82
commit
f8319de976
@ -32,6 +32,16 @@ public:
|
||||
return exclusive_raw == 0;
|
||||
}
|
||||
|
||||
virtual void getTypes(const std::vector<MatType>& inputs,
|
||||
const int requiredOutputs,
|
||||
const int requiredInternals,
|
||||
std::vector<MatType>& outputs,
|
||||
std::vector<MatType>& internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_32S || inputs[0] == CV_64S || inputs[0] == CV_16F, "");
|
||||
outputs.assign(1, inputs[0]);
|
||||
}
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
@ -47,9 +57,30 @@ public:
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
CV_CheckTypeEQ(inputs[0].depth(), outputs[0].depth(), "");
|
||||
|
||||
switch(inputs[0].depth())
|
||||
{
|
||||
case CV_32F:
|
||||
forwardImpl<float>(inputs, outputs);
|
||||
break;
|
||||
case CV_32S:
|
||||
forwardImpl<int32_t>(inputs, outputs);
|
||||
break;
|
||||
case CV_64S:
|
||||
forwardImpl<int64_t>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
CV_Error(Error::BadDepth, "");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void forwardImpl(const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
|
||||
{
|
||||
// Get input tensor.
|
||||
const auto& src_mat = inputs[0];
|
||||
const auto* src_ptr = src_mat.ptr<float>();
|
||||
const T* src_ptr = src_mat.ptr<T>();
|
||||
|
||||
// Get target axis.
|
||||
int axis = inputs.size() > 1 ? parseAxis(inputs[1]) : axis_raw;
|
||||
@ -58,7 +89,7 @@ public:
|
||||
|
||||
// Get output tensor.
|
||||
auto& dst_mat = outputs[0];
|
||||
auto* dst_ptr = dst_mat.ptr<float>();
|
||||
T* dst_ptr = dst_mat.ptr<T>();
|
||||
|
||||
// Get flags.
|
||||
const auto exclusive = exclusive_raw == 1;
|
||||
@ -89,7 +120,7 @@ public:
|
||||
size_t first_inner_offset = target_offset + target_start * inner_size;
|
||||
if (exclusive)
|
||||
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++)
|
||||
dst_ptr[first_inner_offset + inner_idx] = 0.0f;
|
||||
dst_ptr[first_inner_offset + inner_idx] = 0;
|
||||
else
|
||||
for (size_t inner_idx = 0; inner_idx < inner_size; inner_idx++)
|
||||
dst_ptr[first_inner_offset + inner_idx] = src_ptr[first_inner_offset + inner_idx];
|
||||
|
Loading…
Reference in New Issue
Block a user