mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
add ArgMax and ArgMin layers
This commit is contained in:
parent
973e1acb67
commit
e608adea60
@ -284,6 +284,16 @@ CV__DNN_INLINE_NS_BEGIN
|
|||||||
static Ptr<LRNLayer> create(const LayerParams& params);
|
static Ptr<LRNLayer> create(const LayerParams& params);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/** @brief ArgMax/ArgMin layer
|
||||||
|
* @note returns indices as floats, which means the supported range is [-2^24; 2^24]
|
||||||
|
*/
|
||||||
|
class CV_EXPORTS ArgLayer : public Layer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static Ptr<ArgLayer> create(const LayerParams& params);
|
||||||
|
};
|
||||||
|
|
||||||
class CV_EXPORTS PoolingLayer : public Layer
|
class CV_EXPORTS PoolingLayer : public Layer
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
@ -123,6 +123,7 @@ void initializeLayerFactory()
|
|||||||
CV_DNN_REGISTER_LAYER_CLASS(Identity, BlankLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Identity, BlankLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Silence, BlankLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Silence, BlankLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
|
||||||
|
CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
|
||||||
|
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
|
||||||
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
|
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
|
||||||
|
120
modules/dnn/src/layers/arg_layer.cpp
Normal file
120
modules/dnn/src/layers/arg_layer.cpp
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||||
|
// of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#include "../precomp.hpp"
|
||||||
|
#include "layers_common.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
namespace cv { namespace dnn {
|
||||||
|
|
||||||
|
class ArgLayerImpl CV_FINAL : public ArgLayer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
enum class ArgOp
|
||||||
|
{
|
||||||
|
MIN = 0,
|
||||||
|
MAX = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
ArgLayerImpl(const LayerParams& params)
|
||||||
|
{
|
||||||
|
setParamsFrom(params);
|
||||||
|
|
||||||
|
axis = params.get<int>("axis", 0);
|
||||||
|
keepdims = (params.get<int>("keepdims", 1) == 1);
|
||||||
|
select_last_index = (params.get<int>("select_last_index", 0) == 1);
|
||||||
|
|
||||||
|
const std::string& argOp = params.get<std::string>("op");
|
||||||
|
|
||||||
|
if (argOp == "max")
|
||||||
|
{
|
||||||
|
op = ArgOp::MAX;
|
||||||
|
}
|
||||||
|
else if (argOp == "min")
|
||||||
|
{
|
||||||
|
op = ArgOp::MIN;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
CV_Error(Error::StsBadArg, "Unsupported operation");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
return backendId == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_CPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
void handleKeepDims(MatShape& shape, const int axis_) const
|
||||||
|
{
|
||||||
|
if (keepdims)
|
||||||
|
{
|
||||||
|
shape[axis_] = 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
shape.erase(shape.begin() + axis_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||||
|
const int requiredOutputs,
|
||||||
|
std::vector<MatShape> &outputs,
|
||||||
|
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
MatShape inpShape = inputs[0];
|
||||||
|
|
||||||
|
const int axis_ = normalize_axis(axis, inpShape);
|
||||||
|
handleKeepDims(inpShape, axis_);
|
||||||
|
outputs.assign(1, inpShape);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
CV_TRACE_FUNCTION();
|
||||||
|
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||||
|
|
||||||
|
std::vector<Mat> inputs, outputs;
|
||||||
|
inputs_arr.getMatVector(inputs);
|
||||||
|
outputs_arr.getMatVector(outputs);
|
||||||
|
|
||||||
|
CV_Assert_N(inputs.size() == 1, outputs.size() == 1);
|
||||||
|
std::vector<int> outShape = shape(outputs[0]);
|
||||||
|
Mat output(outShape, CV_32SC1);
|
||||||
|
|
||||||
|
switch (op)
|
||||||
|
{
|
||||||
|
case ArgOp::MIN:
|
||||||
|
cv::reduceArgMin(inputs[0], output, axis, select_last_index);
|
||||||
|
break;
|
||||||
|
case ArgOp::MAX:
|
||||||
|
cv::reduceArgMax(inputs[0], output, axis, select_last_index);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
CV_Error(Error::StsBadArg, "Unsupported operation.");
|
||||||
|
}
|
||||||
|
|
||||||
|
output = output.reshape(1, outShape);
|
||||||
|
output.convertTo(outputs[0], CV_32FC1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).
|
||||||
|
int axis;
|
||||||
|
// Keep the reduced dimension or not
|
||||||
|
bool keepdims;
|
||||||
|
// Whether to select the first or the last index or Max/Min.
|
||||||
|
bool select_last_index;
|
||||||
|
// Operation to be performed
|
||||||
|
ArgOp op;
|
||||||
|
};
|
||||||
|
|
||||||
|
Ptr<ArgLayer> ArgLayer::create(const LayerParams& params)
|
||||||
|
{
|
||||||
|
return Ptr<ArgLayer>(new ArgLayerImpl(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
}} // namespace cv::dnn
|
@ -100,6 +100,7 @@ private:
|
|||||||
const DispatchMap dispatch;
|
const DispatchMap dispatch;
|
||||||
static const DispatchMap buildDispatchMap();
|
static const DispatchMap buildDispatchMap();
|
||||||
|
|
||||||
|
void parseArg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||||
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||||
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||||
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||||
@ -768,6 +769,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ONNXImporter::parseArg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||||
|
{
|
||||||
|
const std::string& layer_type = node_proto.op_type();
|
||||||
|
layerParams.type = "Arg";
|
||||||
|
layerParams.set("op", layer_type == "ArgMax" ? "max" : "min");
|
||||||
|
addLayer(layerParams, node_proto);
|
||||||
|
}
|
||||||
|
|
||||||
void setCeilMode(LayerParams& layerParams)
|
void setCeilMode(LayerParams& layerParams)
|
||||||
{
|
{
|
||||||
// auto_pad attribute is deprecated and uses ceil
|
// auto_pad attribute is deprecated and uses ceil
|
||||||
@ -2986,6 +2995,7 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
|
|||||||
{
|
{
|
||||||
DispatchMap dispatch;
|
DispatchMap dispatch;
|
||||||
|
|
||||||
|
dispatch["ArgMax"] = dispatch["ArgMin"] = &ONNXImporter::parseArg;
|
||||||
dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
|
dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
|
||||||
dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
|
dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
|
||||||
dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
|
dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
|
||||||
|
@ -355,6 +355,15 @@ TEST_P(Test_ONNX_layers, Min)
|
|||||||
testONNXModels("min", npy, 0, 0, false, true, 2);
|
testONNXModels("min", npy, 0, 0, false, true, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, ArgLayer)
|
||||||
|
{
|
||||||
|
if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)
|
||||||
|
throw SkipTestException("Only CPU is supported"); // FIXIT use tags
|
||||||
|
|
||||||
|
testONNXModels("argmax");
|
||||||
|
testONNXModels("argmin");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, Scale)
|
TEST_P(Test_ONNX_layers, Scale)
|
||||||
{
|
{
|
||||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||||
|
Loading…
Reference in New Issue
Block a user