Fix pytest `world_size` environment bug (#4590)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/4591/head
Laughing 2 years ago committed by GitHub
parent deac7575b1
commit 47ab96dab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      ultralytics/engine/trainer.py

@ -168,9 +168,11 @@ class BaseTrainer:
def train(self):
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3]
world_size = torch.cuda.device_count()
elif torch.cuda.is_available(): # i.e. device=None or device=''
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
world_size = len(self.args.device.split(','))
elif isinstance(self.args.device, tuple): # multi devices from cli is tuple type
world_size = len(self.args.device)
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
world_size = 1 # default to device 0
else: # i.e. device='cpu' or 'mps'
world_size = 0

Loading…
Cancel
Save