add optional outputs support and fix graph links

pull/21490/head
Smirnov Egor 3 years ago
parent 8f4473b3e3
commit 17b2d92a3d
  1. 82
      modules/dnn/src/onnx/onnx_importer.cpp

@ -455,7 +455,11 @@ void ONNXImporter::addLayer(LayerParams& layerParams,
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
for (int i = 0; i < node_proto.output_size(); ++i)
{
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
const std::string& output_name = node_proto.output(i);
if (!output_name.empty())
{
layer_id.insert(std::make_pair(output_name, LayerInfo(id, i)));
}
}
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
@ -478,7 +482,11 @@ void ONNXImporter::addLayer(LayerParams& layerParams,
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
{
outShapes[node_proto.output(i)] = layerOutShapes[i];
const std::string& output_name = node_proto.output(i);
if (!output_name.empty())
{
outShapes[node_proto.output(i)] = layerOutShapes[i];
}
}
}
@ -678,10 +686,30 @@ void ONNXImporter::populateNet()
CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!");
}
const std::string& extractNodeName(const opencv_onnx::NodeProto& node_proto)
{
if (node_proto.has_name() && !node_proto.name().empty())
{
return node_proto.name();
}
for (int i = 0; i < node_proto.output_size(); ++i)
{
const std::string& name = node_proto.output(i);
// There are two ways to leave an optional input or output unspecified:
// the first, available only for trailing inputs and outputs, is to simply not provide that input;
// the second method is to use an empty string in place of an input or output name.
if (!name.empty())
{
return name;
}
}
CV_Error(Error::StsAssert, "Couldn't deduce Node name.");
}
void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto)
{
CV_Assert(node_proto.output_size() >= 1);
std::string name = node_proto.output(0);
const std::string& name = extractNodeName(node_proto);
const std::string& layer_type = node_proto.op_type();
const std::string& layer_type_domain = node_proto.has_domain() ? node_proto.domain() : std::string();
if (!layer_type_domain.empty() && layer_type_domain != "ai.onnx")
@ -802,6 +830,7 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
const std::string output_name = node_proto.output(0);
CV_Assert(node_proto.input_size() == 1);
layerParams.type = "Pooling";
@ -922,7 +951,7 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, layerParams.name);
node_proto.set_output(0, output_name);
}
else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
{
@ -955,7 +984,7 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
layerParams.set("dim", DictValue::arrayInt(targetShape.data(), targetShape.size()));
node_proto.set_input(0, node_proto.output(0));
node_proto.set_output(0, layerParams.name);
node_proto.set_output(0, output_name);
}
addLayer(layerParams, node_proto);
}
@ -1045,7 +1074,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
{
Mat flipped;
flip(inp, flipped, 0);
addConstant(layerParams.name, flipped);
addConstant(node_proto.output(0), flipped);
return;
}
}
@ -1065,7 +1094,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
inputs.push_back(inp);
runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1);
addConstant(layerParams.name, sliced[0]);
addConstant(node_proto.output(0), sliced[0]);
return;
}
addLayer(layerParams, node_proto);
@ -1130,7 +1159,7 @@ void ONNXImporter::parseBias(LayerParams& layerParams, const opencv_onnx::NodePr
Mat blob_1 = getBlob(node_proto, 1);
CV_Assert(blob_0.size == blob_1.size);
Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
addConstant(layerParams.name, output);
addConstant(node_proto.output(0), output);
return;
}
else if (is_const_0 || is_const_1)
@ -1244,12 +1273,13 @@ void ONNXImporter::parseConstant(LayerParams& layerParams, const opencv_onnx::No
{
CV_Assert(node_proto.input_size() == 0);
CV_Assert(layerParams.blobs.size() == 1);
addConstant(layerParams.name, layerParams.blobs[0]);
addConstant(node_proto.output(0), layerParams.blobs[0]);
}
void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string output_name = node_proto.output(0);
LayerParams lstmParams = layerParams;
lstmParams.name += "/lstm";
@ -1331,7 +1361,7 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
node_proto.set_output(0, layerParams.name); // keep origin LSTM's name
node_proto.set_output(0, output_name); // keep origin LSTM's name
addLayer(layerParams, node_proto);
}
@ -1573,6 +1603,7 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
{
opencv_onnx::NodeProto node_proto = node_proto_;
const std::string& layer_type = node_proto.op_type();
const std::string output_name = node_proto.output(0);
CV_Assert(node_proto.input_size() == 2);
bool isDiv = layer_type == "Div";
@ -1657,7 +1688,7 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
if (inp0.dims == 1 && inp1.dims == 1)
out.dims = 1; // to workaround dims == 1
addConstant(layerParams.name, out);
addConstant(output_name, out);
return;
}
else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
@ -1673,7 +1704,7 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(1));
proto.add_input(node_proto.input(0));
proto.add_output(layerParams.name);
proto.add_output(output_name);
node_proto = proto;
}
@ -1851,7 +1882,7 @@ void ONNXImporter::parseTranspose(LayerParams& layerParams, const opencv_onnx::N
std::vector<Mat> inputs(1, getBlob(node_proto, 0)), transposed;
runLayer(layerParams, inputs, transposed);
CV_Assert(transposed.size() == 1);
addConstant(layerParams.name, transposed[0]);
addConstant(node_proto.output(0), transposed[0]);
return;
}
addLayer(layerParams, node_proto);
@ -1903,7 +1934,7 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
Mat inp = getBlob(node_proto, 0);
Mat out = inp.reshape(1, outShape);
out.dims = outShape.size(); // to workaround dims == 1
addConstant(layerParams.name, out);
addConstant(node_proto.output(0), out);
return;
}
addLayer(layerParams, node_proto);
@ -1930,7 +1961,7 @@ void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::Nod
}
Mat output = input.reshape(1, 2, out_size);
addConstant(layerParams.name, output);
addConstant(node_proto.output(0), output);
return;
}
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
@ -2002,7 +2033,7 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N
}
Mat out = input.reshape(0, dims);
addConstant(layerParams.name, out);
addConstant(node_proto.output(0), out);
return;
}
@ -2039,6 +2070,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
CV_CheckEQ(node_proto.input_size(), 2, "");
const std::string& input0 = node_proto.input(0);
const std::string& input1 = node_proto.input(1);
const std::string output_name = node_proto.output(0);
Mat newShapeMat = getBlob(input1);
MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
@ -2108,7 +2140,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
output = output.reshape(0, targetShape);
addConstant(layerParams.name, output);
addConstant(output_name, output);
return;
}
@ -2138,7 +2170,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
layerParams.set("axis", broadcast_axes[0]);
layerParams.type = "Concat";
node_proto.set_output(0, layerParams.name);
node_proto.set_output(0, output_name);
}
else if (broadcast_axes.empty())
{
@ -2163,7 +2195,7 @@ void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::Nod
if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
std::vector<Mat> inputs(1, getBlob(node_proto, 0)), outputs;
runLayer(layerParams, inputs, outputs);
addConstant(layerParams.name, outputs[0]);
addConstant(node_proto.output(0), outputs[0]);
return;
}
}
@ -2177,7 +2209,7 @@ void ONNXImporter::parseReshape(LayerParams& layerParams, const opencv_onnx::Nod
if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
Mat input = getBlob(node_proto, 0);
Mat out = input.reshape(0, dim);
addConstant(layerParams.name, out);
addConstant(node_proto.output(0), out);
return;
}
replaceLayerParam(layerParams, "shape", "dim");
@ -2229,7 +2261,7 @@ void ONNXImporter::parseShape(LayerParams& layerParams, const opencv_onnx::NodeP
CV_LOG_ERROR(NULL, "DNN/ONNX(Shape): dynamic 'zero' shapes are not supported, input " << toString(inpShape, node_proto.input(0)));
CV_Assert(!isDynamicShape); // not supported
}
addConstant(layerParams.name, shapeMat);
addConstant(node_proto.output(0), shapeMat);
}
void ONNXImporter::parseCast(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
@ -2253,7 +2285,7 @@ void ONNXImporter::parseCast(LayerParams& layerParams, const opencv_onnx::NodePr
Mat dst;
blob.convertTo(dst, type);
dst.dims = blob.dims;
addConstant(layerParams.name, dst);
addConstant(node_proto.output(0), dst);
return;
}
else
@ -2281,7 +2313,7 @@ void ONNXImporter::parseConstantFill(LayerParams& layerParams, const opencv_onnx
for (int i = 0; i < inpShape.size(); i++)
CV_CheckGT(inpShape[i], 0, "");
Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
addConstant(layerParams.name, tensor);
addConstant(node_proto.output(0), tensor);
}
void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
@ -2309,7 +2341,7 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node
} else {
out.dims = 1;
}
addConstant(layerParams.name, out);
addConstant(node_proto.output(0), out);
return;
}
else
@ -2403,7 +2435,7 @@ void ONNXImporter::parseConcat(LayerParams& layerParams, const opencv_onnx::Node
runLayer(layerParams, inputs, concatenated);
CV_Assert(concatenated.size() == 1);
addConstant(layerParams.name, concatenated[0]);
addConstant(node_proto.output(0), concatenated[0]);
return;
}
else

Loading…
Cancel
Save