From 41dfd65cc16fcfb5fb164179b6d367dcac787680 Mon Sep 17 00:00:00 2001 From: Burhan <62214284+Burhan-Q@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:50:14 -0400 Subject: [PATCH] Allows any PyTorch install except `torch==2.4.0` on Windows (#16019) Co-authored-by: Glenn Jocher --- pyproject.toml | 2 +- ultralytics/utils/torch_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 00366df58..03e557e94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,8 +71,8 @@ dependencies = [ "pyyaml>=5.3.1", "requests>=2.23.0", "scipy>=1.4.1", - "torch>=1.8.0,<2.4.0; sys_platform == 'win32'", # Windows CPU errors https://github.com/ultralytics/ultralytics/issues/15049 "torch>=1.8.0", + "torch>=1.8.0,!=2.4.0; sys_platform == 'win32'", # Windows CPU errors w/ 2.4.0 https://github.com/ultralytics/ultralytics/issues/15049 "torchvision>=0.9.0", "tqdm>=4.64.0", # progress bars "psutil", # system utilization diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 16bcddadd..7cde9dc7a 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -45,9 +45,9 @@ TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0") -if WINDOWS and torch.__version__[:3] == "2.4": # reject all versions of 2.4 on Windows +if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows LOGGER.warning( - "WARNING ⚠️ Known issue with torch>=2.4.0 on Windows with CPU, recommend downgrading to torch<=2.3.1 to resolve " + "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve " "https://github.com/ultralytics/ultralytics/issues/15049" )