update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
pull/59/head
Laughing 2 years ago committed by GitHub
parent d0b0fe2592
commit 3a241e4cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      ultralytics/yolo/data/augment.py
  2. 2
      ultralytics/yolo/data/base.py
  3. 54
      ultralytics/yolo/data/build.py
  4. 115
      ultralytics/yolo/engine/trainer.py
  5. 42
      ultralytics/yolo/engine/validator.py
  6. 11
      ultralytics/yolo/utils/__init__.py
  7. 31
      ultralytics/yolo/utils/configs/default.yaml
  8. 47
      ultralytics/yolo/utils/metrics.py
  9. 150
      ultralytics/yolo/utils/plotting.py
  10. 16
      ultralytics/yolo/utils/torch_utils.py
  11. 5
      ultralytics/yolo/v8/classify/train.py
  12. 4
      ultralytics/yolo/v8/classify/val.py
  13. 49
      ultralytics/yolo/v8/segment/train.py
  14. 72
      ultralytics/yolo/v8/segment/val.py

@ -578,8 +578,8 @@ class Albumentations:
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
labels["instances"].update(bboxes=bboxes)
return labels
@ -635,7 +635,7 @@ class Format:
def _format_img(self, img):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
img = torch.from_numpy(img)
return img

@ -151,7 +151,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
s = np.array([x["shape"] for x in self.labels]) # hw
s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]

