[Feature] Add Mask2Former to mmdet (#6938)

update doc

update doc format

deepcopy pixel_decoder cfg

move mask_pseudo_sampler cfg to config file

move part of postprocess from head to detector

fix bug in postprocessing

move class setting from head to config file

remove if else

move mask2bbox to mask/util

update docstring

update docstring in result2json

fix bug

update class_weight

add maskformer_fusion_head

add maskformer fusion head

update

add cfg for filter_low_score

update maskformer

update class_weight

update config

update unit test

rename param

update comments in config

rename variable, rm arg, update unit tests

update mask2bbox

add unit test for mask2bbox

replace unsqueeze(1) and squeeze(1)

add unit test for maskformer_fusion_head

update docstrings

update docstring

delete \

remove modification to ce loss

update docstring

update docstring

update docstring of ce loss

update unit test

update docstring

update docstring

update docstring

rename

rename

add msdeformattn pixel decoder

maskformer refactor

add strides in config

remove redundant code

remove redundant code

update unit test

update config

update
Author: Cedric Luo (committed via GitHub)
Commit: 14f0e9585c, parent: fc8fb168c5
13 changed files (lines changed per file):

   253  configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py
    62  configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py
     6  mmdet/core/bbox/match_costs/__init__.py
    65  mmdet/core/bbox/match_costs/match_cost.py
     3  mmdet/datasets/coco.py
    17  mmdet/datasets/coco_panoptic.py
     4  mmdet/models/dense_heads/__init__.py
   430  mmdet/models/dense_heads/mask2former_head.py
     3  mmdet/models/detectors/__init__.py
    27  mmdet/models/detectors/mask2former.py
   216  tests/test_models/test_dense_heads/test_mask2former_head.py
   111  tests/test_models/test_forward.py
    24  tests/test_utils/test_assigner.py

configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py
@@ -0,0 +1,253 @@
_base_ = [
    '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
]
num_things_classes = 80
num_stuff_classes = 53
num_classes = num_things_classes + num_stuff_classes
model = dict(
    type='Mask2Former',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    panoptic_head=dict(
        type='Mask2FormerHead',
        in_channels=[256, 512, 1024, 2048],  # pass to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=100,
        num_transformer_feat_level=3,
        pixel_decoder=dict(
            type='MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=False,
                        norm_cfg=None,
                        init_cfg=None),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True)),
                    operation_order=('self_attn', 'norm', 'ffn', 'norm')),
                init_cfg=None),
            positional_encoding=dict(
                type='SinePositionalEncoding', num_feats=128, normalize=True),
            init_cfg=None),
        enforce_decoder_input_project=False,
        positional_encoding=dict(
            type='SinePositionalEncoding', num_feats=128, normalize=True),
        transformer_decoder=dict(
            type='DetrTransformerDecoder',
            return_intermediate=True,
            num_layers=9,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=False),
                ffn_cfgs=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    act_cfg=dict(type='ReLU', inplace=True),
                    ffn_drop=0.0,
                    dropout_layer=None,
                    add_identity=True),
                feedforward_channels=2048,
                operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
                                 'ffn', 'norm')),
            init_cfg=None),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0)),
    panoptic_fusion_head=dict(
        type='MaskFormerFusionHead',
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None),
    train_cfg=dict(
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        assigner=dict(
            type='MaskHungarianAssigner',
            cls_cost=dict(type='ClassificationCost', weight=2.0),
            mask_cost=dict(
                type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
            dice_cost=dict(
                type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
        sampler=dict(type='MaskPseudoSampler')),
    test_cfg=dict(
        panoptic_on=True,
        # For now, the dataset does not support
        # evaluating the semantic segmentation metric.
        semantic_on=False,
        instance_on=True,
        # max_per_image is for instance segmentation.
        max_per_image=100,
        iou_thr=0.8,
        # In Mask2Former's panoptic postprocessing, mask areas whose
        # scores are below 0.5 are filtered out.
        filter_low_score=True),
    init_cfg=None)
# dataset settings
image_size = (1024, 1024)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(
        type='LoadPanopticAnnotations',
        with_bbox=True,
        with_mask=True,
        with_seg=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    # large scale jittering
    dict(
        type='Resize',
        img_scale=image_size,
        ratio_range=(0.1, 2.0),
        multiscale_mode='range',
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_size=image_size,
        crop_type='absolute',
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=image_size),
    dict(type='DefaultFormatBundle', img_to_float=True),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data_root = 'data/coco/'
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline),
    val=dict(
        pipeline=test_pipeline,
        ins_ann_file=data_root + 'annotations/instances_val2017.json',
    ),
    test=dict(
        pipeline=test_pipeline,
        ins_ann_file=data_root + 'annotations/instances_val2017.json',
    ))
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        norm_decay_mult=0.0))
optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    gamma=0.1,
    by_epoch=False,
    step=[327778, 355092],
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,  # no warmup
    warmup_iters=10)
max_iters = 368750
runner = dict(type='IterBasedRunner', max_iters=max_iters)
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook', by_epoch=False)
    ])
interval = 200000
workflow = [('train', interval)]
checkpoint_config = dict(
    by_epoch=False, interval=interval, save_last=True, max_keep_ckpts=3)
# Before the 200001st iteration, we do evaluation every 200000 iterations.
# From the 200001st iteration on, the interval becomes 368750 iterations,
# which means evaluation only happens at the end of training.
# In all, we do evaluation at the 200000th iteration and the
# last iteration.
dynamic_intervals = [(max_iters // interval * interval + 1, max_iters)]
evaluation = dict(
    interval=interval, dynamic_intervals=dynamic_intervals, metric='PQ')
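
Note on the evaluation schedule above: dynamic_intervals is compact enough to
misread, so here is a standalone arithmetic check (plain Python, not mmdet
code) of what it evaluates to with these settings:

# Standalone check of the evaluation schedule configured above.
max_iters = 368750
interval = 200000
# First iteration after the last "regular" evaluation:
switch_iter = max_iters // interval * interval + 1
print(switch_iter)                 # 200001
print([(switch_iter, max_iters)])  # [(200001, 368750)] == dynamic_intervals
# So evaluation runs at iteration 200000 (every `interval` iterations)
# and once more at iteration 368750, i.e. at the end of training.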

configs/mask2former/mask2former_swin-t-p4-w7-224_lsj_8x2_50e_coco.py
@@ -0,0 +1,62 @@
_base_ = ['./mask2former_r50_lsj_8x2_50e_coco.py']
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'  # noqa
depths = [2, 2, 6, 2]
model = dict(
    type='Mask2Former',
    backbone=dict(
        _delete_=True,
        type='SwinTransformer',
        embed_dims=96,
        depths=depths,
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        frozen_stages=-1,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    panoptic_head=dict(
        type='Mask2FormerHead', in_channels=[96, 192, 384, 768]),
    init_cfg=None)
# set all layers in the backbone to lr_mult=0.1
# set all norm layers, position_embedding,
# query_embedding and level_embedding to decay_mult=0.0
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
    'backbone': dict(lr_mult=0.1, decay_mult=1.0),
    'backbone.patch_embed.norm': backbone_norm_multi,
    'backbone.norm': backbone_norm_multi,
    'absolute_pos_embed': backbone_embed_multi,
    'relative_position_bias_table': backbone_embed_multi,
    'query_embed': embed_multi,
    'query_feat': embed_multi,
    'level_embed': embed_multi
}
custom_keys.update({
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
})
custom_keys.update({
    f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
    for stage_id in range(len(depths) - 1)
})
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
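
Note on the paramwise config above: the two dict comprehensions expand to one
custom key per Swin block norm plus one per downsample norm. A standalone
sketch of the generated keys, assuming only depths = [2, 2, 6, 2] as in this
config:

# Standalone sketch: enumerate the norm keys the comprehensions generate.
depths = [2, 2, 6, 2]
block_norm_keys = [
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm'
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
]
downsample_norm_keys = [
    f'backbone.stages.{stage_id}.downsample.norm'
    for stage_id in range(len(depths) - 1)
]
print(len(block_norm_keys))  # 12, one per Swin block (2 + 2 + 6 + 2)
print(downsample_norm_keys)  # stages 0-2; the last stage has no downsample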

mmdet/core/bbox/match_costs/__init__.py
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_match_cost
-from .match_cost import (BBoxL1Cost, ClassificationCost, DiceCost,
-                         FocalLossCost, IoUCost)
+from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
+                         DiceCost, FocalLossCost, IoUCost)

 __all__ = [
     'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
-    'FocalLossCost', 'DiceCost'
+    'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
 ]

mmdet/core/bbox/match_costs/match_cost.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+import torch.nn.functional as F

 from mmdet.core.bbox.iou_calculators import bbox_overlaps
 from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
@@ -281,3 +282,67 @@ class DiceCost:
         mask_preds = mask_preds.sigmoid()
         dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
         return dice_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class CrossEntropyLossCost:
+    """CrossEntropyLossCost.
+
+    Args:
+        weight (int | float, optional): Loss weight. Defaults to 1.
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            or softmax. Defaults to True.
+
+    Examples:
+        >>> from mmdet.core.bbox.match_costs import CrossEntropyLossCost
+        >>> import torch
+        >>> bce = CrossEntropyLossCost(use_sigmoid=True)
+        >>> cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10]])
+        >>> gt_labels = torch.tensor([[1, 1], [1, 0]])
+        >>> print(bce(cls_pred, gt_labels))
+    """
+
+    def __init__(self, weight=1., use_sigmoid=True):
+        assert use_sigmoid, 'use_sigmoid = False is not supported yet.'
+        self.weight = weight
+        self.use_sigmoid = use_sigmoid
+
+    def _binary_cross_entropy(self, cls_pred, gt_labels):
+        """
+        Args:
+            cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
+                (num_query, *).
+            gt_labels (Tensor): The learning label of prediction with
+                shape (num_gt, *).
+
+        Returns:
+            Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
+        """
+        cls_pred = cls_pred.flatten(1).float()
+        gt_labels = gt_labels.flatten(1).float()
+        n = cls_pred.shape[1]
+        pos = F.binary_cross_entropy_with_logits(
+            cls_pred, torch.ones_like(cls_pred), reduction='none')
+        neg = F.binary_cross_entropy_with_logits(
+            cls_pred, torch.zeros_like(cls_pred), reduction='none')
+        cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
+            torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
+        cls_cost = cls_cost / n
+        return cls_cost
+
+    def __call__(self, cls_pred, gt_labels):
+        """
+        Args:
+            cls_pred (Tensor): Predicted classification logits.
+            gt_labels (Tensor): Labels.
+
+        Returns:
+            Tensor: Cross entropy cost matrix with weight in
+                shape (num_query, num_gt).
+        """
+        if self.use_sigmoid:
+            cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
+        else:
+            raise NotImplementedError
+
+        return cls_cost * self.weight
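
The Examples section in the docstring above elides its output. For reference,
a self-contained sketch of the same BCE cost computation (mirroring
_binary_cross_entropy with the default weight=1.; printed values are
approximate):

# Standalone sketch of the binary cross entropy matching cost.
import torch
import torch.nn.functional as F

cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10.]])  # (num_query, num_points)
gt_labels = torch.tensor([[1., 1.], [1., 0.]])      # (num_gt, num_points)
pos = F.binary_cross_entropy_with_logits(
    cls_pred, torch.ones_like(cls_pred), reduction='none')
neg = F.binary_cross_entropy_with_logits(
    cls_pred, torch.zeros_like(cls_pred), reduction='none')
# cost[q, g] = mean over points of BCE(cls_pred[q], gt_labels[g])
cost = (torch.einsum('nc,mc->nm', pos, gt_labels) +
        torch.einsum('nc,mc->nm', neg, 1 - gt_labels)) / cls_pred.shape[1]
print(cost)  # ~[[0.1319, 0.7319], [0.7706, 5.7706]]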

mmdet/datasets/coco.py
@@ -405,9 +405,6 @@ class CocoDataset(CustomDataset):
                 'bbox', 'segm', 'proposal', 'proposal_fast'.
             logger (logging.Logger | str | None): Logger used for printing
                 related information during evaluation. Default: None.
-            jsonfile_prefix (str | None): The prefix of json files. It includes
-                the file path and the prefix of filename, e.g., "a/b/prefix".
-                If not specified, a temp file will be created. Default: None.
             classwise (bool): Whether to evaluating the AP for each class.
             proposal_nums (Sequence[int]): Proposal number used for evaluating
                 recalls, such as recall@100, recall@1000.

mmdet/datasets/coco_panoptic.py
@@ -457,8 +457,20 @@ class CocoPanopticDataset(CocoDataset):
         different data types. This method will automatically recognize
         the type, and dump them to json files.

+        .. code-block:: none
+
+            [
+                {
+                    'pan_results': np.array,  # shape (h, w)
+                    # ins_results which includes bboxes and RLE encoded masks
+                    # is optional.
+                    'ins_results': (list[np.array], list[list[str]])
+                },
+                ...
+            ]
+
         Args:
-            results (dict): Testing results of the dataset.
+            results (list[dict]): Testing results of the dataset.
             outfile_prefix (str): The filename prefix of the json files. If the
                 prefix is "somepath/xxx", the json files will be named
                 "somepath/xxx.panoptic.json", "somepath/xxx.bbox.json",
@@ -597,6 +609,7 @@ class CocoPanopticDataset(CocoDataset):
        if 'PQ' in metrics:
            eval_pan_results = self.evaluate_pan_json(
                result_files, outfile_prefix, logger, classwise, nproc=nproc)
            eval_results.update(eval_pan_results)
            metrics.remove('PQ')
@@ -611,11 +624,13 @@ class CocoPanopticDataset(CocoDataset):
                'should not be None'
            coco_gt = COCO(self.ins_ann_file)
            panoptic_cat_ids = self.cat_ids
            self.cat_ids = coco_gt.get_cat_ids(cat_names=self.THING_CLASSES)
            eval_ins_results = self.evaluate_det_segm(results, result_files,
                                                      coco_gt, metrics, logger,
                                                      classwise, **kwargs)
            self.cat_ids = panoptic_cat_ids
            eval_results.update(eval_ins_results)
        if tmp_dir is not None:

mmdet/models/dense_heads/__init__.py
@@ -20,6 +20,7 @@ from .gfl_head import GFLHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .lad_head import LADHead
 from .ld_head import LDHead
+from .mask2former_head import Mask2FormerHead
 from .maskformer_head import MaskFormerHead
 from .nasfcos_head import NASFCOSHead
 from .paa_head import PAAHead
@@ -50,5 +51,6 @@ __all__ = [
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
     'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
     'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
-    'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead'
+    'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
+    'Mask2FormerHead'
 ]

mmdet/models/dense_heads/mask2former_head.py
@@ -0,0 +1,430 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
                                         build_transformer_layer_sequence)
from mmcv.ops import point_sample
from mmcv.runner import ModuleList

from mmdet.core import build_assigner, build_sampler, reduce_mean
from mmdet.models.utils import get_uncertain_point_coords_with_randomness
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead
from .maskformer_head import MaskFormerHead


@HEADS.register_module()
class Mask2FormerHead(MaskFormerHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things.
        num_stuff_classes (int): Number of stuff.
        num_queries (int): Number of queries in the transformer decoder.
        num_transformer_feat_level (int): Number of feature levels fed to
            the transformer decoder. Defaults to 3.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for the pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for the
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for the
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of the
            Mask2Former head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of the
            Mask2Former head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 num_transformer_feat_level=3,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=None,
                 loss_mask=None,
                 loss_dice=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.transformerlayers.\
            attn_cfgs.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.transformerlayers.\
            attn_cfgs.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = build_positional_encoding(
            positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            self.sampler = build_sampler(self.train_cfg.sampler, context=self)
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

    def init_weights(self):
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                           img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Classification score logits from a single
                decoder layer for one image. Shape (num_queries,
                cls_out_channels).
            mask_pred (Tensor): Mask logits from a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image with
                shape (num_gts, ).
            gt_masks (Tensor): Ground truth masks for one image with
                shape (num_gts, h, w).
            img_metas (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_pos, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
        """
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_points_pred,
                                             gt_labels, gt_points_masks,
                                             img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred,
                                              gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds)

    def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                    gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Classification score logits from a single
                decoder layer for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (num_gts, ).
            gt_masks_list (list[Tensor]): Ground truth masks for each image,
                each with shape (num_gts, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         num_total_pos,
         num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                           gt_labels_list, gt_masks_list,
                                           img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(
                mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(
            mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice

    def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
        """Forward for the head part, which is called after every decoder
        layer.

        Args:
            decoder_out (Tensor): in shape (num_queries, batch_size, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple containing three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, h*w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        decoder_out = decoder_out.transpose(0, 1)
        # shape (batch_size, num_queries, c)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
        attn_mask = F.interpolate(
            mask_pred,
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        #   (batch_size * num_heads, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Multi-scale features from the
                upstream network, each a 4D tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple: A tuple containing two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    h, w).
        """
        batch_size = len(img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(feats)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
            decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # padding mask, shape (batch_size, h, w)
            mask = decoder_input.new_zeros(
                (batch_size, ) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                2).permute(2, 0, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (num_queries, batch_size, c)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(
            (1, batch_size, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self.forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (i.e. all background), set it all False
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            attn_masks = [attn_mask, None]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                attn_masks=attn_masks,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self.forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list
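
Note on the masked attention above: forward_head recomputes attn_mask from
each layer's mask prediction, so a query only attends to locations where its
current mask is confident (sigmoid >= 0.5), and a query whose mask is entirely
empty is reset to attend everywhere. A toy standalone sketch of that behaviour
(shapes are illustrative; this is not mmdet code):

# Toy sketch of the attention-mask logic; True entries are blocked positions.
import torch

num_heads, num_queries, hw = 2, 3, 4
mask_pred = torch.randn(1, num_queries, hw)  # (batch_size, num_queries, h*w)
attn_mask = mask_pred.sigmoid() < 0.5        # block low-confidence locations
# A fully blocked row would make softmax degenerate, so reset it to all-False:
attn_mask[attn_mask.sum(-1) == attn_mask.shape[-1]] = False
attn_mask = attn_mask.unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1)
print(attn_mask.shape)  # (batch_size * num_heads, num_queries, h*w) == (2, 3, 4)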

mmdet/models/detectors/__init__.py
@@ -17,6 +17,7 @@ from .grid_rcnn import GridRCNN
 from .htc import HybridTaskCascade
 from .kd_one_stage import KnowledgeDistillationSingleStageDetector
 from .lad import LAD
+from .mask2former import Mask2Former
 from .mask_rcnn import MaskRCNN
 from .mask_scoring_rcnn import MaskScoringRCNN
 from .maskformer import MaskFormer
@@ -51,5 +52,5 @@ __all__ = [
     'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
     'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
     'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
-    'MaskFormer'
+    'MaskFormer', 'Mask2Former'
 ]

mmdet/models/detectors/mask2former.py
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .maskformer import MaskFormer


@DETECTORS.register_module()
class Mask2Former(MaskFormer):
    r"""Implementation of `Masked-attention Mask
    Transformer for Universal Image Segmentation
    <https://arxiv.org/pdf/2112.01527>`_."""

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super().__init__(
            backbone,
            neck=neck,
            panoptic_head=panoptic_head,
            panoptic_fusion_head=panoptic_fusion_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
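
Since the detector only registers itself and reuses the MaskFormer machinery,
building it from the new config is enough to exercise the whole model. A
minimal build sketch (assumes the mmdetection repo root as the working
directory; mirrors what the unit tests below do):

# Minimal build sketch for the new detector.
from mmcv import Config

from mmdet.models import build_detector

cfg = Config.fromfile('configs/mask2former/mask2former_r50_lsj_8x2_50e_coco.py')
detector = build_detector(cfg.model)  # resolved via the DETECTORS registry
print(type(detector).__name__)  # Mask2Former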

tests/test_models/test_dense_heads/test_mask2former_head.py
@@ -0,0 +1,216 @@
import numpy as np
import torch
from mmcv import ConfigDict

from mmdet.core.mask import BitmapMasks
from mmdet.models.dense_heads import Mask2FormerHead


def test_mask2former_head_loss():
    """Tests head loss when truth is empty and non-empty."""
    base_channels = 64
    img_metas = [{
        'batch_input_shape': (128, 160),
        'img_shape': (126, 160, 3),
        'ori_shape': (63, 80, 3)
    }, {
        'batch_input_shape': (128, 160),
        'img_shape': (120, 160, 3),
        'ori_shape': (60, 80, 3)
    }]
    feats = [
        torch.rand((2, 64 * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
        for i in range(4)
    ]
    num_things_classes = 80
    num_stuff_classes = 53
    num_classes = num_things_classes + num_stuff_classes
    config = ConfigDict(
        dict(
            type='Mask2FormerHead',
            in_channels=[base_channels * 2**i for i in range(4)],
            feat_channels=base_channels,
            out_channels=base_channels,
            num_things_classes=num_things_classes,
            num_stuff_classes=num_stuff_classes,
            num_queries=100,
            num_transformer_feat_level=3,
            pixel_decoder=dict(
                type='MSDeformAttnPixelDecoder',
                num_outs=3,
                norm_cfg=dict(type='GN', num_groups=32),
                act_cfg=dict(type='ReLU'),
                encoder=dict(
                    type='DetrTransformerEncoder',
                    num_layers=6,
                    transformerlayers=dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=base_channels,
                            num_heads=8,
                            num_levels=3,
                            num_points=4,
                            im2col_step=64,
                            dropout=0.0,
                            batch_first=False,
                            norm_cfg=None,
                            init_cfg=None),
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=base_channels,
                            feedforward_channels=base_channels * 4,
                            num_fcs=2,
                            ffn_drop=0.0,
                            act_cfg=dict(type='ReLU', inplace=True)),
                        feedforward_channels=base_channels * 4,
                        ffn_dropout=0.0,
                        operation_order=('self_attn', 'norm', 'ffn', 'norm')),
                    init_cfg=None),
                positional_encoding=dict(
                    type='SinePositionalEncoding',
                    num_feats=base_channels // 2,
                    normalize=True),
                init_cfg=None),
            enforce_decoder_input_project=False,
            positional_encoding=dict(
                type='SinePositionalEncoding',
                num_feats=base_channels // 2,
                normalize=True),
            transformer_decoder=dict(
                type='DetrTransformerDecoder',
                return_intermediate=True,
                num_layers=9,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=base_channels,
                        num_heads=8,
                        attn_drop=0.0,
                        proj_drop=0.0,
                        dropout_layer=None,
                        batch_first=False),
                    ffn_cfgs=dict(
                        embed_dims=base_channels,
                        feedforward_channels=base_channels * 8,
                        num_fcs=2,
                        act_cfg=dict(type='ReLU', inplace=True),
                        ffn_drop=0.0,
                        dropout_layer=None,
                        add_identity=True),
                    # the following parameter is not used,
                    # just to make the current api happy
                    feedforward_channels=base_channels * 8,
                    operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
                                     'ffn', 'norm')),
                init_cfg=None),
            loss_cls=dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=2.0,
                reduction='mean',
                class_weight=[1.0] * num_classes + [0.1]),
            loss_mask=dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice=dict(
                type='DiceLoss',
                use_sigmoid=True,
                activate=True,
                reduction='mean',
                naive_dice=True,
                eps=1.0,
                loss_weight=5.0),
            train_cfg=dict(
                num_points=256,
                oversample_ratio=3.0,
                importance_sample_ratio=0.75,
                assigner=dict(
                    type='MaskHungarianAssigner',
                    cls_cost=dict(type='ClassificationCost', weight=2.0),
                    mask_cost=dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dice_cost=dict(
                        type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
                sampler=dict(type='MaskPseudoSampler')),
            test_cfg=dict(
                panoptic_on=True,
                semantic_on=False,
                instance_on=True,
                max_dets_per_image=100,
                object_mask_thr=0.8,
                iou_thr=0.8)))
    self = Mask2FormerHead(**config)
    self.init_weights()
    all_cls_scores, all_mask_preds = self.forward(feats, img_metas)

    # Test that empty ground truth encourages the network to predict
    # background.
    gt_labels_list = [torch.LongTensor([]), torch.LongTensor([])]
    gt_masks_list = [
        torch.zeros((0, 128, 160)).long(),
        torch.zeros((0, 128, 160)).long()
    ]
    empty_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list,
                                gt_masks_list, img_metas)
    # When there is no truth, the cls loss should be nonzero but there should
    # be no mask or dice loss.
    for key, loss in empty_gt_losses.items():
        if 'cls' in key:
            assert loss.item() > 0, 'cls loss should be non-zero'
        elif 'mask' in key:
            assert loss.item(
            ) == 0, 'there should be no mask loss when there are no true masks'
        elif 'dice' in key:
            assert loss.item(
            ) == 0, 'there should be no dice loss when there are no true masks'

    # When truth is non-empty, cls, mask and dice losses should all be nonzero
    # for random inputs.
    gt_labels_list = [
        torch.tensor([10, 100]).long(),
        torch.tensor([100, 10]).long()
    ]
    mask1 = torch.zeros((2, 128, 160)).long()
    mask1[0, :50] = 1
    mask1[1, 50:] = 1
    mask2 = torch.zeros((2, 128, 160)).long()
    mask2[0, :, :50] = 1
    mask2[1, :, 50:] = 1
    gt_masks_list = [mask1, mask2]
    two_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list,
                              gt_masks_list, img_metas)
    for loss in two_gt_losses.values():
        assert loss.item() > 0, 'all losses should be non-zero'

    # test forward_train
    gt_bboxes = None
    gt_labels = [
        torch.tensor([10]).long(),
        torch.tensor([10]).long(),
    ]
    thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32)
    thing_mask1[0, :50] = 1
    thing_mask2 = np.zeros((1, 128, 160), dtype=np.int32)
    thing_mask2[0, :, 50:] = 1
    gt_masks = [
        BitmapMasks(thing_mask1, 128, 160),
        BitmapMasks(thing_mask2, 128, 160),
    ]
    stuff_mask1 = torch.zeros((1, 128, 160)).long()
    stuff_mask1[0, :50] = 10
    stuff_mask1[0, 50:] = 100
    stuff_mask2 = torch.zeros((1, 128, 160)).long()
    stuff_mask2[0, :, 50:] = 10
    stuff_mask2[0, :, :50] = 100
    gt_semantic_seg = [stuff_mask1, stuff_mask2]
    self.forward_train(feats, img_metas, gt_bboxes, gt_labels, gt_masks,
                       gt_semantic_seg)

    # test inference mode
    self.simple_test(feats, img_metas)

tests/test_models/test_forward.py
@@ -811,3 +811,114 @@ def test_maskformer_forward():
                                      rescale=True,
                                      return_loss=False)
            batch_results.append(result)


def test_mask2former_forward():
    model_cfg = _get_detector_cfg(
        'mask2former/mask2former_r50_lsj_8x2_50e_coco.py')
    base_channels = 32
    model_cfg.backbone.depth = 18
    model_cfg.backbone.init_cfg = None
    model_cfg.backbone.base_channels = base_channels
    model_cfg.panoptic_head.in_channels = [
        base_channels * 2**i for i in range(4)
    ]
    model_cfg.panoptic_head.feat_channels = base_channels
    model_cfg.panoptic_head.out_channels = base_channels
    model_cfg.panoptic_head.pixel_decoder.encoder.\
        transformerlayers.attn_cfgs.embed_dims = base_channels
    model_cfg.panoptic_head.pixel_decoder.encoder.\
        transformerlayers.ffn_cfgs.embed_dims = base_channels
    model_cfg.panoptic_head.pixel_decoder.encoder.\
        transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 4
    model_cfg.panoptic_head.pixel_decoder.\
        positional_encoding.num_feats = base_channels // 2
    model_cfg.panoptic_head.positional_encoding.\
        num_feats = base_channels // 2
    model_cfg.panoptic_head.transformer_decoder.\
        transformerlayers.attn_cfgs.embed_dims = base_channels
    model_cfg.panoptic_head.transformer_decoder.\
        transformerlayers.ffn_cfgs.embed_dims = base_channels
    model_cfg.panoptic_head.transformer_decoder.\
        transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
    model_cfg.panoptic_head.transformer_decoder.\
        transformerlayers.feedforward_channels = base_channels * 8

    from mmdet.core import BitmapMasks
    from mmdet.models import build_detector
    detector = build_detector(model_cfg)

    # Test forward train with non-empty truth batch
    detector.train()
    img_metas = [
        {
            'batch_input_shape': (128, 160),
            'img_shape': (126, 160, 3),
            'ori_shape': (63, 80, 3),
            'pad_shape': (128, 160, 3)
        },
    ]
    img = torch.rand((1, 3, 128, 160))
    gt_bboxes = None
    gt_labels = [
        torch.tensor([10]).long(),
    ]
    thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32)
    thing_mask1[0, :50] = 1
    gt_masks = [
        BitmapMasks(thing_mask1, 128, 160),
    ]
    stuff_mask1 = torch.zeros((1, 128, 160)).long()
    stuff_mask1[0, :50] = 10
    stuff_mask1[0, 50:] = 100
    gt_semantic_seg = [
        stuff_mask1,
    ]
    losses = detector.forward(
        img=img,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_semantic_seg=gt_semantic_seg,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward train with an empty truth batch
    gt_bboxes = [
        torch.empty((0, 4)).float(),
    ]
    gt_labels = [
        torch.empty((0, )).long(),
    ]
    mask = np.zeros((0, 128, 160), dtype=np.uint8)
    gt_masks = [
        BitmapMasks(mask, 128, 160),
    ]
    gt_semantic_seg = [
        torch.randint(0, 133, (0, 128, 160)),
    ]
    losses = detector.forward(
        img,
        img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_semantic_seg=gt_semantic_seg,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    assert float(loss.item()) > 0

    # Test forward test
    detector.eval()
    with torch.no_grad():
        img_list = [g[None, :] for g in img]
        batch_results = []
        for one_img, one_meta in zip(img_list, img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      rescale=True,
                                      return_loss=False)
            batch_results.append(result)

tests/test_utils/test_assigner.py
@@ -606,3 +606,27 @@ def test_mask_hungarian_match_assigner():
    assert torch.all(assign_result.gt_inds > -1)
    assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0)
    assert (assign_result.labels > -1).sum() == gt_labels.size(0)

    # test with mask bce mode
    assigner_cfg = dict(
        cls_cost=dict(type='ClassificationCost', weight=0.0),
        mask_cost=dict(
            type='CrossEntropyLossCost', weight=1.0, use_sigmoid=True),
        dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0))
    self = MaskHungarianAssigner(**assigner_cfg)
    assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
                                img_meta)
    assert torch.all(assign_result.gt_inds > -1)
    assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0)
    assert (assign_result.labels > -1).sum() == gt_labels.size(0)

    # test with mask ce mode, which is not supported yet
    assigner_cfg = dict(
        cls_cost=dict(type='ClassificationCost', weight=0.0),
        mask_cost=dict(
            type='CrossEntropyLossCost', weight=1.0, use_sigmoid=False),
        dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0))
    self = MaskHungarianAssigner(**assigner_cfg)
    with pytest.raises(NotImplementedError):
        assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
                                    img_meta)
