mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
dnn: expand refactor with cv::broadcast for onnx models (#24295)
* add expand impl with cv::broadcast * remove expandMid * deduce shape from -1 * add constant folding * handle input constant; handle input constant 1d * add expand conformance tests; add checks to disallow shape of neg values; add early copy for unchanged total elements * fix ExpandSubgraph * dummy commit to trigger build * dummy commit to trigger build 1 * remove conformance from test names
This commit is contained in:
parent
9942757bab
commit
bb171a0c05
@ -1144,6 +1144,12 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<GemmLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS ExpandLayer : public Layer
|
||||
{
|
||||
public:
|
||||
static Ptr<ExpandLayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
//! @}
|
||||
//! @}
|
||||
CV__DNN_INLINE_NS_END
|
||||
|
@ -158,6 +158,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Gather, GatherLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
|
||||
|
149
modules/dnn/src/layers/expand_layer.cpp
Normal file
149
modules/dnn/src/layers/expand_layer.cpp
Normal file
@ -0,0 +1,149 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
class ExpandLayerImpl CV_FINAL : public ExpandLayer
|
||||
{
|
||||
public:
|
||||
ExpandLayerImpl(const LayerParams ¶ms) {
|
||||
setParamsFrom(params);
|
||||
|
||||
// shape as param
|
||||
CV_CheckTrue(params.has("shape"), "DNN/Expand: shape is required in Expand layer initialization");
|
||||
DictValue param_shape = params.get("shape");
|
||||
int ndims_shape = param_shape.size();
|
||||
CV_CheckGT(ndims_shape, 0, "DNN/Expand: ndims of shape must be > 0");
|
||||
target_shape.resize(ndims_shape);
|
||||
for (int i = 0; i < ndims_shape; i++) {
|
||||
target_shape[i] = param_shape.get<int>(i);
|
||||
}
|
||||
|
||||
// FIXME: remove when 0d/1d mat is available
|
||||
const_input_1d = params.get("const_input_1d", false);
|
||||
}
|
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE {
|
||||
return backendId == DNN_BACKEND_OPENCV;
|
||||
}
|
||||
|
||||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE {
|
||||
CV_CheckGE(inputs.size(), static_cast<size_t>(1), "DNN/Expand: one input at least");
|
||||
CV_CheckLE(inputs.size(), static_cast<size_t>(2), "DNN/Expand: two input at most");
|
||||
CV_CheckFalse(target_shape.empty(), "DNN/Expand: shape must known before memory is set");
|
||||
|
||||
MatShape input_shape = inputs[0]; // 1d tensor is represented as 2d mat, e.g. [3] -> [3, 1]
|
||||
if (const_input_1d) {
|
||||
input_shape = {inputs[0][0]};
|
||||
}
|
||||
|
||||
auto& moreDimension = input_shape.size() > target_shape.size() ? input_shape : target_shape;
|
||||
auto& lessDimension = input_shape.size() <= target_shape.size() ? input_shape : target_shape;
|
||||
|
||||
/* Example:
|
||||
i = 3
|
||||
|
|
||||
moreDimension: 1 2 3 4 5, assign non-aligned dimensions to output shape
|
||||
lessDimension: 1 1 5, when dimension is aligned, check valid dimension (either equal or one of them is 1) and assign bigger one
|
||||
|
|
||||
j = 0 = i - (moreDimension.size() - lessDimension.size());
|
||||
*/
|
||||
MatShape outputShape(moreDimension.size(), 1);
|
||||
for (int i = 0; i < moreDimension.size(); i++) {
|
||||
int d = moreDimension[i];
|
||||
int j = i - (moreDimension.size() - lessDimension.size());
|
||||
if (j >= 0) {
|
||||
if (d == 1 || lessDimension[j] == 1 || // broadcast
|
||||
d == lessDimension[j]) { // plain copy
|
||||
outputShape[i] = std::max(d, lessDimension[j]);
|
||||
} else {
|
||||
CV_Error(Error::StsBadSize, cv::format("DNN/Expand: invalid dimension, d (%d) != d (%d)", moreDimension[i], lessDimension[j]));
|
||||
}
|
||||
} else {
|
||||
outputShape[i] = d;
|
||||
}
|
||||
}
|
||||
outputs.assign(1, outputShape);
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
|
||||
std::vector<Mat> inputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
|
||||
const auto &input = inputs[0];
|
||||
auto input_shape = shape(input);
|
||||
if (const_input_1d) {
|
||||
input_shape = {input_shape[0]};
|
||||
}
|
||||
|
||||
auto& moreDimension = input_shape.size() > target_shape.size() ? input_shape : target_shape;
|
||||
auto& lessDimension = input_shape.size() <= target_shape.size() ? input_shape : target_shape;
|
||||
|
||||
MatShape final_target_shape(moreDimension.size(), 1);
|
||||
for (int i = 0; i < moreDimension.size(); i++) {
|
||||
int d = moreDimension[i];
|
||||
int j = i - (moreDimension.size() - lessDimension.size());
|
||||
if (j >= 0) {
|
||||
final_target_shape[i] = std::max(lessDimension[j], d);
|
||||
} else {
|
||||
final_target_shape[i] = d;
|
||||
}
|
||||
}
|
||||
target_shape.clear();
|
||||
target_shape = std::move(final_target_shape);
|
||||
}
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
if (inputs_arr.depth() == CV_16S)
|
||||
{
|
||||
forward_fallback(inputs_arr, outputs_arr, internals_arr);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
int target_shape_total = std::accumulate(target_shape.begin(), target_shape.end(), 1, std::multiplies<int>());
|
||||
if (target_shape_total == inputs[0].total()) {
|
||||
const char *data = inputs[0].ptr<const char>();
|
||||
char *output = outputs[0].ptr<char>();
|
||||
int step = target_shape_total * outputs[0].elemSize();
|
||||
std::memcpy(output, data, step);
|
||||
return;
|
||||
}
|
||||
|
||||
if (const_input_1d) {
|
||||
const char *data = inputs[0].ptr<const char>();
|
||||
char *output = outputs[0].ptr<char>();
|
||||
int step = target_shape.back() * outputs[0].elemSize();
|
||||
int total = std::accumulate(target_shape.begin(), target_shape.end() - 1, 1, std::multiplies<int>());
|
||||
for (int i = 0; i < total; i++) {
|
||||
std::memcpy(output + i * step, data, step);
|
||||
}
|
||||
} else {
|
||||
cv::broadcast(inputs[0], target_shape, outputs[0]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
MatShape target_shape;
|
||||
bool const_input_1d;
|
||||
};
|
||||
|
||||
Ptr<ExpandLayer> ExpandLayer::create(const LayerParams ¶ms) {
|
||||
return makePtr<ExpandLayerImpl>(params);
|
||||
}
|
||||
|
||||
}} // cv::dnn
|
@ -821,6 +821,16 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/* Constant folding shape for Expand.
|
||||
|
||||
Before fusion:
|
||||
+--------------------------------------------------------------+ (X)
|
||||
| |
|
||||
ConstantOfShape[input=[4]] -> Mul[B=-1] -> Equal[A=[2, -1, -1, -1]] -> Where[Y=[2, -1, -1, -1]] -> Expand
|
||||
\ \
|
||||
value=[1] (condition)
|
||||
|
||||
*/
|
||||
class ExpandSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
@ -837,6 +847,128 @@ public:
|
||||
addNodeToMatch("Expand", input, where);
|
||||
setFusedNode("Expand", input, shape);
|
||||
}
|
||||
|
||||
static int extractValue(const Ptr<ImportGraphWrapper>& net, int node_id, int64_t &val) {
|
||||
Ptr<ImportNodeWrapper> node_wrapper = net->getNode(node_id);
|
||||
opencv_onnx::NodeProto* node = node_wrapper.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
|
||||
if (node->attribute_size() == 0) {
|
||||
val = 0;
|
||||
return 1;
|
||||
} else if (node->attribute_size() == 1) {
|
||||
opencv_onnx::AttributeProto attr = node->attribute(0);
|
||||
if (attr.name() != "value") {
|
||||
return 0;
|
||||
}
|
||||
Mat mat_value = getMatFromTensor(attr.t());
|
||||
switch (mat_value.type()) {
|
||||
case CV_32S: {
|
||||
val = static_cast<int64_t>(mat_value.at<int>());
|
||||
} break;
|
||||
default: return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
Mat mat_constant;
|
||||
if (initializer_id != -1) // initializer
|
||||
{
|
||||
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
mat_constant = getMatFromTensor(constant_proto);
|
||||
}
|
||||
|
||||
std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
|
||||
return retvals;
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE {
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
|
||||
int64_t value_ConstantOfShape;
|
||||
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0);
|
||||
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
|
||||
if (B_Mul.size() != static_cast<size_t>(1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto A_Equal = extractConstant(net, matchedNodesIds[2], 0);
|
||||
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto Y_Where = extractConstant(net, matchedNodesIds[3], 2);
|
||||
if (Y_Where.size() != A_Equal.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// run ConstantOfShape
|
||||
std::vector<int64_t> output_ConstantOfShape(std::accumulate(input_ConstantOfShape.begin(), input_ConstantOfShape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()), value_ConstantOfShape);
|
||||
// run Mul
|
||||
std::vector<int64_t> output_Mul = output_ConstantOfShape;
|
||||
for (size_t i = 0; i < output_Mul.size(); i++) {
|
||||
int64_t b = B_Mul[0];
|
||||
output_Mul[i] *= b;
|
||||
}
|
||||
// run Equal
|
||||
std::vector<bool> output_Equal(output_Mul.size());
|
||||
for (int i = 0; i < output_Equal.size(); i++) {
|
||||
if (A_Equal[i] == output_Mul[i]) {
|
||||
output_Equal[i] = true;
|
||||
} else {
|
||||
output_Equal[i] = false;
|
||||
}
|
||||
}
|
||||
// run Where
|
||||
std::vector<int64_t> output_Where(output_Equal.size());
|
||||
for (int i = 0; i < output_Where.size(); i++) {
|
||||
if (output_Equal[i]) {
|
||||
output_Where[i] = output_ConstantOfShape[i];
|
||||
} else {
|
||||
output_Where[i] = Y_Where[i];
|
||||
}
|
||||
}
|
||||
shape = output_Where;
|
||||
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual void finalize(const Ptr<ImportGraphWrapper>& graph,
|
||||
const Ptr<ImportNodeWrapper>& fusedNode,
|
||||
std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE {
|
||||
// replace values
|
||||
opencv_onnx::NodeProto* node_shape = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
|
||||
auto attr = node_shape->mutable_attribute()->Mutable(0);
|
||||
auto tensor = attr->mutable_t();
|
||||
tensor->clear_raw_data();
|
||||
tensor->set_raw_data(std::string((const char*)(shape.data()), shape.size() * sizeof(int64_t)));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> shape;
|
||||
};
|
||||
|
||||
class MishSubgraph : public Subgraph
|
||||
|
@ -93,8 +93,6 @@ class ONNXImporter
|
||||
const opencv_onnx::NodeProto& node_proto);
|
||||
void setParamsDtype(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
void expandMid(const std::string& prefix, opencv_onnx::NodeProto& node_proto,
|
||||
const std::string& input, size_t n);
|
||||
void lstm_extractConsts(LayerParams& layerParams, const opencv_onnx::NodeProto& lstm_proto, size_t idx, int* blobShape_, int size);
|
||||
void lstm_add_reshape(const std::string& input_name, const std::string& output_name, int* layerShape, size_t n);
|
||||
std::string lstm_add_slice(int index, const std::string& input_name, int* begin, int* end, size_t n);
|
||||
@ -657,37 +655,6 @@ void ONNXImporter::addLayer(LayerParams& layerParams,
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Make N copies of input layer and set them as input to node_proto.
|
||||
* @param prefix prefix of new layers' names
|
||||
* @param node_proto node which will contain all copies as inputs
|
||||
* @param input name of the node to copy
|
||||
* @param n number of copies
|
||||
*/
|
||||
void ONNXImporter::expandMid(const std::string& prefix, opencv_onnx::NodeProto& node_proto,
|
||||
const std::string& input, size_t n)
|
||||
{
|
||||
std::vector<std::string> input_names;
|
||||
input_names.reserve(n);
|
||||
for (size_t j = 0; j < n; j++)
|
||||
{
|
||||
LayerParams copyLP;
|
||||
copyLP.name = format("%s/copy_%zu", prefix.c_str(), j);
|
||||
copyLP.type = "Identity";
|
||||
CV_Assert((layer_id.find(copyLP.name) == layer_id.end()) &&
|
||||
"Couldn't copy the node: generated name already exists in the graph.");
|
||||
input_names.push_back(copyLP.name);
|
||||
|
||||
node_proto.set_input(0, input);
|
||||
node_proto.set_output(0, copyLP.name);
|
||||
addLayer(copyLP, node_proto);
|
||||
}
|
||||
node_proto.clear_input();
|
||||
for (size_t i = 0; i < input_names.size(); i++)
|
||||
{
|
||||
node_proto.add_input(input_names[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXImporter::addConstant(const std::string& name, const Mat& blob)
|
||||
{
|
||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: add constant '" << name << "' shape=" << toString(shape(blob)) << ": " << toString(blob));
|
||||
@ -2341,137 +2308,38 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
|
||||
void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
opencv_onnx::NodeProto node_proto = node_proto_;
|
||||
CV_CheckEQ(node_proto.input_size(), 2, "");
|
||||
const std::string& input0 = node_proto.input(0);
|
||||
const std::string& input1 = node_proto.input(1);
|
||||
const std::string output_name = node_proto.output(0);
|
||||
Mat newShapeMat = getBlob(input1);
|
||||
MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
|
||||
CV_CheckEQ(node_proto.input_size(), 2, "DNN/ONNXImporter-Expand: two inputs are required");
|
||||
// input shape must be constant and it is passed as param to the layer
|
||||
CV_CheckTrue(constBlobs.find(node_proto.input(1)) != constBlobs.end(),
|
||||
"DNN/ONNXImporter-Expand: input shape must be constant");
|
||||
|
||||
MatShape inpShape;
|
||||
bool haveVariables = constBlobs.find(input0) == constBlobs.end();
|
||||
if (haveVariables)
|
||||
{
|
||||
IterShape_t shapeIt = outShapes.find(input0);
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
inpShape = shapeIt->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat blob = getBlob(input0);
|
||||
if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end() &&
|
||||
getBlobExtraInfo(node_proto, 0).real_ndims == 1) {
|
||||
inpShape = {(int)blob.total()};
|
||||
} else {
|
||||
inpShape = shape(blob);
|
||||
}
|
||||
Mat mat_input_shape = getBlob(node_proto, 1);
|
||||
CV_CheckTypeEQ(mat_input_shape.depth(), CV_32S, "DNN/ONNXImporter-Expand: data type of input shape must be CV_32S");
|
||||
for (int i = 0; i < mat_input_shape.total(); ++i) {
|
||||
CV_Check(i, *(mat_input_shape.ptr<int>() + i) >= 0, "DNN/ONNXImporter-Expand: invalid shape dimension");
|
||||
}
|
||||
layerParams.set("shape", DictValue::arrayInt(mat_input_shape.ptr<int>(), mat_input_shape.total()));
|
||||
|
||||
String srcName = input0;
|
||||
// Unsqueeze and repeat along new axis
|
||||
if (targetShape.size() > inpShape.size())
|
||||
{
|
||||
inpShape.insert(inpShape.begin(), targetShape.size() - inpShape.size(), 1);
|
||||
for (int i = 0; i < targetShape.size(); i++)
|
||||
{
|
||||
if (abs(targetShape[i]) == 1)
|
||||
targetShape[i] = inpShape[i];
|
||||
}
|
||||
if (haveVariables)
|
||||
{
|
||||
LayerParams reshapeLp;
|
||||
reshapeLp.name = layerParams.name + "/reshape";
|
||||
reshapeLp.type = "Reshape";
|
||||
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(reshapeLp.name);
|
||||
addLayer(reshapeLp, proto);
|
||||
srcName = reshapeLp.name;
|
||||
}
|
||||
}
|
||||
CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
|
||||
|
||||
std::vector<int> broadcast_axes;
|
||||
// shapes aren't right-aligned here because targetShape.size() == inpShape.size()
|
||||
for (int i = 0; i < targetShape.size(); i++)
|
||||
{
|
||||
if (targetShape[i] != inpShape[i])
|
||||
{
|
||||
if (inpShape[i] == 1)
|
||||
{
|
||||
broadcast_axes.push_back(i);
|
||||
}
|
||||
else if (targetShape[i] != 1)
|
||||
{
|
||||
CV_Error(Error::StsError, format("Could not be broadcast by axis: %d", i));
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) {
|
||||
bool const_input_1d = false;
|
||||
if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end()) {
|
||||
if (getBlobExtraInfo(node_proto, 0).real_ndims == 1) {
|
||||
const_input_1d = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!haveVariables)
|
||||
{
|
||||
if (broadcast_axes.empty())
|
||||
{
|
||||
addConstant(output_name, getBlob(node_proto, 0).reshape(1, targetShape));
|
||||
return;
|
||||
}
|
||||
layerParams.set("const_input_1d", const_input_1d);
|
||||
|
||||
Mat input = getBlob(node_proto, 0);
|
||||
MatShape subTargetShape = inpShape;
|
||||
for (auto broadcast_axis : broadcast_axes)
|
||||
{
|
||||
subTargetShape[broadcast_axis] = targetShape[broadcast_axis];
|
||||
input = input.reshape(0, total(inpShape, 0, broadcast_axis));
|
||||
Mat output = cv::repeat(input, 1, subTargetShape[broadcast_axis]);
|
||||
input = output.reshape(0, subTargetShape);
|
||||
}
|
||||
addConstant(output_name, input);
|
||||
std::vector<Mat> inputs, expanded;
|
||||
inputs.push_back(input);
|
||||
runLayer(layerParams, inputs, expanded);
|
||||
CV_CheckEQ(expanded.size(), static_cast<size_t>(1), "DNN/Expand: only one output is expected when folding constant");
|
||||
addConstant(node_proto.output(0), expanded[0]);
|
||||
return;
|
||||
}
|
||||
|
||||
if (broadcast_axes.size() == 2 &&
|
||||
broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
|
||||
{
|
||||
LayerParams constParams;
|
||||
constParams.name = layerParams.name + "/const";
|
||||
CV_Assert(layer_id.find(constParams.name) == layer_id.end());
|
||||
constParams.type = "Const";
|
||||
|
||||
Mat inp = Mat::ones(newShapeMat.total(), newShapeMat.ptr<int>(), CV_32F);
|
||||
constParams.blobs.push_back(inp);
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_output(constParams.name);
|
||||
addLayer(constParams, proto);
|
||||
|
||||
layerParams.type = "Scale";
|
||||
layerParams.set("bias_term", false);
|
||||
node_proto.set_input(0, constParams.name);
|
||||
node_proto.set_input(1, srcName);
|
||||
}
|
||||
else if (broadcast_axes.size() == 1)
|
||||
{
|
||||
// FIXME: this will end up creating massive amount of Identity nodes for broadcasting,
|
||||
// for example, broadcast 1 to 256 needs 256 Identity nodes and 1 Concat node.
|
||||
// Possible improvement is to use "Scale".
|
||||
expandMid(layerParams.name, node_proto, srcName, targetShape[broadcast_axes[0]]);
|
||||
|
||||
layerParams.set("axis", broadcast_axes[0]);
|
||||
layerParams.type = "Concat";
|
||||
node_proto.set_output(0, output_name);
|
||||
}
|
||||
else if (broadcast_axes.empty())
|
||||
{
|
||||
layerParams.type = "Identity";
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
|
@ -987,9 +987,21 @@ TEST_P(Test_ONNX_layers, MatMulAdd)
|
||||
TEST_P(Test_ONNX_layers, Expand)
|
||||
{
|
||||
testONNXModels("expand");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ExpandIdentity) {
|
||||
testONNXModels("expand_identity");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ExpandBatch) {
|
||||
testONNXModels("expand_batch");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ExpandChannels) {
|
||||
testONNXModels("expand_channels");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ExpandNegBatch) {
|
||||
testONNXModels("expand_neg_batch");
|
||||
}
|
||||
|
||||
@ -2681,6 +2693,27 @@ TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeB) {
|
||||
testONNXModels("test_gemm_transposeB", pb, 0, 0, false, true, 2);
|
||||
}
|
||||
|
||||
// Note: These tests are converted from onnx/onnx so that they have constant shape as input.
|
||||
// TODO: They can be moved into conformance tests once dynamic input is properly supported.
|
||||
TEST_P(Test_ONNX_layers, Expand_dim_changed) {
|
||||
testONNXModels("test_expand_dim_changed", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
TEST_P(Test_ONNX_layers, Expand_dim_unchanged) {
|
||||
testONNXModels("test_expand_dim_unchanged", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
TEST_P(Test_ONNX_layers, Expand_shape_model1) {
|
||||
testONNXModels("test_expand_shape_model1", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
TEST_P(Test_ONNX_layers, Expand_shape_model2) {
|
||||
testONNXModels("test_expand_shape_model2", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
TEST_P(Test_ONNX_layers, Expand_shape_model3) {
|
||||
testONNXModels("test_expand_shape_model3", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
TEST_P(Test_ONNX_layers, Expand_shape_model4) {
|
||||
testONNXModels("test_expand_shape_model4", pb, 0, 0, false, true, 1);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user