Support for some reduce layers for onnx

This commit is contained in:
Zihao Mu 2022-03-18 10:19:13 +08:00
parent 48cd2d190f
commit b6b5c27cec
8 changed files with 796 additions and 197 deletions

View File

@ -325,6 +325,20 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<PoolingLayerInt8> create(const LayerParams& params);
};
class CV_EXPORTS ReduceLayer : public Layer
{
public:
int reduceType;
std::vector<size_t> reduceDims;
static Ptr<ReduceLayer> create(const LayerParams& params);
};
class CV_EXPORTS ReduceLayerInt8 : public ReduceLayer
{
public:
static Ptr<ReduceLayerInt8> create(const LayerParams& params);
};
class CV_EXPORTS SoftmaxLayer : public Layer
{
public:

View File

@ -92,6 +92,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Pooling, PoolingLayer);
CV_DNN_REGISTER_LAYER_CLASS(ROIPooling, PoolingLayer);
CV_DNN_REGISTER_LAYER_CLASS(PSROIPooling, PoolingLayer);
CV_DNN_REGISTER_LAYER_CLASS(Reduce, ReduceLayer);
CV_DNN_REGISTER_LAYER_CLASS(LRN, LRNLayer);
CV_DNN_REGISTER_LAYER_CLASS(InnerProduct, InnerProductLayer);
CV_DNN_REGISTER_LAYER_CLASS(Softmax, SoftmaxLayer);
@ -175,6 +176,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(ConvolutionInt8, ConvolutionLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(InnerProductInt8, InnerProductLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(PoolingInt8, PoolingLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(ReduceInt8, ReduceLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(EltwiseInt8, EltwiseLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(BatchNormInt8, BatchNormLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(ScaleInt8, ScaleLayerInt8);

View File

@ -0,0 +1,213 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <algorithm>
#include <stdlib.h>
#include <numeric>
namespace cv
{
namespace dnn
{
class ReduceLayerInt8Impl CV_FINAL : public ReduceLayerInt8
{
public:
ReduceLayerInt8Impl(const LayerParams& params)
{
// Set reduce type
CV_Assert(params.has("reduce"));
String typeString = toLowerCase(params.get<String>("reduce"));
if (typeString == "max")
reduceType = MAX;
else if (typeString == "min")
reduceType = MIN;
else
CV_Error(Error::StsBadArg, "Unknown reduce type \"" + typeString + "\"");
// Set deleted dims
CV_Assert(params.has("deleted_dims"));
DictValue tempDims = params.get("deleted_dims");
int i, n = tempDims.size();
reduceDims.resize(n);
for (i = 0; i < n; i++)
{
reduceDims[i] = tempDims.get<int>(i);
}
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_OPENCV)
{
return true;
}
return false;
}
// reduceType == MIN
struct ReduceOpMIN
{
int8_t apply(const int8_t* first, const int8_t* last)
{
return std::accumulate(first, last, *first,
[](int8_t a, int8_t b)
{
return std::min(a, b);
});
}
};
// reduceType == MAX
struct ReduceOpMAX
{
int8_t apply(const int8_t* first, const int8_t* last)
{
return std::accumulate(first, last, *first,
[](int8_t a, int8_t b)
{
return std::max(a, b);
});
}
};
template<typename Func>
class ReduceInvoker : public ParallelLoopBody
{
public:
const Mat* src;
Mat *dst;
std::vector<size_t> reduceDims;
int nstripes;
int reduceType;
Ptr<Func> func;
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
{
CV_Assert_N(src.isContinuous(), dst.isContinuous(), src.type() == CV_8S, src.type() == dst.type());
ReduceInvoker<Func> p;
p.src = &src;
p.dst = &dst;
p.reduceDims = reduceDims;
p.nstripes = nstripes;
p.reduceType = reduceType;
parallel_for_(Range(0, nstripes), p, nstripes);
}
void operator()(const Range& r) const CV_OVERRIDE
{
size_t total = dst->total();
size_t stripeSize = (total + nstripes - 1)/nstripes;
size_t stripeStart = r.start*stripeSize;
size_t stripeEnd = std::min(r.end*stripeSize, total);
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
int8_t *dstData = (int8_t *)dst->data;
int8_t *srcData = (int8_t *)src->data;
for (size_t ofs = stripeStart; ofs < stripeEnd;)
{
const int8_t* first = srcData + ofs * totalDeleted;
const int8_t* last = srcData + (ofs + 1) * totalDeleted;
dstData[ofs] = func->apply(first, last);
ofs += 1;
}
}
};
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
CV_Assert(inputs.size() == 1);
const int nstripes = getNumThreads();
switch (reduceType)
{
case MIN:
{
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case MAX:
{
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
default:
CV_Error(Error::StsNotImplemented, "Not implemented");
break;
}
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE
{
CV_Assert(inputs.size() > 0);
CV_Assert(reduceDims.size() != 0 && inputs[0].size() >= reduceDims.size());
std::vector<int> outShape;
if (inputs[0].size() == reduceDims.size())
outShape.push_back(1);
else
{
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
{
outShape.push_back(inputs[0][i]);
}
}
outputs.assign(1, outShape);
return false;
}
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
return false;
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const CV_OVERRIDE
{
CV_UNUSED(inputs); // suppress unused variable warning
long flops = 0;
size_t totalDeleted = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
for (int i = 0; i < outputs.size(); i++)
{
flops += total(outputs[i])*(totalDeleted);
}
return flops;
}
private:
enum Type
{
MAX,
MIN
};
};
Ptr<ReduceLayerInt8> ReduceLayerInt8::create(const LayerParams& params)
{
return Ptr<ReduceLayerInt8>(new ReduceLayerInt8Impl(params));
}
}
}

View File

@ -0,0 +1,388 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "opencv2/core/hal/intrin.hpp"
#include "../op_cuda.hpp"
#include "../op_webnn.hpp"
#include <float.h>
#include <algorithm>
#include <numeric>
using std::max;
using std::min;
#include <opencv2/core/utils/logger.hpp>
namespace cv
{
namespace dnn
{
class ReduceLayerImpl CV_FINAL : public ReduceLayer
{
public:
ReduceLayerImpl(const LayerParams& params)
{
// set reduce type
CV_Assert(params.has("reduce"));
String typeString = toLowerCase(params.get<String>("reduce"));
if (typeString == "max")
reduceType= MAX;
else if (typeString == "min")
reduceType= MIN;
else if (typeString == "ave")
reduceType= AVE;
else if (typeString == "sum")
reduceType= SUM;
else if (typeString == "sum_square")
reduceType= SUM_SQUARE;
else if (typeString == "l1")
reduceType= L1;
else if (typeString == "l2")
reduceType= L2;
else if (typeString == "log_sum")
reduceType= LOG_SUM;
else if (typeString == "log_sum_exp")
reduceType= LOG_SUM_EXP;
else if (typeString == "prod")
reduceType= PROD;
else
CV_Error(Error::StsBadArg, "Unknown reduce type\"" + typeString + "\"");
// set deleted dims
CV_Assert(params.has("deleted_dims"));
DictValue tempDims = params.get("deleted_dims");
int i, n = tempDims.size();
reduceDims.resize(n);
for (i = 0; i < n; i++)
{
reduceDims[i] = tempDims.get<int>(i);
}
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
if (backendId == DNN_BACKEND_OPENCV)
{
return true;
}
return false;
}
// reduceType == MIN
struct ReduceOpMIN
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, FLT_MAX,
[](float a, float b)
{
return std::min(a, b);
});
}
};
// reduceType == MAX
struct ReduceOpMAX
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, -FLT_MAX,
[](float a, float b)
{
return std::max(a, b);
});
}
};
// reduceType == SUM
struct ReduceOpSUM
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, 0.f);
}
};
// reduceType == AVE
struct ReduceOpAVE
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
float output = std::accumulate(first, last, 0.f);
return output * ikarea;
}
};
// reduceType == SUM_SQUARE
struct ReduceOpSUM_SQUARE
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, 0.f,
[](float a, float b)
{
return a + b * b;
});
}
};
// reduceType == L1
struct ReduceOpL1
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, 0.f,
[](float a, float b)
{
return a + std::abs(b);
});
}
};
// reduceType == L2
struct ReduceOpL2
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
float output = std::accumulate(first, last, 0.f,
[](float a, float b)
{
return a + b * b;
});
return std::sqrt(output);
}
};
// reduceType == PROD
struct ReduceOpPROD
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
return std::accumulate(first, last, 1.0f, std::multiplies<float>());
}
};
// reduceType == LOG_SUM
struct ReduceOpLOG_SUM
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
float output = std::accumulate(first, last, 0.0f);
return std::log(output);
}
};
// reduceType == LOG_SUM_EXP
struct ReduceOpLOG_SUM_EXP
{
float apply(const float* first, const float* last, const float ikarea = 1.0f)
{
float output = std::accumulate(first, last, 0.0f,
[](float a, float b)
{
return a + std::exp(b);
});
return std::log(output);
}
};
template<typename Func>
class ReduceInvoker : public ParallelLoopBody
{
public:
const Mat* src;
Mat *dst;
std::vector<size_t> reduceDims;
int nstripes;
int reduceType;
Ptr<Func> func;
ReduceInvoker() : src(0), dst(0), nstripes(0), reduceType(MAX), func(makePtr<Func>()) {}
static void run(const Mat& src, Mat& dst, std::vector<size_t> reduceDims, int reduceType, int nstripes)
{
CV_Assert_N( src.isContinuous(), dst.isContinuous(), src.type() == CV_32F, src.type() == dst.type());
ReduceInvoker<Func> p;
p.src = &src;
p.dst = &dst;
p.reduceDims = reduceDims;
p.nstripes = nstripes;
p.reduceType = reduceType;
parallel_for_(Range(0, nstripes), p, nstripes);
}
void operator()(const Range& r) const CV_OVERRIDE
{
size_t total = dst->total();
size_t stripeSize = (total + nstripes - 1)/nstripes;
size_t stripeStart = r.start*stripeSize;
size_t stripeEnd = std::min(r.end*stripeSize, total);
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
float *dstData = (float *)dst->data;
float *srcData = (float *)src->data;
for (size_t ofs = stripeStart; ofs < stripeEnd;)
{
const float* first = srcData + ofs * stride_w;
const float* last = srcData + (ofs + 1) * stride_w;
if (ofs < stripeEnd)
{
dstData[ofs] = func->apply(first, last, 1.0 / stride_w);
ofs += 1;
}
}
}
};
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (inputs_arr.depth() == CV_16S)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
CV_Assert(inputs.size() == 1 || (inputs.size() == 2 && reduceType== SUM));
const int nstripes = getNumThreads();
switch (reduceType)
{
case MIN:
{
ReduceInvoker<ReduceOpMIN>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case MAX:
{
ReduceInvoker<ReduceOpMAX>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case AVE:
{
ReduceInvoker<ReduceOpAVE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case SUM:
{
ReduceInvoker<ReduceOpSUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case L1:
{
ReduceInvoker<ReduceOpL1>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case L2:
{
ReduceInvoker<ReduceOpL2>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case SUM_SQUARE:
{
ReduceInvoker<ReduceOpSUM_SQUARE>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case PROD:
{
ReduceInvoker<ReduceOpPROD>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case LOG_SUM:
{
ReduceInvoker<ReduceOpLOG_SUM>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
case LOG_SUM_EXP:
{
ReduceInvoker<ReduceOpLOG_SUM_EXP>::run(inputs[0], outputs[0], reduceDims, reduceType, nstripes);
break;
}
default:
CV_Error(Error::StsNotImplemented, "Not implemented");
break;
}
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE
{
CV_Assert(inputs.size() > 0);
CV_Assert(reduceDims.size() != 0 && inputs[0].size() >= reduceDims.size());
std::vector<int> outShape;
if (inputs[0].size() == reduceDims.size())
outShape.push_back(1);
else
{
for (int i = 0; i < inputs[0].size() - reduceDims.size(); i++)
{
outShape.push_back(inputs[0][i]);
}
}
outputs.assign(1, outShape);
return false;
}
virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
{
if (reduceType== MAX || reduceType== MIN)
{
return true;
}
return false;
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const CV_OVERRIDE
{
CV_UNUSED(inputs); // suppress unused variable warning
long flops = 0;
size_t stride_w = std::accumulate(reduceDims.begin(), reduceDims.end(), 1, std::multiplies<size_t>());
for (int i = 0; i < outputs.size(); i++)
{
flops += total(outputs[i])*(stride_w);
}
return flops;
}
private:
enum ReduceType
{
MAX,
MIN,
AVE,
SUM,
L1,
L2,
PROD,
SUM_SQUARE,
LOG_SUM,
LOG_SUM_EXP
};
};
Ptr<ReduceLayer> ReduceLayer::create(const LayerParams& params)
{
return Ptr<ReduceLayer>(new ReduceLayerImpl(params));
}
}
}

View File

@ -133,7 +133,9 @@ Net Net::Impl::quantize(InputArrayOfArrays calibData, int inputsDtype, int outpu
if (ld.type == "Blank" || ld.type == "Dropout" || ld.type == "Identity" || ld.type == "Silence" ||
ld.type == "Flatten" || ld.type == "Padding" || ld.type == "Permute" || ld.type == "Reshape" ||
ld.type == "ReLU6" || ld.type == "Reorg" || ld.type == "ShuffleChannel" || ld.type == "Resize" ||
(ld.type == "ReLU" && !ld.params.get<float>("negative_slope", 0.f)) /* ReLU with negative slope 0 */)
(ld.type == "ReLU" && !ld.params.get<float>("negative_slope", 0.f)) || /* ReLU with negative slope 0 */
(ld.type == "Reduce" && (toLowerCase(ld.params.get<String>("reduce")) == "max" ||
toLowerCase(ld.params.get<String>("reduce")) == "min")))
{
for (int i = 0; i < ld.outputBlobs.size(); i++)
{

View File

@ -122,6 +122,7 @@ private:
void parseMaxUnpool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseGlobalPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSlice (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSplit (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@ -1087,7 +1088,7 @@ void ONNXImporter::parseAveragePool(LayerParams& layerParams, const opencv_onnx:
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
void ONNXImporter::parseGlobalPool(LayerParams &layerParams, const opencv_onnx::NodeProto &node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
@ -1096,157 +1097,176 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
CV_Assert(node_proto.input_size() == 1);
layerParams.type = "Pooling";
String pool;
if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
if (layer_type == "GlobalMaxPool")
pool = "MAX";
else if (layer_type == "ReduceSum")
pool = "SUM";
else
else if (layer_type == "GlobalAveragePool")
pool = "AVE";
else
CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation.");
CV_Assert(!layerParams.has("axes"));
layerParams.set("global_pooling", true);
layerParams.set("pool", pool);
layerParams.set("global_pooling", !layerParams.has("axes"));
bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
const std::string output_name = node_proto.output(0);
int depth = layerParams.get<int>("depth", CV_32F);
CV_Assert(node_proto.input_size() <= 2);
String reduceType;
if (layer_type == "ReduceMax")
reduceType = "MAX";
else if (layer_type == "ReduceMin")
reduceType = "MIN";
else if (layer_type == "ReduceSum")
reduceType = "SUM";
else if (layer_type == "ReduceSumSquare")
reduceType = "SUM_SQUARE";
else if (layer_type == "ReduceProd")
reduceType = "PROD";
else if (layer_type == "ReduceL1")
reduceType = "L1";
else if (layer_type == "ReduceL2")
reduceType = "L2";
else if (layer_type == "ReduceLogSum")
reduceType = "LOG_SUM";
else if (layer_type == "ReduceLogSumExp")
reduceType = "LOG_SUM_EXP";
else if (layer_type == "ReduceMean")
reduceType = "AVE";
else
CV_Error(Error::StsNotImplemented, "Unsupported Pooling type of " + layer_type + " operation.");
// The ReduceInt8 can only support "MAX" and "MIN".
if (depth == CV_8S)
{
CV_Assert(reduceType == "MAX" || reduceType == "MIN");
}
layerParams.type = (depth == CV_8S) ? "ReduceInt8" : "Reduce";
layerParams.set("reduce", reduceType);
bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
{
// TODO support the opset 13 of ReduceSum.
// in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
// details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
CV_Error(Error::StsNotImplemented, "Unsupported " + layer_type + " operation of opset 13, please try to "
"re-export the onnx model with opset 11.");
}
MatShape inpShape = outShapes[node_proto.input(0)];
std::vector<bool> shouldDelete(inpShape.size(), false);
if (layerParams.has("axes"))
{
MatShape inpShape = outShapes[node_proto.input(0)];
DictValue axes = layerParams.get("axes");
MatShape targetShape;
std::vector<bool> shouldDelete(inpShape.size(), false);
for (int i = 0; i < axes.size(); i++) {
for (int i = 0; i < axes.size(); i++)
{
int axis = normalize_axis(axes.get<int>(i), inpShape.size());
shouldDelete[axis] = true;
}
for (int axis = 0; axis < inpShape.size(); ++axis){
if (!shouldDelete[axis])
targetShape.push_back(inpShape[axis]);
else if (keepdims)
targetShape.push_back(1);
}
if (inpShape.size() == 3 && axes.size() <= 2)
{
int axis = normalize_axis(axes.get<int>(0), inpShape.size());
CV_CheckNE(axis, 0, "");
LayerParams reshapeLp;
reshapeLp.name = layerParams.name + "/reshape";
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
reshapeLp.set("axis", 0);
reshapeLp.set("num_axes", 1);
int newShape[] = {1, -1};
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 2));
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(reshapeLp.name);
addLayer(reshapeLp, proto);
LayerParams avgLp;
avgLp.name = layerParams.name + "/avg";
avgLp.type = "Pooling";
CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
avgLp.set("pool", pool);
if (axes.size() == 2)
{
CV_CheckEQ(normalize_axis(axes.get<int>(0), inpShape.size()), 1, "Unsupported mode");
CV_CheckEQ(normalize_axis(axes.get<int>(1), inpShape.size()), 2, "Unsupported mode");
avgLp.set("global_pooling", true);
}
else
{
avgLp.set(axis == 2 ? "global_pooling_w" : "global_pooling_h", true);
avgLp.set(axis == 2 ? "kernel_h" : "kernel_w", 1);
}
node_proto.set_input(0, reshapeLp.name);
node_proto.set_output(0, avgLp.name);
addLayer(avgLp, node_proto);
}
else
{
if (inpShape.size() != 4 && inpShape.size() != 5)
CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
CV_Assert(axes.size() <= inpShape.size() - 2);
std::vector<int> kernel_size(inpShape.size() - 2, 1);
if (axes.size() == 1 && (normalize_axis(axes.get<int>(0), inpShape.size()) <= 1))
{
int axis = normalize_axis(axes.get<int>(0), inpShape.size());
MatShape newShape = inpShape;
newShape[axis + 1] = total(newShape, axis + 1);
newShape.resize(axis + 2);
newShape.insert(newShape.begin(), 2 - axis, 1);
LayerParams reshapeLp;
reshapeLp.type = "Reshape";
reshapeLp.name = layerParams.name + "/reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
node_proto.set_output(0, reshapeLp.name);
addLayer(reshapeLp, node_proto);
kernel_size.resize(2);
kernel_size[0] = inpShape[axis];
node_proto.set_input(0, node_proto.output(0));
}
else
{
for (int i = 0; i < axes.size(); i++) {
int axis = normalize_axis(axes.get<int>(i), inpShape.size());
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
kernel_size[axis - 2] = inpShape[axis];
}
}
LayerParams poolLp = layerParams;
poolLp.name = layerParams.name + "/avg";
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
node_proto.set_output(0, poolLp.name);
addLayer(poolLp, node_proto);
}
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, output_name);
}
else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
else
{
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
const size_t dims = keepdims ? shapeIt->second.size() : 1;
LayerParams reshapeLp;
reshapeLp.name = layerParams.name + "/reshape";
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
int newShape[] = {1, 1, 1, -1};
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(reshapeLp.name);
addLayer(reshapeLp, proto);
LayerParams poolLp = layerParams;
poolLp.name = layerParams.name + "/pool";
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
node_proto.set_input(0, reshapeLp.name);
node_proto.set_output(0, poolLp.name);
addLayer(poolLp, node_proto);
layerParams.type = "Reshape";
std::vector<int> targetShape(dims, 1);
layerParams.set("dim", DictValue::arrayInt(targetShape.data(), targetShape.size()));
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, output_name);
for (int i = 0; i < inpShape.size(); i++)
{
shouldDelete[i] = true;
}
}
MatShape targetShape;
for (int i = 0; i < inpShape.size(); ++i)
{
if (!shouldDelete[i])
{
targetShape.push_back(inpShape[i]);
}
else if (keepdims)
{
targetShape.push_back(1);
}
}
if (targetShape.empty())
targetShape.push_back(1);
// Using PermuteLayer to move the deleted axis to the last.
std::vector<int> perm(inpShape.size(), 0);
for (int i = 0; i < inpShape.size(); i++)
perm[i] = i;
bool needPermuet = false;
for (int i = 0; i < inpShape.size(); i++)
{
if (shouldDelete[i])
{
// find the first not deleted element.
std::vector<bool>::iterator iter = std::find(shouldDelete.begin() + i, shouldDelete.end(), false);
if (iter != shouldDelete.end())
{
int index = iter - shouldDelete.begin();
bool temp = shouldDelete[index];
shouldDelete[index] = shouldDelete[i];
shouldDelete[i] = temp;
std::swap(perm[index], perm[i]);
std::swap(inpShape[index], inpShape[i]);
needPermuet = true;
}
else
break;
}
}
auto inputString= node_proto.input(0);
if (needPermuet)
{
LayerParams permuteLp;
permuteLp.name = layerParams.name + "/permute";
permuteLp.type = (depth == CV_8S) ? "PermuteInt8" : "Permute";
permuteLp.set("order", DictValue::arrayInt(perm.data(), perm.size()));
opencv_onnx::NodeProto protoPermute;
protoPermute.add_input(inputString);
protoPermute.add_output(permuteLp.name);
addLayer(permuteLp, protoPermute);
inputString = permuteLp.name;
}
std::vector<int> deletedDims;
for (int axis_i = 0; axis_i < inpShape.size(); ++axis_i)
{
if (shouldDelete[axis_i])
{
deletedDims.push_back(inpShape[axis_i]);
}
}
LayerParams reduceLp = layerParams;
reduceLp.name = layerParams.name + "/reduce";
CV_Assert(layer_id.find(reduceLp.name) == layer_id.end());
reduceLp.set("deleted_dims", DictValue::arrayInt(&deletedDims[0], deletedDims.size()));
node_proto.set_input(0, inputString);
node_proto.set_output(0, reduceLp.name);
addLayer(reduceLp, node_proto);
layerParams.type = (depth == CV_8S) ? "ReshapeInt8" : "Reshape";
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, output_name);
addLayer(layerParams, node_proto);
}
@ -3406,8 +3426,10 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
dispatch["MaxUnpool"] = &ONNXImporter::parseMaxUnpool;
dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
dispatch["ReduceMax"] = &ONNXImporter::parseReduce;
dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = &ONNXImporter::parseGlobalPool;
dispatch["ReduceMax"] = dispatch["ReduceMin"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] = dispatch["ReduceMax"] =
dispatch["ReduceMin"] = dispatch["ReduceSumSquare"] = dispatch["ReduceProd"] = dispatch["ReduceL1"] =
dispatch["ReduceL2"] = dispatch["ReduceLogSum"] = dispatch["ReduceLogSumExp"] = &ONNXImporter::parseReduce;
dispatch["Slice"] = &ONNXImporter::parseSlice;
dispatch["Split"] = &ONNXImporter::parseSplit;
dispatch["Add"] = dispatch["Sum"] = dispatch["Sub"] = &ONNXImporter::parseBias;

View File

@ -20,3 +20,14 @@
"test_split_equal_parts_2d",
"test_split_equal_parts_default_axis",
"test_tan",
"test_reduce_l2_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: 0.00490189 vs 0.004
"test_reduce_log_sum_exp_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: 0.00671387 vs 0.004
"test_reduce_prod_default_axes_keepdims_example", // Expected: (normL1) <= (l1), actual: inf vs 0.004
"test_reduce_prod_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 18.6621 vs 0.004, Expected: (normInf) <= (lInf), actual: 18.6621 vs 0.02
"test_reduce_prod_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
"test_reduce_prod_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
"test_reduce_prod_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.00436729 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0201836 vs 0.02
"test_reduce_sum_square_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.0183411 vs 0.004
"test_reduce_sum_square_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
"test_reduce_sum_square_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02
"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02

View File

@ -339,51 +339,6 @@
"test_range_int32_type_negative_delta_expanded",
"test_reciprocal",
"test_reciprocal_example",
"test_reduce_l1_default_axes_keepdims_example",
"test_reduce_l1_default_axes_keepdims_random",
"test_reduce_l1_do_not_keepdims_example",
"test_reduce_l1_do_not_keepdims_random",
"test_reduce_l1_keep_dims_example",
"test_reduce_l1_keep_dims_random",
"test_reduce_l1_negative_axes_keep_dims_example",
"test_reduce_l1_negative_axes_keep_dims_random",
"test_reduce_l2_default_axes_keepdims_example",
"test_reduce_l2_default_axes_keepdims_random",
"test_reduce_l2_do_not_keepdims_example",
"test_reduce_l2_do_not_keepdims_random",
"test_reduce_l2_keep_dims_example",
"test_reduce_l2_keep_dims_random",
"test_reduce_l2_negative_axes_keep_dims_example",
"test_reduce_l2_negative_axes_keep_dims_random",
"test_reduce_log_sum",
"test_reduce_log_sum_asc_axes",
"test_reduce_log_sum_default",
"test_reduce_log_sum_desc_axes",
"test_reduce_log_sum_exp_default_axes_keepdims_example",
"test_reduce_log_sum_exp_default_axes_keepdims_random",
"test_reduce_log_sum_exp_do_not_keepdims_example",
"test_reduce_log_sum_exp_do_not_keepdims_random",
"test_reduce_log_sum_exp_keepdims_example",
"test_reduce_log_sum_exp_keepdims_random",
"test_reduce_log_sum_exp_negative_axes_keepdims_example",
"test_reduce_log_sum_exp_negative_axes_keepdims_random",
"test_reduce_log_sum_negative_axes",
"test_reduce_min_default_axes_keepdims_example",
"test_reduce_min_default_axes_keepdims_random",
"test_reduce_min_do_not_keepdims_example",
"test_reduce_min_do_not_keepdims_random",
"test_reduce_min_keepdims_example",
"test_reduce_min_keepdims_random",
"test_reduce_min_negative_axes_keepdims_example",
"test_reduce_min_negative_axes_keepdims_random",
"test_reduce_prod_default_axes_keepdims_example",
"test_reduce_prod_default_axes_keepdims_random",
"test_reduce_prod_do_not_keepdims_example",
"test_reduce_prod_do_not_keepdims_random",
"test_reduce_prod_keepdims_example",
"test_reduce_prod_keepdims_random",
"test_reduce_prod_negative_axes_keepdims_example",
"test_reduce_prod_negative_axes_keepdims_random",
"test_reduce_sum_default_axes_keepdims_example",
"test_reduce_sum_default_axes_keepdims_random",
"test_reduce_sum_do_not_keepdims_example",
@ -394,14 +349,6 @@
"test_reduce_sum_keepdims_random",
"test_reduce_sum_negative_axes_keepdims_example",
"test_reduce_sum_negative_axes_keepdims_random",
"test_reduce_sum_square_default_axes_keepdims_example",
"test_reduce_sum_square_default_axes_keepdims_random",
"test_reduce_sum_square_do_not_keepdims_example",
"test_reduce_sum_square_do_not_keepdims_random",
"test_reduce_sum_square_keepdims_example",
"test_reduce_sum_square_keepdims_random",
"test_reduce_sum_square_negative_axes_keepdims_example",
"test_reduce_sum_square_negative_axes_keepdims_random",
"test_reflect_pad",
"test_reshape_allowzero_reordered",
"test_reshape_extended_dims",