mirror of
https://github.com/opencv/opencv.git
synced 2024-12-04 00:39:11 +08:00
skipping missing layers and layer failures
This commit is contained in:
parent
25f908b320
commit
dc5199feea
@ -1,6 +1,6 @@
|
||||
/*************************************************
|
||||
USAGE:
|
||||
./model_diagnostics -m <onnx file location>
|
||||
./model_diagnostics -m <model file location>
|
||||
**************************************************/
|
||||
#include <opencv2/dnn.hpp>
|
||||
#include <opencv2/core/utils/filesystem.hpp>
|
||||
@ -32,7 +32,7 @@ static std::string checkFileExists(const std::string& fileName)
|
||||
}
|
||||
|
||||
std::string diagnosticKeys =
|
||||
"{ model m | | Path to the model .onnx file. }"
|
||||
"{ model m | | Path to the model file. }"
|
||||
"{ config c | | Path to the model configuration file. }"
|
||||
"{ framework f | | [Optional] Name of the model framework. }";
|
||||
|
||||
@ -41,7 +41,7 @@ std::string diagnosticKeys =
|
||||
int main( int argc, const char** argv )
|
||||
{
|
||||
CommandLineParser argParser(argc, argv, diagnosticKeys);
|
||||
argParser.about("Use this tool to run the diagnostics of provided ONNX model"
|
||||
argParser.about("Use this tool to run the diagnostics of provided ONNX/TF model"
|
||||
"to obtain the information about its support (supported layers).");
|
||||
|
||||
if (argc == 1)
|
||||
|
@ -32,6 +32,8 @@ namespace cv {
|
||||
namespace dnn {
|
||||
CV__DNN_INLINE_NS_BEGIN
|
||||
|
||||
extern bool DNN_DIAGNOSTICS_RUN;
|
||||
|
||||
#if HAVE_PROTOBUF
|
||||
|
||||
using ::google::protobuf::RepeatedField;
|
||||
@ -471,6 +473,7 @@ public:
|
||||
TFImporter(Net& net, const char *dataModel, size_t lenModel,
|
||||
const char *dataConfig = NULL, size_t lenConfig = 0);
|
||||
protected:
|
||||
std::unique_ptr<Net> utilNet;
|
||||
Net& dstNet;
|
||||
void populateNet();
|
||||
|
||||
@ -2337,7 +2340,8 @@ void TFImporter::parseCustomLayer(tensorflow::GraphDef& net, const tensorflow::N
|
||||
}
|
||||
|
||||
TFImporter::TFImporter(Net& net, const char *model, const char *config)
|
||||
: dstNet(net), dispatch(buildDispatchMap())
|
||||
: utilNet(DNN_DIAGNOSTICS_RUN ? new Net : nullptr),
|
||||
dstNet(DNN_DIAGNOSTICS_RUN ? *utilNet : net), dispatch(buildDispatchMap())
|
||||
{
|
||||
if (model && model[0])
|
||||
{
|
||||
@ -2358,7 +2362,8 @@ TFImporter::TFImporter(
|
||||
const char *dataModel, size_t lenModel,
|
||||
const char *dataConfig, size_t lenConfig
|
||||
)
|
||||
: dstNet(net), dispatch(buildDispatchMap())
|
||||
: utilNet(DNN_DIAGNOSTICS_RUN ? new Net : nullptr),
|
||||
dstNet(DNN_DIAGNOSTICS_RUN ? *utilNet : net), dispatch(buildDispatchMap())
|
||||
{
|
||||
if (dataModel != NULL && lenModel > 0)
|
||||
{
|
||||
@ -2615,6 +2620,11 @@ DataLayout TFImporter::predictOutputDataLayout(const tensorflow::NodeDef& layer)
|
||||
return it->second;
|
||||
}
|
||||
|
||||
Ptr<Layer> dummy_constructor(LayerParams & params)
|
||||
{
|
||||
return new Layer(params);
|
||||
}
|
||||
|
||||
void TFImporter::populateNet()
|
||||
{
|
||||
CV_Assert(netBin.ByteSize() || netTxt.ByteSize());
|
||||
@ -2757,9 +2767,9 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer)
|
||||
const std::string& name = layer.name();
|
||||
const std::string& type = layer.op();
|
||||
|
||||
LayerParams layerParams;
|
||||
try
|
||||
{
|
||||
LayerParams layerParams;
|
||||
|
||||
if (layers_to_ignore.find(name) != layers_to_ignore.end())
|
||||
{
|
||||
@ -2777,14 +2787,37 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer)
|
||||
}
|
||||
else
|
||||
{
|
||||
if (DNN_DIAGNOSTICS_RUN && !LayerFactory::createLayerInstance(type, layerParams))
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/TF: Node='" << name << "' of type='"<< type
|
||||
<< "' is not supported. This error won't be displayed again.");
|
||||
LayerFactory::registerLayer(type, dummy_constructor);
|
||||
}
|
||||
|
||||
parseCustomLayer(net, layer, layerParams);
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/TF: Can't parse layer for node='" << name << "'. Exception: " << e.what());
|
||||
if (!DNN_DIAGNOSTICS_RUN)
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/TF: Can't parse layer for node='" << name << "' of type='" << type
|
||||
<< "'. Exception: " << e.what());
|
||||
throw;
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_LOG_ERROR(NULL, "DNN/TF: Can't parse layer for node='" << name << "' of type='" << type
|
||||
<< "'. Exception: " << e.what());
|
||||
|
||||
// internal layer failure (didnt call addLayer)
|
||||
if (dstNet.getLayerId(name) == -1)
|
||||
{
|
||||
int id = dstNet.addLayer(name, type, layerParams);
|
||||
layer_id[name] = id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user