Add RTDETR Trainer (#2745)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>pull/3212/head^2^2
parent
03bce07848
commit
a0ba8ef5f0
23 changed files with 982 additions and 307 deletions
@ -0,0 +1,46 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
# YOLOv8 backbone and neck with an RTDETRDecoder head fed by the P3-P5 feature maps. For usage examples see https://docs.ultralytics.com/tasks/detect |
||||
|
||||
# Parameters |
||||
nc: 80 # number of classes |
||||
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' |
||||
# [depth, width, max_channels] |
||||
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs |
||||
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs |
||||
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs |
||||
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs |
||||
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs |
||||
|
||||
# YOLOv8.0n backbone |
||||
backbone: |
||||
# [from, repeats, module, args] |
||||
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 |
||||
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 |
||||
- [-1, 3, C2f, [128, True]] |
||||
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 |
||||
- [-1, 6, C2f, [256, True]] |
||||
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 |
||||
- [-1, 6, C2f, [512, True]] |
||||
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 |
||||
- [-1, 3, C2f, [1024, True]] |
||||
- [-1, 1, SPPF, [1024, 5]] # 9 |
||||
|
||||
# YOLOv8.0n head |
||||
head: |
||||
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] |
||||
- [[-1, 6], 1, Concat, [1]] # cat backbone P4 |
||||
- [-1, 3, C2f, [512]] # 12 |
||||
|
||||
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] |
||||
- [[-1, 4], 1, Concat, [1]] # cat backbone P3 |
||||
- [-1, 3, C2f, [256]] # 15 (P3/8-small) |
||||
|
||||
- [-1, 1, Conv, [256, 3, 2]] |
||||
- [[-1, 12], 1, Concat, [1]] # cat head P4 |
||||
- [-1, 3, C2f, [512]] # 18 (P4/16-medium) |
||||
|
||||
- [-1, 1, Conv, [512, 3, 2]] |
||||
- [[-1, 9], 1, Concat, [1]] # cat head P5 |
||||
- [-1, 3, C2f, [1024]] # 21 (P5/32-large) |
||||
|
||||
- [[15, 18, 21], 1, RTDETRDecoder, [nc]] # RTDETRDecoder(P3, P4, P5) |
@ -0,0 +1,78 @@ |
||||
from copy import copy |
||||
|
||||
import torch |
||||
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel |
||||
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr |
||||
from ultralytics.yolo.v8.detect import DetectionTrainer |
||||
|
||||
from .val import RTDETRDataset, RTDETRValidator |
||||
|
||||
|
||||
class RTDETRTrainer(DetectionTrainer):
    """Trainer for RT-DETR detection models.

    Specialises DetectionTrainer so that the model, dataset and validator are the RT-DETR variants
    (RTDETRDetectionModel, RTDETRDataset and RTDETRValidator).
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build an RTDETRDetectionModel and optionally load weights into it.

        Args:
            cfg: Model configuration forwarded to RTDETRDetectionModel.
            weights: Weights handed to `model.load`; loading is skipped when this is falsy.
            verbose (bool): Request verbose model construction; only honoured when RANK == -1.

        Returns:
            The constructed RTDETRDetectionModel.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode='val', batch=None):
        """Build RTDETR Dataset

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

        Returns:
            The RTDETRDataset for `img_path`.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # augmentation is enabled for the training split only
            hyp=self.args,
            rect=False,  # no rect
            cache=self.args.cache or None,
            prefix=colorstr(f'{mode}: '),
            data=self.data)

    def get_validator(self):
        """Set the RT-DETR loss names and return an RTDETRValidator bound to the test loader."""
        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """Preprocess a batch by delegating to DetectionTrainer.preprocess_batch.

        The previous implementation additionally split boxes and classes into per-image lists that were
        discarded before returning; that dead work has been removed and the returned batch is unchanged.
        """
        return super().preprocess_batch(batch)
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
    """Train an RT-DETR model with a fixed set of overrides.

    Args:
        cfg: Configuration object; only `cfg.data` and `cfg.device` are read.
        use_python (bool): Accepted for signature compatibility; not used by this function.
    """
    # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
    # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
    overrides = {
        'model': 'rtdetr-l.yaml',
        'data': cfg.data or 'coco128.yaml',  # fall back to coco128 when no dataset is configured
        'device': '' if cfg.device is None else cfg.device,
        'imgsz': 640,
        'exist_ok': True,
        'batch': 4,
        'deterministic': False,
        'amp': False}
    RTDETRTrainer(overrides=overrides).train()


