Patch `torch.load(..., weights_only=False)` to reduce warnings (#14638)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/14572/head
Glenn Jocher 4 months ago committed by GitHub
parent 72466b9648
commit fb20867262
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      docs/en/reference/utils/patches.md
  2. 3
      ultralytics/utils/__init__.py
  3. 30
      ultralytics/utils/patches.py

@ -23,6 +23,10 @@ keywords: Ultralytics, utils, patches, imread, imwrite, imshow, torch_save, Open
<br><br><hr><br>
## ::: ultralytics.utils.patches.torch_load
<br><br><hr><br>
## ::: ultralytics.utils.patches.torch_save
<br><br>

@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()
# Apply monkey patches
from ultralytics.utils.patches import imread, imshow, imwrite, torch_save
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
torch.load = torch_load
torch.save = torch_save
if WINDOWS:
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths

@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray):
# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save # copy to avoid recursion errors
_torch_load = torch.load # copy to avoid recursion errors
_torch_save = torch.save
def torch_load(*args, **kwargs):
"""
Load a PyTorch model with updated arguments to avoid warnings.
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
Args:
*args (Any): Variable length argument list to pass to torch.load.
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
Returns:
(Any): The loaded PyTorch object.
Note:
For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
if the argument is not provided, to avoid deprecation warnings.
"""
from ultralytics.utils.torch_utils import TORCH_1_13
if TORCH_1_13 and "weights_only" not in kwargs:
kwargs["weights_only"] = False
return _torch_load(*args, **kwargs)
def torch_save(*args, use_dill=True, **kwargs):
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs):
Args:
*args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (any): Keyword arguments to pass to torch.save.
**kwargs (Any): Keyword arguments to pass to torch.save.
"""
try:
assert use_dill

Loading…
Cancel
Save