mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 01:13:28 +08:00
tile impl
This commit is contained in:
parent
2aad039b4f
commit
441624a5fb
@ -1079,6 +1079,12 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<ScatterNDLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS TileLayer : public Layer
|
||||
{
|
||||
public:
|
||||
static Ptr<TileLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
//! @}
|
||||
//! @}
|
||||
CV__DNN_INLINE_NS_END
|
||||
|
@ -183,6 +183,7 @@ void initializeLayerFactory()
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Tile, TileLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Quantize, QuantizeLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Dequantize, DequantizeLayer);
|
||||
|
97
modules/dnn/src/layers/tile_layer.cpp
Normal file
97
modules/dnn/src/layers/tile_layer.cpp
Normal file
@ -0,0 +1,97 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include "layers_common.hpp"
|
||||
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
class TileLayerImpl CV_FINAL : public TileLayer
|
||||
{
|
||||
public:
|
||||
TileLayerImpl(const LayerParams& params)
|
||||
{
|
||||
setParamsFrom(params);
|
||||
if (params.has("repeats"))
|
||||
{
|
||||
DictValue param_repeats = params.get("repeats");
|
||||
int n_repeats = param_repeats.size();
|
||||
|
||||
CV_Assert(n_repeats > 0);
|
||||
repeats.resize(n_repeats);
|
||||
for (int i = 0; i < n_repeats; i++)
|
||||
repeats[i] = param_repeats.get<int>(i);
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Tile: repeats needs to be treated as parameter but it is missing.");
|
||||
}
|
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV;
|
||||
}
|
||||
|
||||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_CheckEQ(inputs.size(), 1ull, "Tile: one input is expected");
|
||||
|
||||
// repeats must have the same length as input's dimension number
|
||||
// FIXIT: it breaks when the input is 1d tensor (represented as 2d mat with size=2 in opencv dnn)
|
||||
CV_CheckEQ(inputs[0].size(), repeats.size(), "Tile: repeats must be a 1D tensor of the same length as input's dimension number");
|
||||
|
||||
outputs.assign(1, inputs[0]);
|
||||
for (int i = 0; i < repeats.size(); i++)
|
||||
{
|
||||
outputs[0][i] *= repeats[i];
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
const Mat& data = inputs[0];
|
||||
Mat& out = outputs[0];
|
||||
|
||||
Mat tmp = data.clone();
|
||||
MatShape tmp_shape = shape(tmp);
|
||||
MatShape out_shape = shape(out);
|
||||
int rep_i, ndims = data.dims;
|
||||
int dims = 1;
|
||||
for (int i = 0; i < ndims; i++)
|
||||
{
|
||||
rep_i = repeats[i];
|
||||
if (rep_i != 1)
|
||||
{
|
||||
tmp = tmp.reshape(0, dims);
|
||||
tmp = cv::repeat(tmp, 1, rep_i);
|
||||
dims *= out_shape[i];
|
||||
}
|
||||
}
|
||||
tmp = tmp.reshape(0, out_shape);
|
||||
|
||||
tmp.copyTo(out);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> repeats;
|
||||
};
|
||||
|
||||
Ptr<TileLayer> TileLayer::create(const LayerParams& params)
|
||||
{
|
||||
return makePtr<TileLayerImpl>(params);
|
||||
}
|
||||
|
||||
}} // namespace cv::dnn
|
@ -189,6 +189,7 @@ private:
|
||||
void parseDepthToSpace (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseRange (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseTile (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
// Domain: com.microsoft
|
||||
@ -3156,6 +3157,82 @@ void ONNXImporter::parseScatter(LayerParams& layerParams, const opencv_onnx::Nod
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseTile(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
// for Tile>1, only the case of 'repeats' being constant is supported.
|
||||
// 'repeats' is treated as a parameter instead of an input to determine shape in pre-run.
|
||||
|
||||
CV_Assert(node_proto.input_size() == 2 || node_proto.input_size() == 3); // tile-1: 3 inputs, tile>1: 2 inputs
|
||||
bool is_opset_1 = node_proto.input_size() == 3;
|
||||
|
||||
std::vector<size_t> const_input_idx;
|
||||
for (size_t i = 0; i < node_proto.input_size(); ++i)
|
||||
if (layer_id.find(node_proto.input(i)) == layer_id.end())
|
||||
const_input_idx.push_back(i);
|
||||
|
||||
bool all_const = false;
|
||||
if (const_input_idx.size() == node_proto.input_size()) // all inputs are constant
|
||||
{
|
||||
all_const = true;
|
||||
}
|
||||
else if ((const_input_idx.size() == 1 && const_input_idx[0] == 1) || // tile>1
|
||||
(const_input_idx.size() == 2 && const_input_idx[0] == 1 && const_input_idx[1] == 2)) // tile-1
|
||||
{
|
||||
all_const = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!is_opset_1)
|
||||
CV_Error(Error::StsNotImplemented, "ONNX/Tile: repeats being non-constant is not supported.");
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "ONNX/Tile: tiles or axis being non-constant are not supported.");
|
||||
}
|
||||
|
||||
int input0_dims = 1;
|
||||
if (all_const)
|
||||
input0_dims = getBlob(node_proto, 0).dims;
|
||||
else
|
||||
input0_dims = outShapes[node_proto.input(0)].size();
|
||||
|
||||
// repeats, treated as paramenter
|
||||
std::vector<int> repeats_vec(input0_dims, 1);
|
||||
Mat input1_blob = getBlob(node_proto, 1);
|
||||
if (is_opset_1)
|
||||
{
|
||||
// input1 in tile-1: tiles, 1d tensor of shape [1]
|
||||
CV_CheckEQ(input1_blob.total(), 1ull, "ONNX/Tile: tiles must be a 0D tensor or 1D tensor of shape [1].");
|
||||
int tiles = input1_blob.at<int>(0);
|
||||
// input2 in tile-1: axis, 1d tensor of shape [1]
|
||||
Mat input2_blob = getBlob(node_proto, 2);
|
||||
CV_CheckEQ(input2_blob.total(), 1ull, "ONNX/Tile: axis must be a 0D tensor or 1D tensor of shape [1].");
|
||||
int axis = input2_blob.at<int>(0);
|
||||
repeats_vec[axis] = tiles;
|
||||
}
|
||||
else
|
||||
{
|
||||
// input1 in tile>1: repeats
|
||||
CV_CheckEQ(input1_blob.dims, 2, "ONNX/Tile: repeats must be a 1D tensor."); // 1D tensor is represented as a 2D Mat
|
||||
for (int i = 0; i < input1_blob.total(); i++)
|
||||
repeats_vec[i] = input1_blob.at<int>(i);
|
||||
}
|
||||
layerParams.set("repeats", DictValue::arrayInt(repeats_vec.data(), repeats_vec.size()));
|
||||
|
||||
if (all_const)
|
||||
{
|
||||
std::vector<Mat> inputs, output;
|
||||
Mat input0_blob = getBlob(node_proto, 0);
|
||||
inputs.push_back(input0_blob);
|
||||
runLayer(layerParams, inputs, output);
|
||||
CV_Assert(output.size() == 1);
|
||||
addConstant(node_proto.output(0), output[0]);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
bool is_all_input_const = true;
|
||||
@ -3857,6 +3934,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
||||
dispatch["CumSum"] = &ONNXImporter::parseCumSum;
|
||||
dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace;
|
||||
dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter;
|
||||
dispatch["Tile"] = &ONNXImporter::parseTile;
|
||||
|
||||
dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
|
||||
dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
|
||||
|
@ -2483,6 +2483,11 @@ TEST_P(Test_ONNX_layers, YOLOv7)
|
||||
normAssertDetections(refClassIds, refScores, refBoxes, keep_classIds, keep_confidences, keep_boxes);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Tile)
|
||||
{
|
||||
testONNXModels("tile", pb);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user