if __name__ == '__main__':
    train()
@ -0,0 +1,291 @@ |
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
|
||||
from ultralytics.vit.utils.ops import HungarianMatcher |
||||
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss |
||||
from ultralytics.yolo.utils.metrics import bbox_iou |
||||
|
||||
|
||||
class DETRLoss(nn.Module):
    """Set-prediction loss for DETR-style detectors.

    Predictions are matched to ground truths with a HungarianMatcher (unless match indices are supplied by
    the caller). A classification loss, an L1 box loss and a GIoU loss are then computed for the last decoder
    layer and, when `aux_loss` is enabled, summed over the remaining decoder layers.
    """

    def __init__(self,
                 nc=80,
                 loss_gain=None,
                 aux_loss=True,
                 use_fl=True,
                 use_vfl=False,
                 use_uni_match=False,
                 uni_match_ind=0):
        """
        Args:
            nc (int): The number of classes.
            loss_gain (dict): The coefficient of loss. Defaults to
                {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1} when None.
            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
            use_fl (bool): Use FocalLoss or not.
            use_vfl (bool): Use VarifocalLoss or not.
            use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
            uni_match_ind (int): The fixed indices of a layer.
        """
        super().__init__()

        if loss_gain is None:
            loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss() if use_fl else None
        self.vfl = VarifocalLoss() if use_vfl else None

        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        self.device = None  # set from the predictions on every forward() call

    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
        """Compute the classification loss.

        Args:
            pred_scores (torch.Tensor): Class logits, [b, query, num_classes].
            targets (torch.Tensor): Class index per query, [b, query]; the value `self.nc` marks "no object".
            gt_scores (torch.Tensor): Target score per query, [b, query].
            num_gts (int): Number of matched ground truths, used to normalise the focal/varifocal loss.
            postfix (str): Suffix appended to the loss name.

        Returns:
            (dict): `{f'loss_class{postfix}': loss}` scaled by `self.loss_gain['class']`.
        """
        name_class = f'loss_class{postfix}'
        bs, nq = pred_scores.shape[:2]
        # One-hot encode with an extra "no object" column and drop it, so unmatched queries become all-zero rows.
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        if self.fl:
            # VarifocalLoss is only used when focal loss is enabled too and there is at least one ground truth.
            if num_gts and self.vfl:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

        return {name_class: loss_cls.squeeze() * self.loss_gain['class']}

    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
        """Compute the L1 and GIoU losses between already-matched boxes.

        Args:
            pred_bboxes (torch.Tensor): Matched predicted boxes; compared element-wise with `gt_bboxes`.
            gt_bboxes (torch.Tensor): Matched ground-truth boxes.
            postfix (str): Suffix appended to the loss names.

        Returns:
            (dict): `loss_bbox{postfix}` and `loss_giou{postfix}`, each scaled by its gain and averaged over
                the number of ground truths. Both are zero tensors when there are no ground truths.
        """
        name_bbox = f'loss_bbox{postfix}'
        name_giou = f'loss_giou{postfix}'

        loss = {}
        if len(gt_bboxes) == 0:
            loss[name_bbox] = torch.tensor(0., device=self.device)
            loss[name_giou] = torch.tensor(0., device=self.device)
            return loss

        loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
        loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
        loss = {k: v.squeeze() for k, v in loss.items()}
        return loss

    def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
        """Compute the mask and dice losses (mask branch, currently not called by `_get_loss`).

        Args:
            masks (torch.Tensor): Predicted masks, [b, query, h, w].
            gt_mask (list): Ground-truth masks, list[[n, H, W]].
            match_indices (list): Matched (prediction, target) index pairs per image.
            postfix (str): Suffix appended to the loss names.

        Returns:
            (dict): `loss_mask{postfix}` and `loss_dice{postfix}`.
        """
        name_mask = f'loss_mask{postfix}'
        name_dice = f'loss_dice{postfix}'

        loss = {}
        if sum(len(a) for a in gt_mask) == 0:
            loss[name_mask] = torch.tensor(0., device=self.device)
            loss[name_dice] = torch.tensor(0., device=self.device)
            return loss

        num_gts = len(gt_mask)
        src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
        src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
        # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
        # NOTE(review): the call below will fail as written if this branch is ever enabled.
        loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
                                                                        torch.tensor([num_gts], dtype=torch.float32))
        loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
        return loss

    def _dice_loss(self, inputs, targets, num_gts):
        """Return the dice loss of `inputs` (logits) against `targets`, summed over rows and divided by `num_gts`.

        Both tensors are flattened from dim 1; a smoothing constant of 1 is added to numerator and denominator.
        """
        inputs = torch.sigmoid(inputs)  # torch.sigmoid replaces the deprecated F.sigmoid
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_gts

    def _get_loss_aux(self,
                      pred_bboxes,
                      pred_scores,
                      gt_bboxes,
                      gt_cls,
                      gt_groups,
                      match_indices=None,
                      postfix='',
                      masks=None,
                      gt_mask=None):
        """Get auxiliary losses, i.e. the class/bbox/giou losses summed over the given decoder layers.

        Args:
            pred_bboxes (torch.Tensor): Per-layer predicted boxes, iterated over the first dimension.
            pred_scores (torch.Tensor): Per-layer predicted class logits, iterated over the first dimension.
            gt_bboxes (torch.Tensor): Ground-truth boxes.
            gt_cls (torch.Tensor): Ground-truth classes.
            gt_groups (list): Number of ground truths of each image.
            match_indices (list, optional): Precomputed match indices shared by all layers. When None and
                `use_uni_match` is set, they are computed once from layer `uni_match_ind`; otherwise each
                layer is matched independently inside `_get_loss`.
            postfix (str): Suffix appended to the loss names.
            masks (torch.Tensor, optional): Per-layer predicted masks.
            gt_mask (list, optional): Ground-truth masks.

        Returns:
            (dict): `loss_class_aux{postfix}`, `loss_bbox_aux{postfix}` and `loss_giou_aux{postfix}`.
        """
        # NOTE: loss class, bbox, giou, mask, dice
        totals = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        if match_indices is None and self.use_uni_match:
            match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
                                         pred_scores[self.uni_match_ind],
                                         gt_bboxes,
                                         gt_cls,
                                         gt_groups,
                                         masks=masks[self.uni_match_ind] if masks is not None else None,
                                         gt_mask=gt_mask)
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            loss_ = self._get_loss(aux_bboxes,
                                   aux_scores,
                                   gt_bboxes,
                                   gt_cls,
                                   gt_groups,
                                   masks=aux_masks,
                                   gt_mask=gt_mask,
                                   postfix=postfix,
                                   match_indices=match_indices)
            totals[0] += loss_[f'loss_class{postfix}']
            totals[1] += loss_[f'loss_bbox{postfix}']
            totals[2] += loss_[f'loss_giou{postfix}']
            # TODO: accumulate mask/dice losses into totals[3]/totals[4] once the mask branch is supported.

        return {
            f'loss_class_aux{postfix}': totals[0],
            f'loss_bbox_aux{postfix}': totals[1],
            f'loss_giou_aux{postfix}': totals[2]}

    def _get_index(self, match_indices):
        """Flatten per-image match indices into batch-wide index tensors.

        Args:
            match_indices (list): One (prediction indices, target indices) pair per image.

        Returns:
            (tuple): `((batch_idx, src_idx), dst_idx)` where `batch_idx` holds the image index of every matched
                prediction, `src_idx` the concatenated prediction indices and `dst_idx` the concatenated
                target indices.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx

    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
        """Gather the matched predictions and targets image by image and concatenate them.

        Images without matches contribute an empty (0, last_dim) tensor on `self.device`.
        """
        pred_assigned = torch.cat([
            t[src] if len(src) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
            for t, (src, _) in zip(pred_bboxes, match_indices)])
        gt_assigned = torch.cat([
            t[dst] if len(dst) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
            for t, (_, dst) in zip(gt_bboxes, match_indices)])
        return pred_assigned, gt_assigned

    def _get_loss(self,
                  pred_bboxes,
                  pred_scores,
                  gt_bboxes,
                  gt_cls,
                  gt_groups,
                  masks=None,
                  gt_mask=None,
                  postfix='',
                  match_indices=None):
        """Get losses for a single decoder layer.

        Args:
            pred_bboxes (torch.Tensor): [b, query, 4]
            pred_scores (torch.Tensor): [b, query, num_classes]
            gt_bboxes (torch.Tensor): [num_gts, 4]
            gt_cls (torch.Tensor): [num_gts, ]
            gt_groups (list): Number of ground truths of each image.
            masks (torch.Tensor, optional): Predicted masks, only forwarded to the matcher.
            gt_mask (list, optional): Ground-truth masks, only forwarded to the matcher.
            postfix (str): Suffix appended to the loss names.
            match_indices (list, optional): Precomputed match indices; computed with the matcher when None.

        Returns:
            (dict): `loss_class{postfix}`, `loss_bbox{postfix}` and `loss_giou{postfix}`.
        """
        if match_indices is None:
            match_indices = self.matcher(pred_bboxes,
                                         pred_scores,
                                         gt_bboxes,
                                         gt_cls,
                                         gt_groups,
                                         masks=masks,
                                         gt_mask=gt_mask)

        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        bs, nq = pred_scores.shape[:2]
        # Every query starts as "no object" (class index nc); matched queries receive their ground-truth class.
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        # The target score of a matched query is its IoU with the assigned ground truth.
        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        loss = {}
        loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
        loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
        # TODO: add the mask losses here once the mask branch is supported.
        return loss

    def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
        """
        Args:
            pred_bboxes (torch.Tensor): [l, b, query, 4]
            pred_scores (torch.Tensor): [l, b, query, num_classes]
            batch (dict): A dict includes:
                gt_cls (torch.Tensor) with shape [num_gts, ],
                gt_bboxes (torch.Tensor): [num_gts, 4],
                gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
            postfix (str): postfix of loss name.
            **kwargs: `match_indices` may be passed to bypass the matcher.

        Returns:
            (dict): Losses of the last layer plus, when `aux_loss` is enabled, the auxiliary losses of the
                preceding layers.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get('match_indices')
        gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']

        total_loss = self._get_loss(pred_bboxes[-1],
                                    pred_scores[-1],
                                    gt_bboxes,
                                    gt_cls,
                                    gt_groups,
                                    postfix=postfix,
                                    match_indices=match_indices)

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
                                   postfix))

        return total_loss
