Merge branch 'main' into augment-fix

augment-fix
Francesco Mattioli 6 months ago
commit 323d9338c5
  1. 2
      ultralytics/__init__.py
  2. 5
      ultralytics/engine/trainer.py
  3. 3
      ultralytics/solutions/queue_management.py
  4. 1
      ultralytics/utils/torch_utils.py
  5. 9
      ultralytics/utils/tuner.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.86"
__version__ = "8.2.87"
import os

@ -42,6 +42,7 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
TORCH_2_4,
EarlyStopping,
ModelEMA,
autocast,
@ -265,7 +266,9 @@ class BaseTrainer:
if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
self.scaler = (
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
)
if world_size > 1:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

@ -89,7 +89,7 @@ class QueueManager:
"""Extracts and processes tracks for queue management in a video stream."""
# Initialize annotator and draw the queue region
self.annotator = Annotator(self.im0, self.tf, self.names)
self.counts = 0 # Reset counts every frame
if tracks[0].boxes.id is not None:
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
@ -132,7 +132,6 @@ class QueueManager:
txt_color=self.count_txt_color,
)
self.counts = 0 # Reset counts after displaying
self.display_frames()
def display_frames(self):

@ -40,6 +40,7 @@ except ImportError:
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCH_2_4 = check_version(torch.__version__, "2.4.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")

@ -143,5 +143,10 @@ def run_ray_tune(
# Run the hyperparameter search
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
# Get the results of the hyperparameter search
results = tuner.get_results()
# Shut down Ray to clean up workers
ray.shutdown()
return results

Loading…
Cancel
Save