`ultralytics 8.2.93` new SafeClass and SafeUnpickler classes (#16269)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8642/head^2 v8.2.93
Ultralytics Assistant 2 months ago committed by GitHub
parent e309b6efab
commit c2068df9d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 53
      ultralytics/nn/tasks.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.92" __version__ = "8.2.93"
import os import os

@ -1,6 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib import contextlib
import pickle
import types
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None):
del sys.modules[old] del sys.modules[old]
def torch_safe_load(weight): class SafeClass:
"""A placeholder class to replace unknown classes during unpickling."""
def __init__(self, *args, **kwargs):
"""Initialize SafeClass instance, ignoring all arguments."""
pass
class SafeUnpickler(pickle.Unpickler):
"""Custom Unpickler that replaces unknown classes with SafeClass."""
def find_class(self, module, name):
"""Attempt to find a class, returning SafeClass if not among safe modules."""
safe_modules = (
"torch",
"collections",
"collections.abc",
"builtins",
"math",
"numpy",
# Add other modules considered safe
)
if module in safe_modules:
return super().find_class(module, name)
else:
return SafeClass
def torch_safe_load(weight, safe_only=False):
""" """
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
error, logs a warning message, and attempts to install the missing module via the check_requirements() function. error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
@ -758,9 +788,18 @@ def torch_safe_load(weight):
Args: Args:
weight (str): The file path of the PyTorch model. weight (str): The file path of the PyTorch model.
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
Example:
```python
from ultralytics.nn.tasks import torch_safe_load
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
```
Returns: Returns:
(dict): The loaded PyTorch model. ckpt (dict): The loaded model checkpoint.
file (str): The loaded filename
""" """
from ultralytics.utils.downloads import attempt_download_asset from ultralytics.utils.downloads import attempt_download_asset
@ -779,6 +818,14 @@ def torch_safe_load(weight):
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10 "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
}, },
): ):
if safe_only:
# Load via custom pickle module
safe_pickle = types.ModuleType("safe_pickle")
safe_pickle.Unpickler = SafeUnpickler
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
with open(file, "rb") as f:
ckpt = torch.load(f, pickle_module=safe_pickle)
else:
ckpt = torch.load(file, map_location="cpu") ckpt = torch.load(file, map_location="cpu")
except ModuleNotFoundError as e: # e.name is missing module name except ModuleNotFoundError as e: # e.name is missing module name
@ -809,7 +856,7 @@ def torch_safe_load(weight):
) )
ckpt = {"model": ckpt.model} ckpt = {"model": ckpt.model}
return ckpt, file # load return ckpt, file
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

Loading…
Cancel
Save