[Feature] Add Mask2Former to mmdet (#6938)
Squashed commits: update doc; update doc format; deepcopy pixel_decoder cfg; move mask_pseudo_sampler cfg to config file; move part of postprocess from head to detector; fix bug in postprocessing; move class setting from head to config file; remove if else; move mask2bbox to mask/util; update docstring; update docstring in result2json; fix bug; update class_weight; add maskformer_fusion_head; add maskformer fusion head; update; add cfg for filter_low_score; update maskformer; update class_weight; update config; update unit test; rename param; update comments in config; rename variable, rm arg, update unit tests; update mask2bbox; add unit test for mask2bbox; replace unsqueeze(1) and squeeze(1); add unit test for maskformer_fusion_head; update docstrings; update docstring; delete \; remove modification to ce loss; update docstring; update docstring; update docstring of ce loss; update unit test; update docstring; update docstring; update docstring; rename; rename; add msdeformattn pixel decoder; maskformer refactor; add strides in config; remove redundant code; remove redundant code; update unit test; update config; update
parent fc8fb168c5, commit 14f0e9585c
13 changed files with 1212 additions and 9 deletions
@@ -0,0 +1,253 @@
_base_ = [
    '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
]
num_things_classes = 80
num_stuff_classes = 53
num_classes = num_things_classes + num_stuff_classes
model = dict(
    type='Mask2Former',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    panoptic_head=dict(
        type='Mask2FormerHead',
        in_channels=[256, 512, 1024, 2048],  # passed to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=100,
        num_transformer_feat_level=3,
        pixel_decoder=dict(
            type='MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=False,
                        norm_cfg=None,
                        init_cfg=None),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True)),
                    operation_order=('self_attn', 'norm', 'ffn', 'norm')),
                init_cfg=None),
            positional_encoding=dict(
                type='SinePositionalEncoding', num_feats=128, normalize=True),
            init_cfg=None),
        enforce_decoder_input_project=False,
        positional_encoding=dict(
            type='SinePositionalEncoding', num_feats=128, normalize=True),
        transformer_decoder=dict(
            type='DetrTransformerDecoder',
            return_intermediate=True,
            num_layers=9,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=False),
                ffn_cfgs=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    act_cfg=dict(type='ReLU', inplace=True),
                    ffn_drop=0.0,
                    dropout_layer=None,
                    add_identity=True),
                feedforward_channels=2048,
                operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
                                 'ffn', 'norm')),
            init_cfg=None),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0)),
    panoptic_fusion_head=dict(
        type='MaskFormerFusionHead',
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None),
    train_cfg=dict(
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        assigner=dict(
            type='MaskHungarianAssigner',
            cls_cost=dict(type='ClassificationCost', weight=2.0),
            mask_cost=dict(
                type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
            dice_cost=dict(
                type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
        sampler=dict(type='MaskPseudoSampler')),
    test_cfg=dict(
        panoptic_on=True,
        # For now, the dataset does not support evaluating the semantic
        # segmentation metric.
        semantic_on=False,
        instance_on=True,
        # max_per_image is for instance segmentation.
        max_per_image=100,
        iou_thr=0.8,
        # In Mask2Former's panoptic postprocessing, mask areas whose score
        # is less than 0.5 are filtered out.
        filter_low_score=True),
    init_cfg=None)
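One detail worth noting in the head config above: the classification branch predicts num_classes + 1 logits, where the trailing logit is the "no object" class, and class_weight down-weights it to 0.1 in the DETR/MaskFormer style. A quick sanity check of the weight vector (a sketch, not part of the config):

num_things_classes, num_stuff_classes = 80, 53
num_classes = num_things_classes + num_stuff_classes
class_weight = [1.0] * num_classes + [0.1]
# 134 entries; the last one down-weights the "no object" background class
assert len(class_weight) == num_classes + 1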

# dataset settings
image_size = (1024, 1024)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(
        type='LoadPanopticAnnotations',
        with_bbox=True,
        with_mask=True,
        with_seg=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    # large scale jittering
    dict(
        type='Resize',
        img_scale=image_size,
        ratio_range=(0.1, 2.0),
        multiscale_mode='range',
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_size=image_size,
        crop_type='absolute',
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=image_size),
    dict(type='DefaultFormatBundle', img_to_float=True),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
]
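The Resize + RandomCrop pair above implements large scale jittering (LSJ): the image is rescaled by a random ratio of the 1024x1024 target and then cropped or padded back to a fixed canvas. A quick sketch of the pixel range this ratio_range works out to:

image_size = (1024, 1024)
ratio_range = (0.1, 2.0)
# the long side is resized to anywhere between ~102 and 2048 pixels
print([int(image_size[0] * r) for r in ratio_range])  # [102, 2048]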
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data_root = 'data/coco/'
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline),
    val=dict(
        pipeline=test_pipeline,
        ins_ann_file=data_root + 'annotations/instances_val2017.json',
    ),
    test=dict(
        pipeline=test_pipeline,
        ins_ann_file=data_root + 'annotations/instances_val2017.json',
    ))

embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        norm_decay_mult=0.0))
optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))
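The custom_keys above are matched against parameter names by substring, so every backbone parameter trains at a 10x smaller learning rate while the query and level embeddings skip weight decay. A minimal sketch of that mechanism with a toy model (the module names here are hypothetical):

import torch.nn as nn
from mmcv.runner import build_optimizer

toy = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'head': nn.Linear(4, 4)})
toy_cfg = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)}))
toy_optim = build_optimizer(toy, toy_cfg)
# parameters whose names contain 'backbone' get lr=1e-5, the rest keep 1e-4
print(sorted({g['lr'] for g in toy_optim.param_groups}))  # [1e-05, 0.0001]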

# learning policy
lr_config = dict(
    policy='step',
    gamma=0.1,
    by_epoch=False,
    step=[327778, 355092],
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,  # no warmup
    warmup_iters=10)

max_iters = 368750
runner = dict(type='IterBasedRunner', max_iters=max_iters)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook', by_epoch=False)
    ])
interval = 200000
workflow = [('train', interval)]
checkpoint_config = dict(
    by_epoch=False, interval=interval, save_last=True, max_keep_ckpts=3)

# Before the 200001st iteration (max_iters // interval * interval + 1),
# we evaluate every `interval` (200000) iterations. After that, the
# evaluation interval becomes `max_iters` (368750), so the next evaluation
# falls exactly at the end of training. In all, evaluation runs at the
# 200000th iteration and at the last iteration.
dynamic_intervals = [(max_iters // interval * interval + 1, max_iters)]
evaluation = dict(
    interval=interval, dynamic_intervals=dynamic_intervals, metric='PQ')
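For reference, the whole config can be loaded and the detector instantiated through the standard mmcv/mmdet APIs. A minimal sketch, assuming the file lands at the usual configs/mask2former/ location:

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile(
    'configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py')
model = build_detector(cfg.model)  # train_cfg/test_cfg live inside cfg.model
print(type(model).__name__)  # Mask2Former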
@@ -0,0 +1,62 @@
_base_ = ['./mask2former_r50_lsj_8x2_50e_coco.py']
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'  # noqa

depths = [2, 2, 6, 2]
model = dict(
    type='Mask2Former',
    backbone=dict(
        _delete_=True,
        type='SwinTransformer',
        embed_dims=96,
        depths=depths,
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        frozen_stages=-1,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    panoptic_head=dict(
        type='Mask2FormerHead', in_channels=[96, 192, 384, 768]),
    init_cfg=None)

# set all layers in the backbone to lr_mult=0.1
# set all norm layers, position_embedding,
# query_embedding and level_embedding to decay_mult=0.0
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
    'backbone': dict(lr_mult=0.1, decay_mult=1.0),
    'backbone.patch_embed.norm': backbone_norm_multi,
    'backbone.norm': backbone_norm_multi,
    'absolute_pos_embed': backbone_embed_multi,
    'relative_position_bias_table': backbone_embed_multi,
    'query_embed': embed_multi,
    'query_feat': embed_multi,
    'level_embed': embed_multi
}
custom_keys.update({
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
})
custom_keys.update({
    f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
    for stage_id in range(len(depths) - 1)
})
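The two dict comprehensions above enumerate every Swin block norm and every downsample norm so each can be matched by exact key. A quick sketch of what they expand to for depths = [2, 2, 6, 2]:

depths = [2, 2, 6, 2]
block_norm_keys = [
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm'
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
]
print(len(block_norm_keys))  # 12 keys, one per Swin block
print(block_norm_keys[:2])
# ['backbone.stages.0.blocks.0.norm', 'backbone.stages.0.blocks.1.norm']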
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_match_cost
-from .match_cost import (BBoxL1Cost, ClassificationCost, DiceCost,
-                         FocalLossCost, IoUCost)
+from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
+                         DiceCost, FocalLossCost, IoUCost)

 __all__ = [
     'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
-    'FocalLossCost', 'DiceCost'
+    'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
 ]

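The newly exported CrossEntropyLossCost (defined in match_cost.py, not shown in this excerpt) supplies the pairwise binary cross-entropy term used by MaskHungarianAssigner. A minimal sketch of the underlying trick, not the PR's exact implementation: per-point BCE against all-ones and all-zeros targets can be combined into an N x M cost matrix without materializing every prediction/ground-truth pair:

import torch
import torch.nn.functional as F

def bce_match_cost(cls_pred, gt_labels):
    """Pairwise BCE cost: (N, P) logits x (M, P) binary masks -> (N, M)."""
    cls_pred = cls_pred.flatten(1).float()
    gt_labels = gt_labels.flatten(1).float()
    num_points = cls_pred.shape[1]
    pos = F.binary_cross_entropy_with_logits(
        cls_pred, torch.ones_like(cls_pred), reduction='none')
    neg = F.binary_cross_entropy_with_logits(
        cls_pred, torch.zeros_like(cls_pred), reduction='none')
    cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
        torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
    return cost / num_points

cost = bce_match_cost(torch.randn(100, 256), torch.randint(0, 2, (5, 256)))
print(cost.shape)  # torch.Size([100, 5])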
@@ -0,0 +1,430 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
                                         build_transformer_layer_sequence)
from mmcv.ops import point_sample
from mmcv.runner import ModuleList

from mmdet.core import build_assigner, build_sampler, reduce_mean
from mmdet.models.utils import get_uncertain_point_coords_with_randomness
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead
from .maskformer_head import MaskFormerHead


@HEADS.register_module()
class Mask2FormerHead(MaskFormerHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things classes.
        num_stuff_classes (int): Number of stuff classes.
        num_queries (int): Number of queries in the Transformer decoder.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
            Mask2Former head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
            Mask2Former head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 num_transformer_feat_level=3,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=None,
                 loss_mask=None,
                 loss_dice=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.\
            attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.\
            attn_cfgs.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(
            positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                           img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (num_gts, ).
            gt_masks (Tensor): Ground truth mask for each image, each with
                shape (num_gts, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
        """
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_points_pred,
                                             gt_labels, gt_points_masks,
                                             img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred,
                                              gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds)

    def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                    gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (num_gts, ).
            gt_masks_list (list[Tensor]): Ground truth mask for each image,
                each with shape (num_gts, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, num_total_pos,
         num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                           gt_labels_list, gt_masks_list,
                                           img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(
                mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(
            mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice
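Both the matching targets in _get_target_single and the mask/dice losses above are evaluated on a set of sampled points rather than full-resolution masks, following the PointRend-style sampling described in the Mask2Former paper. mmcv.ops.point_sample bilinearly interpolates logits at normalized (x, y) coordinates; a small sketch:

import torch
from mmcv.ops import point_sample

mask_logits = torch.randn(3, 1, 64, 64)  # (num_masks, 1, h, w)
point_coords = torch.rand(3, 128, 2)     # 128 random points in [0, 1]^2
sampled = point_sample(mask_logits, point_coords).squeeze(1)
print(sampled.shape)  # torch.Size([3, 128]); losses are computed on these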

    def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (num_queries, batch_size, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple containing three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, h, w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        decoder_out = decoder_out.transpose(0, 1)
        # shape (batch_size, num_queries, c)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
        attn_mask = F.interpolate(
            mask_pred,
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        # (batch_size * num_heads, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask
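The boolean attn_mask returned above follows torch.nn.MultiheadAttention semantics: a True entry marks a key location that must not be attended. attn_mask.sigmoid() < 0.5 therefore restricts each query's cross-attention to the foreground of its own current mask prediction, which is the "masked attention" that gives Mask2Former its name. A tiny illustration:

import torch

mask_logits = torch.tensor([[2.0, -3.0, 0.2, -0.1]])  # one query, 4 locations
attn_mask = mask_logits.sigmoid() < 0.5
print(attn_mask)  # tensor([[False,  True, False,  True]]); True is masked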

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Multi-scale features from the
                upstream network, each a 4D tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: A tuple containing two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    h, w).
        """
        batch_size = len(img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(feats)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # padding mask: there are no padded regions, so it is all False
            mask = decoder_input.new_zeros(
                (batch_size, ) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                2).permute(2, 0, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (num_queries, batch_size, c)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self.forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (i.e. all background), set it all False
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            attn_masks = [attn_mask, None]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                attn_masks=attn_masks,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self.forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list
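Note how the decoder loop walks the feature pyramid: with 9 layers and num_transformer_feat_level=3, level_idx = i % 3 visits each resolution three times, from low to high. A one-line sketch of the schedule:

num_layers, num_levels = 9, 3
print([i % num_levels for i in range(num_layers)])
# [0, 1, 2, 0, 1, 2, 0, 1, 2]: low-res to high-res, repeated three times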
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .maskformer import MaskFormer


@DETECTORS.register_module()
class Mask2Former(MaskFormer):
    r"""Implementation of `Masked-attention Mask
    Transformer for Universal Image Segmentation
    <https://arxiv.org/pdf/2112.01527>`_."""

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super().__init__(
            backbone,
            neck=neck,
            panoptic_head=panoptic_head,
            panoptic_fusion_head=panoptic_fusion_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
@@ -0,0 +1,216 @@
import numpy as np
import torch
from mmcv import ConfigDict

from mmdet.core.mask import BitmapMasks
from mmdet.models.dense_heads import Mask2FormerHead


def test_mask2former_head_loss():
    """Tests head loss when truth is empty and non-empty."""
    base_channels = 64
    img_metas = [{
        'batch_input_shape': (128, 160),
        'img_shape': (126, 160, 3),
        'ori_shape': (63, 80, 3)
    }, {
        'batch_input_shape': (128, 160),
        'img_shape': (120, 160, 3),
        'ori_shape': (60, 80, 3)
    }]
    feats = [
        torch.rand((2, 64 * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
        for i in range(4)
    ]
    num_things_classes = 80
    num_stuff_classes = 53
    num_classes = num_things_classes + num_stuff_classes
    config = ConfigDict(
        dict(
            type='Mask2FormerHead',
            in_channels=[base_channels * 2**i for i in range(4)],
            feat_channels=base_channels,
            out_channels=base_channels,
            num_things_classes=num_things_classes,
            num_stuff_classes=num_stuff_classes,
            num_queries=100,
            num_transformer_feat_level=3,
            pixel_decoder=dict(
                type='MSDeformAttnPixelDecoder',
                num_outs=3,
                norm_cfg=dict(type='GN', num_groups=32),
                act_cfg=dict(type='ReLU'),
                encoder=dict(
                    type='DetrTransformerEncoder',
                    num_layers=6,
                    transformerlayers=dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=base_channels,
                            num_heads=8,
                            num_levels=3,
                            num_points=4,
                            im2col_step=64,
                            dropout=0.0,
                            batch_first=False,
                            norm_cfg=None,
                            init_cfg=None),
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=base_channels,
                            feedforward_channels=base_channels * 4,
                            num_fcs=2,
                            ffn_drop=0.0,
                            act_cfg=dict(type='ReLU', inplace=True)),
                        feedforward_channels=base_channels * 4,
                        ffn_dropout=0.0,
                        operation_order=('self_attn', 'norm', 'ffn', 'norm')),
                    init_cfg=None),
                positional_encoding=dict(
                    type='SinePositionalEncoding',
                    num_feats=base_channels // 2,
                    normalize=True),
                init_cfg=None),
            enforce_decoder_input_project=False,
            positional_encoding=dict(
                type='SinePositionalEncoding',
                num_feats=base_channels // 2,
                normalize=True),
            transformer_decoder=dict(
                type='DetrTransformerDecoder',
                return_intermediate=True,
                num_layers=9,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=base_channels,
                        num_heads=8,
                        attn_drop=0.0,
                        proj_drop=0.0,
                        dropout_layer=None,
                        batch_first=False),
                    ffn_cfgs=dict(
                        embed_dims=base_channels,
                        feedforward_channels=base_channels * 8,
                        num_fcs=2,
                        act_cfg=dict(type='ReLU', inplace=True),
                        ffn_drop=0.0,
                        dropout_layer=None,
                        add_identity=True),
                    # the following parameter is not used; it is kept only
                    # to satisfy the current API
                    feedforward_channels=base_channels * 8,
                    operation_order=('cross_attn', 'norm', 'self_attn',
                                     'norm', 'ffn', 'norm')),
                init_cfg=None),
            loss_cls=dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=2.0,
                reduction='mean',
                class_weight=[1.0] * num_classes + [0.1]),
            loss_mask=dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice=dict(
                type='DiceLoss',
                use_sigmoid=True,
                activate=True,
                reduction='mean',
                naive_dice=True,
                eps=1.0,
                loss_weight=5.0),
            train_cfg=dict(
                num_points=256,
                oversample_ratio=3.0,
                importance_sample_ratio=0.75,
                assigner=dict(
                    type='MaskHungarianAssigner',
                    cls_cost=dict(type='ClassificationCost', weight=2.0),
                    mask_cost=dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dice_cost=dict(
                        type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
                sampler=dict(type='MaskPseudoSampler')),
            test_cfg=dict(
                panoptic_on=True,
                semantic_on=False,
                instance_on=True,
                max_dets_per_image=100,
                object_mask_thr=0.8,
                iou_thr=0.8)))
    self = Mask2FormerHead(**config)
    self.init_weights()
    all_cls_scores, all_mask_preds = self.forward(feats, img_metas)
    # Test that empty ground truth encourages the network to predict
    # background.
    gt_labels_list = [torch.LongTensor([]), torch.LongTensor([])]
    gt_masks_list = [
        torch.zeros((0, 128, 160)).long(),
        torch.zeros((0, 128, 160)).long()
    ]

    empty_gt_losses = self.loss(all_cls_scores, all_mask_preds,
                                gt_labels_list, gt_masks_list, img_metas)
    # When there is no truth, the cls loss should be nonzero but there should
    # be no mask or dice loss.
    for key, loss in empty_gt_losses.items():
        if 'cls' in key:
            assert loss.item() > 0, 'cls loss should be non-zero'
        elif 'mask' in key:
            assert loss.item(
            ) == 0, 'there should be no mask loss when there are no true masks'
        elif 'dice' in key:
            assert loss.item(
            ) == 0, 'there should be no dice loss when there are no true masks'

    # When truth is non-empty, the cls, mask and dice losses should all be
    # nonzero (random inputs).
    gt_labels_list = [
        torch.tensor([10, 100]).long(),
        torch.tensor([100, 10]).long()
    ]
    mask1 = torch.zeros((2, 128, 160)).long()
    mask1[0, :50] = 1
    mask1[1, 50:] = 1
    mask2 = torch.zeros((2, 128, 160)).long()
    mask2[0, :, :50] = 1
    mask2[1, :, 50:] = 1
    gt_masks_list = [mask1, mask2]
    two_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list,
                              gt_masks_list, img_metas)
    for loss in two_gt_losses.values():
        assert loss.item() > 0, 'all losses should be non-zero'

    # test forward_train
    gt_bboxes = None
    gt_labels = [
        torch.tensor([10]).long(),
        torch.tensor([10]).long(),
    ]
    thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32)
    thing_mask1[0, :50] = 1
    thing_mask2 = np.zeros((1, 128, 160), dtype=np.int32)
    thing_mask2[0, :, 50:] = 1
    gt_masks = [
        BitmapMasks(thing_mask1, 128, 160),
        BitmapMasks(thing_mask2, 128, 160),
    ]
    stuff_mask1 = torch.zeros((1, 128, 160)).long()
    stuff_mask1[0, :50] = 10
    stuff_mask1[0, 50:] = 100
    stuff_mask2 = torch.zeros((1, 128, 160)).long()
    stuff_mask2[0, :, 50:] = 10
    stuff_mask2[0, :, :50] = 100
    gt_semantic_seg = [stuff_mask1, stuff_mask2]

    self.forward_train(feats, img_metas, gt_bboxes, gt_labels, gt_masks,
                       gt_semantic_seg)

    # test inference mode
    self.simple_test(feats, img_metas)