|
|
@ -64,6 +64,8 @@ def select_device(device='', batch=0, newline=False, verbose=True): |
|
|
|
if cpu or mps: |
|
|
|
if cpu or mps: |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False |
|
|
|
elif device: # non-cpu device requested |
|
|
|
elif device: # non-cpu device requested |
|
|
|
|
|
|
|
if device == 'cuda': |
|
|
|
|
|
|
|
device = '0' |
|
|
|
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None) |
|
|
|
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None) |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() |
|
|
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): |
|
|
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): |
|
|
|