mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 04:12:52 +08:00
Merge pull request #21343 from alalek:dnn_onnx_handle_domains
This commit is contained in:
commit
3b0ed61826
@ -15,6 +15,9 @@
|
||||
#define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_VERBOSE + 1
|
||||
#include <opencv2/core/utils/logger.hpp>
|
||||
|
||||
#include <opencv2/core/utils/configuration.private.hpp>
|
||||
|
||||
|
||||
#ifdef HAVE_PROTOBUF
|
||||
|
||||
#include <iostream>
|
||||
@ -23,6 +26,10 @@
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined _MSC_VER && _MSC_VER < 1910/*MSVS 2017*/
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4503) // decorated name length exceeded, name was truncated
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__) && __GNUC__ >= 5
|
||||
#pragma GCC diagnostic push
|
||||
@ -41,8 +48,6 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
|
||||
extern bool DNN_DIAGNOSTICS_RUN;
|
||||
|
||||
class ONNXLayerHandler;
|
||||
|
||||
class ONNXImporter
|
||||
{
|
||||
opencv_onnx::ModelProto model_proto;
|
||||
@ -75,7 +80,7 @@ public:
|
||||
void populateNet();
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ONNXLayerHandler> layerHandler;
|
||||
std::unique_ptr<detail::LayerHandler> missingLayerHandler;
|
||||
Net& dstNet;
|
||||
|
||||
opencv_onnx::GraphProto graph_proto;
|
||||
@ -93,13 +98,16 @@ protected:
|
||||
void handleNode(const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
private:
|
||||
friend class ONNXLayerHandler;
|
||||
typedef void (ONNXImporter::*ONNXImporterNodeParser)(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
typedef std::map<std::string, ONNXImporterNodeParser> DispatchMap;
|
||||
typedef std::map<std::string, DispatchMap> DomainDispatchMap;
|
||||
|
||||
const DispatchMap dispatch;
|
||||
static const DispatchMap buildDispatchMap();
|
||||
DomainDispatchMap domain_dispatch_map;
|
||||
void buildDispatchMap_ONNX_AI(int opset_version);
|
||||
void buildDispatchMap_COM_MICROSOFT(int opset_version);
|
||||
|
||||
// Domain: 'ai.onnx' (default)
|
||||
// URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
||||
void parseArg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseMaxUnpool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseMaxPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
@ -148,6 +156,9 @@ private:
|
||||
void parseSoftMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseCumSum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
// Domain: com.microsoft
|
||||
// URL: https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md
|
||||
void parseQuantDequant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseQConv (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseQMatMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
@ -157,43 +168,20 @@ private:
|
||||
void parseQAvgPool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseQConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
// '???' domain or '???' layer type
|
||||
void parseCustomLayer (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
int onnx_opset; // OperatorSetIdProto for 'onnx' domain
|
||||
std::map<std::string, int> onnx_opset_map; // map from OperatorSetIdProto
|
||||
void parseOperatorSet();
|
||||
|
||||
const std::string str_domain_ai_onnx = "ai.onnx";
|
||||
};
|
||||
|
||||
class ONNXLayerHandler : public detail::LayerHandler
|
||||
{
|
||||
public:
|
||||
explicit ONNXLayerHandler(ONNXImporter* importer_);
|
||||
|
||||
void fillRegistry(const opencv_onnx::GraphProto& net);
|
||||
|
||||
protected:
|
||||
ONNXImporter* importer;
|
||||
};
|
||||
|
||||
ONNXLayerHandler::ONNXLayerHandler(ONNXImporter* importer_) : importer(importer_){}
|
||||
|
||||
void ONNXLayerHandler::fillRegistry(const opencv_onnx::GraphProto &net)
|
||||
{
|
||||
int layersSize = net.node_size();
|
||||
for (int li = 0; li < layersSize; li++) {
|
||||
const opencv_onnx::NodeProto &node_proto = net.node(li);
|
||||
const std::string& name = node_proto.output(0);
|
||||
const std::string& type = node_proto.op_type();
|
||||
if (importer->dispatch.find(type) == importer->dispatch.end())
|
||||
{
|
||||
addMissing(name, type);
|
||||
}
|
||||
}
|
||||
printMissing();
|
||||
}
|
||||
|
||||
ONNXImporter::ONNXImporter(Net& net, const char *onnxFile)
|
||||
: layerHandler(DNN_DIAGNOSTICS_RUN ? new ONNXLayerHandler(this) : nullptr)
|
||||
, dstNet(net), dispatch(buildDispatchMap())
|
||||
: missingLayerHandler(DNN_DIAGNOSTICS_RUN ? new detail::LayerHandler() : nullptr)
|
||||
, dstNet(net)
|
||||
, onnx_opset(0)
|
||||
{
|
||||
hasDynamicShapes = false;
|
||||
@ -215,8 +203,8 @@ ONNXImporter::ONNXImporter(Net& net, const char *onnxFile)
|
||||
}
|
||||
|
||||
ONNXImporter::ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
|
||||
: layerHandler(DNN_DIAGNOSTICS_RUN ? new ONNXLayerHandler(this) : nullptr)
|
||||
, dstNet(net), dispatch(buildDispatchMap())
|
||||
: missingLayerHandler(DNN_DIAGNOSTICS_RUN ? new detail::LayerHandler() : nullptr)
|
||||
, dstNet(net)
|
||||
, onnx_opset(0)
|
||||
{
|
||||
hasDynamicShapes = false;
|
||||
@ -638,20 +626,37 @@ void ONNXImporter::parseOperatorSet()
|
||||
const ::opencv_onnx::OperatorSetIdProto& opset_entry = model_proto.opset_import(i);
|
||||
const std::string& domain = opset_entry.has_domain() ? opset_entry.domain() : std::string();
|
||||
int version = opset_entry.has_version() ? opset_entry.version() : -1;
|
||||
if (domain.empty() || domain == "ai.onnx")
|
||||
if (domain.empty() || domain == str_domain_ai_onnx)
|
||||
{
|
||||
// ONNX opset covered by specification: https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
||||
onnx_opset = std::max(onnx_opset, version);
|
||||
onnx_opset_map[str_domain_ai_onnx] = onnx_opset;
|
||||
}
|
||||
else
|
||||
{
|
||||
// OpenCV don't know other opsets
|
||||
// will fail later on unsupported node processing
|
||||
CV_LOG_WARNING(NULL, "DNN/ONNX: unsupported opset[" << i << "]: domain='" << domain << "' version=" << version);
|
||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: using non-standard ONNX opset[" << i << "]: domain='" << domain << "' version=" << version);
|
||||
onnx_opset_map[domain] = onnx_opset;
|
||||
}
|
||||
}
|
||||
|
||||
CV_LOG_INFO(NULL, "DNN/ONNX: ONNX opset version = " << onnx_opset);
|
||||
|
||||
buildDispatchMap_ONNX_AI(onnx_opset);
|
||||
for (const auto& pair : onnx_opset_map)
|
||||
{
|
||||
if (pair.first == str_domain_ai_onnx)
|
||||
{
|
||||
continue; // done above
|
||||
}
|
||||
else if (pair.first == "com.microsoft")
|
||||
{
|
||||
buildDispatchMap_COM_MICROSOFT(pair.second);
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_LOG_INFO(NULL, "DNN/ONNX: unknown domain='" << pair.first << "' version=" << pair.second << ". No dispatch map, you may need to register 'custom' layers.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXImporter::handleQuantizedNode(LayerParams& layerParams,
|
||||
@ -790,7 +795,6 @@ void ONNXImporter::populateNet()
|
||||
|
||||
if (DNN_DIAGNOSTICS_RUN) {
|
||||
CV_LOG_INFO(NULL, "DNN/ONNX: start diagnostic run!");
|
||||
layerHandler->fillRegistry(graph_proto);
|
||||
}
|
||||
|
||||
for(int li = 0; li < layersSize; li++)
|
||||
@ -805,22 +809,52 @@ void ONNXImporter::populateNet()
|
||||
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
CV_Assert(node_proto.output_size() >= 1);
|
||||
std::string name = node_proto.output(0);
|
||||
const std::string& name = node_proto.output(0);
|
||||
const std::string& layer_type = node_proto.op_type();
|
||||
const std::string& layer_type_domain = node_proto.has_domain() ? node_proto.domain() : std::string();
|
||||
if (!layer_type_domain.empty() && layer_type_domain != "ai.onnx")
|
||||
const std::string& layer_type_domain = [&]()
|
||||
{
|
||||
CV_LOG_WARNING(NULL, "DNN/ONNX: can't handle node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s@%s]:(%s)", layer_type.c_str(), layer_type_domain.c_str(), name.c_str())
|
||||
);
|
||||
if (DNN_DIAGNOSTICS_RUN)
|
||||
return; // ignore error
|
||||
CV_Error(Error::StsNotImplemented, cv::format("ONNX: unsupported domain: %s", layer_type_domain.c_str()));
|
||||
}
|
||||
if (!node_proto.has_domain())
|
||||
return str_domain_ai_onnx;
|
||||
const std::string& domain = node_proto.domain();
|
||||
if (domain.empty())
|
||||
return str_domain_ai_onnx;
|
||||
return domain;
|
||||
}();
|
||||
const auto& dispatch = [&]()
|
||||
{
|
||||
if (layer_type_domain != str_domain_ai_onnx)
|
||||
{
|
||||
if (onnx_opset_map.find(layer_type_domain) == onnx_opset_map.end())
|
||||
{
|
||||
CV_LOG_WARNING(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
<< " from undeclared domain='" << layer_type_domain << "'"
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
<< " from domain='" << layer_type_domain << "'"
|
||||
);
|
||||
}
|
||||
auto it = domain_dispatch_map.find(layer_type_domain);
|
||||
if (it == domain_dispatch_map.end())
|
||||
{
|
||||
CV_LOG_WARNING(NULL, "DNN/ONNX: missing dispatch map for domain='" << layer_type_domain << "'");
|
||||
return DispatchMap();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
);
|
||||
return domain_dispatch_map[str_domain_ai_onnx];
|
||||
}
|
||||
}();
|
||||
|
||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
);
|
||||
LayerParams layerParams;
|
||||
try
|
||||
{
|
||||
@ -848,7 +882,9 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
|
||||
if (DNN_DIAGNOSTICS_RUN)
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/ONNX: Potential problem during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str()) << "\n" << e.msg
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
<< " from domain='" << layer_type_domain << "'"
|
||||
<< "\n" << e.msg
|
||||
);
|
||||
cv::AutoLock lock(getLayerFactoryMutex());
|
||||
auto registeredLayers = getLayerFactoryImpl();
|
||||
@ -869,6 +905,7 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
<< " from domain='" << layer_type_domain << "'"
|
||||
);
|
||||
}
|
||||
for (int i = 0; i < node_proto.input_size(); i++)
|
||||
@ -888,7 +925,7 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
|
||||
}
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
|
||||
CV_Error(Error::StsError, cv::format("Node [%s@%s]:(%s) parse error: %s", layer_type.c_str(), layer_type_domain.c_str(), name.c_str(), e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2836,6 +2873,28 @@ void ONNXImporter::parseCumSum(LayerParams& layerParams, const opencv_onnx::Node
|
||||
|
||||
void ONNXImporter::parseCustomLayer(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
const std::string& name = layerParams.name;
|
||||
std::string& layer_type = layerParams.type;
|
||||
const std::string& layer_type_domain = node_proto.has_domain() ? node_proto.domain() : std::string();
|
||||
if (!layer_type_domain.empty() && layer_type_domain != str_domain_ai_onnx)
|
||||
{
|
||||
// append ONNX domain name
|
||||
static bool DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME = utils::getConfigurationParameterBool("OPENCV_DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME", true);
|
||||
if (DNN_CUSTOM_ONNX_TYPE_INCLUDE_DOMAIN_NAME)
|
||||
{
|
||||
layer_type = layer_type_domain + "." + layer_type;
|
||||
}
|
||||
}
|
||||
|
||||
CV_LOG_INFO(NULL, "DNN/ONNX: unknown node type, try using custom handler for node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
|
||||
<< cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
|
||||
);
|
||||
|
||||
if (missingLayerHandler)
|
||||
{
|
||||
missingLayerHandler->addMissing(layerParams.name, layerParams.type);
|
||||
}
|
||||
|
||||
for (int j = 0; j < node_proto.input_size(); j++) {
|
||||
if (layer_id.find(node_proto.input(j)) == layer_id.end())
|
||||
layerParams.blobs.push_back(getBlob(node_proto, j));
|
||||
@ -3233,8 +3292,11 @@ void ONNXImporter::parseQConcat(LayerParams& layerParams, const opencv_onnx::Nod
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
|
||||
// Domain: ai.onnx (default)
|
||||
// URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
|
||||
void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
||||
{
|
||||
CV_UNUSED(opset_version);
|
||||
DispatchMap dispatch;
|
||||
|
||||
dispatch["ArgMax"] = dispatch["ArgMin"] = &ONNXImporter::parseArg;
|
||||
@ -3286,18 +3348,32 @@ const ONNXImporter::DispatchMap ONNXImporter::buildDispatchMap()
|
||||
dispatch["SoftMax"] = dispatch["LogSoftmax"] = &ONNXImporter::parseSoftMax;
|
||||
dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput;
|
||||
dispatch["CumSum"] = &ONNXImporter::parseCumSum;
|
||||
|
||||
// ai.onnx: opset 10+
|
||||
dispatch["QuantizeLinear"] = dispatch["DequantizeLinear"] = &ONNXImporter::parseQuantDequant;
|
||||
dispatch["QLinearConv"] = &ONNXImporter::parseQConv;
|
||||
dispatch["QLinearMatMul"] = &ONNXImporter::parseQMatMul;
|
||||
|
||||
domain_dispatch_map[str_domain_ai_onnx] = dispatch;
|
||||
}
|
||||
|
||||
// Domain: com.microsoft
|
||||
// URL: https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md
|
||||
void ONNXImporter::buildDispatchMap_COM_MICROSOFT(int opset_version)
|
||||
{
|
||||
CV_UNUSED(opset_version);
|
||||
DispatchMap dispatch;
|
||||
|
||||
dispatch["QLinearAdd"] = dispatch["QLinearMul"] = &ONNXImporter::parseQEltwise;
|
||||
dispatch["QLinearAveragePool"] = dispatch["QLinearGlobalAveragePool"] = &ONNXImporter::parseQAvgPool;
|
||||
dispatch["QLinearLeakyRelu"] = &ONNXImporter::parseQLeakyRelu;
|
||||
dispatch["QLinearSigmoid"] = &ONNXImporter::parseQSigmoid;
|
||||
dispatch["QLinearAveragePool"] = dispatch["QLinearGlobalAveragePool"] = &ONNXImporter::parseQAvgPool;
|
||||
dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat;
|
||||
|
||||
return dispatch;
|
||||
domain_dispatch_map["com.microsoft"] = dispatch;
|
||||
}
|
||||
|
||||
|
||||
Net readNetFromONNX(const String& onnxFile)
|
||||
{
|
||||
return detail::readNetDiagnostic<ONNXImporter>(onnxFile.c_str());
|
||||
|
@ -1489,16 +1489,6 @@ TEST_P(Test_ONNX_layers, DivConst)
|
||||
}
|
||||
|
||||
|
||||
// FIXIT disabled due to non-standard ONNX model domains, need to add ONNX domains support
|
||||
// Example:
|
||||
// DNN/ONNX: unsupported opset[1]: domain='com.microsoft.experimental' version=1
|
||||
// DNN/ONNX: unsupported opset[2]: domain='ai.onnx.preview.training' version=1
|
||||
// DNN/ONNX: unsupported opset[3]: domain='com.microsoft.nchwc' version=1
|
||||
// DNN/ONNX: unsupported opset[4]: domain='com.microsoft.mlfeaturizers' version=1
|
||||
// DNN/ONNX: unsupported opset[5]: domain='ai.onnx.ml' version=2
|
||||
// DNN/ONNX: unsupported opset[6]: domain='com.microsoft' version=1
|
||||
// DNN/ONNX: unsupported opset[7]: domain='ai.onnx.training' version=1
|
||||
#if 0
|
||||
TEST_P(Test_ONNX_layers, Quantized_Convolution)
|
||||
{
|
||||
testONNXModels("quantized_conv_uint8_weights", npy, 0.004, 0.02);
|
||||
@ -1604,7 +1594,6 @@ TEST_P(Test_ONNX_layers, Quantized_Constant)
|
||||
{
|
||||
testONNXModels("quantized_constant", npy, 0.002, 0.008);
|
||||
}
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||
|
||||
@ -1749,8 +1738,7 @@ TEST_P(Test_ONNX_nets, ResNet50v1)
|
||||
testONNXModels("resnet50v1", pb, default_l1, default_lInf, true, target != DNN_TARGET_MYRIAD);
|
||||
}
|
||||
|
||||
// FIXIT missing ONNX domains support
|
||||
TEST_P(Test_ONNX_nets, DISABLED_ResNet50_Int8)
|
||||
TEST_P(Test_ONNX_nets, ResNet50_Int8)
|
||||
{
|
||||
testONNXModels("resnet50_int8", pb, default_l1, default_lInf, true);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user