mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #25163 from dkurt:onnx_graph_ref
Avoid copy of ONNX graph during import
This commit is contained in:
commit
66ff71085e
@ -112,7 +112,7 @@ protected:
|
|||||||
std::unique_ptr<ONNXLayerHandler> layerHandler;
|
std::unique_ptr<ONNXLayerHandler> layerHandler;
|
||||||
Net& dstNet;
|
Net& dstNet;
|
||||||
|
|
||||||
opencv_onnx::GraphProto graph_proto;
|
opencv_onnx::GraphProto* graph_proto;
|
||||||
std::string framework_name;
|
std::string framework_name;
|
||||||
|
|
||||||
std::map<std::string, Mat> constBlobs;
|
std::map<std::string, Mat> constBlobs;
|
||||||
@ -787,7 +787,7 @@ void ONNXImporter::setParamsDtype(LayerParams& layerParams, const opencv_onnx::N
|
|||||||
void ONNXImporter::populateNet()
|
void ONNXImporter::populateNet()
|
||||||
{
|
{
|
||||||
CV_Assert(model_proto.has_graph());
|
CV_Assert(model_proto.has_graph());
|
||||||
graph_proto = model_proto.graph();
|
graph_proto = model_proto.mutable_graph();
|
||||||
|
|
||||||
std::string framework_version;
|
std::string framework_version;
|
||||||
if (model_proto.has_producer_name())
|
if (model_proto.has_producer_name())
|
||||||
@ -799,25 +799,25 @@ void ONNXImporter::populateNet()
|
|||||||
<< (model_proto.has_ir_version() ? cv::format(" v%d", (int)model_proto.ir_version()) : cv::String())
|
<< (model_proto.has_ir_version() ? cv::format(" v%d", (int)model_proto.ir_version()) : cv::String())
|
||||||
<< " model produced by '" << framework_name << "'"
|
<< " model produced by '" << framework_name << "'"
|
||||||
<< (framework_version.empty() ? cv::String() : cv::format(":%s", framework_version.c_str()))
|
<< (framework_version.empty() ? cv::String() : cv::format(":%s", framework_version.c_str()))
|
||||||
<< ". Number of nodes = " << graph_proto.node_size()
|
<< ". Number of nodes = " << graph_proto->node_size()
|
||||||
<< ", initializers = " << graph_proto.initializer_size()
|
<< ", initializers = " << graph_proto->initializer_size()
|
||||||
<< ", inputs = " << graph_proto.input_size()
|
<< ", inputs = " << graph_proto->input_size()
|
||||||
<< ", outputs = " << graph_proto.output_size()
|
<< ", outputs = " << graph_proto->output_size()
|
||||||
);
|
);
|
||||||
|
|
||||||
parseOperatorSet();
|
parseOperatorSet();
|
||||||
|
|
||||||
simplifySubgraphs(graph_proto);
|
simplifySubgraphs(*graph_proto);
|
||||||
|
|
||||||
const int layersSize = graph_proto.node_size();
|
const int layersSize = graph_proto->node_size();
|
||||||
CV_LOG_DEBUG(NULL, "DNN/ONNX: graph simplified to " << layersSize << " nodes");
|
CV_LOG_DEBUG(NULL, "DNN/ONNX: graph simplified to " << layersSize << " nodes");
|
||||||
|
|
||||||
constBlobs = getGraphTensors(graph_proto); // scan GraphProto.initializer
|
constBlobs = getGraphTensors(*graph_proto); // scan GraphProto.initializer
|
||||||
std::vector<String> netInputs; // map with network inputs (without const blobs)
|
std::vector<String> netInputs; // map with network inputs (without const blobs)
|
||||||
// Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
|
// Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
|
||||||
for (int i = 0; i < graph_proto.input_size(); ++i)
|
for (int i = 0; i < graph_proto->input_size(); ++i)
|
||||||
{
|
{
|
||||||
const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto.input(i);
|
const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto->input(i);
|
||||||
CV_Assert(valueInfoProto.has_name());
|
CV_Assert(valueInfoProto.has_name());
|
||||||
const std::string& name = valueInfoProto.name();
|
const std::string& name = valueInfoProto.name();
|
||||||
CV_Assert(valueInfoProto.has_type());
|
CV_Assert(valueInfoProto.has_type());
|
||||||
@ -873,26 +873,26 @@ void ONNXImporter::populateNet()
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dump outputs
|
// dump outputs
|
||||||
for (int i = 0; i < graph_proto.output_size(); ++i)
|
for (int i = 0; i < graph_proto->output_size(); ++i)
|
||||||
{
|
{
|
||||||
dumpValueInfoProto(i, graph_proto.output(i), "output");
|
dumpValueInfoProto(i, graph_proto->output(i), "output");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DNN_DIAGNOSTICS_RUN) {
|
if (DNN_DIAGNOSTICS_RUN) {
|
||||||
CV_LOG_INFO(NULL, "DNN/ONNX: start diagnostic run!");
|
CV_LOG_INFO(NULL, "DNN/ONNX: start diagnostic run!");
|
||||||
layerHandler->fillRegistry(graph_proto);
|
layerHandler->fillRegistry(*graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int li = 0; li < layersSize; li++)
|
for(int li = 0; li < layersSize; li++)
|
||||||
{
|
{
|
||||||
const opencv_onnx::NodeProto& node_proto = graph_proto.node(li);
|
const opencv_onnx::NodeProto& node_proto = graph_proto->node(li);
|
||||||
handleNode(node_proto);
|
handleNode(node_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
// register outputs
|
// register outputs
|
||||||
for (int i = 0; i < graph_proto.output_size(); ++i)
|
for (int i = 0; i < graph_proto->output_size(); ++i)
|
||||||
{
|
{
|
||||||
const std::string& output_name = graph_proto.output(i).name();
|
const std::string& output_name = graph_proto->output(i).name();
|
||||||
if (output_name.empty())
|
if (output_name.empty())
|
||||||
{
|
{
|
||||||
CV_LOG_ERROR(NULL, "DNN/ONNX: can't register output without name: " << i);
|
CV_LOG_ERROR(NULL, "DNN/ONNX: can't register output without name: " << i);
|
||||||
@ -3180,9 +3180,9 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
|
|||||||
{
|
{
|
||||||
// remove from graph proto
|
// remove from graph proto
|
||||||
for (size_t i = 1; i < node_proto.output_size(); i++) {
|
for (size_t i = 1; i < node_proto.output_size(); i++) {
|
||||||
for (int j = graph_proto.output_size() - 1; j >= 0; j--) {
|
for (int j = graph_proto->output_size() - 1; j >= 0; j--) {
|
||||||
if (graph_proto.output(j).name() == node_proto.output(i)) {
|
if (graph_proto->output(j).name() == node_proto.output(i)) {
|
||||||
graph_proto.mutable_output()->DeleteSubrange(j, 1);
|
graph_proto->mutable_output()->DeleteSubrange(j, 1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3683,9 +3683,9 @@ void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::No
|
|||||||
layerParams.type = "ScaleInt8";
|
layerParams.type = "ScaleInt8";
|
||||||
layerParams.set("bias_term", op == "sum");
|
layerParams.set("bias_term", op == "sum");
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
for (int i = 0; i < graph_proto.initializer_size(); i++)
|
for (int i = 0; i < graph_proto->initializer_size(); i++)
|
||||||
{
|
{
|
||||||
opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
|
opencv_onnx::TensorProto tensor_proto = graph_proto->initializer(i);
|
||||||
if (tensor_proto.name() == node_proto.input(constId))
|
if (tensor_proto.name() == node_proto.input(constId))
|
||||||
{
|
{
|
||||||
axis = inpShape.size() - tensor_proto.dims_size();
|
axis = inpShape.size() - tensor_proto.dims_size();
|
||||||
|
Loading…
Reference in New Issue
Block a user