`ultralytics 8.2.93` new SafeClass and SafeUnpickler classes (#16269)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8642/head^2 v8.2.93
Ultralytics Assistant 2 months ago committed by GitHub
parent e309b6efab
commit c2068df9d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 55
      ultralytics/nn/tasks.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.92"
__version__ = "8.2.93"
import os

@ -1,6 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import pickle
import types
from copy import deepcopy
from pathlib import Path
@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None):
del sys.modules[old]
def torch_safe_load(weight):
class SafeClass:
"""A placeholder class to replace unknown classes during unpickling."""
def __init__(self, *args, **kwargs):
"""Initialize SafeClass instance, ignoring all arguments."""
pass
class SafeUnpickler(pickle.Unpickler):
"""Custom Unpickler that replaces unknown classes with SafeClass."""
def find_class(self, module, name):
"""Attempt to find a class, returning SafeClass if not among safe modules."""
safe_modules = (
"torch",
"collections",
"collections.abc",
"builtins",
"math",
"numpy",
# Add other modules considered safe
)
if module in safe_modules:
return super().find_class(module, name)
else:
return SafeClass
def torch_safe_load(weight, safe_only=False):
"""
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
@ -758,9 +788,18 @@ def torch_safe_load(weight):
Args:
weight (str): The file path of the PyTorch model.
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
Example:
```python
from ultralytics.nn.tasks import torch_safe_load
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
```
Returns:
(dict): The loaded PyTorch model.
ckpt (dict): The loaded model checkpoint.
file (str): The loaded filename
"""
from ultralytics.utils.downloads import attempt_download_asset
@ -779,7 +818,15 @@ def torch_safe_load(weight):
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
},
):
ckpt = torch.load(file, map_location="cpu")
if safe_only:
# Load via custom pickle module
safe_pickle = types.ModuleType("safe_pickle")
safe_pickle.Unpickler = SafeUnpickler
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
with open(file, "rb") as f:
ckpt = torch.load(f, pickle_module=safe_pickle)
else:
ckpt = torch.load(file, map_location="cpu")
except ModuleNotFoundError as e: # e.name is missing module name
if e.name == "models":
@ -809,7 +856,7 @@ def torch_safe_load(weight):
)
ckpt = {"model": ckpt.model}
return ckpt, file # load
return ckpt, file
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

Loading…
Cancel
Save