@ -5,7 +5,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader, dataloader, distributed
from ..utils import LOGGER
from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK
@ -52,53 +52,36 @@ def seed_worker(worker_id):
random.seed(worker_seed)
# TODO: we can inject most args from a config file
def build_dataloader(
img_path,
img_size, #
batch_size, #
single_cls=False, #
hyp=None, #
augment=False,
cache=False, #
image_weights=False, #
stride=32,
label_path=None,
pad=0.0,
rect=False,
rank=-1,
workers=8,
prefix="",
shuffle=False,
use_segments=False,
use_keypoints=False,
):
if rect and shuffle:
def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
assert mode in ["train", "val"]
shuffle = mode == "train"
if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset(
img_path=img_path,
img_size=img_size,
batch_size=batch_size,
label_path=label_path,
augment=augment, # augmentation
hyp=hyp,
rect=rect, # rectangular batches
cache=cache,
single_cls=single_cls,
img_size=cfg.img_size,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
hyp=cfg.get("augment_hyp", None),
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
single_cls=cfg.get("single_cls", False),
stride=int(stride),
pad=pad,
prefix=prefix,
use_segments=use_segments,
use_keypoints=use_keypoints,
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
@ -118,6 +101,7 @@ def build_dataloader(
# build classification
# TODO: using cfg like `build_dataloader`
def build_classification_dataloader(path,
imgsz=224,
batch_size=16,

@ -24,11 +24,11 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
RANK = int(os.getenv('RANK', -1))
@ -48,13 +48,15 @@ class BaseTrainer:
self.wdir = self.save_dir / 'weights' # weights dir
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size
self.epochs = self.args.epochs
print_args(dict(self.args))
# Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
# device
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders.
@ -73,10 +75,11 @@ class BaseTrainer:
self.scheduler = None
# epoch level metrics
self.metrics = {} # handle metrics returned by validator
self.best_fitness = None
self.fitness = None
self.loss = None
self.tloss = None
self.csv = self.save_dir / 'results.csv'
for callback, func in callbacks.default_callbacks.items():
self.add_callback(callback, func)
@ -122,6 +125,7 @@ class BaseTrainer:
if world_size > 1:
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
else:
# self._do_train(int(os.getenv("RANK", -1)), world_size)
self._do_train()
def _setup_ddp(self, rank, world_size):
@ -129,21 +133,20 @@ class BaseTrainer:
os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
self.model = self.model.to(self.device)
self.model = DDP(self.model, device_ids=[rank])
self.args.batch_size = self.args.batch_size // world_size
def _setup_train(self, rank):
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process
"""
# Optimizer
self.set_model_attributes()
accumulate = max(round(self.args.nbs / self.args.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs # scale weight_decay
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
self.optimizer = build_optimizer(model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
@ -151,18 +154,21 @@ class BaseTrainer:
decay=self.args.weight_decay)
# Scheduler
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.args.epochs) # cosine 1->hyp['lrf']
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf) # linear
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
# dataloaders
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
batch_size = self.batch_size // world_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
if rank in {0, -1}:
print(" Creating testloader rank :", rank)
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1)
self.validator = self.get_validator()
print("created testloader :", rank)
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
validator = self.get_validator()
# init metric, for plot_results
metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.validator = validator
self.ema = ModelEMA(self.model)
def _do_train(self, rank=-1, world_size=1):
@ -172,7 +178,7 @@ class BaseTrainer:
self.model = self.model.to(self.device)
self.trigger_callbacks("before_train")
self._setup_train(rank)
self._setup_train(rank, world_size)
self.epoch = 0
self.epoch_time = None
@ -181,13 +187,17 @@ class BaseTrainer:
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
for epoch in range(self.args.epochs):
for epoch in range(self.epochs):
self.trigger_callbacks("on_epoch_start")
self.model.train()
if rank != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
if rank in {-1, 0}:
self.console.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
self.trigger_callbacks("on_batch_start")
# forward
@ -197,7 +207,7 @@ class BaseTrainer:
ni = i + nb * epoch
if ni <= nw:
xi = [0, nw] # x interp
accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
for j, x in enumerate(self.optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(
@ -207,37 +217,47 @@ class BaseTrainer:
preds = self.model(batch["img"])
self.loss, self.loss_items = self.criterion(preds, batch)
if rank != -1:
self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items
# backward
self.model.zero_grad(set_to_none=True)
self.scaler.scale(self.loss).backward()
# optimize
if ni - last_opt_step >= accumulate:
if ni - last_opt_step >= self.accumulate:
self.optimizer_step()
last_opt_step = ni
# log
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if rank in {-1, 0}:
pbar.set_description(
(" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem,
*losses, batch["img"].shape[-1]))
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
self.trigger_callbacks('on_batch_end')
if self.args.plots and ni < 3:
self.plot_training_samples(batch, ni)
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step()
if rank in [-1, 0]:
# validation
self.trigger_callbacks('on_val_start')
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
self.metrics, self.fitness = self.validate()
final_epoch = (epoch + 1 == self.epochs)
if not self.args.noval or final_epoch:
self.metrics, self.fitness = self.validate()
self.trigger_callbacks('on_val_end')
log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
self.save_metrics(metrics=log_vals)
# save model
if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
self.save_model()
self.trigger_callbacks('on_model_save')
@ -248,9 +268,15 @@ class BaseTrainer:
# TODO: termination condition
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
self.trigger_callbacks('on_train_end')
if rank in [-1, 0]:
# do the last evaluation with best.pt
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
self.trigger_callbacks('on_train_end')
dist.destroy_process_group() if world_size != 1 else None
torch.cuda.empty_cache()
def save_model(self):
ckpt = {
@ -306,7 +332,7 @@ class BaseTrainer:
"fitness" metric.
"""
metrics = self.validator(self)
fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = self.fitness
return metrics, fitness
@ -339,12 +365,12 @@ class BaseTrainer:
"""
raise NotImplementedError("criterion function not implemented in trainer")
def label_loss_items(self, loss_items):
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
return {"loss": loss_items}
return {"loss": loss_items} if loss_items is not None else ["loss"]
def set_model_attributes(self):
"""
@ -355,6 +381,31 @@ class BaseTrainer:
def build_targets(self, preds, targets):
pass
def progress_string(self):
return ""
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
pass
def save_metrics(self, metrics):
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
with open(self.csv, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
def plot_metrics(self):
pass
def final_eval(self):
# TODO: need standalone evaluator to do this
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
@ -382,7 +433,7 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups "
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
return optimizer

@ -1,4 +1,5 @@
import logging
from pathlib import Path
import torch
from omegaconf import OmegaConf
@ -6,6 +7,7 @@ from tqdm import tqdm
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
@ -15,16 +17,17 @@ class BaseValidator:
Base validator class.
"""
def __init__(self, dataloader, pbar=None, logger=None, args=None):
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or logging.getLogger()
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.device = select_device(self.args.device, dataloader.batch_size)
self.save_dir = save_dir if save_dir is not None else \
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
self.cuda = self.device.type != 'cpu'
self.batch_i = None
self.training = True
self.loss = None
def __call__(self, trainer=None, model=None):
"""
@ -35,20 +38,22 @@ class BaseValidator:
if self.training:
model = trainer.ema.ema or trainer.model
self.args.half &= self.device.type != 'cpu'
# NOTE: half() inference in evaluation will make training stuck,
# so I comment it out for now, I think we can reuse half mode after we add EMA.
model = model.half() if self.args.half else model.float()
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
else: # TODO: handle this when detectMultiBackend is supported
assert model is not None, "Either trainer or model is needed for validation"
# model = DetectMultiBacked(model)
# TODO: implement init_model_attributes()
model.eval()
dt = Profile(), Profile(), Profile(), Profile()
self.loss = 0
n_batches = len(self.dataloader)
desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
# NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training,
# so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py.
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model))
with torch.no_grad():
for batch_i, batch in enumerate(bar):
@ -59,20 +64,23 @@ class BaseValidator:
# inference
with dt[1]:
preds = model(batch["img"].float())
preds = model(batch["img"])
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)
# loss
with dt[2]:
if self.training:
self.loss += trainer.criterion(preds, batch)[0]
loss += trainer.criterion(preds, batch)[1]
# pre-process predictions
with dt[3]:
preds = self.postprocess(preds)
self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
stats = self.get_stats()
self.check_stats(stats)
@ -81,7 +89,7 @@ class BaseValidator:
# print speeds
if not self.training:
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
self.logger.info(
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
@ -90,7 +98,8 @@ class BaseValidator:
model.float()
# TODO: implement save json
return stats
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
if self.training else stats
def preprocess(self, batch):
return batch
@ -105,7 +114,7 @@ class BaseValidator:
pass
def get_stats(self):
pass
return {}
def check_stats(self, stats):
pass
@ -115,3 +124,14 @@ class BaseValidator:
def get_desc(self):
pass
@property
def metric_keys(self):
return []
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
pass
def plot_predictions(self, batch, preds, ni):
pass

@ -3,6 +3,7 @@ import logging.config
import os
import platform
import sys
import threading
from pathlib import Path
# Constants
@ -130,3 +131,13 @@ class TryExcept(contextlib.ContextDecorator):
if value:
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
return wrapper

@ -26,11 +26,11 @@ deterministic: True
local_rank: -1
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
shuffle: True
rect: False # support rectangular training
cos_lr: False # Use cosine LR scheduler
overlap_mask: True # Segmentation masks overlap
mask_ratio: 4 # Segmentation mask downsample ratio
noval: False
# Val/Test settings ----------------------------------------------------------------------------------------------------
save_json: False
@ -43,7 +43,7 @@ plots: False
save_txt: False
# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
@ -59,22 +59,23 @@ iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
label_smoothing: 0.0
nbs: 64 # nominal batch size
# anchors: 3
augment_hyp:
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
# Hydra configs --------------------------------------------------------------------------------------------------------
hydra:

@ -283,6 +283,50 @@ def smooth(y, f=0.05):
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
else:
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
y = smooth(py.mean(0), 0.05)
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
# Arguments
@ -365,14 +409,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
f1 = 2 * p * r / (p + r + eps)
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
# TODO: plot
'''
if plot:
plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
'''
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]

@ -1,12 +1,16 @@
import contextlib
import math
from pathlib import Path
from urllib.error import URLError
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
from .checks import check_font, check_requirements, is_ascii
from .files import increment_path
@ -179,3 +183,147 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
return crop
@threaded
def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy().astype(int)
if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
max_subplots = 16 # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im
# Resize (optional)
scale = max_size / ns / max(h, w)
if scale < 1:
h = math.ceil(scale * h)
w = math.ceil(scale * w)
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
boxes = xywh2xyxy(bboxes[idx]).T
classes = cls[idx].astype('int')
labels = confs is None # labels if no conf column
conf = None if labels else confs[idx] # check for confidence presence (label vs pred)
if boxes.shape[1]:
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
boxes[[0, 2]] *= w # scale to pixels
boxes[[1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes *= scale
boxes[[0, 2]] += x
boxes[[1, 3]] += y
for j, box in enumerate(boxes.T.tolist()):
c = classes[j]
color = colors(c)
c = names[c] if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)
# Plot masks
if len(masks):
if masks.max() > 1.0: # mean that masks are overlap
image_masks = masks[[i]] # (1, 640, 640)
nl = idx.sum()
index = np.arange(nl).reshape(nl, 1, 1) + 1
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
else:
image_masks = masks[idx]
im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):
if labels or conf[j] > 0.25: # 0.25 conf thresh
color = colors(classes[j])
mh, mw = image_masks[j].shape
if mh != h or mw != w:
mask = image_masks[j].astype(np.uint8)
mask = cv2.resize(mask, (w, h))
mask = mask.astype(bool)
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
save_dir = Path(file).parent if file else Path(dir)
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
ax = ax.ravel()
files = list(save_dir.glob("results*.csv"))
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
for f in files:
try:
data = pd.read_csv(f)
index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] +
0.1 * data.values[:, 11])
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
y = data.values[:, j]
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
if best:
# best
ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
else:
# last
ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
print(f"Warning: Plotting error for {f}: {e}")
ax[1].legend()
fig.savefig(save_dir / "results.png", dpi=200)
plt.close()
def output_to_target(output, max_det=300):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
targets = []
for i, o in enumerate(output):
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
j = torch.full((conf.shape[0], 1), i)
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
targets = torch.cat(targets, 0).numpy()
return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]

