Merge pull request #23482 from zihaomu:onnx_opset13_split

DNN: support the split node of onnx opset >= 13
pull/23495/head
Alexander Smorkalov 2 years ago committed by GitHub
commit aa17f881b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 19
      modules/dnn/src/onnx/onnx_importer.cpp
  2. 2
      modules/dnn/test/test_onnx_importer.cpp

@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{ {
int axis = layerParams.get<int>("axis", 0);
MatShape inpShape = outShapes[node_proto.input(0)];
axis = normalize_axis(axis, inpShape.size());
if (layerParams.has("split")) if (layerParams.has("split"))
{ {
DictValue splits = layerParams.get("split"); DictValue splits = layerParams.get("split");
@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
} }
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
} }
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
{
CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end());
Mat splitsBlob = getBlob(node_proto, 1);
int splitSize = splitsBlob.total();
std::vector<int> slicePoints(splitSize - 1, splitsBlob.at<int>(0));
for (int i = 1; i < splitSize - 1; ++i)
{
slicePoints[i] = slicePoints[i - 1] + splitsBlob.at<int>(i);
}
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
}
else else
{ {
layerParams.set("num_split", node_proto.output_size()); layerParams.set("num_split", node_proto.output_size());
} }
int depth = layerParams.get<int>("depth", CV_32F); int depth = layerParams.get<int>("depth", CV_32F);
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice"; layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
layerParams.set("axis", layerParams.get<float>("axis", 0)); layerParams.set("axis", axis);
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }

@ -1149,6 +1149,8 @@ TEST_P(Test_ONNX_layers, Split)
testONNXModels("split_2"); testONNXModels("split_2");
testONNXModels("split_3"); testONNXModels("split_3");
testONNXModels("split_4"); testONNXModels("split_4");
testONNXModels("split_5");
testONNXModels("split_6");
testONNXModels("split_neg_axis"); testONNXModels("split_neg_axis");
} }

Loading…
Cancel
Save