diff --git a/modules/dnn/src/layers/nary_eltwise_layers.cpp b/modules/dnn/src/layers/nary_eltwise_layers.cpp
index 5750766e51..c369a710a6 100644
--- a/modules/dnn/src/layers/nary_eltwise_layers.cpp
+++ b/modules/dnn/src/layers/nary_eltwise_layers.cpp
@@ -47,11 +47,12 @@ public:
     std::vector<char*> ptrs;
     std::vector<std::vector<int>> shapes;
     std::vector<std::vector<size_t>> steps;
+    std::vector<size_t> elemsize;

     NaryEltwiseHelper() {
     }

-    void helperInit(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
+    void init(const std::vector<Mat>& inputs, const std::vector<Mat>& outputs)
     {
         narrays = 0;
         max_ndims = 0;
@@ -61,6 +62,7 @@ public:
         ptrs.clear();
         shapes.clear();
         steps.clear();
+        elemsize.clear();

         ninputs = inputs.size();
         narrays = ninputs + 1;
@@ -95,15 +97,33 @@ public:
             }
             orig_shapes.push_back(_size);
             orig_steps.push_back(_step);
+
+            int esz = i == 0 ? outputs[0].elemSize() : inputs[i - 1].elemSize();
+            elemsize.push_back(esz);
         }
     }

-    // use FP32 as default type in finalized() function
-    template <typename T>
+    void reInit(size_t newElemSize) {
+        std::vector<size_t> newElemSizes(elemsize.size(), newElemSize);
+        reInit(newElemSizes);
+    }
+
+    void reInit(std::vector<size_t> newElemSizes) {
+        for (size_t array_index = 0; array_index < orig_steps.size(); array_index++) {
+            auto &step = orig_steps[array_index];
+            int esz = elemsize[array_index];
+            int new_esz = newElemSizes[array_index];
+            for (size_t step_index = 0; step_index < step.size(); step_index++) {
+                step[step_index] = static_cast<size_t>(step[step_index] / esz * new_esz);
+            }
+            elemsize[array_index] = newElemSizes[array_index];
+        }
+        prepare_for_broadcast_op();
+    }
+
     bool prepare_for_broadcast_op()
     {
         int i, j, k;
-        std::vector<size_t> elemsize(this->all_ndims.size(), sizeof(T));

         // step 1.
         // * make all inputs and the output max_ndims-dimensional.
@@ -313,8 +333,8 @@ public:
         inputs_arr.getMatVector(inputs);
         outputs_arr.getMatVector(outputs);

-        helper.helperInit(inputs, outputs);
-        CV_Assert(helper.prepare_for_broadcast_op<float>());
+        helper.init(inputs, outputs);
+        CV_Assert(helper.prepare_for_broadcast_op());
     }

     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -579,6 +599,7 @@ public:

         if (inputs_arr.depth() == CV_16F)
         {
+            helper.reInit(sizeof(float));
             forward_fallback(inputs_arr, outputs_arr, internals_arr);
             return;
         }
@@ -733,14 +754,13 @@ public:
         switch (type)
         {
             case CV_8U:
+                // TODO: integrate with type inference
+                helper.reInit(sizeof(uint8_t));
                 opDispatch<uint8_t>(std::forward<Args>(args)...);
-                helper.prepare_for_broadcast_op<uint8_t>();
-                /*
-                    recompute broadcasted shapes
-                    because default type is FP32 which is calculated in finalize() function
-                */
                 break;
             case CV_32S:
+                // TODO: integrate with type inference
+                helper.reInit(sizeof(int32_t));
                 opDispatch<int32_t>(std::forward<Args>(args)...);
                 break;
             case CV_32F:
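
Reviewer note: the core of this patch is that reInit() rescales the per-dimension byte steps recorded by init() whenever the element size changes, e.g. when the CV_16F path converts data to FP32 before calling forward_fallback(). Below is a minimal, self-contained sketch of that arithmetic; the rescaleSteps() helper and the 2x3 CV_16F example are illustrative assumptions, not part of the patch.

    // Sketch only: rescale byte strides when the element size changes.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Each stride recorded for elements of size oldEsz is an exact multiple
    // of oldEsz, so dividing first keeps the arithmetic exact (the same
    // "step / esz * new_esz" order the patch uses in reInit()).
    static void rescaleSteps(std::vector<size_t>& steps, size_t oldEsz, size_t newEsz)
    {
        for (size_t& s : steps)
            s = s / oldEsz * newEsz;
    }

    int main()
    {
        // Byte steps of a contiguous 2x3 CV_16F tensor: {3 * 2, 2} = {6, 2}.
        std::vector<size_t> steps = {6, 2};

        // The CV_16F fallback widens elements from 2 to 4 bytes (FP32),
        // so the strides must become {12, 4}.
        rescaleSteps(steps, 2, sizeof(float));

        std::printf("%zu %zu\n", steps[0], steps[1]); // prints: 12 4
        return 0;
    }

Dividing before multiplying also matters when shrinking (e.g. reInit(sizeof(uint8_t)) on steps computed for FP32): since every step is a multiple of the old element size, the division is exact and never produces an oversized intermediate product.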