||||
|
||||
|
||||
class RTDETRDetectionLoss(DETRLoss):
    """DETRLoss extended with the losses of the denoising (dn) queries used by RT-DETR."""

    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
        """Compute the detection loss and, when denoising metadata is given, the denoising loss.

        Args:
            preds (tuple): `(pred_bboxes, pred_scores)` as expected by DETRLoss.forward.
            batch (dict): Ground truths with the keys `cls`, `bboxes` and `gt_groups`.
            dn_bboxes (torch.Tensor, optional): Boxes predicted for the denoising queries.
            dn_scores (torch.Tensor, optional): Class logits predicted for the denoising queries.
            dn_meta (dict, optional): Denoising metadata with the keys `dn_pos_idx` and `dn_num_group`.

        Returns:
            (dict): The DETRLoss losses plus one `*_dn` entry per loss. The `*_dn` entries are zero tensors
                when `dn_meta` is None.
        """
        pred_bboxes, pred_scores = preds
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        if dn_meta is not None:
            dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
            assert len(batch['gt_groups']) == len(dn_pos_idx)

            # denoising match indices
            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])

            # compute denoising training loss
            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
            total_loss.update(dn_loss)
        else:
            # Keep the set of loss keys identical whether or not denoising is active.
            total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})

        return total_loss

    @staticmethod
    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
        """Get the match indices for denoising.

        Args:
            dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
            dn_num_group (int): The number of groups of denoising.
            gt_groups (List(int)): a list of batch size length includes the number of gts of each image.

        Returns:
            dn_match_indices (List(tuple)): Matched indices, one `(dn_pos_idx[i], gt_idx)` pair per image where
                `gt_idx` indexes the flattened ground truths. Images without ground truths get a pair of
                empty tensors.

        Raises:
            AssertionError: If an image's positive denoising indices and its repeated ground-truth indices
                differ in length.
        """
        dn_match_indices = []
        # Offset of each image's first ground truth inside the flattened ground-truth tensors.
        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        for i, num_gt in enumerate(gt_groups):
            if num_gt > 0:
                # Index tensors are long so they can index tensors on every supported PyTorch version.
                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                gt_idx = gt_idx.repeat(dn_num_group)
                # Parenthesised so both string parts form the message; previously the f-string was a
                # separate no-op statement and the lengths never appeared in the error.
                assert len(dn_pos_idx[i]) == len(gt_idx), (
                    'Expected the same length, '
                    f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.')
                dn_match_indices.append((dn_pos_idx[i], gt_idx))
            else:
                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
        return dn_match_indices
@ -0,0 +1,230 @@ |
||||
# TODO: license |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
from scipy.optimize import linear_sum_assignment |
||||
|
||||
from ultralytics.yolo.utils.metrics import bbox_iou |
||||
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh |
||||
|
||||
|
||||
class HungarianMatcher(nn.Module):
    """Assign predictions to ground truths by solving a linear sum assignment over a combined cost.

    The cost of a (prediction, ground truth) pair is a weighted sum of a classification cost, the L1 distance
    between the boxes and a GIoU cost; mask and dice costs are added when `with_mask` is enabled.
    """

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
        """
        Args:
            cost_gain (dict): The coefficient of hungarian matcher cost. Defaults to
                {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} when None.
            use_fl (bool): Use the focal-style classification cost on sigmoid scores; otherwise the negative
                softmax score is used.
            with_mask (bool): Add the mask and dice costs to the cost matrix.
            num_sample_points (int): Number of random points sampled per image for the mask cost.
            alpha (float): Alpha of the focal-style classification cost.
            gamma (float): Gamma of the focal-style classification cost.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
        """
        Args:
            pred_bboxes (Tensor): [b, query, 4]
            pred_scores (Tensor): [b, query, num_classes]
            gt_cls (torch.Tensor) with shape [num_gts, ]
            gt_bboxes (torch.Tensor): [num_gts, 4]
            gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
            masks (Tensor|None): [b, query, h, w]
            gt_mask (List(Tensor)): list[[n, H, W]]

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order), offset so that
                  they index the flattened ground truths of the whole batch
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
            Both index tensors are of dtype torch.long so they can be used directly for indexing.
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

        # We flatten to compute the cost matrices in a batch
        # [batch_size * num_queries, num_classes]
        pred_scores = pred_scores.detach().view(-1, nc)
        pred_scores = torch.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        # [batch_size * num_queries, 4]
        pred_bboxes = pred_bboxes.detach().view(-1, 4)

        # Compute the classification cost
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # Compute the L1 cost between boxes
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)

        # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Final cost matrix
        C = self.cost_gain['class'] * cost_class + \
            self.cost_gain['bbox'] * cost_bbox + \
            self.cost_gain['giou'] * cost_giou
        # Compute the mask cost and dice cost
        if self.with_mask:
            C += self._cost_mask(bs, gt_groups, masks, gt_mask)

        C = C.view(bs, nq, -1).cpu()
        # Each image is only matched against its own slice of ground-truth columns.
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        # Offset of each image's first ground truth inside the flattened ground-truth tensors.
        gt_offsets = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        # (idx for queries, idx for gt)
        return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_offsets[k])
                for k, (i, j) in enumerate(indices)]

    def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
        """Compute the mask (binary cross entropy) plus dice cost on randomly sampled points.

        Args:
            bs (int): Batch size.
            num_gts (List(int)): Number of ground truths of each image.
            masks (Tensor): Predicted masks, [b, query, h, w].
            gt_mask (List(Tensor)): Ground-truth masks, list[[n, H, W]].

        Returns:
            (Tensor): Cost matrix `cost_gain['mask'] * cost_mask + cost_gain['dice'] * cost_dice`.
        """
        assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
        # all masks share the same set of points for efficient matching
        # (created on the masks' device so that grid_sample receives input and grid on one device)
        sample_points = torch.rand([bs, 1, self.num_sample_points, 2], device=masks.device)
        sample_points = 2.0 * sample_points - 1.0

        out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
        out_mask = out_mask.flatten(0, 1)

        tgt_mask = torch.cat(gt_mask).unsqueeze(1)
        sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
        # Two single-dim squeezes instead of squeeze([1, 2]), which older PyTorch releases do not accept.
        tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze(2).squeeze(1)

        with torch.cuda.amp.autocast(False):
            # binary cross entropy cost
            pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
            neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
            cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
            cost_mask /= self.num_sample_points

            # dice cost
            out_mask = torch.sigmoid(out_mask)
            numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
            denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
            cost_dice = 1 - (numerator + 1) / (denominator + 1)

            C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
        return C
