Merge pull request #26324 from asmorkalov:as/model_diagnostics_engine

Added DNN engine selector to model diagnostics tool.
This commit is contained in:
Alexander Smorkalov 2024-10-21 15:51:53 +03:00 committed by GitHub
commit 6648482b69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -49,9 +49,11 @@ static std::vector<int> parseShape(const std::string &shape_str) {
}
std::string diagnosticKeys =
"{ help h | | Print help. }"
"{ model m | | Path to the model file. }"
"{ config c | | Path to the model configuration file. }"
"{ framework f | | [Optional] Name of the model framework. }"
"{ engine e | auto | [Optional] Graph negine selector: auto or classic or new}"
"{ input0_name | | [Optional] Name of input0. Use with input0_shape}"
"{ input0_shape | | [Optional] Shape of input0. Use with input0_name}"
"{ input1_name | | [Optional] Name of input1. Use with input1_shape}"
@ -75,6 +77,12 @@ int main( int argc, const char** argv )
return 0;
}
if(argParser.has("help"))
{
argParser.printMessage();
return 0;
}
std::string model = checkFileExists(argParser.get<std::string>("model"));
std::string config = checkFileExists(argParser.get<std::string>("config"));
std::string frameworkId = argParser.get<std::string>("framework");
@ -90,13 +98,30 @@ int main( int argc, const char** argv )
std::string input4_name = argParser.get<std::string>("input4_name");
std::string input4_shape = argParser.get<std::string>("input4_shape");
dnn::EngineType engine = dnn::ENGINE_AUTO;
if (argParser.has("engine"))
{
std::string eng_name = argParser.get<std::string>("engine");
if(eng_name == "auto")
engine = dnn::ENGINE_AUTO;
else if(eng_name == "classic")
engine = dnn::ENGINE_CLASSIC;
else if(eng_name == "new")
engine = dnn::ENGINE_NEW;
else
{
std::cerr << "Unknown DNN graph engine \"" << eng_name << "\"\n";
return -1;
}
}
CV_Assert(!model.empty());
enableModelDiagnostics(true);
skipModelImport(true);
redirectError(diagnosticsErrorCallback, NULL);
Net ocvNet = readNet(model, config, frameworkId);
Net ocvNet = readNet(model, config, frameworkId, engine);
std::vector<std::string> input_names;
std::vector<std::vector<int>> input_shapes;