From 169602442c8213e22b31688240d063f116ffe90f Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Mon, 24 Jun 2024 02:00:34 +0800
Subject: [PATCH] Fix HUB session with DDP training (#13103)

Signed-off-by: Glenn Jocher
Co-authored-by: Glenn Jocher
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant
---
 ultralytics/engine/trainer.py      | 7 ++++++-
 ultralytics/hub/session.py         | 2 +-
 ultralytics/utils/callbacks/hub.py | 2 +-
 ultralytics/utils/dist.py          | 1 +
 ultralytics/utils/torch_utils.py   | 4 ++--
 5 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 69df57044..c833e7616 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -48,6 +48,7 @@ from ultralytics.utils.torch_utils import (
     one_cycle,
     select_device,
     strip_optimizer,
+    torch_distributed_zero_first,
 )
 
 
@@ -127,7 +128,8 @@ class BaseTrainer:
 
         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        self.trainset, self.testset = self.get_dataset()
+        with torch_distributed_zero_first(RANK):  # avoid auto-downloading dataset multiple times
+            self.trainset, self.testset = self.get_dataset()
         self.ema = None
 
         # Optimization utils init
@@ -143,6 +145,9 @@ class BaseTrainer:
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
 
+        # HUB
+        self.hub_session = None
+
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         if RANK in {-1, 0}:
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index 4de4fec3b..369df744a 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -72,7 +72,7 @@ class HUBTrainingSession:
         try:
             session = cls(identifier)
             assert session.client.authenticated, "HUB not authenticated"
-            if args:
+            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                 session.create_model(args)
             assert session.model.id, "HUB model not loaded correctly"
             return session
diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py
index f312a6103..fbcd1667e 100644
--- a/ultralytics/utils/callbacks/hub.py
+++ b/ultralytics/utils/callbacks/hub.py
@@ -9,7 +9,7 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS
 
 def on_pretrain_routine_start(trainer):
     """Create a remote Ultralytics HUB session to log local model training."""
-    if RANK in {-1, 0} and SETTINGS["hub"] is True and not getattr(trainer, "hub_session", None):
+    if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
         trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
 
 
diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py
index b669e52f5..ff980967f 100644
--- a/ultralytics/utils/dist.py
+++ b/ultralytics/utils/dist.py
@@ -37,6 +37,7 @@ if __name__ == "__main__":
     cfg = DEFAULT_CFG_DICT.copy()
     cfg.update(save_dir='')  # handle the extra key 'save_dir'
     trainer = {name}(cfg=cfg, overrides=overrides)
+    trainer.args.model = "{getattr(trainer.hub_session, 'model_url', trainer.args.model)}"
     results = trainer.train()
 """
     (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 41416a2d8..830ef915a 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -43,8 +43,8 @@ TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
 
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
-    """Decorator to make all processes in distributed training wait for each local_master to do something."""
-    initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
+    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
+    initialized = dist.is_available() and dist.is_initialized()
     if initialized and local_rank not in {-1, 0}:
         dist.barrier(device_ids=[local_rank])
     yield
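
For context: torch_distributed_zero_first implements the "zero first" two-barrier
pattern, which trainer.py now wraps around dataset setup. Below is a minimal
standalone sketch, assuming the usual release barrier after yield (only the
function's opening lines appear in the hunk above); zero_first and the with-block
body are illustrative stand-ins, not Ultralytics API:

    import os
    from contextlib import contextmanager

    import torch.distributed as dist

    @contextmanager
    def zero_first(local_rank: int):
        """Let the local master (rank -1 or 0) run the with-block body first."""
        initialized = dist.is_available() and dist.is_initialized()
        if initialized and local_rank not in {-1, 0}:
            dist.barrier(device_ids=[local_rank])  # non-zero ranks wait here
        yield  # local master runs the body now; waiting ranks run it after release
        if initialized and local_rank == 0:
            dist.barrier(device_ids=[0])  # rank 0 releases the waiting ranks

    local_rank = int(os.environ.get("LOCAL_RANK", -1))  # -1 when not launched by torchrun
    with zero_first(local_rank):
        ...  # e.g. download the dataset once; other ranks then read the cached copy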
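
The dist.py change matters because multi-GPU training re-launches the trainer from
a temp file generated under USER_CONFIG_DIR/DDP; writing the resolved model back
into that file points every DDP subprocess at the same HUB model, so the HUB
callback can re-attach the session in the spawned rank-0 process. Assuming an
active HUB session, the generated file would look roughly like this (the trainer
class, overrides dict, and model ID are illustrative placeholders):

    # Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
    overrides = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'device': '0,1'}  # vars(trainer.args)

    if __name__ == "__main__":
        from ultralytics.models.yolo.detect import DetectionTrainer
        from ultralytics.utils import DEFAULT_CFG_DICT

        cfg = DEFAULT_CFG_DICT.copy()
        cfg.update(save_dir='')  # handle the extra key 'save_dir'
        trainer = DetectionTrainer(cfg=cfg, overrides=overrides)
        trainer.args.model = "https://hub.ultralytics.com/models/EXAMPLE_MODEL_ID"  # from hub_session.model_url
        results = trainer.train()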