`ultralytics 8.0.159` add Classify training `resume` feature (#4482)

pull/4483/head v8.0.159
Glenn Jocher 1 year ago committed by GitHub
parent b2f279ffdd
commit c0a9660310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      ultralytics/__init__.py
  2. 2
      ultralytics/cfg/__init__.py
  3. 10
      ultralytics/models/yolo/classify/train.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.158'
__version__ = '8.0.159'
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO

@ -419,7 +419,7 @@ def entrypoint(debug=''):
overrides['source'] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
if 'data' not in overrides and 'resume' not in overrides:
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':

@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = str(self.model)
model, ckpt = str(self.model), None
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, _ = attempt_load_one_weight(model, device='cpu')
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.split('.')[-1] in ('yaml', 'yml'):
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer):
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # do not return ckpt. Classification doesn't support resume
return ckpt
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer):
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def resume_training(self, ckpt):
"""Resumes training from a given checkpoint."""
pass
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png

Loading…
Cancel
Save