From 51281f8d69065510e962b54c105295fd60e11444 Mon Sep 17 00:00:00 2001 From: zihaomu Date: Tue, 11 Apr 2023 16:18:50 +0800 Subject: [PATCH] support the split node of onnx opset >= 13 --- modules/dnn/src/onnx/onnx_importer.cpp | 19 ++++++++++++++++++- modules/dnn/test/test_onnx_importer.cpp | 2 ++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 651d1b1571..e074d54169 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { + int axis = layerParams.get("axis", 0); + MatShape inpShape = outShapes[node_proto.input(0)]; + axis = normalize_axis(axis, inpShape.size()); + if (layerParams.has("split")) { DictValue splits = layerParams.get("split"); @@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP } layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); } + else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute. + { + CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end()); + Mat splitsBlob = getBlob(node_proto, 1); + int splitSize = splitsBlob.total(); + + std::vector slicePoints(splitSize - 1, splitsBlob.at(0)); + for (int i = 1; i < splitSize - 1; ++i) + { + slicePoints[i] = slicePoints[i - 1] + splitsBlob.at(i); + } + layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); + } else { layerParams.set("num_split", node_proto.output_size()); } int depth = layerParams.get("depth", CV_32F); layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice"; - layerParams.set("axis", layerParams.get("axis", 0)); + layerParams.set("axis", axis); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index e566acd827..b8615912c5 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1149,6 +1149,8 @@ TEST_P(Test_ONNX_layers, Split) testONNXModels("split_2"); testONNXModels("split_3"); testONNXModels("split_4"); + testONNXModels("split_5"); + testONNXModels("split_6"); testONNXModels("split_neg_axis"); }