mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 13:10:12 +08:00
dnn: refactor reduce (#23613)
* initial impl * remove reduce in8; fix reduce importer * fix bugs and add log sum exp * remove unnecessary header and fix indentation
This commit is contained in:
parent
5229312ad2
commit
eefee8574a
@ -346,18 +346,9 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
class CV_EXPORTS ReduceLayer : public Layer
|
||||
{
|
||||
public:
|
||||
int reduceType;
|
||||
// reduceDims contains the dimensions that need to be reduced, targetDims is the target output dimension.
|
||||
std::vector<size_t> reduceDims, targetDims;
|
||||
static Ptr<ReduceLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS ReduceLayerInt8 : public ReduceLayer
|
||||
{
|
||||
public:
|
||||
static Ptr<ReduceLayerInt8> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS SoftmaxLayer : public Layer
|
||||
{
|
||||
public:
|
||||
|
@ -194,7 +194,6 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8, ConvolutionLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8, PoolingLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ReduceInt8, ReduceLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8, EltwiseLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8, BatchNormLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8, ScaleLayerInt8);
|
||||
|
@ -1,234 +0,0 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include "layers_common.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdlib.h>
|
||||
#include <numeric>
|
||||
|
||||
namespace cv
|
||||
{
|
||||
namespace dnn
|
||||
{
|
||||
|
||||
class ReduceLayerInt8Impl CV_FINAL : public ReduceLayerInt8
|
||||
{
|
||||
public:
|
||||
ReduceLayerInt8Impl(const LayerParams& params)
|
||||
{
|
||||
// Set reduce type
|
||||
CV_Assert(params.has("reduce"));
|
||||
String typeString = toLowerCase(params.get<String>("reduce"));
|
||||
if (typeString == "max")
|
||||
reduceType = MAX;
|
||||
else if (typeString == "min")
|
||||
reduceType = MIN;
|
||||
else
|
||||
CV_Error(Error::StsBadArg, "Unknown reduce type \"" + typeString + "\"");
|
||||
|
||||
// Set deleted dims
|
||||
CV_Assert(params.has("deleted_dims"));
|
||||
DictValue tempDims = params.get("deleted_dims");
|
||||
int i, n = tempDims.size();
|
||||
reduceDims.resize(n);
|
||||
for (i = 0; i < n; i++)
|
||||
{
|
||||
reduceDims[i] = tempDims.get<int>(i);
|
||||
}
|
||||
|
||||
CV_Assert(params.has("target_dims"));
|
||||
tempDims = params.get("target_dims");
|
||||
n = tempDims.size();
|
||||
targetDims.resize(n);
|
||||
for (i = 0; i < n; i++)
|
||||
{
|
||||
targetDims[i] = tempDims.get<int>(i);
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||
{
|
||||
if (backendId == DNN_BACKEND_OPENCV)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// reduceType == MIN
|
||||
struct ReduceOpMIN
|
||||
{
|
||||
int8_t apply(const int8_t* first, const int8_t* last)
|
||||
{
|
||||
return std::accumulate(first, last, *first,
|
||||
[](int8_t a, int8_t b)
|
||||
{
|
||||
return std::min(a, b);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == MAX
|
||||
struct ReduceOpMAX
|
||||
{
|
||||
int8_t apply(const int8_t* first, const int8_t* last)
|
||||
{
|
||||
return std::accumulate(first, last, *first,
|
||||
[](int8_t a, int8_t b)
|
||||
{
|
||||
return std::max(a, b);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Func>
|
||||
class ReduceInvoker : public ParallelLoopBody
|
||||
{
|
||||
public:
|
||||
const Mat* src;
|
||||
Mat *dst;
|
||||
std::vector<size_t> reduceDims;
|
||||
int nstripes;
|
||||
int reduceType;
|
||||
Ptr<Func> func;
|
||||
|
||||
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
|
||||
|
||||
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
|
||||
{
|
||||
CV_Assert_N(src.isContinuous(), dst.isContinuous(), src.type() == CV_8S, src.type() == dst.type());
|
||||
|
||||
ReduceInvoker<Func> p;
|
||||
|
||||
p.src = &src;
|
||||
p.dst = &dst;
|
||||
|
||||
p.reduceDims = reduceDims;
|
||||
p.nstripes = nstripes;
|
||||
p.reduceType = reduceType;
|
||||
|
||||
parallel_for_(Range(0, nstripes), p, nstripes);
|
||||
}
|
||||
|
||||
void operator()(const Range& r) const CV_OVERRIDE
|
||||
{
|
||||
size_t total = dst->total();
|
||||
size_t stripeSize = (total + nstripes - 1)/nstripes;
|
||||
size_t stripeStart = r.start*stripeSize;
|
||||
size_t stripeEnd = std::min(r.end*stripeSize, total);
|
||||
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
|
||||
|
||||
int8_t *dstData = (int8_t *)dst->data;
|
||||
int8_t *srcData = (int8_t *)src->data;
|
||||
|
||||
for (size_t ofs = stripeStart; ofs < stripeEnd;)
|
||||
{
|
||||
const int8_t* first = srcData + ofs * totalDeleted;
|
||||
const int8_t* last = srcData + (ofs + 1) * totalDeleted;
|
||||
|
||||
dstData[ofs] = func->apply(first, last);
|
||||
ofs += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
CV_Assert(inputs.size() == 1);
|
||||
const int nstripes = getNumThreads();
|
||||
|
||||
switch (reduceType)
|
||||
{
|
||||
case MIN:
|
||||
{
|
||||
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case MAX:
|
||||
{
|
||||
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsNotImplemented, "Not implemented");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(inputs.size() > 0);
|
||||
CV_Assert( reduceDims.size() !=0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
|
||||
|
||||
// outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape.
|
||||
std::vector<int> outShapeTmp, outShape;
|
||||
outShape.assign(targetDims.begin(), targetDims.end());
|
||||
if (inputs[0].size() == reduceDims.size())
|
||||
outShapeTmp.push_back(1);
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
|
||||
{
|
||||
outShapeTmp.push_back(inputs[0][i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Support dynamic shape of Batch size.
|
||||
// Note that: when there are multiple dynamic inputs, we will give an error.
|
||||
if (total(outShape) != total(outShapeTmp))
|
||||
{
|
||||
if (outShape[0] != outShapeTmp[0])
|
||||
outShape[0] = outShapeTmp[0];
|
||||
}
|
||||
|
||||
CV_Assert(total(outShape) == total(outShapeTmp));
|
||||
outputs.assign(1, outShape);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
|
||||
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
|
||||
const std::vector<MatShape> &outputs) const CV_OVERRIDE
|
||||
{
|
||||
CV_UNUSED(inputs); // suppress unused variable warning
|
||||
long flops = 0;
|
||||
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
|
||||
for (int i = 0; i < outputs.size(); i++)
|
||||
{
|
||||
flops += total(outputs[i])*(totalDeleted);
|
||||
}
|
||||
return flops;
|
||||
}
|
||||
private:
|
||||
enum Type
|
||||
{
|
||||
MAX,
|
||||
MIN
|
||||
};
|
||||
};
|
||||
|
||||
Ptr<ReduceLayerInt8> ReduceLayerInt8::create(const LayerParams& params)
|
||||
{
|
||||
return Ptr<ReduceLayerInt8>(new ReduceLayerInt8Impl(params));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -3,251 +3,449 @@
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include "opencv2/core/hal/intrin.hpp"
|
||||
#include "../op_cuda.hpp"
|
||||
#include "../op_webnn.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
#include <float.h>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
using std::max;
|
||||
using std::min;
|
||||
|
||||
#include <opencv2/core/utils/logger.hpp>
|
||||
|
||||
namespace cv
|
||||
{
|
||||
namespace dnn
|
||||
{
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
class ReduceLayerImpl CV_FINAL : public ReduceLayer
|
||||
{
|
||||
public:
|
||||
ReduceLayerImpl(const LayerParams& params)
|
||||
{
|
||||
ReduceLayerImpl(const LayerParams& params) {
|
||||
setParamsFrom(params);
|
||||
|
||||
// set reduce type
|
||||
CV_Assert(params.has("reduce"));
|
||||
String typeString = toLowerCase(params.get<String>("reduce"));
|
||||
if (typeString == "max")
|
||||
reduceType= MAX;
|
||||
else if (typeString == "min")
|
||||
reduceType= MIN;
|
||||
else if (typeString == "ave")
|
||||
reduceType= AVE;
|
||||
else if (typeString == "sum")
|
||||
reduceType= SUM;
|
||||
else if (typeString == "sum_square")
|
||||
reduceType= SUM_SQUARE;
|
||||
else if (typeString == "l1")
|
||||
reduceType= L1;
|
||||
else if (typeString == "l2")
|
||||
reduceType= L2;
|
||||
else if (typeString == "log_sum")
|
||||
reduceType= LOG_SUM;
|
||||
else if (typeString == "log_sum_exp")
|
||||
reduceType= LOG_SUM_EXP;
|
||||
else if (typeString == "prod")
|
||||
reduceType= PROD;
|
||||
String op_type = toLowerCase(params.get<String>("reduce"));
|
||||
if (op_type == "max")
|
||||
reduce_type = ReduceType::MAX;
|
||||
else if (op_type == "min")
|
||||
reduce_type = ReduceType::MIN;
|
||||
else if (op_type == "mean")
|
||||
reduce_type = ReduceType::MEAN;
|
||||
else if (op_type == "sum")
|
||||
reduce_type = ReduceType::SUM;
|
||||
else if (op_type == "sum_square")
|
||||
reduce_type = ReduceType::SUM_SQUARE;
|
||||
else if (op_type == "l1")
|
||||
reduce_type = ReduceType::L1;
|
||||
else if (op_type == "l2")
|
||||
reduce_type = ReduceType::L2;
|
||||
else if (op_type == "log_sum")
|
||||
reduce_type = ReduceType::LOG_SUM;
|
||||
else if (op_type == "log_sum_exp")
|
||||
reduce_type = ReduceType::LOG_SUM_EXP;
|
||||
else if (op_type == "prod")
|
||||
reduce_type = ReduceType::PROD;
|
||||
else
|
||||
CV_Error(Error::StsBadArg, "Unknown reduce type\"" + typeString + "\"");
|
||||
CV_Error(Error::StsBadArg, "Unknown reduce type\"" + op_type + "\"");
|
||||
|
||||
// set deleted dims
|
||||
CV_Assert(params.has("deleted_dims"));
|
||||
DictValue tempDims = params.get("deleted_dims");
|
||||
int i, n = tempDims.size();
|
||||
reduceDims.resize(n);
|
||||
for (i = 0; i < n; i++)
|
||||
{
|
||||
reduceDims[i] = tempDims.get<int>(i);
|
||||
}
|
||||
keepdims = params.get<bool>("keepdims", true);
|
||||
noop_with_empty_axes = params.get<bool>("noop_with_empty_axes", false);
|
||||
|
||||
CV_Assert(params.has("target_dims"));
|
||||
tempDims = params.get("target_dims");
|
||||
n = tempDims.size();
|
||||
targetDims.resize(n);
|
||||
for (i = 0; i < n; i++)
|
||||
{
|
||||
targetDims[i] = tempDims.get<int>(i);
|
||||
// get axes if it is existed, otherwise reduce all
|
||||
if (params.has("axes")) {
|
||||
auto param_axes = params.get("axes");
|
||||
int num_axes = param_axes.size();
|
||||
axes.resize(num_axes);
|
||||
for (int i = 0; i < num_axes; ++i)
|
||||
axes[i] = param_axes.get<int>(i);
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||
{
|
||||
if (backendId == DNN_BACKEND_OPENCV)
|
||||
{
|
||||
return true;
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE {
|
||||
return backendId == DNN_BACKEND_OPENCV;
|
||||
}
|
||||
|
||||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
|
||||
if (axes.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
auto shape_input = shape(inputs[0]);
|
||||
for (auto i = 0; i < axes.size(); ++i) {
|
||||
auto norm_axis = normalize_axis(axes[i], shape_input);
|
||||
axes[i] = norm_axis;
|
||||
}
|
||||
|
||||
bool do_nothing = true;
|
||||
for (auto axis : axes) {
|
||||
if (shape_input[axis] != 1) {
|
||||
do_nothing = false;
|
||||
}
|
||||
}
|
||||
if (do_nothing) {
|
||||
axes.clear();
|
||||
noop_with_empty_axes = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
// empty axes
|
||||
if (axes.empty()) {
|
||||
if (noop_with_empty_axes) {
|
||||
// do nothing
|
||||
outputs.assign(1, inputs[0]);
|
||||
} else {
|
||||
// reduce all axes
|
||||
MatShape shape_output;
|
||||
if (keepdims) {
|
||||
shape_output = inputs[0];
|
||||
for (auto i = 0; i < shape_output.size(); ++i)
|
||||
shape_output[i] = 1;
|
||||
} else {
|
||||
shape_output.push_back(1);
|
||||
}
|
||||
outputs.assign(1, shape_output);
|
||||
}
|
||||
} else {
|
||||
auto shape_output_ = inputs[0];
|
||||
for (size_t i = 0; i < axes.size(); ++i) {
|
||||
auto norm_axis = normalize_axis(axes[i], inputs[0]);
|
||||
shape_output_[norm_axis] = -1;
|
||||
}
|
||||
MatShape shape_output;
|
||||
for (size_t i = 0; i < shape_output_.size(); ++i) {
|
||||
if (shape_output_[i] == -1) {
|
||||
if (keepdims)
|
||||
shape_output.push_back(1);
|
||||
else
|
||||
continue;
|
||||
} else
|
||||
shape_output.push_back(shape_output_[i]);
|
||||
}
|
||||
if (shape_output.empty())
|
||||
shape_output.push_back(1);
|
||||
|
||||
outputs.assign(1, shape_output);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// reduceType == MIN
|
||||
struct ReduceOpMIN
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, FLT_MAX,
|
||||
[](float a, float b)
|
||||
{
|
||||
return std::min(a, b);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == MAX
|
||||
struct ReduceOpMAX
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, -FLT_MAX,
|
||||
[](float a, float b)
|
||||
{
|
||||
return std::max(a, b);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == SUM
|
||||
struct ReduceOpSUM
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, 0.f);
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == AVE
|
||||
struct ReduceOpAVE
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
float output = std::accumulate(first, last, 0.f);
|
||||
return output * ikarea;
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == SUM_SQUARE
|
||||
struct ReduceOpSUM_SQUARE
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, 0.f,
|
||||
[](float a, float b)
|
||||
{
|
||||
return a + b * b;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == L1
|
||||
struct ReduceOpL1
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, 0.f,
|
||||
[](float a, float b)
|
||||
{
|
||||
return a + std::abs(b);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == L2
|
||||
struct ReduceOpL2
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
float output = std::accumulate(first, last, 0.f,
|
||||
[](float a, float b)
|
||||
{
|
||||
return a + b * b;
|
||||
});
|
||||
return std::sqrt(output);
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == PROD
|
||||
struct ReduceOpPROD
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
return std::accumulate(first, last, 1.0f, std::multiplies<float>());
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == LOG_SUM
|
||||
struct ReduceOpLOG_SUM
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
float output = std::accumulate(first, last, 0.0f);
|
||||
return std::log(output);
|
||||
}
|
||||
};
|
||||
|
||||
// reduceType == LOG_SUM_EXP
|
||||
struct ReduceOpLOG_SUM_EXP
|
||||
{
|
||||
float apply(const float* first, const float* last, const float ikarea = 1.0f)
|
||||
{
|
||||
float output = std::accumulate(first, last, 0.0f,
|
||||
[](float a, float b)
|
||||
{
|
||||
return a + std::exp(b);
|
||||
});
|
||||
return std::log(output);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Func>
|
||||
class ReduceInvoker : public ParallelLoopBody
|
||||
{
|
||||
template <typename T>
|
||||
class ReduceBase {
|
||||
public:
|
||||
const Mat* src;
|
||||
Mat *dst;
|
||||
std::vector<size_t> reduceDims;
|
||||
int nstripes;
|
||||
int reduceType;
|
||||
Ptr<Func> func;
|
||||
using dtype_input = T;
|
||||
|
||||
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
|
||||
ReduceBase(size_t n, const T& init) : n_(n), accumulator_(init) {}
|
||||
virtual void update(const T& a) = 0;
|
||||
virtual T get_value() { return accumulator_; }
|
||||
virtual ~ReduceBase() = default;
|
||||
protected:
|
||||
size_t n_;
|
||||
T accumulator_;
|
||||
};
|
||||
|
||||
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
|
||||
{
|
||||
CV_Assert_N( src.isContinuous(), dst.isContinuous(), src.type() == CV_32F, src.type() == dst.type());
|
||||
template <typename T>
|
||||
class ReduceMin : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceMin(size_t n, const T& init) : ReduceBase<T>(n, init) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ = a > this->accumulator_ ? this->accumulator_ : a;
|
||||
}
|
||||
};
|
||||
|
||||
ReduceInvoker<Func> p;
|
||||
template <typename T>
|
||||
class ReduceMax : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceMax(size_t n, const T& init) : ReduceBase<T>(n, init) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ = a > this->accumulator_ ? a : this->accumulator_;
|
||||
}
|
||||
};
|
||||
|
||||
p.src = &src;
|
||||
p.dst = &dst;
|
||||
template <typename T>
|
||||
class ReduceSum : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += a;
|
||||
}
|
||||
};
|
||||
|
||||
p.reduceDims = reduceDims;
|
||||
p.nstripes = nstripes;
|
||||
p.reduceType = reduceType;
|
||||
template <typename T>
|
||||
class ReduceMean : public ReduceSum<T> {
|
||||
public:
|
||||
ReduceMean(size_t n, const T& init) : ReduceSum<T>(n, init) {}
|
||||
T get_value() override {
|
||||
return this->accumulator_ / static_cast<T>(this->n_);
|
||||
}
|
||||
};
|
||||
|
||||
parallel_for_(Range(0, nstripes), p, nstripes);
|
||||
template <typename T>
|
||||
class ReduceSumSquare : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceSumSquare(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += a * a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReduceL1 : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceL1(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += a > 0 ? a : -a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReduceL2 : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceL2(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += a * a;
|
||||
}
|
||||
T get_value() override {
|
||||
return std::sqrt(this->accumulator_);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReduceProd : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceProd(size_t n, const T& init) : ReduceBase<T>(n, 1) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ *= a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReduceLogSum : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceLogSum(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += a;
|
||||
}
|
||||
T get_value() override {
|
||||
return static_cast<T>(std::log(this->accumulator_));
|
||||
}
|
||||
};
|
||||
|
||||
// FIXME: overflow caution
|
||||
template <typename T>
|
||||
class ReduceLogSumExp : public ReduceBase<T> {
|
||||
public:
|
||||
ReduceLogSumExp(size_t n, const T& init) : ReduceBase<T>(n, 0) {}
|
||||
void update(const T& a) override {
|
||||
this->accumulator_ += static_cast<T>(std::exp(a));
|
||||
}
|
||||
T get_value() override {
|
||||
return static_cast<T>(std::log(this->accumulator_));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Op>
|
||||
class ReduceAllInvoker : public ParallelLoopBody {
|
||||
public:
|
||||
using dtype = typename Op::dtype_input;
|
||||
|
||||
const Mat& src;
|
||||
Mat& dst;
|
||||
|
||||
int n_reduce;
|
||||
int loop_size;
|
||||
|
||||
int total;
|
||||
int cost_per_thread;
|
||||
|
||||
ReduceAllInvoker(const Mat& src_, Mat& dst_) : src(src_), dst(dst_) {
|
||||
auto shape_src = shape(src);
|
||||
|
||||
n_reduce = std::accumulate(shape_src.begin(), shape_src.end(), 1, std::multiplies<int>());
|
||||
loop_size = n_reduce;
|
||||
|
||||
total = 1;
|
||||
cost_per_thread = 1;
|
||||
}
|
||||
|
||||
void operator()(const Range& r) const CV_OVERRIDE
|
||||
{
|
||||
size_t total = dst->total();
|
||||
size_t stripeSize = (total + nstripes - 1)/nstripes;
|
||||
size_t stripeStart = r.start*stripeSize;
|
||||
size_t stripeEnd = std::min(r.end*stripeSize, total);
|
||||
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
|
||||
void operator()(const Range& r) const CV_OVERRIDE {
|
||||
int start = r.start;
|
||||
int end = r.end;
|
||||
|
||||
float *dstData = (float *)dst->data;
|
||||
float *srcData = (float *)src->data;
|
||||
const dtype* p_src = src.ptr<const dtype>();
|
||||
dtype* p_dst = dst.ptr<dtype>();
|
||||
|
||||
for (size_t ofs = stripeStart; ofs < stripeEnd;)
|
||||
{
|
||||
const float* first = srcData + ofs * stride_w;
|
||||
const float* last = srcData + (ofs + 1) * stride_w;
|
||||
for (int i = start; i < end; ++i) {
|
||||
Op accumulator(n_reduce, *p_src);
|
||||
for (int l = 0; l < loop_size; ++l) {
|
||||
accumulator.update(p_src[l]);
|
||||
}
|
||||
p_dst[i] = accumulator.get_value();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (ofs < stripeEnd)
|
||||
{
|
||||
dstData[ofs] = func->apply(first, last, 1.0 / stride_w);
|
||||
ofs += 1;
|
||||
template <typename Op>
|
||||
class ReduceInvoker : public ParallelLoopBody {
|
||||
public:
|
||||
using dtype = typename Op::dtype_input;
|
||||
|
||||
const Mat& src;
|
||||
Mat& dst;
|
||||
|
||||
std::vector<int> reduced_axes; // assume in ascending order
|
||||
|
||||
int n_reduce;
|
||||
int loop_size;
|
||||
|
||||
int last_reduced_dim;
|
||||
int last_reduced_step;
|
||||
std::vector<int> projected_steps;
|
||||
|
||||
int last_unreduced_dim;
|
||||
int last_unreduced_step;
|
||||
std::vector<int> unprojected_steps;
|
||||
|
||||
int total;
|
||||
int cost_per_thread;
|
||||
|
||||
ReduceInvoker(const Mat& src_, Mat& dst_, std::vector<int> axes_) : src(src_), dst(dst_), reduced_axes(axes_) {
|
||||
auto shape_src = shape(src);
|
||||
|
||||
auto steps_src = shape_src;
|
||||
steps_src[steps_src.size() - 1] = 1;
|
||||
for (int i = static_cast<int>(steps_src.size()) - 2; i >= 0; --i)
|
||||
steps_src[i] = steps_src[i + 1] * shape_src[i + 1];
|
||||
|
||||
size_t projection_size = 1;
|
||||
for (auto axis : reduced_axes) {
|
||||
projection_size *= shape_src[axis];
|
||||
}
|
||||
n_reduce = projection_size;
|
||||
|
||||
last_reduced_dim = shape_src[reduced_axes.back()];
|
||||
last_reduced_step = steps_src[reduced_axes.back()];
|
||||
loop_size = last_reduced_dim * last_reduced_step;
|
||||
projection_size /= last_reduced_dim;
|
||||
|
||||
// calculate projected_steps
|
||||
int last_reduced_axis = static_cast<int>(reduced_axes.size()) - 1;
|
||||
if (last_reduced_axis == 0) {
|
||||
projected_steps.resize(1, 0);
|
||||
} else {
|
||||
projected_steps.resize(projection_size);
|
||||
std::vector<int> projected_indices(last_reduced_axis, 0);
|
||||
for (size_t i = 0, current_step = 0; i < projection_size; ++i) {
|
||||
projected_steps[i] = current_step;
|
||||
++projected_indices[last_reduced_axis - 1];
|
||||
current_step += steps_src[reduced_axes[last_reduced_axis - 1]];
|
||||
for (int j = last_reduced_axis - 1; j > 0; --j) {
|
||||
if (projected_indices[j] < shape_src[reduced_axes[j]]) {
|
||||
break;
|
||||
}
|
||||
projected_indices[j] = 0;
|
||||
++projected_indices[j - 1];
|
||||
current_step = steps_src[reduced_axes[j - 1]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculate unprojected_steps
|
||||
std::vector<int> unreduced_axes;
|
||||
for (int i = 0; i < static_cast<int>(shape_src.size()); ++i) {
|
||||
if (std::find(reduced_axes.begin(), reduced_axes.end(), i) == reduced_axes.end()) {
|
||||
unreduced_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
size_t unprojection_size = 1;
|
||||
for (auto axis : unreduced_axes) {
|
||||
unprojection_size *= shape_src[axis];
|
||||
}
|
||||
last_unreduced_dim = shape_src[unreduced_axes.back()];
|
||||
last_unreduced_step = steps_src[unreduced_axes.back()];
|
||||
unprojection_size /= last_unreduced_dim;
|
||||
|
||||
std::vector<int> unprojected_indices(unreduced_axes.size(), 0);
|
||||
unprojected_steps.reserve(unprojection_size);
|
||||
if (unprojected_indices.size() <= 1) {
|
||||
unprojected_steps.push_back(0);
|
||||
} else {
|
||||
for (size_t i = 0, current_step = 0; i < unprojection_size; ++i) {
|
||||
unprojected_steps.push_back(current_step);
|
||||
++unprojected_indices[unprojected_indices.size() - 2];
|
||||
current_step += steps_src[unreduced_axes[unreduced_axes.size() - 2]];
|
||||
for (int j = static_cast<int>(unreduced_axes.size()) - 2; j > 0; --j) {
|
||||
if (unprojected_indices[j] < shape_src[unreduced_axes[j]]) {
|
||||
break;
|
||||
}
|
||||
unprojected_indices[j] = 0;
|
||||
++unprojected_indices[j - 1];
|
||||
current_step = steps_src[unreduced_axes[j - 1]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto shape_dst = shape(dst);
|
||||
total = std::accumulate(shape_dst.begin(), shape_dst.end(), 1, std::multiplies<int>());
|
||||
cost_per_thread = static_cast<int>(projected_steps.size() * last_reduced_step);
|
||||
}
|
||||
|
||||
static void run(const Mat& src, Mat& dst, std::vector<int> axes, bool noop_with_empty_axes) {
|
||||
CV_Assert(src.isContinuous());
|
||||
CV_Assert(dst.isContinuous());
|
||||
|
||||
if (axes.empty()) {
|
||||
if (noop_with_empty_axes) {
|
||||
// copyTo is not used here for the reason that we want a
|
||||
// copy for the case when dims at all axes are 1
|
||||
const auto p_src = src.ptr<const dtype>();
|
||||
auto p_dst = dst.ptr<dtype>();
|
||||
std::memcpy(p_dst, p_src, sizeof(dtype) * dst.total());
|
||||
return;
|
||||
}
|
||||
|
||||
ReduceAllInvoker<Op> p(src, dst);
|
||||
double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
|
||||
parallel_for_(Range(0, p.total), p, nstripes);
|
||||
return;
|
||||
}
|
||||
|
||||
ReduceInvoker<Op> p(src, dst, axes);
|
||||
double nstripes = (size_t)p.total * (size_t)p.cost_per_thread * (1 / 1024.0);
|
||||
parallel_for_(Range(0, p.total), p, nstripes);
|
||||
}
|
||||
|
||||
void operator()(const Range& r) const CV_OVERRIDE {
|
||||
int start = r.start;
|
||||
int end = r.end;
|
||||
|
||||
const dtype* p_src = src.ptr<const dtype>();
|
||||
dtype* p_dst = dst.ptr<dtype>();
|
||||
|
||||
size_t main_index = start / last_unreduced_dim;
|
||||
size_t loop = start / last_unreduced_dim;
|
||||
size_t origin = unprojected_steps[main_index] + loop * last_unreduced_step;
|
||||
for (int i = start; i < end; ++i) {
|
||||
Op accumulator(n_reduce, p_src[origin + projected_steps[0]]);
|
||||
for (auto projected_step : projected_steps) {
|
||||
const dtype* loop_p_src = p_src + origin + projected_step;
|
||||
for (auto l = 0; l < loop_size; l += last_reduced_step) {
|
||||
accumulator.update(loop_p_src[l]);
|
||||
}
|
||||
}
|
||||
p_dst[i] = accumulator.get_value();
|
||||
|
||||
++loop;
|
||||
if (loop >= last_unreduced_dim) {
|
||||
loop = 0;
|
||||
++main_index;
|
||||
if (main_index < unprojected_steps.size()) {
|
||||
origin = unprojected_steps[main_index];
|
||||
}
|
||||
} else {
|
||||
origin += last_unreduced_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -267,129 +465,43 @@ public:
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
CV_Assert(inputs.size() == 1 || (inputs.size() == 2 && reduceType== SUM));
|
||||
const int nstripes = getNumThreads();
|
||||
|
||||
switch (reduceType)
|
||||
{
|
||||
case MIN:
|
||||
{
|
||||
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case MAX:
|
||||
{
|
||||
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case AVE:
|
||||
{
|
||||
ReduceInvoker<ReduceOpAVE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case SUM:
|
||||
{
|
||||
ReduceInvoker<ReduceOpSUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case L1:
|
||||
{
|
||||
ReduceInvoker<ReduceOpL1>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case L2:
|
||||
{
|
||||
ReduceInvoker<ReduceOpL2>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case SUM_SQUARE:
|
||||
{
|
||||
ReduceInvoker<ReduceOpSUM_SQUARE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case PROD:
|
||||
{
|
||||
ReduceInvoker<ReduceOpPROD>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case LOG_SUM:
|
||||
{
|
||||
ReduceInvoker<ReduceOpLOG_SUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
case LOG_SUM_EXP:
|
||||
{
|
||||
ReduceInvoker<ReduceOpLOG_SUM_EXP>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsNotImplemented, "Not implemented");
|
||||
break;
|
||||
typeDispatch(outputs[0].type(), inputs[0], outputs[0], axes, noop_with_empty_axes);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline void opDispatch(Args&&... args) {
|
||||
switch (reduce_type) {
|
||||
case ReduceType::MAX: ReduceInvoker<ReduceMax<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::MIN: ReduceInvoker<ReduceMin<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::MEAN: ReduceInvoker<ReduceMean<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::SUM: ReduceInvoker<ReduceSum<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::L1: ReduceInvoker<ReduceL1<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::L2: ReduceInvoker<ReduceL2<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::PROD: ReduceInvoker<ReduceProd<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::SUM_SQUARE: ReduceInvoker<ReduceSumSquare<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::LOG_SUM: ReduceInvoker<ReduceLogSum<T>>::run(std::forward<Args>(args)...); break;
|
||||
case ReduceType::LOG_SUM_EXP: ReduceInvoker<ReduceLogSumExp<T>>::run(std::forward<Args>(args)...); break;
|
||||
default: CV_Error(Error::StsBadArg, "DNN/Reduce: Unsupported operation.");
|
||||
}
|
||||
}
|
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(inputs.size() > 0);
|
||||
CV_Assert( reduceDims.size() !=0 && targetDims.size() != 0 && inputs[0].size() >= reduceDims.size());
|
||||
|
||||
// outShapeTmp can save the right number of `total(outShapeTmp)`. And the outShape is used as the final output shape.
|
||||
std::vector<int> outShapeTmp, outShape;
|
||||
outShape.assign(targetDims.begin(), targetDims.end());
|
||||
if (inputs[0].size() == reduceDims.size())
|
||||
outShapeTmp.push_back(1);
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
|
||||
{
|
||||
outShapeTmp.push_back(inputs[0][i]);
|
||||
}
|
||||
template <typename... Args>
|
||||
inline void typeDispatch(const int type, Args&&... args) {
|
||||
switch (type) {
|
||||
case CV_8U: opDispatch<uint8_t>(std::forward<Args>(args)...); break;
|
||||
case CV_32S: opDispatch<int32_t>(std::forward<Args>(args)...); break;
|
||||
case CV_32F: opDispatch<float>(std::forward<Args>(args)...); break;
|
||||
default: CV_Error(cv::Error::BadDepth, "DNN/Reduce: Unsupported type.");
|
||||
}
|
||||
|
||||
// Support dynamic shape of Batch size.
|
||||
// Note that: when there are multiple dynamic inputs, we will give an error.
|
||||
if (total(outShape) != total(outShapeTmp) && outShape[0] != outShapeTmp[0])
|
||||
{
|
||||
outShape[0] = outShapeTmp[0];
|
||||
}
|
||||
|
||||
CV_Assert(total(outShape) == total(outShapeTmp));
|
||||
outputs.assign(1, outShape);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
|
||||
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
|
||||
{
|
||||
if (reduceType== MAX || reduceType== MIN)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
|
||||
const std::vector<MatShape> &outputs) const CV_OVERRIDE
|
||||
{
|
||||
CV_UNUSED(inputs); // suppress unused variable warning
|
||||
long flops = 0;
|
||||
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
|
||||
for (int i = 0; i < outputs.size(); i++)
|
||||
{
|
||||
flops += total(outputs[i])*(stride_w);
|
||||
}
|
||||
return flops;
|
||||
}
|
||||
private:
|
||||
enum ReduceType
|
||||
{
|
||||
MAX,
|
||||
MIN,
|
||||
AVE,
|
||||
MEAN,
|
||||
SUM,
|
||||
L1,
|
||||
L2,
|
||||
@ -397,7 +509,11 @@ private:
|
||||
SUM_SQUARE,
|
||||
LOG_SUM,
|
||||
LOG_SUM_EXP
|
||||
};
|
||||
} reduce_type;
|
||||
|
||||
bool keepdims;
|
||||
bool noop_with_empty_axes;
|
||||
std::vector<int> axes;
|
||||
};
|
||||
|
||||
Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
|
||||
@ -405,5 +521,4 @@ Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
|
||||
return Ptr<ReduceLayer>(new ReduceLayerImpl(params));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}} // cv::dnn
|
||||
|
@ -1178,165 +1178,49 @@ void ONNXImporter::parseGlobalPool(LayerParams &layerParams, const opencv_onnx::
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
|
||||
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
opencv_onnx::NodeProto node_proto = node_proto_;
|
||||
const std::string& layer_type = node_proto.op_type();
|
||||
const std::string output_name = node_proto.output(0);
|
||||
int depth = layerParams.get<int>("depth", CV_32F);
|
||||
|
||||
CV_Assert(node_proto.input_size() <= 2);
|
||||
String reduceType;
|
||||
|
||||
if (layer_type == "ReduceMax")
|
||||
reduceType = "MAX";
|
||||
else if (layer_type == "ReduceMin")
|
||||
reduceType = "MIN";
|
||||
else if (layer_type == "ReduceSum")
|
||||
reduceType = "SUM";
|
||||
else if (layer_type == "ReduceSumSquare")
|
||||
reduceType = "SUM_SQUARE";
|
||||
else if (layer_type == "ReduceProd")
|
||||
reduceType = "PROD";
|
||||
else if (layer_type == "ReduceL1")
|
||||
reduceType = "L1";
|
||||
else if (layer_type == "ReduceL2")
|
||||
reduceType = "L2";
|
||||
else if (layer_type == "ReduceLogSum")
|
||||
reduceType = "LOG_SUM";
|
||||
else if (layer_type == "ReduceLogSumExp")
|
||||
reduceType = "LOG_SUM_EXP";
|
||||
else if (layer_type == "ReduceMean")
|
||||
reduceType = "AVE";
|
||||
const auto& op_type = node_proto.op_type();
|
||||
String reduce_type;
|
||||
if (op_type == "ReduceMax")
|
||||
reduce_type = "MAX";
|
||||
else if (op_type == "ReduceMean")
|
||||
reduce_type = "MEAN";
|
||||
else if (op_type == "ReduceMin")
|
||||
reduce_type = "MIN";
|
||||
else if (op_type == "ReduceProd")
|
||||
reduce_type = "PROD";
|
||||
else if (op_type == "ReduceSum")
|
||||
reduce_type = "SUM";
|
||||
else if (op_type == "ReduceL1")
|
||||
reduce_type = "L1";
|
||||
else if (op_type == "ReduceL2")
|
||||
reduce_type = "L2";
|
||||
else if (op_type == "ReduceLogSum")
|
||||
reduce_type = "LOG_SUM";
|
||||
else if (op_type == "ReduceLogSumExp")
|
||||
reduce_type = "LOG_SUM_EXP";
|
||||
else if (op_type == "ReduceSumSquare")
|
||||
reduce_type = "SUM_SQUARE";
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation.");
|
||||
CV_Error(Error::StsNotImplemented, "DNN/ONNX: " + op_type + " is not supported.");
|
||||
layerParams.set("reduce", reduce_type);
|
||||
|
||||
// The ReduceInt8 can only support "MAX" and "MIN".
|
||||
if (depth == CV_8S)
|
||||
{
|
||||
CV_Assert(reduceType == "MAX" || reduceType == "MIN");
|
||||
int num_inputs = node_proto.input_size();
|
||||
CV_Check(num_inputs, num_inputs >= 1 && num_inputs <= 2, "DNN/ONNX: Reduce layers should have at least one input and at most two inputs");
|
||||
|
||||
// "axes" is turned to one of the inputs since opset 18,
|
||||
// except for ReduceSum, which has "axes" input since opset 13.
|
||||
if (!layerParams.has("axes") && num_inputs == 2 && constBlobs.find(node_proto.input(1)) != constBlobs.end()) {
|
||||
Mat mat_axes = getBlob(node_proto, 1);
|
||||
int num_axes = mat_axes.total();
|
||||
std::vector<int> axes(num_axes);
|
||||
for (int i = 0; i < num_axes; ++i)
|
||||
axes[i] = mat_axes.at<int>(i);
|
||||
layerParams.set("axes", DictValue::arrayInt(&axes[0], num_axes));
|
||||
}
|
||||
|
||||
layerParams.type = (depth == CV_8S) ? "ReduceInt8" : "Reduce";
|
||||
layerParams.set("reduce", reduceType);
|
||||
bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
|
||||
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
std::vector<bool> shouldDelete(inpShape.size(), false);
|
||||
|
||||
if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
|
||||
{
|
||||
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
||||
{
|
||||
Mat axesMat = getBlob(node_proto, 1);
|
||||
int axesNum = axesMat.total();
|
||||
for (int i = 0; i < axesNum; i++)
|
||||
{
|
||||
int axis = normalize_axis(axesMat.at<int>(i), inpShape.size());
|
||||
shouldDelete[axis] = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
// in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
|
||||
// details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
|
||||
CV_Error(Error::StsNotImplemented, "Non-constant axis values in ReduceSum are not supported.");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (layerParams.has("axes"))
|
||||
{
|
||||
DictValue axes = layerParams.get("axes");
|
||||
for (int i = 0; i < axes.size(); i++)
|
||||
{
|
||||
int axis = normalize_axis(axes.get<int>(i), inpShape.size());
|
||||
shouldDelete[axis] = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < inpShape.size(); i++)
|
||||
{
|
||||
shouldDelete[i] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> targetShape;
|
||||
for (int i = 0; i < inpShape.size(); ++i)
|
||||
{
|
||||
if (!shouldDelete[i])
|
||||
{
|
||||
targetShape.push_back(inpShape[i]);
|
||||
}
|
||||
else if (keepdims)
|
||||
{
|
||||
targetShape.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (targetShape.empty())
|
||||
targetShape.push_back(1);
|
||||
|
||||
// Using PermuteLayer to move the deleted axis to the last.
|
||||
std::vector<int> perm(inpShape.size(), 0);
|
||||
for (int i = 0; i < inpShape.size(); i++)
|
||||
perm[i] = i;
|
||||
|
||||
bool needPermuet = false;
|
||||
for (int i = 0; i < inpShape.size(); i++)
|
||||
{
|
||||
if (shouldDelete[i])
|
||||
{
|
||||
// find the first not deleted element.
|
||||
std::vector<bool>::iterator iter = std::find(shouldDelete.begin() + i, shouldDelete.end(), false);
|
||||
|
||||
if (iter != shouldDelete.end())
|
||||
{
|
||||
int index = iter - shouldDelete.begin();
|
||||
|
||||
bool temp = shouldDelete[index];
|
||||
shouldDelete[index] = shouldDelete[i];
|
||||
shouldDelete[i] = temp;
|
||||
|
||||
std::swap(perm[index], perm[i]);
|
||||
std::swap(inpShape[index], inpShape[i]);
|
||||
needPermuet = true;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto inputString= node_proto.input(0);
|
||||
if (needPermuet)
|
||||
{
|
||||
LayerParams permuteLp;
|
||||
permuteLp.name = layerParams.name + "/permute";
|
||||
permuteLp.type = (depth == CV_8S) ? "PermuteInt8" : "Permute";
|
||||
permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size()));
|
||||
|
||||
opencv_onnx::NodeProto protoPermute;
|
||||
protoPermute.add_input(inputString);
|
||||
protoPermute.add_output(permuteLp.name);
|
||||
addLayer(permuteLp, protoPermute);
|
||||
inputString = permuteLp.name;
|
||||
}
|
||||
|
||||
std::vector<int> deletedDims;
|
||||
for (int axis_i = 0; axis_i < inpShape.size(); ++axis_i)
|
||||
{
|
||||
if (shouldDelete[axis_i])
|
||||
{
|
||||
deletedDims.push_back(inpShape[axis_i]);
|
||||
}
|
||||
}
|
||||
|
||||
layerParams.set("deleted_dims", DictValue::arrayInt(&deletedDims[0], deletedDims.size()));
|
||||
layerParams.set("target_dims", DictValue::arrayInt(&targetShape[0], targetShape.size()));
|
||||
|
||||
node_proto.set_input(0, inputString);
|
||||
node_proto.set_output(0, output_name);
|
||||
|
||||
layerParams.type = "Reduce";
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user