fix incorrect steps and elemsize when dtype changes

This commit is contained in:
fengyuentau 2024-02-06 16:27:25 +08:00
parent 8850a8219e
commit fcaa8ce3c2

View File

@ -47,11 +47,12 @@ public:
std::vector<char*> ptrs;
std::vector<std::vector<int>> shapes;
std::vector<std::vector<size_t>> steps;
std::vector<size_t> elemsize;
NaryEltwiseHelper() {
}
void helperInit(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
void init(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
{
narrays = 0;
max_ndims = 0;
@ -61,6 +62,7 @@ public:
ptrs.clear();
shapes.clear();
steps.clear();
elemsize.clear();
ninputs = inputs.size();
narrays = ninputs + 1;
@ -95,15 +97,33 @@ public:
}
orig_shapes.push_back(_size);
orig_steps.push_back(_step);
int esz = i == 0 ? outputs[0].elemSize() : inputs[i - 1].elemSize();
elemsize.push_back(esz);
}
}
// use FP32 as default type in finalized() function
template <typename T>
void reInit(size_t newElemSize) {
std::vector<size_t> newElemSizes(elemsize.size(), newElemSize);
reInit(newElemSizes);
}
void reInit(std::vector<size_t> newElemSizes) {
for (size_t array_index = 0; array_index < orig_steps.size(); array_index++) {
auto &step = orig_steps[array_index];
int esz = elemsize[array_index];
int new_esz = newElemSizes[array_index];
for (size_t step_index = 0; step_index < step.size(); step_index++) {
step[step_index] = static_cast<size_t>(step[step_index] / esz * new_esz);
}
elemsize[array_index] = newElemSizes[array_index];
}
prepare_for_broadcast_op();
}
bool prepare_for_broadcast_op()
{
int i, j, k;
std::vector<size_t> elemsize(this->all_ndims.size(), sizeof(T));
// step 1.
// * make all inputs and the output max_ndims-dimensional.
@ -313,8 +333,8 @@ public:
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
helper.helperInit(inputs, outputs);
CV_Assert(helper.prepare_for_broadcast_op<float>());
helper.init(inputs, outputs);
CV_Assert(helper.prepare_for_broadcast_op());
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
@ -579,6 +599,7 @@ public:
if (inputs_arr.depth() == CV_16F)
{
helper.reInit(sizeof(float));
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
@ -733,14 +754,13 @@ public:
switch (type)
{
case CV_8U:
// TODO: integrate with type inference
helper.reInit(sizeof(uint8_t));
opDispatch<uint8_t>(std::forward<Args>(args)...);
helper.prepare_for_broadcast_op<uint8_t>();
/*
recompute broadcasted shapes
because default type is FP32 which is calculated in finalize() function
*/
break;
case CV_32S:
// TODO: integrate with type inference
helper.reInit(sizeof(int32_t));
opDispatch<int32_t>(std::forward<Args>(args)...);
break;
case CV_32F: