|
|
|
@ -8,7 +8,7 @@ from ultralytics.data import ClassificationDataset, build_dataloader |
|
|
|
|
from ultralytics.engine.trainer import BaseTrainer |
|
|
|
|
from ultralytics.models import yolo |
|
|
|
|
from ultralytics.nn.tasks import ClassificationModel |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK |
|
|
|
|
from ultralytics.utils.plotting import plot_images, plot_results |
|
|
|
|
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first |
|
|
|
|
|
|
|
|
@ -141,7 +141,6 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
self.metrics = self.validator(model=f) |
|
|
|
|
self.metrics.pop("fitness", None) |
|
|
|
|
self.run_callbacks("on_fit_epoch_end") |
|
|
|
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") |
|
|
|
|
|
|
|
|
|
def plot_training_samples(self, batch, ni): |
|
|
|
|
"""Plots training samples with their annotations.""" |
|
|
|
|