Avoid copy of ONNX graph during import

pull/25163/head
Dmitry Kurtaev 1 year ago
parent b34ec57682
commit 98aed21dd4
1 changed file (44 changed lines):
modules/dnn/src/onnx/onnx_importer.cpp

@@ -112,7 +112,7 @@ protected:
 std::unique_ptr<ONNXLayerHandler> layerHandler;
 Net& dstNet;
-opencv_onnx::GraphProto graph_proto;
+opencv_onnx::GraphProto* graph_proto;
 std::string framework_name;
 std::map<std::string, Mat> constBlobs;
@@ -787,7 +787,7 @@ void ONNXImporter::setParamsDtype(LayerParams& layerParams, const opencv_onnx::N
 void ONNXImporter::populateNet()
 {
 CV_Assert(model_proto.has_graph());
-graph_proto = model_proto.graph();
+graph_proto = model_proto.mutable_graph();
 std::string framework_version;
 if (model_proto.has_producer_name())
@@ -799,25 +799,25 @@ void ONNXImporter::populateNet()
 << (model_proto.has_ir_version() ? cv::format(" v%d", (int)model_proto.ir_version()) : cv::String())
 << " model produced by '" << framework_name << "'"
 << (framework_version.empty() ? cv::String() : cv::format(":%s", framework_version.c_str()))
-<< ". Number of nodes = " << graph_proto.node_size()
-<< ", initializers = " << graph_proto.initializer_size()
-<< ", inputs = " << graph_proto.input_size()
-<< ", outputs = " << graph_proto.output_size()
+<< ". Number of nodes = " << graph_proto->node_size()
+<< ", initializers = " << graph_proto->initializer_size()
+<< ", inputs = " << graph_proto->input_size()
+<< ", outputs = " << graph_proto->output_size()
 );
 parseOperatorSet();
-simplifySubgraphs(graph_proto);
-const int layersSize = graph_proto.node_size();
+simplifySubgraphs(*graph_proto);
+const int layersSize = graph_proto->node_size();
 CV_LOG_DEBUG(NULL, "DNN/ONNX: graph simplified to " << layersSize << " nodes");
-constBlobs = getGraphTensors(graph_proto); // scan GraphProto.initializer
+constBlobs = getGraphTensors(*graph_proto); // scan GraphProto.initializer
 std::vector<String> netInputs; // map with network inputs (without const blobs)
 // Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
-for (int i = 0; i < graph_proto.input_size(); ++i)
+for (int i = 0; i < graph_proto->input_size(); ++i)
 {
-const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto.input(i);
+const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto->input(i);
 CV_Assert(valueInfoProto.has_name());
 const std::string& name = valueInfoProto.name();
 CV_Assert(valueInfoProto.has_type());
@@ -873,26 +873,26 @@ void ONNXImporter::populateNet()
 }
 // dump outputs
-for (int i = 0; i < graph_proto.output_size(); ++i)
+for (int i = 0; i < graph_proto->output_size(); ++i)
 {
-dumpValueInfoProto(i, graph_proto.output(i), "output");
+dumpValueInfoProto(i, graph_proto->output(i), "output");
 }
 if (DNN_DIAGNOSTICS_RUN) {
 CV_LOG_INFO(NULL, "DNN/ONNX: start diagnostic run!");
-layerHandler->fillRegistry(graph_proto);
+layerHandler->fillRegistry(*graph_proto);
 }
 for(int li = 0; li < layersSize; li++)
 {
-const opencv_onnx::NodeProto& node_proto = graph_proto.node(li);
+const opencv_onnx::NodeProto& node_proto = graph_proto->node(li);
 handleNode(node_proto);
 }
 // register outputs
-for (int i = 0; i < graph_proto.output_size(); ++i)
+for (int i = 0; i < graph_proto->output_size(); ++i)
 {
-const std::string& output_name = graph_proto.output(i).name();
+const std::string& output_name = graph_proto->output(i).name();
 if (output_name.empty())
 {
 CV_LOG_ERROR(NULL, "DNN/ONNX: can't register output without name: " << i);
@@ -3180,9 +3180,9 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
 {
 // remove from graph proto
 for (size_t i = 1; i < node_proto.output_size(); i++) {
-for (int j = graph_proto.output_size() - 1; j >= 0; j--) {
-if (graph_proto.output(j).name() == node_proto.output(i)) {
-graph_proto.mutable_output()->DeleteSubrange(j, 1);
+for (int j = graph_proto->output_size() - 1; j >= 0; j--) {
+if (graph_proto->output(j).name() == node_proto.output(i)) {
+graph_proto->mutable_output()->DeleteSubrange(j, 1);
 break;
 }
 }
@@ -3683,9 +3683,9 @@ void ONNXImporter::parseQEltwise(LayerParams& layerParams, const opencv_onnx::No
 layerParams.type = "ScaleInt8";
 layerParams.set("bias_term", op == "sum");
 int axis = 1;
-for (int i = 0; i < graph_proto.initializer_size(); i++)
+for (int i = 0; i < graph_proto->initializer_size(); i++)
 {
 opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
 if (tensor_proto.name() == node_proto.input(constId))
 {
 axis = inpShape.size() - tensor_proto.dims_size();
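
The crux of the change: graph_proto used to be a by-value member, so graph_proto = model_proto.graph() copied the entire graph (every node, initializer, input and output) before import even started; storing the pointer returned by model_proto.mutable_graph() just aliases the graph the ModelProto already owns. Below is a minimal standalone sketch of that pattern; GraphProto/ModelProto here are toy stand-ins (not the real protobuf-generated ONNX classes), with a copy constructor that logs so the avoided copy is visible.

#include <cstdio>
#include <string>
#include <vector>

// Toy stand-ins for the protobuf-generated classes (NOT the real
// opencv_onnx::GraphProto / ModelProto API), just enough to show the copy.
struct GraphProto {
    std::vector<std::string> nodes;                // pretend node list
    GraphProto() = default;
    GraphProto(const GraphProto& other) : nodes(other.nodes) {
        std::puts("GraphProto copied");            // the cost the commit removes
    }
    int node_size() const { return (int)nodes.size(); }
};

struct ModelProto {
    GraphProto graph_;
    const GraphProto& graph() const { return graph_; }  // read-only accessor
    GraphProto* mutable_graph() { return &graph_; }      // aliasing accessor, no copy
};

int main() {
    ModelProto model;
    model.graph_.nodes = {"Conv", "Relu", "Gemm"};

    // Old pattern: assigning into a by-value member deep-copies the whole graph.
    GraphProto copied = model.graph();            // prints "GraphProto copied"

    // New pattern: keep a pointer to the graph the ModelProto already owns.
    GraphProto* alias = model.mutable_graph();    // no copy, same storage
    std::printf("nodes: %d (copy) vs %d (alias)\n",
                copied.node_size(), alias->node_size());
    return 0;
}

With the pointer member, the call sites in the diff change mechanically: member access switches from "." to "->", and helpers that still take a GraphProto reference (simplifySubgraphs, getGraphTensors, fillRegistry) receive *graph_proto.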
