From 47ab96dab6ccffde018c7af8059b0fc6b74a20c0 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Sun, 27 Aug 2023 21:14:29 +0800 Subject: [PATCH] Fix pytest `world_size` environment bug (#4590) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/engine/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 58e8071d..9b280e38 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -168,9 +168,11 @@ class BaseTrainer: def train(self): """Allow device='', device=None on Multi-GPU systems to default to device=0.""" - if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3] - world_size = torch.cuda.device_count() - elif torch.cuda.is_available(): # i.e. device=None or device='' + if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' + world_size = len(self.args.device.split(',')) + elif isinstance(self.args.device, tuple): # multi devices from cli is tuple type + world_size = len(self.args.device) + elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number world_size = 1 # default to device 0 else: # i.e. device='cpu' or 'mps' world_size = 0