`ultralytics 8.2.48` strip model `criterion` on save (#14106)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14107/head v8.2.48
Glenn Jocher 5 months ago committed by GitHub
parent 7c1999929a
commit e7ede6564d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 6
      ultralytics/nn/tasks.py
  3. 21
      ultralytics/utils/torch_utils.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.47"
__version__ = "8.2.48"
import os

@ -788,14 +788,14 @@ def torch_safe_load(weight):
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
)
) from e
LOGGER.warning(
f"WARNING ⚠ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
f"WARNING ⚠ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
)
check_requirements(e.name) # install missing module
ckpt = torch.load(file, map_location="cpu")

@ -511,23 +511,30 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
```
"""
x = torch.load(f, map_location=torch.device("cpu"))
if "model" not in x:
if not isinstance(x, dict) or "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
# Update model
if x.get("ema"):
x["model"] = x["ema"] # replace model with EMA
if hasattr(x["model"], "args"):
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
if x.get("ema"):
x["model"] = x["ema"] # replace model with ema
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x["epoch"] = -1
if hasattr(x["model"], "criterion"):
x["model"].criterion = None # strip loss criterion
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
# Update other keys
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x["epoch"] = -1
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
# Save
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

Loading…
Cancel
Save