@ -245,3 +245,19 @@ class ModelEMA:
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
class ClassificationTrainer(BaseTrainer):
def set_model_attributes(self):
self.model.names = self.data["names"]
def load_model(self, model_cfg, weights, data):
# TODO: why treat clf models as unique. We should have clf yamls?
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
ClassificationModel.reshape_outputs(model, data["nc"])
return model
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.img_size,
batch_size=batch_size,

@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
top1, top5 = acc.mean(0).tolist()
return {"top1": top1, "top5": top5, "fitness": top5}
@property
def metric_keys(self):
return ["top1", "top5"]

@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage
class SegmentationTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size, rank=0):
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
# TODO: manage splits differently
# calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_dataloader(
img_path=dataset_path,
img_size=self.args.img_size,
batch_size=batch_size,
single_cls=self.args.single_cls,
cache=self.args.cache,
image_weights=self.args.image_weights,
stride=gs,
rect=self.args.rect,
rank=rank,
workers=self.args.workers,
shuffle=self.args.shuffle,
use_segments=True,
)[0]
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
self.model.names = self.data["names"]
def get_validator(self):
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
return v8.segment.SegmentationValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
args=self.args)
def criterion(self, preds, batch):
head = de_parallel(self.model).model[-1]
@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
else:
mask_gti = masks[tidxs[i]][j]
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
else:
lseg += (proto * 0).sum()
obji = BCEobj(pi[..., 4], tobj)
lobj += obji * balance[i] # obj loss
@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
loss = lbox + lobj + lcls + lseg
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
def label_loss_items(self, loss_items):
def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future
keys = ["lbox", "lseg", "lobj", "lcls"]
return dict(zip(keys, loss_items))
keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self):
return ('\n' + '%11s' * 7) % \
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
def plot_training_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths,
fname=self.save_dir / f"train_batch{ni}.jpg")
def plot_metrics(self):
plot_results_with_masks(file=self.csv) # save results.png
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg):

