opencv/apps/model-diagnostics/model_diagnostics.cpp
Yuantao Feng bc0618b688
Merge pull request #25582 from fengyuentau:dnn/dump_pbtxt
Current net exporter `dump` and `dumpToFile` exports the network structure (and its params) to a .dot file which works with `graphviz`. This is hard to use and not friendly to new user. What's worse, the produced picture is not looking pretty.
dnn: better net exporter that works with netron #25582

This PR introduces new exporter `dumpToPbtxt` and uses this new exporter by default with environment variable `OPENCV_DNN_NETWORK_DUMP`. It mimics the string output of a onnx model but modified with dnn-specific changes, see below for an example.

![image](https://github.com/opencv/opencv/assets/17219438/0644bed1-da71-4019-8466-88390698e4df)

## Usage

Call `cv::dnn::Net::dumpToPbtxt`:

```cpp
TEST(DumpNet, dumpToPbtxt) {
    std::string path = "/path/to/model.onnx";
    auto net = readNet(path);

    Mat input(std::vector<int>{1, 3, 640, 480}, CV_32F);
    net.setInput(input);

    net.dumpToPbtxt("yunet.pbtxt");
}
```

Set `export OPENCV_DNN_NETWORK_DUMP=1`

```cpp
TEST(DumpNet, env) {
    std::string path = "/path/to/model.onnx";
    auto net = readNet(path);

    Mat input(std::vector<int>{1, 3, 640, 480}, CV_32F);
    net.setInput(input);

    net.forward();
}
```

---

Note:
- `pbtxt` is registered as one of the ONNX model suffix in netron. So you can see `module: ai.onnx` and such in the model.
- We can get the string output of an ONNX model with the following script

```python
import onnx
net = onnx.load("/path/to/model.onnx")
net_str = str(net)
file = open("/path/to/model.pbtxt", "w")
file.write(net_str)
file.close()
```

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
2024-05-17 11:07:05 +03:00

148 lines
5.9 KiB
C++

/*************************************************
USAGE:
./model_diagnostics -m <model file location>
**************************************************/
#include <opencv2/dnn.hpp>
#include <opencv2/core/utils/filesystem.hpp>
#include <opencv2/dnn/utils/debug_utils.hpp>
#include <iostream>
using namespace cv;
using namespace dnn;
static
int diagnosticsErrorCallback(int /*status*/, const char* /*func_name*/,
const char* /*err_msg*/, const char* /*file_name*/,
int /*line*/, void* /*userdata*/)
{
fflush(stdout);
fflush(stderr);
return 0;
}
static std::string checkFileExists(const std::string& fileName)
{
if (fileName.empty() || utils::fs::exists(fileName))
return fileName;
CV_Error(Error::StsObjectNotFound, "File " + fileName + " was not found! "
"Please, specify a full path to the file.");
}
static std::vector<int> parseShape(const std::string &shape_str) {
std::stringstream ss(shape_str);
std::string item;
std::vector<std::string> items;
while (std::getline(ss, item, ',')) {
items.push_back(item);
}
std::vector<int> shape;
for (size_t i = 0; i < items.size(); i++) {
shape.push_back(std::stoi(items[i]));
}
return shape;
}
std::string diagnosticKeys =
"{ model m | | Path to the model file. }"
"{ config c | | Path to the model configuration file. }"
"{ framework f | | [Optional] Name of the model framework. }"
"{ input0_name | | [Optional] Name of input0. Use with input0_shape}"
"{ input0_shape | | [Optional] Shape of input0. Use with input0_name}"
"{ input1_name | | [Optional] Name of input1. Use with input1_shape}"
"{ input1_shape | | [Optional] Shape of input1. Use with input1_name}"
"{ input2_name | | [Optional] Name of input2. Use with input2_shape}"
"{ input2_shape | | [Optional] Shape of input2. Use with input2_name}"
"{ input3_name | | [Optional] Name of input3. Use with input3_shape}"
"{ input3_shape | | [Optional] Shape of input3. Use with input3_name}"
"{ input4_name | | [Optional] Name of input4. Use with input4_shape}"
"{ input4_shape | | [Optional] Shape of input4. Use with input4_name}";
int main( int argc, const char** argv )
{
CommandLineParser argParser(argc, argv, diagnosticKeys);
argParser.about("Use this tool to run the diagnostics of provided ONNX/TF model"
"to obtain the information about its support (supported layers).");
if (argc == 1)
{
argParser.printMessage();
return 0;
}
std::string model = checkFileExists(argParser.get<std::string>("model"));
std::string config = checkFileExists(argParser.get<std::string>("config"));
std::string frameworkId = argParser.get<std::string>("framework");
std::string input0_name = argParser.get<std::string>("input0_name");
std::string input0_shape = argParser.get<std::string>("input0_shape");
std::string input1_name = argParser.get<std::string>("input1_name");
std::string input1_shape = argParser.get<std::string>("input1_shape");
std::string input2_name = argParser.get<std::string>("input2_name");
std::string input2_shape = argParser.get<std::string>("input2_shape");
std::string input3_name = argParser.get<std::string>("input3_name");
std::string input3_shape = argParser.get<std::string>("input3_shape");
std::string input4_name = argParser.get<std::string>("input4_name");
std::string input4_shape = argParser.get<std::string>("input4_shape");
CV_Assert(!model.empty());
enableModelDiagnostics(true);
skipModelImport(true);
redirectError(diagnosticsErrorCallback, NULL);
Net ocvNet = readNet(model, config, frameworkId);
std::vector<std::string> input_names;
std::vector<std::vector<int>> input_shapes;
if (!input0_name.empty() || !input0_shape.empty()) {
CV_CheckFalse(input0_name.empty(), "input0_name cannot be empty");
CV_CheckFalse(input0_shape.empty(), "input0_shape cannot be empty");
input_names.push_back(input0_name);
input_shapes.push_back(parseShape(input0_shape));
}
if (!input1_name.empty() || !input1_shape.empty()) {
CV_CheckFalse(input1_name.empty(), "input1_name cannot be empty");
CV_CheckFalse(input1_shape.empty(), "input1_shape cannot be empty");
input_names.push_back(input1_name);
input_shapes.push_back(parseShape(input1_shape));
}
if (!input2_name.empty() || !input2_shape.empty()) {
CV_CheckFalse(input2_name.empty(), "input2_name cannot be empty");
CV_CheckFalse(input2_shape.empty(), "input2_shape cannot be empty");
input_names.push_back(input2_name);
input_shapes.push_back(parseShape(input2_shape));
}
if (!input3_name.empty() || !input3_shape.empty()) {
CV_CheckFalse(input3_name.empty(), "input3_name cannot be empty");
CV_CheckFalse(input3_shape.empty(), "input3_shape cannot be empty");
input_names.push_back(input3_name);
input_shapes.push_back(parseShape(input3_shape));
}
if (!input4_name.empty() || !input4_shape.empty()) {
CV_CheckFalse(input4_name.empty(), "input4_name cannot be empty");
CV_CheckFalse(input4_shape.empty(), "input4_shape cannot be empty");
input_names.push_back(input4_name);
input_shapes.push_back(parseShape(input4_shape));
}
if (!input_names.empty() && !input_shapes.empty() && input_names.size() == input_shapes.size()) {
ocvNet.setInputsNames(input_names);
for (size_t i = 0; i < input_names.size(); i++) {
Mat input(input_shapes[i], CV_32F);
ocvNet.setInput(input, input_names[i]);
}
size_t dot_index = model.rfind('.');
std::string graph_filename = model.substr(0, dot_index) + ".pbtxt";
ocvNet.dumpToPbtxt(graph_filename);
}
return 0;
}