||||
|
||||
|
||||
def get_cdn_group(batch,
                  num_classes,
                  num_queries,
                  class_embed,
                  num_dn=100,
                  cls_noise_ratio=0.5,
                  box_noise_scale=1.0,
                  training=False):
    """Get contrastive denoising training group

    Every ground truth is replicated into `num_group` positive and `num_group` negative denoising queries.
    Class labels are randomly replaced and boxes randomly perturbed (negatives more strongly), and the
    result is padded per image to a common number of denoising queries.

    Args:
        batch (dict): A dict includes:
            gt_cls (torch.Tensor) with shape [num_gts, ],
            gt_bboxes (torch.Tensor): [num_gts, 4],
            gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
            batch_idx (torch.Tensor): image index of every gt.
        num_classes (int): Number of classes.
        num_queries (int): Number of queries.
        class_embed (torch.Tensor): Embedding weights to map cls to embedding space.
        num_dn (int): Number of denoising.
        cls_noise_ratio (float): Noise ratio for class.
        box_noise_scale (float): Noise scale for bbox.
        training (bool): If it's training or not.

    Returns:
        (tuple): `(padding_cls, padding_bbox, attn_mask, dn_meta)`, or four `None` values when not training,
            when `num_dn <= 0` or when the batch holds no ground truths. With `n_dn = max_gts * 2 * num_group`:
            - padding_cls (torch.Tensor): [bs, n_dn, embed_dim] class embeddings of the denoising queries,
              zero at padded slots.
            - padding_bbox (torch.Tensor): [bs, n_dn, 4] boxes of the denoising queries, zero at padded slots;
              passed through `inverse_sigmoid` only when box noise is applied.
            - attn_mask (torch.Tensor): [n_dn + num_queries, n_dn + num_queries] bool mask, True where
              attention is blocked.
            - dn_meta (dict): `dn_pos_idx` (per-image indices of the positive denoising queries),
              `dn_num_group` and `dn_num_split` (`[n_dn, num_queries]`).
    """
    if (not training) or num_dn <= 0:
        return None, None, None, None
    gt_groups = batch['gt_groups']
    total_num = sum(gt_groups)
    max_nums = max(gt_groups, default=0)  # default guards against an empty batch
    if max_nums == 0:
        return None, None, None, None

    num_group = max(num_dn // max_nums, 1)
    # pad gt to max_num of a batch
    bs = len(gt_groups)
    gt_cls = batch['cls']  # (bs*num, )
    gt_bbox = batch['bboxes']  # bs*num, 4
    b_idx = batch['batch_idx']

    # each group has positive and negative queries.
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # positive and negative mask
    # (bs*num*num_group, ), the second total_num*num_group part as negative samples
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # half of bbox prob
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # randomly put a new one here
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)

        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4

        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        rand_part[neg_idx] += 1.0  # negative samples are perturbed more strongly than positive ones
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)
        dn_bbox = inverse_sigmoid(dn_bbox)

    # total denoising queries
    num_dn = int(max_nums * 2 * num_group)
    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    # Slot of every gt inside its own image, 0..n-1 per image.
    map_indices = torch.cat([torch.arange(num, dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # match query cannot see the reconstruct
    attn_mask[num_dn:, :num_dn] = True
    # reconstruct cannot see each other: each group is blocked from every denoising column outside its own
    # range. The slices are empty for the first/last group, so no special-casing is needed.
    group_size = max_nums * 2
    for i in range(num_group):
        start, end = group_size * i, group_size * (i + 1)
        attn_mask[start:end, end:num_dn] = True
        attn_mask[start:end, :start] = True
    dn_meta = {
        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
        'dn_num_group': num_group,
        'dn_num_split': [num_dn, num_queries]}

    return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
        class_embed.device), dn_meta
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-6):
    """Return the logit of `x` after clipping it to [0, 1], with `eps` keeping the result finite at 0 and 1."""
    clipped = x.clip(min=0., max=1.)
    odds = clipped / (1 - clipped + eps)
    return torch.log(odds + eps)
Loading…
Reference in new issue