Merge pull request #20453 from rogday:onnx_importer_fix

Split layer dispatch into functions in ONNXImporter

* split layer dispatch into functions

* fixes

* identation and comment fixes

* fix constness
This commit is contained in:
rogday 2021-07-28 18:06:24 +03:00 committed by GitHub
parent d83901e665
commit cff0168f3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 1724 additions and 1517 deletions

View File

@ -14,6 +14,7 @@ Mutex& getInitializationMutex();
void initializeLayerFactory(); void initializeLayerFactory();
namespace detail { namespace detail {
#define CALL_MEMBER_FN(object, ptrToMemFn) ((object).*(ptrToMemFn))
struct NetImplBase struct NetImplBase
{ {

View File

@ -62,7 +62,7 @@ class ONNXImporter
public: public:
ONNXImporter(Net& net, const char *onnxFile) ONNXImporter(Net& net, const char *onnxFile)
: dstNet(net) : dstNet(net), dispatch(buildDispatchMap())
{ {
hasDynamicShapes = false; hasDynamicShapes = false;
CV_Assert(onnxFile); CV_Assert(onnxFile);
@ -83,7 +83,7 @@ public:
} }
ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer) ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
: dstNet(net) : dstNet(net), dispatch(buildDispatchMap())
{ {
hasDynamicShapes = false; hasDynamicShapes = false;
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing in-memory ONNX model (" << sizeBuffer << " bytes)"); CV_LOG_DEBUG(NULL, "DNN/ONNX: processing in-memory ONNX model (" << sizeBuffer << " bytes)");
@ -124,6 +124,57 @@ protected:
typedef std::map<std::string, LayerInfo>::iterator IterLayerId_t; typedef std::map<std::string, LayerInfo>::iterator IterLayerId_t;
void handleNode(const opencv_onnx::NodeProto& node_proto); void handleNode(const opencv_onnx::NodeProto& node_proto);
private:
typedef void (ONNXImporter::*ONNXImporterNodeParser)(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
typedef std::map<std::string, ONNXImporterNodeParser> DispatchMap;
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSlice (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSplit (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseBias (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePow (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseNeg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConstant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLSTM (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseImageScaler (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseClip (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLeakyRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseElu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTanh (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLRN (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseBatchNormalization (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseGemm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMatMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConv (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConvTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSqueeze (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseFlatten (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseUnsqueeze (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseExpand (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseReshape (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parsePad (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseShape (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCast (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConstantFill (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseGather (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseResize (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseUpsample (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseSoftMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCustom (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
const DispatchMap dispatch;
static const DispatchMap buildDispatchMap();
}; };
inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey) inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
@ -448,13 +499,11 @@ void ONNXImporter::populateNet()
CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!"); CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!");
} }
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_) void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
{ {
opencv_onnx::NodeProto node_proto = node_proto_; // TODO FIXIT
CV_Assert(node_proto.output_size() >= 1); CV_Assert(node_proto.output_size() >= 1);
std::string name = node_proto.output(0); std::string name = node_proto.output(0);
std::string layer_type = node_proto.op_type(); const std::string& layer_type = node_proto.op_type();
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: " CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str()) << cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
); );
@ -468,22 +517,56 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.type = layer_type; layerParams.type = layer_type;
layerParams.set("has_dynamic_shapes", hasDynamicShapes); layerParams.set("has_dynamic_shapes", hasDynamicShapes);
if (layer_type == "MaxPool") DispatchMap::const_iterator iter = dispatch.find(layer_type);
if (iter != dispatch.end())
{ {
CALL_MEMBER_FN(*this, iter->second)(layerParams, node_proto);
}
else
{
parseCustom(layerParams, node_proto);
}
}
catch (const cv::Exception& e)
{
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
);
for (int i = 0; i < node_proto.input_size(); i++)
{
CV_LOG_INFO(NULL, " Input[" << i << "] = '" << node_proto.input(i) << "'");
}
for (int i = 0; i < node_proto.output_size(); i++)
{
CV_LOG_INFO(NULL, " Output[" << i << "] = '" << node_proto.output(i) << "'");
}
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
}
}
void ONNXImporter::parseMaxPool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Pooling"; layerParams.type = "Pooling";
layerParams.set("pool", "MAX"); layerParams.set("pool", "MAX");
layerParams.set("ceil_mode", layerParams.has("pad_mode")); layerParams.set("ceil_mode", layerParams.has("pad_mode"));
} addLayer(layerParams, node_proto);
else if (layer_type == "AveragePool") }
{
void ONNXImporter::parseAveragePool(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Pooling"; layerParams.type = "Pooling";
layerParams.set("pool", "AVE"); layerParams.set("pool", "AVE");
layerParams.set("ceil_mode", layerParams.has("pad_mode")); layerParams.set("ceil_mode", layerParams.has("pad_mode"));
layerParams.set("ave_pool_padded_area", framework_name == "pytorch"); layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
} addLayer(layerParams, node_proto);
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || }
layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
{ // "GlobalAveragePool" "GlobalMaxPool" "ReduceMean" "ReduceSum" "ReduceMax"
void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
CV_Assert(node_proto.input_size() == 1); CV_Assert(node_proto.input_size() == 1);
layerParams.type = "Pooling"; layerParams.type = "Pooling";
String pool; String pool;
@ -635,9 +718,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(0, node_proto.output(0)); node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, layerParams.name); node_proto.set_output(0, layerParams.name);
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Slice") }
{
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int axis = 0; int axis = 0;
std::vector<int> begin; std::vector<int> begin;
std::vector<int> end; std::vector<int> end;
@ -744,9 +829,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, sliced[0]); addConstant(layerParams.name, sliced[0]);
return; return;
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Split") }
{
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (layerParams.has("split")) if (layerParams.has("split"))
{ {
DictValue splits = layerParams.get("split"); DictValue splits = layerParams.get("split");
@ -765,9 +852,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("num_split", node_proto.output_size()); layerParams.set("num_split", node_proto.output_size());
} }
layerParams.type = "Slice"; layerParams.type = "Slice";
} addLayer(layerParams, node_proto);
else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub") }
{
// "Add" "Sum" "Sub"
void ONNXImporter::parseBias(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
bool isSub = layer_type == "Sub"; bool isSub = layer_type == "Sub";
CV_CheckEQ(node_proto.input_size(), 2, ""); CV_CheckEQ(node_proto.input_size(), 2, "");
bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end(); bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
@ -859,9 +951,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.type = "Scale"; layerParams.type = "Scale";
layerParams.set("bias_term", true); layerParams.set("bias_term", true);
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Pow") }
{
void ONNXImporter::parsePow(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (layer_id.find(node_proto.input(1)) != layer_id.end()) if (layer_id.find(node_proto.input(1)) != layer_id.end())
CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power"); CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power");
@ -872,26 +966,33 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
blob.convertTo(blob, CV_32F); blob.convertTo(blob, CV_32F);
layerParams.type = "Power"; layerParams.type = "Power";
layerParams.set("power", blob.ptr<float>()[0]); layerParams.set("power", blob.ptr<float>()[0]);
} addLayer(layerParams, node_proto);
else if (layer_type == "Max") }
{
void ONNXImporter::parseMax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Eltwise"; layerParams.type = "Eltwise";
layerParams.set("operation", "max"); layerParams.set("operation", "max");
} addLayer(layerParams, node_proto);
else if (layer_type == "Neg") }
{
void ONNXImporter::parseNeg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Power"; layerParams.type = "Power";
layerParams.set("scale", -1); layerParams.set("scale", -1);
} addLayer(layerParams, node_proto);
else if (layer_type == "Constant") }
{
void ONNXImporter::parseConstant(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 0); CV_Assert(node_proto.input_size() == 0);
CV_Assert(layerParams.blobs.size() == 1); CV_Assert(layerParams.blobs.size() == 1);
addConstant(layerParams.name, layerParams.blobs[0]); addConstant(layerParams.name, layerParams.blobs[0]);
return; }
}
else if (layer_type == "LSTM") void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{ {
opencv_onnx::NodeProto node_proto = node_proto_;
LayerParams lstmParams = layerParams; LayerParams lstmParams = layerParams;
lstmParams.name += "/lstm"; lstmParams.name += "/lstm";
@ -958,9 +1059,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size())); layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
node_proto.set_output(0, layerParams.name); // keep origin LSTM's name node_proto.set_output(0, layerParams.name); // keep origin LSTM's name
} addLayer(layerParams, node_proto);
else if (layer_type == "ImageScaler") }
{
void ONNXImporter::parseImageScaler(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f; const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
layerParams.erase("scale"); layerParams.erase("scale");
@ -982,42 +1085,58 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("scale", scale); layerParams.set("scale", scale);
layerParams.type = "Power"; layerParams.type = "Power";
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Clip") }
{
void ONNXImporter::parseClip(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ReLU6"; layerParams.type = "ReLU6";
replaceLayerParam(layerParams, "min", "min_value"); replaceLayerParam(layerParams, "min", "min_value");
replaceLayerParam(layerParams, "max", "max_value"); replaceLayerParam(layerParams, "max", "max_value");
addLayer(layerParams, node_proto);
}
} void ONNXImporter::parseLeakyRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
else if (layer_type == "LeakyRelu") {
{
layerParams.type = "ReLU"; layerParams.type = "ReLU";
replaceLayerParam(layerParams, "alpha", "negative_slope"); replaceLayerParam(layerParams, "alpha", "negative_slope");
} addLayer(layerParams, node_proto);
else if (layer_type == "Relu") }
{
void ONNXImporter::parseRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ReLU"; layerParams.type = "ReLU";
} addLayer(layerParams, node_proto);
else if (layer_type == "Elu") }
{
void ONNXImporter::parseElu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "ELU"; layerParams.type = "ELU";
} addLayer(layerParams, node_proto);
else if (layer_type == "Tanh") }
{
void ONNXImporter::parseTanh(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "TanH"; layerParams.type = "TanH";
} addLayer(layerParams, node_proto);
else if (layer_type == "PRelu") }
{
void ONNXImporter::parsePRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "PReLU"; layerParams.type = "PReLU";
layerParams.blobs.push_back(getBlob(node_proto, 1)); layerParams.blobs.push_back(getBlob(node_proto, 1));
} addLayer(layerParams, node_proto);
else if (layer_type == "LRN") }
{
void ONNXImporter::parseLRN(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
replaceLayerParam(layerParams, "size", "local_size"); replaceLayerParam(layerParams, "size", "local_size");
} addLayer(layerParams, node_proto);
else if (layer_type == "InstanceNormalization") }
{
void ONNXImporter::parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
if (node_proto.input_size() != 3) if (node_proto.input_size() != 3)
CV_Error(Error::StsNotImplemented, CV_Error(Error::StsNotImplemented,
"Expected input, scale, bias"); "Expected input, scale, bias");
@ -1052,9 +1171,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
//Replace Batch Norm's input to MVN //Replace Batch Norm's input to MVN
node_proto.set_input(0, mvnParams.name); node_proto.set_input(0, mvnParams.name);
layerParams.type = "BatchNorm"; layerParams.type = "BatchNorm";
} addLayer(layerParams, node_proto);
else if (layer_type == "BatchNormalization") }
{
void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
if (node_proto.input_size() != 5) if (node_proto.input_size() != 5)
CV_Error(Error::StsNotImplemented, CV_Error(Error::StsNotImplemented,
"Expected input, scale, bias, mean and var"); "Expected input, scale, bias, mean and var");
@ -1082,9 +1203,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} else { } else {
layerParams.set("has_bias", false); layerParams.set("has_bias", false);
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Gemm") }
{
void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() >= 2); CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "InnerProduct"; layerParams.type = "InnerProduct";
Mat weights = getBlob(node_proto, 1); Mat weights = getBlob(node_proto, 1);
@ -1115,9 +1238,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]); layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
layerParams.set("bias_term", node_proto.input_size() == 3); layerParams.set("bias_term", node_proto.input_size() == 3);
} addLayer(layerParams, node_proto);
else if (layer_type == "MatMul") }
{
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 2); CV_Assert(node_proto.input_size() == 2);
layerParams.type = "InnerProduct"; layerParams.type = "InnerProduct";
layerParams.set("bias_term", false); layerParams.set("bias_term", false);
@ -1135,9 +1260,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
secondInpDims = outShapes[node_proto.input(1)].size(); secondInpDims = outShapes[node_proto.input(1)].size();
} }
layerParams.set("axis", firstInpDims - secondInpDims + 1); layerParams.set("axis", firstInpDims - secondInpDims + 1);
} addLayer(layerParams, node_proto);
else if (layer_type == "Mul" || layer_type == "Div") }
{
// "Mul" "Div"
void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
CV_Assert(node_proto.input_size() == 2); CV_Assert(node_proto.input_size() == 2);
bool isDiv = layer_type == "Div"; bool isDiv = layer_type == "Div";
@ -1255,9 +1385,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
layerParams.type = "Scale"; layerParams.type = "Scale";
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Conv") }
{
void ONNXImporter::parseConv(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() >= 2); CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "Convolution"; layerParams.type = "Convolution";
for (int j = 1; j < node_proto.input_size(); j++) { for (int j = 1; j < node_proto.input_size(); j++) {
@ -1307,9 +1440,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(0, padLp.name); node_proto.set_input(0, padLp.name);
} }
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "ConvTranspose") }
{
void ONNXImporter::parseConvTranspose(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() >= 2); CV_Assert(node_proto.input_size() >= 2);
layerParams.type = "Deconvolution"; layerParams.type = "Deconvolution";
for (int j = 1; j < node_proto.input_size(); j++) { for (int j = 1; j < node_proto.input_size(); j++) {
@ -1350,9 +1485,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
{ {
replaceLayerParam(layerParams, "output_padding", "adj"); replaceLayerParam(layerParams, "output_padding", "adj");
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Transpose") }
{
void ONNXImporter::parseTranspose(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Permute"; layerParams.type = "Permute";
replaceLayerParam(layerParams, "perm", "order"); replaceLayerParam(layerParams, "perm", "order");
@ -1365,9 +1502,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, transposed[0]); addConstant(layerParams.name, transposed[0]);
return; return;
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Squeeze") }
{
void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes")); CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
DictValue axes_dict = layerParams.get("axes"); DictValue axes_dict = layerParams.get("axes");
MatShape inpShape = outShapes[node_proto.input(0)]; MatShape inpShape = outShapes[node_proto.input(0)];
@ -1415,9 +1554,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, out); addConstant(layerParams.name, out);
return; return;
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Flatten") }
{
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_CheckEQ(node_proto.input_size(), 1, ""); CV_CheckEQ(node_proto.input_size(), 1, "");
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{ {
@ -1430,9 +1571,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
addConstant(layerParams.name, output); addConstant(layerParams.name, output);
return; return;
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Unsqueeze") }
{
void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 1); CV_Assert(node_proto.input_size() == 1);
DictValue axes = layerParams.get("axes"); DictValue axes = layerParams.get("axes");
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
@ -1478,9 +1621,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size())); layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size())); layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Expand") }
{
void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 2, ""); CV_CheckEQ(node_proto.input_size(), 2, "");
const std::string& input0 = node_proto.input(0); const std::string& input0 = node_proto.input(0);
const std::string& input1 = node_proto.input(1); const std::string& input1 = node_proto.input(1);
@ -1602,9 +1748,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
else else
CV_Error(Error::StsNotImplemented, "Unsupported Expand op"); CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
} addLayer(layerParams, node_proto);
else if (layer_type == "Reshape") }
{
void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape")); CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
if (node_proto.input_size() == 2) { if (node_proto.input_size() == 2) {
@ -1636,9 +1784,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
replaceLayerParam(layerParams, "shape", "dim"); replaceLayerParam(layerParams, "shape", "dim");
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Pad") }
{
void ONNXImporter::parsePad(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
layerParams.type = "Padding"; layerParams.type = "Padding";
replaceLayerParam(layerParams, "mode", "type"); replaceLayerParam(layerParams, "mode", "type");
if (node_proto.input_size() == 3 || node_proto.input_size() == 2) if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
@ -1655,9 +1805,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams.set("value", value.ptr<float>()[0]); layerParams.set("value", value.ptr<float>()[0]);
} }
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Shape") }
{
void ONNXImporter::parseShape(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.input_size() == 1); CV_Assert(node_proto.input_size() == 1);
IterShape_t shapeIt = outShapes.find(node_proto.input(0)); IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end()); CV_Assert(shapeIt != outShapes.end());
@ -1669,10 +1821,10 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
shapeMat.dims = 1; shapeMat.dims = 1;
addConstant(layerParams.name, shapeMat); addConstant(layerParams.name, shapeMat);
return; }
}
else if (layer_type == "Cast") void ONNXImporter::parseCast(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{ {
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{ {
Mat blob = getBlob(node_proto, 0); Mat blob = getBlob(node_proto, 0);
@ -1697,9 +1849,12 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
else else
layerParams.type = "Identity"; layerParams.type = "Identity";
} addLayer(layerParams, node_proto);
else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill") }
{
// "ConstantOfShape" "ConstantFill"
void ONNXImporter::parseConstantFill(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int depth = CV_32F; int depth = CV_32F;
float fill_value; float fill_value;
if (!layerParams.blobs.empty()) if (!layerParams.blobs.empty())
@ -1718,10 +1873,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
CV_CheckGT(inpShape[i], 0, ""); CV_CheckGT(inpShape[i], 0, "");
Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value)); Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
addConstant(layerParams.name, tensor); addConstant(layerParams.name, tensor);
return; }
}
else if (layer_type == "Gather") void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{ {
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() == 2); CV_Assert(node_proto.input_size() == 2);
Mat indexMat = getBlob(node_proto, 1); Mat indexMat = getBlob(node_proto, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1); CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
@ -1796,9 +1952,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
layerParams = sliceLp; layerParams = sliceLp;
} }
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Concat") }
{
void ONNXImporter::parseConcat(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
bool hasVariableInps = false; bool hasVariableInps = false;
for (int i = 0; i < node_proto.input_size(); ++i) for (int i = 0; i < node_proto.input_size(); ++i)
{ {
@ -1856,9 +2014,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
} }
} }
} addLayer(layerParams, node_proto);
else if (layer_type == "Resize") }
{
void ONNXImporter::parseResize(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
for (int i = 1; i < node_proto.input_size(); i++) for (int i = 1; i < node_proto.input_size(); i++)
CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end()); CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
@ -1903,9 +2063,11 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
} }
replaceLayerParam(layerParams, "mode", "interpolation"); replaceLayerParam(layerParams, "mode", "interpolation");
} addLayer(layerParams, node_proto);
else if (layer_type == "Upsample") }
{
void ONNXImporter::parseUpsample(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
//fused from Resize Subgraph //fused from Resize Subgraph
if (layerParams.has("coordinate_transformation_mode")) if (layerParams.has("coordinate_transformation_mode"))
{ {
@ -1950,14 +2112,21 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
} }
} }
replaceLayerParam(layerParams, "mode", "interpolation"); replaceLayerParam(layerParams, "mode", "interpolation");
} addLayer(layerParams, node_proto);
else if (layer_type == "SoftMax" || layer_type == "LogSoftmax") }
{
// "SoftMax" "LogSoftmax"
void ONNXImporter::parseSoftMax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
const std::string& layer_type = node_proto.op_type();
layerParams.type = "Softmax"; layerParams.type = "Softmax";
layerParams.set("log_softmax", layer_type == "LogSoftmax"); layerParams.set("log_softmax", layer_type == "LogSoftmax");
} addLayer(layerParams, node_proto);
else if (layer_type == "DetectionOutput") }
{
void ONNXImporter::parseDetectionOutput(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 3, ""); CV_CheckEQ(node_proto.input_size(), 3, "");
if (constBlobs.find(node_proto.input(2)) != constBlobs.end()) if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
{ {
@ -1974,31 +2143,68 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
node_proto.set_input(2, constParams.name); node_proto.set_input(2, constParams.name);
} }
} addLayer(layerParams, node_proto);
else }
{
void ONNXImporter::parseCustom(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
for (int j = 0; j < node_proto.input_size(); j++) { for (int j = 0; j < node_proto.input_size(); j++) {
if (layer_id.find(node_proto.input(j)) == layer_id.end()) if (layer_id.find(node_proto.input(j)) == layer_id.end())
layerParams.blobs.push_back(getBlob(node_proto, j)); layerParams.blobs.push_back(getBlob(node_proto, j));
} }
}
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }
catch (const cv::Exception& e)
{ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: " {
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str()) DispatchMap dispatch;
);
for (int i = 0; i < node_proto.input_size(); i++) dispatch["MaxPool"] = &ONNXImporter::parseMaxPool;
{ dispatch["AveragePool"] = &ONNXImporter::parseAveragePool;
CV_LOG_INFO(NULL, " Input[" << i << "] = '" << node_proto.input(i) << "'"); dispatch["GlobalAveragePool"] = dispatch["GlobalMaxPool"] = dispatch["ReduceMean"] = dispatch["ReduceSum"] =
} dispatch["ReduceMax"] = &ONNXImporter::parseReduce;
for (int i = 0; i < node_proto.output_size(); i++) dispatch["Slice"] = &ONNXImporter::parseSlice;
{ dispatch["Split"] = &ONNXImporter::parseSplit;
CV_LOG_INFO(NULL, " Output[" << i << "] = '" << node_proto.output(i) << "'"); dispatch["Add"] = dispatch["Sum"] = dispatch["Sub"] = &ONNXImporter::parseBias;
} dispatch["Pow"] = &ONNXImporter::parsePow;
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what())); dispatch["Max"] = &ONNXImporter::parseMax;
} dispatch["Neg"] = &ONNXImporter::parseNeg;
dispatch["Constant"] = &ONNXImporter::parseConstant;
dispatch["LSTM"] = &ONNXImporter::parseLSTM;
dispatch["ImageScaler"] = &ONNXImporter::parseImageScaler;
dispatch["Clip"] = &ONNXImporter::parseClip;
dispatch["LeakyRelu"] = &ONNXImporter::parseLeakyRelu;
dispatch["Relu"] = &ONNXImporter::parseRelu;
dispatch["Elu"] = &ONNXImporter::parseElu;
dispatch["Tanh"] = &ONNXImporter::parseTanh;
dispatch["PRelu"] = &ONNXImporter::parsePRelu;
dispatch["LRN"] = &ONNXImporter::parseLRN;
dispatch["InstanceNormalization"] = &ONNXImporter::parseInstanceNormalization;
dispatch["BatchNormalization"] = &ONNXImporter::parseBatchNormalization;
dispatch["Gemm"] = &ONNXImporter::parseGemm;
dispatch["MatMul"] = &ONNXImporter::parseMatMul;
dispatch["Mul"] = dispatch["Div"] = &ONNXImporter::parseMul;
dispatch["Conv"] = &ONNXImporter::parseConv;
dispatch["ConvTranspose"] = &ONNXImporter::parseConvTranspose;
dispatch["Transpose"] = &ONNXImporter::parseTranspose;
dispatch["Squeeze"] = &ONNXImporter::parseSqueeze;
dispatch["Flatten"] = &ONNXImporter::parseFlatten;
dispatch["Unsqueeze"] = &ONNXImporter::parseUnsqueeze;
dispatch["Expand"] = &ONNXImporter::parseExpand;
dispatch["Reshape"] = &ONNXImporter::parseReshape;
dispatch["Pad"] = &ONNXImporter::parsePad;
dispatch["Shape"] = &ONNXImporter::parseShape;
dispatch["Cast"] = &ONNXImporter::parseCast;
dispatch["ConstantFill"] = dispatch["ConstantOfShape"] = &ONNXImporter::parseConstantFill;
dispatch["Gather"] = &ONNXImporter::parseGather;
dispatch["Concat"] = &ONNXImporter::parseConcat;
dispatch["Resize"] = &ONNXImporter::parseResize;
dispatch["Upsample"] = &ONNXImporter::parseUpsample;
dispatch["SoftMax"] = dispatch["LogSoftmax"] = &ONNXImporter::parseSoftMax;
dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;
dispatch["Custom"] = &ONNXImporter::parseCustom;
return dispatch;
} }
Net readNetFromONNX(const String& onnxFile) Net readNetFromONNX(const String& onnxFile)

View File

@ -2869,7 +2869,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer)
DispatchMap::const_iterator iter = dispatch.find(type); DispatchMap::const_iterator iter = dispatch.find(type);
if (iter != dispatch.end()) if (iter != dispatch.end())
{ {
((*this).*(iter->second))(net, layer, layerParams); CALL_MEMBER_FN(*this, iter->second)(net, layer, layerParams);
} }
else else
{ {