diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 1968fbaa2f..979bc037da 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.92" +__version__ = "8.2.93" import os diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 46974620a6..274f56d54b 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -1,6 +1,8 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import contextlib +import pickle +import types from copy import deepcopy from pathlib import Path @@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None): del sys.modules[old] -def torch_safe_load(weight): +class SafeClass: + """A placeholder class to replace unknown classes during unpickling.""" + + def __init__(self, *args, **kwargs): + """Initialize SafeClass instance, ignoring all arguments.""" + pass + + +class SafeUnpickler(pickle.Unpickler): + """Custom Unpickler that replaces unknown classes with SafeClass.""" + + def find_class(self, module, name): + """Attempt to find a class, returning SafeClass if not among safe modules.""" + safe_modules = ( + "torch", + "collections", + "collections.abc", + "builtins", + "math", + "numpy", + # Add other modules considered safe + ) + if module in safe_modules: + return super().find_class(module, name) + else: + return SafeClass + + +def torch_safe_load(weight, safe_only=False): """ Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() function. @@ -758,9 +788,18 @@ def torch_safe_load(weight): Args: weight (str): The file path of the PyTorch model. + safe_only (bool): If True, replace unknown classes with SafeClass during loading. + + Example: + ```python + from ultralytics.nn.tasks import torch_safe_load + + ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True) + ``` Returns: - (dict): The loaded PyTorch model. + ckpt (dict): The loaded model checkpoint. + file (str): The loaded filename """ from ultralytics.utils.downloads import attempt_download_asset @@ -779,7 +818,15 @@ def torch_safe_load(weight): "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10 }, ): - ckpt = torch.load(file, map_location="cpu") + if safe_only: + # Load via custom pickle module + safe_pickle = types.ModuleType("safe_pickle") + safe_pickle.Unpickler = SafeUnpickler + safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load() + with open(file, "rb") as f: + ckpt = torch.load(f, pickle_module=safe_pickle) + else: + ckpt = torch.load(file, map_location="cpu") except ModuleNotFoundError as e: # e.name is missing module name if e.name == "models": @@ -809,7 +856,7 @@ def torch_safe_load(weight): ) ckpt = {"model": ckpt.model} - return ckpt, file # load + return ckpt, file def attempt_load_weights(weights, device=None, inplace=True, fuse=False):