`ultralytics 8.2.95` faster checkpoint saving (#16311)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/1661/head^2 v8.2.95
Glenn Jocher 2 months ago committed by GitHub
parent 7b19e0daa0
commit ba438aea5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 79
      docs/en/modes/track.md
  2. 2
      ultralytics/__init__.py
  3. 11
      ultralytics/engine/trainer.py
  4. 4
      ultralytics/nn/tasks.py
  5. 13
      ultralytics/utils/torch_utils.py

@ -292,42 +292,27 @@ Finally, after all threads have completed their task, the windows displaying the
# Define model names and video sources
MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
SOURCES = ["path/to/video1.mp4", 0] # local video, 0 for webcam
SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam
def run_tracker_in_thread(model_name, filename, index):
def run_tracker_in_thread(model_name, filename):
"""
Runs a video file or webcam stream concurrently with the YOLOv8 model using threading. This function captures video
frames from a given file or camera source and utilizes the YOLOv8 model for object tracking. The function runs in
its own thread for concurrent processing.
Run YOLO tracker in its own thread for concurrent processing.
Args:
model_name (str): The YOLOv8 model object.
filename (str): The path to the video file or the identifier for the webcam/external camera source.
model (obj): The YOLOv8 model object.
index (int): An index to uniquely identify the file being processed, used for display purposes.
"""
model = YOLO(model_name)
video = cv2.VideoCapture(filename)
while True:
ret, frame = video.read()
if not ret:
break
results = model.track(frame, persist=True)
res_plotted = results[0].plot()
cv2.imshow(f"Tracking_Stream_{index}", res_plotted)
if cv2.waitKey(1) == ord("q"):
break
video.release()
results = model.track(filename, save=True, stream=True)
for r in results:
pass
# Create and start tracker threads using a for loop
tracker_threads = []
for i, (video_file, model_name) in enumerate(zip(SOURCES, MODEL_NAMES), start=1):
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file, i), daemon=True)
for video_file, model_name in zip(SOURCES, MODEL_NAMES):
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
tracker_threads.append(thread)
thread.start()
@ -395,35 +380,37 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
from ultralytics import YOLO
# Define model names and video sources
MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam
def run_tracker_in_thread(filename, model, file_index):
video = cv2.VideoCapture(filename)
while True:
ret, frame = video.read()
if not ret:
break
results = model.track(frame, persist=True)
res_plotted = results[0].plot()
cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
video.release()
def run_tracker_in_thread(model_name, filename):
"""
Run YOLO tracker in its own thread for concurrent processing.
model1 = YOLO("yolov8n.pt")
model2 = YOLO("yolov8n-seg.pt")
video_file1 = "path/to/video1.mp4"
video_file2 = 0 # Path to a second video file, or 0 for a webcam
Args:
model_name (str): The YOLOv8 model object.
filename (str): The path to the video file or the identifier for the webcam/external camera source.
"""
model = YOLO(model_name)
results = model.track(filename, save=True, stream=True)
for r in results:
pass
tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True)
tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True)
tracker_thread1.start()
tracker_thread2.start()
# Create and start tracker threads using a for loop
tracker_threads = []
for video_file, model_name in zip(SOURCES, MODEL_NAMES):
thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
tracker_threads.append(thread)
thread.start()
tracker_thread1.join()
tracker_thread2.join()
# Wait for all tracker threads to finish
for thread in tracker_threads:
thread.join()
# Clean up and close windows
cv2.destroyAllWindows()
```

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.94"
__version__ = "8.2.95"
import os

@ -668,13 +668,14 @@ class BaseTrainer:
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
ckpt = {}
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
if self.last.is_file(): # update best.pt train_metrics from last.pt
k = "train_results"
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
if f is self.last:
ckpt = strip_optimizer(f)
elif f is self.best:
k = "train_results" # update best.pt train_metrics from last.pt
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
LOGGER.info(f"\nValidating {f}...")
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)

@ -759,6 +759,10 @@ class SafeClass:
"""Initialize SafeClass instance, ignoring all arguments."""
pass
def __call__(self, *args, **kwargs):
"""Run SafeClass instance, ignoring all arguments."""
pass
class SafeUnpickler(pickle.Unpickler):
"""Custom Unpickler that replaces unknown classes with SafeClass."""

@ -533,16 +533,17 @@ class ModelEMA:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
Args:
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
Returns:
None
(dict): The combined checkpoint dictionary.
Example:
```python
@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
assert "model" in x, "'model' missing from checkpoint"
except Exception as e:
LOGGER.warning(f"WARNING ⚠ Skipping {f}, not a valid Ultralytics model: {e}")
return
return {}
updates = {
metadata = {
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
# x['model'].args = x['train_args']
# Save
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
combined = {**metadata, **x, **(updates or {})}
torch.save(combined, s or f, use_dill=False) # combine dicts (prefer to the right)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
return combined
def convert_optimizer_state_dict_to_fp16(state_dict):

Loading…
Cancel
Save