@ -6,23 +6,24 @@ import torch.nn.functional as F
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
fitness_segmentation, mask_iou)
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
from ultralytics.yolo.utils.torch_utils import de_parallel
class SegmentationValidator(BaseValidator):
def __init__(self, dataloader, pbar=None, logger=None, args=None):
super().__init__(dataloader, pbar, logger, args)
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
if self.args.save_json:
check_requirements(['pycocotools'])
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.data_dict = yaml_load(self.args.data) if self.args.data else None
self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
self.is_coco = False
self.class_map = None
self.targets = None
@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
self.loss = torch.zeros(4, device=self.device)
self.jdict = []
self.stats = []
self.plot_masks = []
def get_desc(self):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
def update_metrics(self, preds, batch):
# Metrics
plot_masks = [] # masks for plotting
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
labels = self.targets[self.targets[:, 0] == si, 1:]
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
shape = batch["shape"][si]
shape = batch["ori_shape"][si]
# path = batch["shape"][si][0]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# TODO: Save/log
'''
@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
'''
# TODO Plot images
'''
if self.args.plots and self.batch_i < 3:
if len(plot_masks):
plot_masks = torch.cat(plot_masks, dim=0)
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
'''
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any():
# TODO: save_dir
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
self.metrics.update(results)
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
metrics |= zip(keys, self.metrics.mean_results())
metrics |= zip(self.metric_keys, self.metrics.mean_results())
return metrics
def print_results(self):
@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
# plot TODO: save_dir
if self.args.plots:
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
@property
def metric_keys(self):
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP_0.5(B)",
"metrics/mAP_0.5:0.95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP_0.5(M)",
"metrics/mAP_0.5:0.95(M)",]
def plot_val_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)
def plot_predictions(self, batch, preds, ni):
images = batch["img"]
paths = batch["im_file"]
if len(self.plot_masks):
plot_masks = torch.cat(self.plot_masks, dim=0)
batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
self.plot_masks.clear()

Loading…
Cancel
Save