You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.1 KiB
68 lines
2.1 KiB
2 years ago
|
import torch
|
||
|
|
||
|
from ultralytics import YOLO
|
||
|
from ultralytics.nn.modules import Detect, Segment
|
||
|
|
||
|
|
||
|
def export_onnx(model, file):
|
||
|
# YOLOv5 ONNX export
|
||
|
import onnx
|
||
|
im = torch.zeros(1, 3, 640, 640)
|
||
|
model.eval()
|
||
|
model(im, profile=True)
|
||
|
for k, m in model.named_modules():
|
||
|
if isinstance(m, (Detect, Segment)):
|
||
|
m.export = True
|
||
|
|
||
|
torch.onnx.export(
|
||
|
model,
|
||
|
im,
|
||
|
file,
|
||
|
verbose=False,
|
||
|
opset_version=12,
|
||
|
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||
|
input_names=['images'])
|
||
|
|
||
|
# Checks
|
||
|
model_onnx = onnx.load(file) # load onnx model
|
||
|
onnx.checker.check_model(model_onnx) # check onnx model
|
||
|
|
||
|
# Metadata
|
||
|
d = {'stride': int(max(model.stride)), 'names': model.names}
|
||
|
for k, v in d.items():
|
||
|
meta = model_onnx.metadata_props.add()
|
||
|
meta.key, meta.value = k, str(v)
|
||
|
onnx.save(model_onnx, file)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
model = YOLO()
|
||
|
print("yolov8n")
|
||
|
model.new("yolov8n.yaml")
|
||
|
print("yolov8n-seg")
|
||
|
model.new("yolov8n-seg.yaml")
|
||
|
print("yolov8s")
|
||
|
model.new("yolov8s.yaml")
|
||
|
# export_onnx(model.model, "yolov8s.onnx")
|
||
|
print("yolov8s-seg")
|
||
|
model.new("yolov8s-seg.yaml")
|
||
|
# export_onnx(model.model, "yolov8s-seg.onnx")
|
||
|
print("yolov8m")
|
||
|
model.new("yolov8m.yaml")
|
||
|
print("yolov8m-seg")
|
||
|
model.new("yolov8m-seg.yaml")
|
||
|
print("yolov8l")
|
||
|
model.new("yolov8l.yaml")
|
||
|
print("yolov8l-seg")
|
||
|
model.new("yolov8l-seg.yaml")
|
||
|
print("yolov8x")
|
||
|
model.new("yolov8x.yaml")
|
||
|
print("yolov8x-seg")
|
||
|
model.new("yolov8x-seg.yaml")
|
||
|
|
||
|
# n vs n-seg: 8.9GFLOPs vs 12.8GFLOPs, 3.16M vs 3.6M. ch[0] // 4 (11.9GFLOPs, 3.39M)
|
||
|
# s vs s-seg: 28.8GFLOPs vs 44.4GFLOPs, 11.1M vs 12.9M. ch[0] // 4 (39.5GFLOPs, 11.7M)
|
||
|
# m vs m-seg: 79.3GFLOPs vs 113.8GFLOPs, 25.9M vs 29.5M. ch[0] // 4 (103.GFLOPs, 27.1M)
|
||
|
# l vs l-seg: 165.7GFLOPs vs 226.3GFLOPs, 43.7M vs 49.6M. ch[0] // 4 (207GFLOPs, 45.7M)
|
||
|
# x vs x-seg: 258.5GFLOPs vs 353.0GFLOPs, 68.3M vs 77.5M. ch[0] // 4 (324GFLOPs, 71.4M)
|