diff --git a/modules/dnn/src/layers/gather_layer.cpp b/modules/dnn/src/layers/gather_layer.cpp index ed31f85747..06ca8fcd66 100644 --- a/modules/dnn/src/layers/gather_layer.cpp +++ b/modules/dnn/src/layers/gather_layer.cpp @@ -52,7 +52,6 @@ public: outputs.assign(1, inputs[0]); } - void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { CV_TRACE_FUNCTION(); @@ -66,43 +65,31 @@ public: CV_CheckEQ(inputs.size(), (size_t)2, ""); CV_CheckEQ(outputs.size(), (size_t)1, ""); + CV_CheckTypeEQ(inputs[0].type(), outputs[0].type(), ""); - const Mat& inp = inputs[0]; - - int indicesType = inputs[1].type(); - CV_CheckType(indicesType, indicesType == CV_32SC1 || indicesType == CV_64SC1, ""); - Mat indices32S; - if (indicesType == CV_64SC1) - { - inputs[1].convertTo(indices32S, CV_32S); - } + if (inputs[1].type() == CV_32SC1) + forward_impl(inputs[0], inputs[1], outputs[0]); + else if (inputs[1].type() == CV_64SC1) + forward_impl(inputs[0], inputs[1], outputs[0]); else - { - indices32S = inputs[1]; - } - const size_t indices_total = indices32S.total(); - indices32S = indices32S.reshape(1, indices_total); + CV_CheckType(inputs[1].type(), inputs[1].type() == CV_32SC1 || inputs[1].type() == CV_64SC1, ""); + } - Mat& out = outputs[0]; - - CV_CheckTypeEQ(inp.type(), out.type(), ""); - CV_CheckTypeEQ(indices32S.type(), CV_32SC1, ""); + template + void forward_impl(const Mat& inp, const Mat& indices, Mat& out) + { + const size_t indices_total = indices.total(); const int axis = normalize_axis(m_axis, shape(inp)); // FIXIT: why should we work with non-normalized input? it should be handled in importer or layers's output generator const int axis_size = (int)inp.size[axis]; - for (size_t j = 0 ; j < indices_total; ++j) - { - int& idx = indices32S.at(j); - idx = normalize_axis(idx, axis_size); // validate and normalize indices - } const size_t outer_size = axis == 0 ? inp.total() : inp.step1(axis - 1); const size_t outer_dims = inp.total() / outer_size; const size_t inner_size = inp.step1(axis); - const int* idx = indices32S.ptr(); + const T_INDEX* idx = indices.ptr(); const char* src = inp.ptr(); char* dst = out.ptr(); CV_CheckEQ(out.total(), outer_dims * indices_total * inner_size, ""); @@ -115,7 +102,7 @@ public: const size_t src_offset = i * outer_size; for (size_t j = 0 ; j < indices_total; ++j) { - const int index = idx[j]; + const int index = normalize_axis(idx[j], axis_size); CV_DbgCheck(index, index >= 0 && index < axis_size, ""); const size_t new_offset = src_offset + index * inner_size; std::memcpy(dst, src + new_offset * es, inner_bytes);