|
|
@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP |
|
|
|
|
|
|
|
|
|
|
|
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) |
|
|
|
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) |
|
|
|
{ |
|
|
|
{ |
|
|
|
|
|
|
|
int axis = layerParams.get<int>("axis", 0); |
|
|
|
|
|
|
|
MatShape inpShape = outShapes[node_proto.input(0)]; |
|
|
|
|
|
|
|
axis = normalize_axis(axis, inpShape.size()); |
|
|
|
|
|
|
|
|
|
|
|
if (layerParams.has("split")) |
|
|
|
if (layerParams.has("split")) |
|
|
|
{ |
|
|
|
{ |
|
|
|
DictValue splits = layerParams.get("split"); |
|
|
|
DictValue splits = layerParams.get("split"); |
|
|
@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP |
|
|
|
} |
|
|
|
} |
|
|
|
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); |
|
|
|
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
|
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end()); |
|
|
|
|
|
|
|
Mat splitsBlob = getBlob(node_proto, 1); |
|
|
|
|
|
|
|
int splitSize = splitsBlob.total(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> slicePoints(splitSize - 1, splitsBlob.at<int>(0)); |
|
|
|
|
|
|
|
for (int i = 1; i < splitSize - 1; ++i) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
slicePoints[i] = slicePoints[i - 1] + splitsBlob.at<int>(i); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); |
|
|
|
|
|
|
|
} |
|
|
|
else |
|
|
|
else |
|
|
|
{ |
|
|
|
{ |
|
|
|
layerParams.set("num_split", node_proto.output_size()); |
|
|
|
layerParams.set("num_split", node_proto.output_size()); |
|
|
|
} |
|
|
|
} |
|
|
|
int depth = layerParams.get<int>("depth", CV_32F); |
|
|
|
int depth = layerParams.get<int>("depth", CV_32F); |
|
|
|
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice"; |
|
|
|
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice"; |
|
|
|
layerParams.set("axis", layerParams.get<float>("axis", 0)); |
|
|
|
layerParams.set("axis", axis); |
|
|
|
addLayer(layerParams, node_proto); |
|
|
|
addLayer(layerParams, node_proto); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|