@@ -2,18 +2,18 @@ from copy import copy
 
 import hydra
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.yolo import v8
 from ultralytics.yolo.utils import DEFAULT_CONFIG
-from ultralytics.yolo.utils.loss import BboxLoss
-from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
+from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
-from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
+from ultralytics.yolo.utils.tal import make_anchors
 from ultralytics.yolo.utils.torch_utils import de_parallel
 
+from ..detect.train import Loss
+
 
 # BaseTrainer python usage
 class SegmentationTrainer(v8.detect.DetectionTrainer):
@@ -55,51 +55,12 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
 
 
 # Criterion class for computing training losses
-class SegLoss:
+class SegLoss(Loss):
 
     def __init__(self, model, overlap=True):  # model must be de-paralleled
-        device = next(model.parameters()).device  # get model device
-        h = model.args  # hyperparameters
-        m = model.model[-1]  # Detect() module
-        self.bce = nn.BCEWithLogitsLoss(reduction='none')
-        self.hyp = h
-        self.stride = m.stride  # model strides
-        self.nc = m.nc  # number of classes
-        self.no = m.no
-        self.nm = m.nm  # number of masks
-        self.reg_max = m.reg_max
+        super().__init__(model)
+        self.nm = model.model[-1].nm  # number of masks
         self.overlap = overlap
-        self.device = device
-
-        self.use_dfl = m.reg_max > 1
-
-        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
-        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
-        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
-
-    def preprocess(self, targets, batch_size, scale_tensor):
-        if targets.shape[0] == 0:
-            out = torch.zeros(batch_size, 0, 5, device=self.device)
-        else:
-            i = targets[:, 0]  # image index
-            _, counts = i.unique(return_counts=True)
-            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
-            for j in range(batch_size):
-                matches = i == j
-                n = matches.sum()
-                if n:
-                    out[j, :n] = targets[matches, 1:]
-            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
-        return out
-
-    def bbox_decode(self, anchor_points, pred_dist):
-        if self.use_dfl:
-            b, a, c = pred_dist.shape  # batch, anchors, channels
-            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
-            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
-            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
-        return dist2bbox(pred_dist, anchor_points, xywh=False)
 
     def __call__(self, preds, batch):
         loss = torch.zeros(4, device=self.device)  # box, cls, dfl
@@ -163,10 +124,10 @@ class SegLoss:
                 # else:
                 #     loss[1] += proto.sum() * 0
 
-        loss[0] *= 7.5  # box gain
-        loss[1] *= 7.5 / batch_size  # seg gain
-        loss[2] *= 0.5  # cls gain
-        loss[3] *= 1.5  # dfl gain
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.box / batch_size  # seg gain
+        loss[2] *= self.hyp.cls  # cls gain
+        loss[3] *= self.hyp.dfl  # dfl gain
 
         return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
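
Note on the removed preprocess and bbox_decode: both match the implementations SegLoss now inherits from Loss in ..detect.train, so the second hunk deletes duplication without changing behavior. For reference, the DFL decode computes each box side as the expectation of a softmax distribution over reg_max discrete bins. A minimal standalone sketch of that decode with toy shapes (illustrative only, not the trainer's real tensors):

import torch

# Toy DFL decode: batch b=1, anchors a=2, reg_max=16 -> c = 4 * 16 = 64 channels.
b, a, reg_max = 1, 2, 16
pred_dist = torch.randn(b, a, 4 * reg_max)        # raw per-bin logits
proj = torch.arange(reg_max, dtype=torch.float)   # bin values 0..15, as in self.proj

# Each of the 4 sides becomes the expectation over its softmaxed bins, i.e. the
# pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj) line above.
ltrb = pred_dist.view(b, a, 4, reg_max).softmax(3).matmul(proj)  # (b, a, 4)

# dist2bbox(..., xywh=False): anchor point minus left/top, plus right/bottom.
anchor_points = torch.tensor([[[8.0, 8.0], [24.0, 8.0]]])         # (b, a, 2)
lt, rb = ltrb.chunk(2, -1)
bboxes = torch.cat((anchor_points - lt, anchor_points + rb), -1)  # xyxy, (1, 2, 4)

Because the expectation over bins 0..reg_max - 1 is bounded by reg_max - 1, BboxLoss is constructed with m.reg_max - 1 as its maximum regression distance.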
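Note on the last hunk: assuming the stock DEFAULT_CONFIG hyperparameters (box=7.5, cls=0.5, dfl=1.5), the new hyp-driven gains reproduce the old hardcoded constants exactly, so default runs should be unaffected while the gains become tunable; the segmentation term deliberately reuses the box gain. A quick equivalence check with a stand-in hyp namespace (hypothetical, for illustration):

import torch
from types import SimpleNamespace

hyp = SimpleNamespace(box=7.5, cls=0.5, dfl=1.5)  # stand-in for model.args defaults
batch_size = 16
loss_new = torch.ones(4)   # box, seg, cls, dfl components
loss_old = torch.ones(4)

# New hyp-driven gains...
loss_new[0] *= hyp.box
loss_new[1] *= hyp.box / batch_size
loss_new[2] *= hyp.cls
loss_new[3] *= hyp.dfl

# ...match the old hardcoded gains at the default settings.
loss_old[0] *= 7.5
loss_old[1] *= 7.5 / batch_size
loss_old[2] *= 0.5
loss_old[3] *= 1.5

assert torch.equal(loss_new, loss_old)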