|
|
|
@ -170,6 +170,8 @@ def select_device(device="", batch=0, newline=False, verbose=True): |
|
|
|
|
elif device: # non-cpu device requested |
|
|
|
|
if device == "cuda": |
|
|
|
|
device = "0" |
|
|
|
|
if "," in device: |
|
|
|
|
device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1" |
|
|
|
|
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) |
|
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() |
|
|
|
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))): |
|
|
|
@ -191,7 +193,7 @@ def select_device(device="", batch=0, newline=False, verbose=True): |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available |
|
|
|
|
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7 |
|
|
|
|
devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"] |
|
|
|
|
n = len(devices) # device count |
|
|
|
|
if n > 1: # multi-GPU |
|
|
|
|
if batch < 1: |
|
|
|
|