`ultralytics 8.2.71` Multinode DDP training (#14879)

Co-authored-by: Haris Rehman <haris.rehman.cowlar@gmail.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Authored by Haris Rehman, committed by GitHub
parent 16fc325308
commit 9c5d1a2451
4 changed files:

1. .gitignore (3 changes)
2. ultralytics/__init__.py (2 changes)
3. ultralytics/engine/trainer.py (5 changes)
4. ultralytics/utils/torch_utils.py (3 changes)

.gitignore (vendored)

@@ -26,6 +26,9 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+requirements.txt
+setup.py
+ultralytics.egg-info

 # PyInstaller
 # Usually these files are written by a python script from a template

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.70"
+__version__ = "8.2.71"

 import os

ultralytics/engine/trainer.py

@@ -26,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.utils import (
     DEFAULT_CFG,
+    LOCAL_RANK,
     LOGGER,
     RANK,
     TQDM,

@@ -129,7 +130,7 @@ class BaseTrainer:
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        with torch_distributed_zero_first(RANK):  # avoid auto-downloading dataset multiple times
+        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
             self.trainset, self.testset = self.get_dataset()
         self.ema = None

@@ -285,7 +286,7 @@
         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
         if RANK in {-1, 0}:
             # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
             self.test_loader = self.get_dataloader(
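For context on the RANK to LOCAL_RANK swap above: RANK is the global process index across every node in the job, while LOCAL_RANK is the GPU index within a single node, and only the latter is a valid CUDA device ordinal. A minimal sketch of the distinction, assuming a torchrun launch (the 2-node, 8-GPU-per-node figures and the train.py name are illustrative, not part of this diff):

import os

import torch

# torchrun --nnodes 2 --nproc_per_node 8 ... train.py sets these per process:
rank = int(os.getenv("RANK", -1))  # global index: 0..15 across both nodes
local_rank = int(os.getenv("LOCAL_RANK", -1))  # per-node GPU index: 0..7

# On the second node, rank runs 8..15 but the node only has cuda:0..cuda:7,
# so any device-scoped call must use local_rank, never the global rank.
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

On a single node the two indices coincide, which is why the RANK-based calls only failed once training spanned multiple machines.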

ultralytics/utils/torch_utils.py

@@ -48,11 +48,12 @@ TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
 def torch_distributed_zero_first(local_rank: int):
     """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
     initialized = dist.is_available() and dist.is_initialized()
     if initialized and local_rank not in {-1, 0}:
         dist.barrier(device_ids=[local_rank])
     yield
     if initialized and local_rank == 0:
-        dist.barrier(device_ids=[0])
+        dist.barrier(device_ids=[local_rank])


 def smart_inference_mode():
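A sketch of the patched context manager in use; the @contextmanager decorator sits above the def in the real file but outside this hunk, and the body of the with block below is an illustrative placeholder for trainer.py's actual dataset setup:

import os
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Make all processes wait while the local master (local_rank 0) runs the body first."""
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and local_rank not in {-1, 0}:
        dist.barrier(device_ids=[local_rank])  # non-masters block until the local master finishes
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[local_rank])  # local master releases the waiting processes


# Usage as in trainer.py: one download per node, the other ranks reuse the files.
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))
with torch_distributed_zero_first(LOCAL_RANK):
    data = "coco8.yaml"  # placeholder for get_dataset(), which may download files

Note that inside the local_rank == 0 branch the old device_ids=[0] and the new device_ids=[local_rank] are numerically identical; the behavioral fix is in trainer.py, where the manager is now entered with LOCAL_RANK so that device_ids always names a GPU present on the local node and each node's local master performs the download its own filesystem needs.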
