@@ -511,23 +511,30 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         ```
     """
     x = torch.load(f, map_location=torch.device("cpu"))
-    if "model" not in x:
+    if not isinstance(x, dict) or "model" not in x:
         LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
         return
 
-    args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None  # combine args
-    if x.get("ema"):
-        x["model"] = x["ema"]  # replace model with ema
-    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
-        x[k] = None
-    x["epoch"] = -1
+    # Update model
+    if x.get("ema"):
+        x["model"] = x["ema"]  # replace model with EMA
+    if hasattr(x["model"], "args"):
+        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
+    if hasattr(x["model"], "criterion"):
+        x["model"].criterion = None  # strip loss criterion
     x["model"].half()  # to FP16
     for p in x["model"].parameters():
         p.requires_grad = False
+
+    # Update other keys
+    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
+    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
+        x[k] = None
+    x["epoch"] = -1
     x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
     # x['model'].args = x['train_args']
 
     # Save
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
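For reference, a minimal usage sketch of the updated function (the checkpoint directory below is hypothetical):

```python
from pathlib import Path

from ultralytics.utils.torch_utils import strip_optimizer

# Strip training state from every checkpoint under a run directory,
# shrinking each .pt file to an FP16, inference-only copy.
for f in Path("runs/detect/train/weights").rglob("*.pt"):  # hypothetical path
    strip_optimizer(f)
```

Since `s` defaults to an empty string, `torch.save(x, s or f)` overwrites `f` in place; pass `s` to keep the original and write the stripped checkpoint to a separate path.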