`ultralytics 8.2.95` faster checkpoint saving (#16311)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/1661/head^2 v8.2.95
Author: Glenn Jocher · committed by GitHub · 2 months ago
parent 7b19e0daa0
commit ba438aea5a
GPG Key ID: B5690EEEBB952194
Changed files:
1. docs/en/modes/track.md (79 changed lines)
2. ultralytics/__init__.py (2 changed lines)
3. ultralytics/engine/trainer.py (11 changed lines)
4. ultralytics/nn/tasks.py (4 changed lines)
5. ultralytics/utils/torch_utils.py (13 changed lines)

docs/en/modes/track.md
@@ -292,42 +292,27 @@ Finally, after all threads have completed their task, the windows displaying the
 # Define model names and video sources
 MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
-SOURCES = ["path/to/video1.mp4", 0]  # local video, 0 for webcam
+SOURCES = ["path/to/video.mp4", "0"]  # local video, 0 for webcam


-def run_tracker_in_thread(model_name, filename, index):
+def run_tracker_in_thread(model_name, filename):
     """
-    Runs a video file or webcam stream concurrently with the YOLOv8 model using threading. This function captures video
-    frames from a given file or camera source and utilizes the YOLOv8 model for object tracking. The function runs in
-    its own thread for concurrent processing.
+    Run YOLO tracker in its own thread for concurrent processing.

     Args:
+        model_name (str): The YOLOv8 model object.
         filename (str): The path to the video file or the identifier for the webcam/external camera source.
-        model (obj): The YOLOv8 model object.
-        index (int): An index to uniquely identify the file being processed, used for display purposes.
     """
     model = YOLO(model_name)
-    video = cv2.VideoCapture(filename)
-
-    while True:
-        ret, frame = video.read()
-        if not ret:
-            break
-        results = model.track(frame, persist=True)
-        res_plotted = results[0].plot()
-        cv2.imshow(f"Tracking_Stream_{index}", res_plotted)
-        if cv2.waitKey(1) == ord("q"):
-            break
-    video.release()
+    results = model.track(filename, save=True, stream=True)
+    for r in results:
+        pass


 # Create and start tracker threads using a for loop
 tracker_threads = []
-for i, (video_file, model_name) in enumerate(zip(SOURCES, MODEL_NAMES), start=1):
-    thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file, i), daemon=True)
+for video_file, model_name in zip(SOURCES, MODEL_NAMES):
+    thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
     tracker_threads.append(thread)
     thread.start()
docs/en/modes/track.md
@@ -395,35 +380,37 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
 from ultralytics import YOLO

+# Define model names and video sources
+MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
+SOURCES = ["path/to/video.mp4", "0"]  # local video, 0 for webcam

-def run_tracker_in_thread(filename, model, file_index):
-    video = cv2.VideoCapture(filename)
-    while True:
-        ret, frame = video.read()
-        if not ret:
-            break
-        results = model.track(frame, persist=True)
-        res_plotted = results[0].plot()
-        cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted)
-        if cv2.waitKey(1) & 0xFF == ord("q"):
-            break
-    video.release()

+def run_tracker_in_thread(model_name, filename):
+    """
+    Run YOLO tracker in its own thread for concurrent processing.

-model1 = YOLO("yolov8n.pt")
-model2 = YOLO("yolov8n-seg.pt")
-video_file1 = "path/to/video1.mp4"
-video_file2 = 0  # Path to a second video file, or 0 for a webcam
+    Args:
+        model_name (str): The YOLOv8 model object.
+        filename (str): The path to the video file or the identifier for the webcam/external camera source.
+    """
+    model = YOLO(model_name)
+    results = model.track(filename, save=True, stream=True)
+    for r in results:
+        pass

-tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True)
-tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True)

-tracker_thread1.start()
-tracker_thread2.start()
+# Create and start tracker threads using a for loop
+tracker_threads = []
+for video_file, model_name in zip(SOURCES, MODEL_NAMES):
+    thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
+    tracker_threads.append(thread)
+    thread.start()

-tracker_thread1.join()
-tracker_thread2.join()
+# Wait for all tracker threads to finish
+for thread in tracker_threads:
+    thread.join()

+# Clean up and close windows
 cv2.destroyAllWindows()
 ```
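Note: the docs snippet above is shown as it appears inside track.md and omits its imports. Below is a self-contained sketch of the updated pattern for reference; the `threading` and `ultralytics` imports and the placeholder sources are assumptions for illustration, not part of this diff.

```python
# Minimal, self-contained sketch of the updated multithreaded tracking pattern.
# Assumes the `ultralytics` package is installed; model names and sources are placeholders.
import threading

from ultralytics import YOLO

MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
SOURCES = ["path/to/video.mp4", "0"]  # local video file, "0" for webcam


def run_tracker_in_thread(model_name: str, filename: str) -> None:
    """Run one YOLO tracker per thread; each thread owns its own model instance."""
    model = YOLO(model_name)  # a separate model per thread avoids shared-state issues
    # stream=True returns a generator; iterating it drives frame-by-frame tracking
    for _ in model.track(source=filename, save=True, stream=True):
        pass


threads = [
    threading.Thread(target=run_tracker_in_thread, args=(m, s), daemon=True)
    for m, s in zip(MODEL_NAMES, SOURCES)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```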

ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.94"
+__version__ = "8.2.95"

 import os

ultralytics/engine/trainer.py
@@ -668,13 +668,14 @@ class BaseTrainer:
     def final_eval(self):
         """Performs final evaluation and validation for object detection YOLO model."""
+        ckpt = {}
         for f in self.last, self.best:
             if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
-                    if self.last.is_file():  # update best.pt train_metrics from last.pt
-                        k = "train_results"
-                        torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
+                if f is self.last:
+                    ckpt = strip_optimizer(f)
+                elif f is self.best:
+                    k = "train_results"  # update best.pt train_metrics from last.pt
+                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
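For context on the "faster checkpoint saving" in the title: the old path called `torch.load` on both checkpoints and re-saved `best.pt` just to copy `train_results` from `last.pt`; the new path reuses the checkpoint dict that `strip_optimizer` already returns and overlays the key via `updates`. A hedged sketch of the caller-side difference is below; the checkpoint paths are placeholders and must exist for it to run.

```python
# Sketch (not from the diff) of the pattern final_eval now follows.
from ultralytics.utils.torch_utils import strip_optimizer

# Old pattern: an extra load/save round trip just to copy one key across files.
# best = torch.load("best.pt")
# best["train_results"] = torch.load("last.pt")["train_results"]
# torch.save(best, "best.pt")

# New pattern: strip_optimizer() returns the checkpoint it just wrote, so the
# key is overlaid onto best.pt during the single save it performs anyway.
ckpt = strip_optimizer("last.pt")
strip_optimizer("best.pt", updates={"train_results": ckpt["train_results"]} if "train_results" in ckpt else None)
```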

ultralytics/nn/tasks.py
@@ -759,6 +759,10 @@ class SafeClass:
         """Initialize SafeClass instance, ignoring all arguments."""
         pass

+    def __call__(self, *args, **kwargs):
+        """Run SafeClass instance, ignoring all arguments."""
+        pass
+

 class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""

ultralytics/utils/torch_utils.py
@@ -533,16 +533,17 @@ class ModelEMA:
         copy_attr(self.ema, model, include, exclude)


-def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
     """
     Strip optimizer from 'f' to finalize training, optionally save as 's'.

     Args:
         f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
         s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
+        updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.

     Returns:
-        None
+        (dict): The combined checkpoint dictionary.

     Example:
         ```python
@@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         assert "model" in x, "'model' missing from checkpoint"
     except Exception as e:
         LOGGER.warning(f"WARNING ⚠ Skipping {f}, not a valid Ultralytics model: {e}")
-        return
+        return {}

-    updates = {
+    metadata = {
         "date": datetime.now().isoformat(),
         "version": __version__,
         "license": "AGPL-3.0 License (https://ultralytics.com/license)",
@@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     # x['model'].args = x['train_args']

     # Save
-    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
+    combined = {**metadata, **x, **(updates or {})}
+    torch.save(combined, s or f, use_dill=False)  # combine dicts (prefer to the right)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
+    return combined


 def convert_optimizer_state_dict_to_fp16(state_dict):
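Because the save now merges `{**metadata, **x, **(updates or {})}`, caller-supplied `updates` take precedence over both the checkpoint contents and the generated metadata, and the merged dict is returned so callers can reuse it without reloading the file. A minimal sketch of that precedence, using made-up values:

```python
# Later mappings win in a dict merge, so `updates` overrides metadata and the checkpoint.
metadata = {"version": "8.2.95", "license": "AGPL-3.0"}
x = {"model": "<stripped model>", "train_results": {"epoch": 99}}     # loaded checkpoint
updates = {"train_results": {"epoch": 100}}                           # hypothetical overlay from last.pt

combined = {**metadata, **x, **(updates or {})}
assert combined["train_results"]["epoch"] == 100  # updates take precedence
assert combined["version"] == "8.2.95"            # metadata kept when not overridden
```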
