|
|
|
@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PyTorch functions ---------------------------------------------------------------------------------------------------- |
|
|
|
|
_torch_save = torch.save # copy to avoid recursion errors |
|
|
|
|
_torch_load = torch.load # copy to avoid recursion errors |
|
|
|
|
_torch_save = torch.save |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_load(*args, **kwargs): |
|
|
|
|
""" |
|
|
|
|
Load a PyTorch model with updated arguments to avoid warnings. |
|
|
|
|
|
|
|
|
|
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
*args (Any): Variable length argument list to pass to torch.load. |
|
|
|
|
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(Any): The loaded PyTorch object. |
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False' |
|
|
|
|
if the argument is not provided, to avoid deprecation warnings. |
|
|
|
|
""" |
|
|
|
|
from ultralytics.utils.torch_utils import TORCH_1_13 |
|
|
|
|
|
|
|
|
|
if TORCH_1_13 and "weights_only" not in kwargs: |
|
|
|
|
kwargs["weights_only"] = False |
|
|
|
|
|
|
|
|
|
return _torch_load(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_save(*args, use_dill=True, **kwargs): |
|
|
|
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs): |
|
|
|
|
Args: |
|
|
|
|
*args (tuple): Positional arguments to pass to torch.save. |
|
|
|
|
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. |
|
|
|
|
**kwargs (any): Keyword arguments to pass to torch.save. |
|
|
|
|
**kwargs (Any): Keyword arguments to pass to torch.save. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
assert use_dill |
|
|
|
|