|
|
|
@ -1,6 +1,8 @@ |
|
|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
|
|
import pickle |
|
|
|
|
import types |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
@ -750,7 +752,35 @@ def temporary_modules(modules=None, attributes=None): |
|
|
|
|
del sys.modules[old] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_safe_load(weight): |
|
|
|
|
class SafeClass: |
|
|
|
|
"""A placeholder class to replace unknown classes during unpickling.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
|
"""Initialize SafeClass instance, ignoring all arguments.""" |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SafeUnpickler(pickle.Unpickler): |
|
|
|
|
"""Custom Unpickler that replaces unknown classes with SafeClass.""" |
|
|
|
|
|
|
|
|
|
def find_class(self, module, name): |
|
|
|
|
"""Attempt to find a class, returning SafeClass if not among safe modules.""" |
|
|
|
|
safe_modules = ( |
|
|
|
|
"torch", |
|
|
|
|
"collections", |
|
|
|
|
"collections.abc", |
|
|
|
|
"builtins", |
|
|
|
|
"math", |
|
|
|
|
"numpy", |
|
|
|
|
# Add other modules considered safe |
|
|
|
|
) |
|
|
|
|
if module in safe_modules: |
|
|
|
|
return super().find_class(module, name) |
|
|
|
|
else: |
|
|
|
|
return SafeClass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_safe_load(weight, safe_only=False): |
|
|
|
|
""" |
|
|
|
|
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the |
|
|
|
|
error, logs a warning message, and attempts to install the missing module via the check_requirements() function. |
|
|
|
@ -758,9 +788,18 @@ def torch_safe_load(weight): |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
weight (str): The file path of the PyTorch model. |
|
|
|
|
safe_only (bool): If True, replace unknown classes with SafeClass during loading. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
```python |
|
|
|
|
from ultralytics.nn.tasks import torch_safe_load |
|
|
|
|
|
|
|
|
|
ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True) |
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(dict): The loaded PyTorch model. |
|
|
|
|
ckpt (dict): The loaded model checkpoint. |
|
|
|
|
file (str): The loaded filename |
|
|
|
|
""" |
|
|
|
|
from ultralytics.utils.downloads import attempt_download_asset |
|
|
|
|
|
|
|
|
@ -779,6 +818,14 @@ def torch_safe_load(weight): |
|
|
|
|
"ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10 |
|
|
|
|
}, |
|
|
|
|
): |
|
|
|
|
if safe_only: |
|
|
|
|
# Load via custom pickle module |
|
|
|
|
safe_pickle = types.ModuleType("safe_pickle") |
|
|
|
|
safe_pickle.Unpickler = SafeUnpickler |
|
|
|
|
safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load() |
|
|
|
|
with open(file, "rb") as f: |
|
|
|
|
ckpt = torch.load(f, pickle_module=safe_pickle) |
|
|
|
|
else: |
|
|
|
|
ckpt = torch.load(file, map_location="cpu") |
|
|
|
|
|
|
|
|
|
except ModuleNotFoundError as e: # e.name is missing module name |
|
|
|
@ -809,7 +856,7 @@ def torch_safe_load(weight): |
|
|
|
|
) |
|
|
|
|
ckpt = {"model": ckpt.model} |
|
|
|
|
|
|
|
|
|
return ckpt, file # load |
|
|
|
|
return ckpt, file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): |
|
|
|
|