From fb2086726268a02ff2e191fd2a880cbb55d29c4c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 23 Jul 2024 20:35:58 +0200 Subject: [PATCH] Patch `torch.load(..., weights_only=False)` to reduce warnings (#14638) Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- docs/en/reference/utils/patches.md | 4 ++++ ultralytics/utils/__init__.py | 3 ++- ultralytics/utils/patches.py | 30 ++++++++++++++++++++++++++++-- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/docs/en/reference/utils/patches.md b/docs/en/reference/utils/patches.md index 444a27423..50422d8c8 100644 --- a/docs/en/reference/utils/patches.md +++ b/docs/en/reference/utils/patches.md @@ -23,6 +23,10 @@ keywords: Ultralytics, utils, patches, imread, imwrite, imshow, torch_save, Open



+## ::: ultralytics.utils.patches.torch_load + +



+ ## ::: ultralytics.utils.patches.torch_save

diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index b97a0fc42..39f6ad2b3 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running() set_sentry() # Apply monkey patches -from ultralytics.utils.patches import imread, imshow, imwrite, torch_save +from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save +torch.load = torch_load torch.save = torch_save if WINDOWS: # Apply cv2 patches for non-ASCII and non-UTF characters in image paths diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py index d43840711..d918e0efe 100644 --- a/ultralytics/utils/patches.py +++ b/ultralytics/utils/patches.py @@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray): # PyTorch functions ---------------------------------------------------------------------------------------------------- -_torch_save = torch.save # copy to avoid recursion errors +_torch_load = torch.load # copy to avoid recursion errors +_torch_save = torch.save + + +def torch_load(*args, **kwargs): + """ + Load a PyTorch model with updated arguments to avoid warnings. + + This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. + + Args: + *args (Any): Variable length argument list to pass to torch.load. + **kwargs (Any): Arbitrary keyword arguments to pass to torch.load. + + Returns: + (Any): The loaded PyTorch object. + + Note: + For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False' + if the argument is not provided, to avoid deprecation warnings. + """ + from ultralytics.utils.torch_utils import TORCH_1_13 + + if TORCH_1_13 and "weights_only" not in kwargs: + kwargs["weights_only"] = False + + return _torch_load(*args, **kwargs) def torch_save(*args, use_dill=True, **kwargs): @@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs): Args: *args (tuple): Positional arguments to pass to torch.save. use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. - **kwargs (any): Keyword arguments to pass to torch.save. + **kwargs (Any): Keyword arguments to pass to torch.save. """ try: assert use_dill