|
|
|
@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
model = str(self.model) |
|
|
|
|
model, ckpt = str(self.model), None |
|
|
|
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets |
|
|
|
|
if model.endswith('.pt'): |
|
|
|
|
self.model, _ = attempt_load_one_weight(model, device='cpu') |
|
|
|
|
self.model, ckpt = attempt_load_one_weight(model, device='cpu') |
|
|
|
|
for p in self.model.parameters(): |
|
|
|
|
p.requires_grad = True # for training |
|
|
|
|
elif model.split('.')[-1] in ('yaml', 'yml'): |
|
|
|
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') |
|
|
|
|
ClassificationModel.reshape_outputs(self.model, self.data['nc']) |
|
|
|
|
|
|
|
|
|
return # do not return ckpt. Classification doesn't support resume |
|
|
|
|
return ckpt |
|
|
|
|
|
|
|
|
|
def build_dataset(self, img_path, mode='train', batch=None): |
|
|
|
|
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train') |
|
|
|
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
loss_items = [round(float(loss_items), 5)] |
|
|
|
|
return dict(zip(keys, loss_items)) |
|
|
|
|
|
|
|
|
|
def resume_training(self, ckpt): |
|
|
|
|
"""Resumes training from a given checkpoint.""" |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
def plot_metrics(self): |
|
|
|
|
"""Plots metrics from a CSV file.""" |
|
|
|
|
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png |
|
|
|
|