Mirror of https://github.com/opencv/opencv.git (synced 2025-06-08 01:53:19 +08:00)
Added native int64 indices support to gather layer (#25211)
Co-authored-by: Alexander Lyulkov <alexander.lyulkov@opencv.ai>
This commit is contained in:
parent f2cf3c8890
commit aa9e80b07b
@@ -52,7 +52,6 @@ public:
|
|||||||
outputs.assign(1, inputs[0]);
|
outputs.assign(1, inputs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||||
{
|
{
|
||||||
CV_TRACE_FUNCTION();
|
CV_TRACE_FUNCTION();
|
||||||
@@ -66,43 +65,31 @@ public:
|
|||||||
|
|
||||||
CV_CheckEQ(inputs.size(), (size_t)2, "");
|
CV_CheckEQ(inputs.size(), (size_t)2, "");
|
||||||
CV_CheckEQ(outputs.size(), (size_t)1, "");
|
CV_CheckEQ(outputs.size(), (size_t)1, "");
|
||||||
|
CV_CheckTypeEQ(inputs[0].type(), outputs[0].type(), "");
|
||||||
|
|
||||||
const Mat& inp = inputs[0];
|
if (inputs[1].type() == CV_32SC1)
|
||||||
|
forward_impl<int32_t>(inputs[0], inputs[1], outputs[0]);
|
||||||
int indicesType = inputs[1].type();
|
else if (inputs[1].type() == CV_64SC1)
|
||||||
CV_CheckType(indicesType, indicesType == CV_32SC1 || indicesType == CV_64SC1, "");
|
forward_impl<int64_t>(inputs[0], inputs[1], outputs[0]);
|
||||||
Mat indices32S;
|
|
||||||
if (indicesType == CV_64SC1)
|
|
||||||
{
|
|
||||||
inputs[1].convertTo(indices32S, CV_32S);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
CV_CheckType(inputs[1].type(), inputs[1].type() == CV_32SC1 || inputs[1].type() == CV_64SC1, "");
|
||||||
indices32S = inputs[1];
|
}
|
||||||
}
|
|
||||||
const size_t indices_total = indices32S.total();
|
|
||||||
indices32S = indices32S.reshape(1, indices_total);
|
|
||||||
|
|
||||||
Mat& out = outputs[0];
|
template<typename T_INDEX>
|
||||||
|
void forward_impl(const Mat& inp, const Mat& indices, Mat& out)
|
||||||
CV_CheckTypeEQ(inp.type(), out.type(), "");
|
{
|
||||||
CV_CheckTypeEQ(indices32S.type(), CV_32SC1, "");
|
|
||||||
|
|
||||||
|
const size_t indices_total = indices.total();
|
||||||
const int axis = normalize_axis(m_axis, shape(inp));
|
const int axis = normalize_axis(m_axis, shape(inp));
|
||||||
|
|
||||||
// FIXIT: why should we work with non-normalized input? it should be handled in importer or layers's output generator
|
// FIXIT: why should we work with non-normalized input? it should be handled in importer or layers's output generator
|
||||||
const int axis_size = (int)inp.size[axis];
|
const int axis_size = (int)inp.size[axis];
|
||||||
for (size_t j = 0 ; j < indices_total; ++j)
|
|
||||||
{
|
|
||||||
int& idx = indices32S.at<int>(j);
|
|
||||||
idx = normalize_axis(idx, axis_size); // validate and normalize indices
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t outer_size = axis == 0 ? inp.total() : inp.step1(axis - 1);
|
const size_t outer_size = axis == 0 ? inp.total() : inp.step1(axis - 1);
|
||||||
const size_t outer_dims = inp.total() / outer_size;
|
const size_t outer_dims = inp.total() / outer_size;
|
||||||
const size_t inner_size = inp.step1(axis);
|
const size_t inner_size = inp.step1(axis);
|
||||||
|
|
||||||
const int* idx = indices32S.ptr<int>();
|
const T_INDEX* idx = indices.ptr<T_INDEX>();
|
||||||
const char* src = inp.ptr<const char>();
|
const char* src = inp.ptr<const char>();
|
||||||
char* dst = out.ptr<char>();
|
char* dst = out.ptr<char>();
|
||||||
CV_CheckEQ(out.total(), outer_dims * indices_total * inner_size, "");
|
CV_CheckEQ(out.total(), outer_dims * indices_total * inner_size, "");
|
||||||
@@ -115,7 +102,7 @@ public:
|
|||||||
const size_t src_offset = i * outer_size;
|
const size_t src_offset = i * outer_size;
|
||||||
for (size_t j = 0 ; j < indices_total; ++j)
|
for (size_t j = 0 ; j < indices_total; ++j)
|
||||||
{
|
{
|
||||||
const int index = idx[j];
|
const int index = normalize_axis(idx[j], axis_size);
|
||||||
CV_DbgCheck(index, index >= 0 && index < axis_size, "");
|
CV_DbgCheck(index, index >= 0 && index < axis_size, "");
|
||||||
const size_t new_offset = src_offset + index * inner_size;
|
const size_t new_offset = src_offset + index * inner_size;
|
||||||
std::memcpy(dst, src + new_offset * es, inner_bytes);
|
std::memcpy(dst, src + new_offset * es, inner_bytes);
|
||||||
|
Loading…
Reference in New Issue
Block a user