diff --git a/README.md b/README.md index e13ef56450..704c04b794 100644 --- a/README.md +++ b/README.md @@ -87,14 +87,25 @@ YOLOv8 may also be used directly in a Python environment, and accepts the same [ from ultralytics import YOLO # Load a model -model = YOLO("yolov8n.yaml") # build a new model from scratch -model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training) - -# Use the model -model.train(data="coco8.yaml", epochs=3) # train the model -metrics = model.val() # evaluate model performance on the validation set -results = model("https://ultralytics.com/images/bus.jpg") # predict on an image -path = model.export(format="onnx") # export the model to ONNX format +model = YOLO("yolov8n.pt") + +# Train the model +train_results = model.train( + data="coco8.yaml", # path to dataset YAML + epochs=100, # number of training epochs + imgsz=640, # training image size + device="cpu", # device to run on, i.e. device=0 or device=0,1,2,3 or device=cpu +) + +# Evaluate model performance on the validation set +metrics = model.val() + +# Perform object detection on an image +results = model("path/to/image.jpg") +results[0].show() + +# Export the model to ONNX format +path = model.export(format="onnx") # return path to exported model ``` See YOLOv8 [Python Docs](https://docs.ultralytics.com/usage/python/) for more examples. @@ -139,23 +150,6 @@ See [Detection Docs](https://docs.ultralytics.com/tasks/detect/) for usage examp -
Detection (Open Image V7) - -See [Detection Docs](https://docs.ultralytics.com/tasks/detect/) for usage examples with these models trained on [Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/), which include 600 pre-trained classes. - -| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
A100 TensorRT
(ms) | params
(M) | FLOPs
(B) | -| ----------------------------------------------------------------------------------------- | --------------------- | -------------------- | ------------------------------ | ----------------------------------- | ------------------ | ----------------- | -| [YOLOv8n](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-oiv7.pt) | 640 | 18.4 | 142.4 | 1.21 | 3.5 | 10.5 | -| [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s-oiv7.pt) | 640 | 27.7 | 183.1 | 1.40 | 11.4 | 29.7 | -| [YOLOv8m](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8m-oiv7.pt) | 640 | 33.6 | 408.5 | 2.26 | 26.2 | 80.6 | -| [YOLOv8l](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8l-oiv7.pt) | 640 | 34.9 | 596.9 | 2.43 | 44.1 | 167.4 | -| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x-oiv7.pt) | 640 | 36.3 | 860.6 | 3.56 | 68.7 | 260.6 | - -- **mAPval** values are for single-model single-scale on [Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/) dataset.
Reproduce by `yolo val detect data=open-images-v7.yaml device=0` -- **Speed** averaged over Open Image V7 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance.
Reproduce by `yolo val detect data=open-images-v7.yaml batch=1 device=0|cpu` - -
-
Segmentation (COCO) See [Segmentation Docs](https://docs.ultralytics.com/tasks/segment/) for usage examples with these models trained on [COCO-Seg](https://docs.ultralytics.com/datasets/segment/coco/), which include 80 pre-trained classes. diff --git a/README.zh-CN.md b/README.zh-CN.md index 4319c0dff8..1e7b972762 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -89,14 +89,25 @@ YOLOv8 也可以在 Python 环境中直接使用,并接受与上述 CLI 示例 from ultralytics import YOLO # 加载模型 -model = YOLO("yolov8n.yaml") # 从头开始构建新模型 -model = YOLO("yolov8n.pt") # 加载预训练模型(建议用于训练) - -# 使用模型 -model.train(data="coco8.yaml", epochs=3) # 训练模型 -metrics = model.val() # 在验证集上评估模型性能 -results = model("https://ultralytics.com/images/bus.jpg") # 对图像进行预测 -success = model.export(format="onnx") # 将模型导出为 ONNX 格式 +model = YOLO("yolov8n.pt") + +# 训练模型 +train_results = model.train( + data="coco8.yaml", # 数据配置文件的路径 + epochs=100, # 训练的轮数 + imgsz=640, # 训练图像大小 + device="cpu", # 运行的设备,例如 device=0 或 device=0,1,2,3 或 device=cpu +) + +# 在验证集上评估模型性能 +metrics = model.val() + +# 对图像进行目标检测 +results = model("path/to/image.jpg") +results[0].show() + +# 将模型导出为 ONNX 格式 +path = model.export(format="onnx") # 返回导出的模型路径 ``` 查看 YOLOv8 [Python 文档](https://docs.ultralytics.com/usage/python/)以获取更多示例。 @@ -141,23 +152,6 @@ Ultralytics 提供了 YOLOv8 的交互式笔记本,涵盖训练、验证、跟
-
检测(Open Image V7) - -查看[检测文档](https://docs.ultralytics.com/tasks/detect/)以获取这些在[Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/)上训练的模型的使用示例,其中包括600个预训练类别。 - -| 模型 | 尺寸
(像素) | mAP验证
50-95 | 速度
CPU ONNX
(毫秒) | 速度
A100 TensorRT
(毫秒) | 参数
(M) | 浮点运算
(B) | -| ----------------------------------------------------------------------------------------- | ------------------- | --------------------- | ------------------------------- | ------------------------------------ | ---------------- | -------------------- | -| [YOLOv8n](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-oiv7.pt) | 640 | 18.4 | 142.4 | 1.21 | 3.5 | 10.5 | -| [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s-oiv7.pt) | 640 | 27.7 | 183.1 | 1.40 | 11.4 | 29.7 | -| [YOLOv8m](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8m-oiv7.pt) | 640 | 33.6 | 408.5 | 2.26 | 26.2 | 80.6 | -| [YOLOv8l](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8l-oiv7.pt) | 640 | 34.9 | 596.9 | 2.43 | 44.1 | 167.4 | -| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x-oiv7.pt) | 640 | 36.3 | 860.6 | 3.56 | 68.7 | 260.6 | - -- **mAP验证** 值适用于在[Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/)数据集上的单模型单尺度。
通过 `yolo val detect data=open-images-v7.yaml device=0` 以复现。 -- **速度** 在使用[Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/)实例对Open Image V7验证图像进行平均测算。
通过 `yolo val detect data=open-images-v7.yaml batch=1 device=0|cpu` 以复现。 - -
-
分割 (COCO) 查看[分割文档](https://docs.ultralytics.com/tasks/segment/)以获取这些在[COCO-Seg](https://docs.ultralytics.com/datasets/segment/coco/)上训练的模型的使用示例,其中包括80个预训练类别。 diff --git a/docs/en/modes/predict.md b/docs/en/modes/predict.md index 3bda0c079b..5ca5dab9d8 100644 --- a/docs/en/modes/predict.md +++ b/docs/en/modes/predict.md @@ -61,7 +61,7 @@ Ultralytics YOLO models return either a Python list of `Results` objects, or a m model = YOLO("yolov8n.pt") # pretrained YOLOv8n model # Run batched inference on a list of images - results = model(["im1.jpg", "im2.jpg"]) # return a list of Results objects + results = model(["image1.jpg", "image2.jpg"]) # return a list of Results objects # Process results list for result in results: @@ -83,7 +83,7 @@ Ultralytics YOLO models return either a Python list of `Results` objects, or a m model = YOLO("yolov8n.pt") # pretrained YOLOv8n model # Run batched inference on a list of images - results = model(["im1.jpg", "im2.jpg"], stream=True) # return a generator of Results objects + results = model(["image1.jpg", "image2.jpg"], stream=True) # return a generator of Results objects # Process results generator for result in results: @@ -109,8 +109,8 @@ YOLOv8 can process different types of input sources for inference, as shown in t | image | `'image.jpg'` | `str` or `Path` | Single image file. | | URL | `'https://ultralytics.com/images/bus.jpg'` | `str` | URL to an image. | | screenshot | `'screen'` | `str` | Capture a screenshot. | -| PIL | `Image.open('im.jpg')` | `PIL.Image` | HWC format with RGB channels. | -| OpenCV | `cv2.imread('im.jpg')` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. | +| PIL | `Image.open('image.jpg')` | `PIL.Image` | HWC format with RGB channels. | +| OpenCV | `cv2.imread('image.jpg')` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. | | numpy | `np.zeros((640,1280,3))` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. | | torch | `torch.zeros(16,3,320,640)` | `torch.Tensor` | BCHW format with RGB channels `float32 (0.0-1.0)`. | | CSV | `'sources.csv'` | `str` or `Path` | CSV file containing paths to images, videos, or directories. | @@ -710,16 +710,16 @@ When using YOLO models in a multi-threaded application, it's important to instan from ultralytics import YOLO - def thread_safe_predict(image_path): + def thread_safe_predict(model, image_path): """Performs thread-safe prediction on an image using a locally instantiated YOLO model.""" - local_model = YOLO("yolov8n.pt") - results = local_model.predict(image_path) + model = YOLO(model) + results = model.predict(image_path) # Process results # Starting threads that each have their own model instance - Thread(target=thread_safe_predict, args=("image1.jpg",)).start() - Thread(target=thread_safe_predict, args=("image2.jpg",)).start() + Thread(target=thread_safe_predict, args=("yolov8n.pt", "image1.jpg")).start() + Thread(target=thread_safe_predict, args=("yolov8n.pt", "image2.jpg")).start() ``` For an in-depth look at thread-safe inference with YOLO models and step-by-step instructions, please refer to our [YOLO Thread-Safe Inference Guide](../guides/yolo-thread-safe-inference.md). This guide will provide you with all the necessary information to avoid common pitfalls and ensure that your multi-threaded inference runs smoothly. diff --git a/docs/en/modes/track.md b/docs/en/modes/track.md index cfeb8c9084..7ed84189cc 100644 --- a/docs/en/modes/track.md +++ b/docs/en/modes/track.md @@ -290,63 +290,35 @@ Finally, after all threads have completed their task, the windows displaying the from ultralytics import YOLO + # Define model names and video sources + MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"] + SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam - def run_tracker_in_thread(filename, model, file_index): - """ - Runs a video file or webcam stream concurrently with the YOLOv8 model using threading. - This function captures video frames from a given file or camera source and utilizes the YOLOv8 model for object - tracking. The function runs in its own thread for concurrent processing. + def run_tracker_in_thread(model_name, filename): + """ + Run YOLO tracker in its own thread for concurrent processing. Args: + model_name (str): The YOLOv8 model object. filename (str): The path to the video file or the identifier for the webcam/external camera source. - model (obj): The YOLOv8 model object. - file_index (int): An index to uniquely identify the file being processed, used for display purposes. - - Note: - Press 'q' to quit the video display window. """ - video = cv2.VideoCapture(filename) # Read the video file - - while True: - ret, frame = video.read() # Read the video frames - - # Exit the loop if no more frames in either video - if not ret: - break - - # Track objects in frames if available - results = model.track(frame, persist=True) - res_plotted = results[0].plot() - cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted) - - key = cv2.waitKey(1) - if key == ord("q"): - break - - # Release video sources - video.release() - + model = YOLO(model_name) + results = model.track(filename, save=True, stream=True) + for r in results: + pass - # Load the models - model1 = YOLO("yolov8n.pt") - model2 = YOLO("yolov8n-seg.pt") - # Define the video files for the trackers - video_file1 = "path/to/video1.mp4" # Path to video file, 0 for webcam - video_file2 = 0 # Path to video file, 0 for webcam, 1 for external camera + # Create and start tracker threads using a for loop + tracker_threads = [] + for video_file, model_name in zip(SOURCES, MODEL_NAMES): + thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True) + tracker_threads.append(thread) + thread.start() - # Create the tracker threads - tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True) - tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True) - - # Start the tracker threads - tracker_thread1.start() - tracker_thread2.start() - - # Wait for the tracker threads to finish - tracker_thread1.join() - tracker_thread2.join() + # Wait for all tracker threads to finish + for thread in tracker_threads: + thread.join() # Clean up and close windows cv2.destroyAllWindows() @@ -408,35 +380,37 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt from ultralytics import YOLO + # Define model names and video sources + MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"] + SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam - def run_tracker_in_thread(filename, model, file_index): - video = cv2.VideoCapture(filename) - while True: - ret, frame = video.read() - if not ret: - break - results = model.track(frame, persist=True) - res_plotted = results[0].plot() - cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted) - if cv2.waitKey(1) & 0xFF == ord("q"): - break - video.release() + def run_tracker_in_thread(model_name, filename): + """ + Run YOLO tracker in its own thread for concurrent processing. - model1 = YOLO("yolov8n.pt") - model2 = YOLO("yolov8n-seg.pt") - video_file1 = "path/to/video1.mp4" - video_file2 = 0 # Path to a second video file, or 0 for a webcam + Args: + model_name (str): The YOLOv8 model object. + filename (str): The path to the video file or the identifier for the webcam/external camera source. + """ + model = YOLO(model_name) + results = model.track(filename, save=True, stream=True) + for r in results: + pass - tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True) - tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True) - tracker_thread1.start() - tracker_thread2.start() + # Create and start tracker threads using a for loop + tracker_threads = [] + for video_file, model_name in zip(SOURCES, MODEL_NAMES): + thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True) + tracker_threads.append(thread) + thread.start() - tracker_thread1.join() - tracker_thread2.join() + # Wait for all tracker threads to finish + for thread in tracker_threads: + thread.join() + # Clean up and close windows cv2.destroyAllWindows() ``` diff --git a/docs/en/yolov5/quickstart_tutorial.md b/docs/en/yolov5/quickstart_tutorial.md index f8cabb9f23..582dfcbda8 100644 --- a/docs/en/yolov5/quickstart_tutorial.md +++ b/docs/en/yolov5/quickstart_tutorial.md @@ -44,8 +44,8 @@ Harness `detect.py` for versatile inference on various sources. It automatically ```bash python detect.py --weights yolov5s.pt --source 0 # webcam - img.jpg # image - vid.mp4 # video + image.jpg # image + video.mp4 # video screen # screenshot path/ # directory list.txt # list of images diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml index 5154762a7d..839511e357 100644 --- a/docs/mkdocs_github_authors.yaml +++ b/docs/mkdocs_github_authors.yaml @@ -45,7 +45,7 @@ username: zhixuwei 49699333+dependabot[bot]@users.noreply.github.com: avatar: https://avatars.githubusercontent.com/u/27347476?v=4 - username: dependabot[bot] + username: dependabot 53246858+hasanghaffari93@users.noreply.github.com: avatar: https://avatars.githubusercontent.com/u/53246858?v=4 username: hasanghaffari93 diff --git a/tests/__init__.py b/tests/__init__.py index 3356f1cadb..ea6b398292 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,13 +1,13 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable +from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks # Constants used in tests MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path CFG = "yolov8n.yaml" SOURCE = ASSETS / "bus.jpg" +SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"] TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files -IS_TMP_WRITEABLE = is_dir_writeable(TMP) CUDA_IS_AVAILABLE = checks.cuda_is_available() CUDA_DEVICE_COUNT = checks.cuda_device_count() @@ -15,6 +15,7 @@ __all__ = ( "MODEL", "CFG", "SOURCE", + "SOURCES_LIST", "TMP", "IS_TMP_WRITEABLE", "CUDA_IS_AVAILABLE", diff --git a/tests/test_python.py b/tests/test_python.py index aa18029d75..b5dd0c883b 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -1,6 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import contextlib +import csv import urllib from copy import copy from pathlib import Path @@ -12,7 +13,7 @@ import torch import yaml from PIL import Image -from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP +from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP from ultralytics import RTDETR, YOLO from ultralytics.cfg import MODELS, TASK2DATA, TASKS from ultralytics.data.build import load_inference_source @@ -26,11 +27,14 @@ from ultralytics.utils import ( WEIGHTS_DIR, WINDOWS, checks, + is_dir_writeable, is_github_action_running, ) from ultralytics.utils.downloads import download from ultralytics.utils.torch_utils import TORCH_1_9 +IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init + def test_model_forward(): """Test the forward pass of the YOLO model.""" @@ -70,11 +74,37 @@ def test_model_profile(): @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") def test_predict_txt(): """Tests YOLO predictions with file, directory, and pattern sources listed in a text file.""" - txt_file = TMP / "sources.txt" - with open(txt_file, "w") as f: - for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]: - f.write(f"{x}\n") - _ = YOLO(MODEL)(source=txt_file, imgsz=32) + file = TMP / "sources_multi_row.txt" + with open(file, "w") as f: + for src in SOURCES_LIST: + f.write(f"{src}\n") + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_multi_row(): + """Tests YOLO predictions with sources listed in multiple rows of a CSV file.""" + file = TMP / "sources_multi_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["source"]) + writer.writerows([[src] for src in SOURCES_LIST]) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_single_row(): + """Tests YOLO predictions with sources listed in a single row of a CSV file.""" + file = TMP / "sources_single_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(SOURCES_LIST) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images @pytest.mark.parametrize("model_name", MODELS) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 3ff10e1cce..9f6607772e 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.94" +__version__ = "8.2.95" import os diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 2c26ea2189..ae98540b03 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -668,13 +668,14 @@ class BaseTrainer: def final_eval(self): """Performs final evaluation and validation for object detection YOLO model.""" + ckpt = {} for f in self.last, self.best: if f.exists(): - strip_optimizer(f) # strip optimizers - if f is self.best: - if self.last.is_file(): # update best.pt train_metrics from last.pt - k = "train_results" - torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best) + if f is self.last: + ckpt = strip_optimizer(f) + elif f is self.best: + k = "train_results" # update best.pt train_metrics from last.pt + strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None) LOGGER.info(f"\nValidating {f}...") self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 274f56d54b..ad860e93e4 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -759,6 +759,10 @@ class SafeClass: """Initialize SafeClass instance, ignoring all arguments.""" pass + def __call__(self, *args, **kwargs): + """Run SafeClass instance, ignoring all arguments.""" + pass + class SafeUnpickler(pickle.Unpickler): """Custom Unpickler that replaces unknown classes with SafeClass.""" diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 7cde9dc7a8..758a4e11fe 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -533,16 +533,17 @@ class ModelEMA: copy_attr(self.ema, model, include, exclude) -def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None: +def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict: """ Strip optimizer from 'f' to finalize training, optionally save as 's'. Args: f (str): file path to model to strip the optimizer from. Default is 'best.pt'. s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten. + updates (dict): a dictionary of updates to overlay onto the checkpoint before saving. Returns: - None + (dict): The combined checkpoint dictionary. Example: ```python @@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None: assert "model" in x, "'model' missing from checkpoint" except Exception as e: LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}") - return + return {} - updates = { + metadata = { "date": datetime.now().isoformat(), "version": __version__, "license": "AGPL-3.0 License (https://ultralytics.com/license)", @@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None: # x['model'].args = x['train_args'] # Save - torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right) + combined = {**metadata, **x, **(updates or {})} + torch.save(combined, s or f, use_dill=False) # combine dicts (prefer to the right) mb = os.path.getsize(s or f) / 1e6 # file size LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") + return combined def convert_optimizer_state_dict_to_fp16(state_dict):