mirror of
https://github.com/opencv/opencv.git
synced 2024-12-18 19:38:02 +08:00
fix incorrect steps and elemsize when dtype changes
This commit is contained in:
parent
8850a8219e
commit
fcaa8ce3c2
@ -47,11 +47,12 @@ public:
|
|||||||
std::vector<char*> ptrs;
|
std::vector<char*> ptrs;
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<std::vector<int>> shapes;
|
||||||
std::vector<std::vector<size_t>> steps;
|
std::vector<std::vector<size_t>> steps;
|
||||||
|
std::vector<size_t> elemsize;
|
||||||
|
|
||||||
NaryEltwiseHelper() {
|
NaryEltwiseHelper() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void helperInit(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
|
void init(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
|
||||||
{
|
{
|
||||||
narrays = 0;
|
narrays = 0;
|
||||||
max_ndims = 0;
|
max_ndims = 0;
|
||||||
@ -61,6 +62,7 @@ public:
|
|||||||
ptrs.clear();
|
ptrs.clear();
|
||||||
shapes.clear();
|
shapes.clear();
|
||||||
steps.clear();
|
steps.clear();
|
||||||
|
elemsize.clear();
|
||||||
|
|
||||||
ninputs = inputs.size();
|
ninputs = inputs.size();
|
||||||
narrays = ninputs + 1;
|
narrays = ninputs + 1;
|
||||||
@ -95,15 +97,33 @@ public:
|
|||||||
}
|
}
|
||||||
orig_shapes.push_back(_size);
|
orig_shapes.push_back(_size);
|
||||||
orig_steps.push_back(_step);
|
orig_steps.push_back(_step);
|
||||||
|
|
||||||
|
int esz = i == 0 ? outputs[0].elemSize() : inputs[i - 1].elemSize();
|
||||||
|
elemsize.push_back(esz);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// use FP32 as default type in finalized() function
|
void reInit(size_t newElemSize) {
|
||||||
template <typename T>
|
std::vector<size_t> newElemSizes(elemsize.size(), newElemSize);
|
||||||
|
reInit(newElemSizes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reInit(std::vector<size_t> newElemSizes) {
|
||||||
|
for (size_t array_index = 0; array_index < orig_steps.size(); array_index++) {
|
||||||
|
auto &step = orig_steps[array_index];
|
||||||
|
int esz = elemsize[array_index];
|
||||||
|
int new_esz = newElemSizes[array_index];
|
||||||
|
for (size_t step_index = 0; step_index < step.size(); step_index++) {
|
||||||
|
step[step_index] = static_cast<size_t>(step[step_index] / esz * new_esz);
|
||||||
|
}
|
||||||
|
elemsize[array_index] = newElemSizes[array_index];
|
||||||
|
}
|
||||||
|
prepare_for_broadcast_op();
|
||||||
|
}
|
||||||
|
|
||||||
bool prepare_for_broadcast_op()
|
bool prepare_for_broadcast_op()
|
||||||
{
|
{
|
||||||
int i, j, k;
|
int i, j, k;
|
||||||
std::vector<size_t> elemsize(this->all_ndims.size(), sizeof(T));
|
|
||||||
|
|
||||||
// step 1.
|
// step 1.
|
||||||
// * make all inputs and the output max_ndims-dimensional.
|
// * make all inputs and the output max_ndims-dimensional.
|
||||||
@ -313,8 +333,8 @@ public:
|
|||||||
inputs_arr.getMatVector(inputs);
|
inputs_arr.getMatVector(inputs);
|
||||||
outputs_arr.getMatVector(outputs);
|
outputs_arr.getMatVector(outputs);
|
||||||
|
|
||||||
helper.helperInit(inputs, outputs);
|
helper.init(inputs, outputs);
|
||||||
CV_Assert(helper.prepare_for_broadcast_op<float>());
|
CV_Assert(helper.prepare_for_broadcast_op());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
@ -579,6 +599,7 @@ public:
|
|||||||
|
|
||||||
if (inputs_arr.depth() == CV_16F)
|
if (inputs_arr.depth() == CV_16F)
|
||||||
{
|
{
|
||||||
|
helper.reInit(sizeof(float));
|
||||||
forward_fallback(inputs_arr, outputs_arr, internals_arr);
|
forward_fallback(inputs_arr, outputs_arr, internals_arr);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -733,14 +754,13 @@ public:
|
|||||||
switch (type)
|
switch (type)
|
||||||
{
|
{
|
||||||
case CV_8U:
|
case CV_8U:
|
||||||
|
// TODO: integrate with type inference
|
||||||
|
helper.reInit(sizeof(uint8_t));
|
||||||
opDispatch<uint8_t>(std::forward<Args>(args)...);
|
opDispatch<uint8_t>(std::forward<Args>(args)...);
|
||||||
helper.prepare_for_broadcast_op<uint8_t>();
|
|
||||||
/*
|
|
||||||
recompute broadcasted shapes
|
|
||||||
because default type is FP32 which is calculated in finalize() function
|
|
||||||
*/
|
|
||||||
break;
|
break;
|
||||||
case CV_32S:
|
case CV_32S:
|
||||||
|
// TODO: integrate with type inference
|
||||||
|
helper.reInit(sizeof(int32_t));
|
||||||
opDispatch<int32_t>(std::forward<Args>(args)...);
|
opDispatch<int32_t>(std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case CV_32F:
|
case CV_32F:
|
||||||
|
Loading…
Reference in New Issue
Block a user