diff --git a/.gitignore b/.gitignore index 04ff57b..c317a91 100644 --- a/.gitignore +++ b/.gitignore @@ -40,7 +40,14 @@ dist/ # project dirs /datasets/coco /datasets/lvis +/datasets/pic +/datasets/ytvos /models +/demo_outputs +/example_inputs /debug /weights +/export eval.sh +train.sh +benchmark.sh \ No newline at end of file diff --git a/README.md b/README.md index 4cac1aa..b991a7d 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,34 @@ Name | box AP | download --- |:---:|:---: [FCOS_R_50_1x](configs/FCOS-Detection/R_50_1x.yaml) | 38.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/glqFc13cCoEyHYy/download) +### COCO Instance Segmentation Baselines with [BlendMask](https://arxiv.org/abs/2001.00309) + +Model | Name |inference time (ms/im) | box AP | mask AP | download +--- |:---:|:---:|:---:|:---:|:---: +Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 | +BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 40 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/o0bpkmhMiuYgIcQ/download) +Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 90 | 38.6 | 35.2 | +BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 83 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/crpmeVCnQ3StvSz/download) +Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | | 41.0 | 37.2 | +BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/9u1cG2zXvEva5SM/download) +Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | | 42.9 | 38.6 | +BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/mYm5VCXICoeLNHq/download) +BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/TAZPxSDvPuhegKp/download) + +### COCO Panoptic Segmentation Baselines with BlendMask +Model | Name | PQ | PQTh | PQSt | download +--- |:---:|:---:|:---:|:---:|:---: +Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 | +BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/bG0IhYeMAvlTGTq/download) +Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 | +BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/AEwbhyQ9F3lqvsz/download) +BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/GyWDhsukAYYokZg/download) + +### Person in Context with BlendMask +Model | Name | box AP | mask AP | download +--- |:---:|:---:|:---:|:---: +BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/d4f16WshXYbOuIo) + ## Installation First install Detectron2 following the official guide: [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). 
Then build AdelaiDet with: diff --git a/adet/config/defaults.py b/adet/config/defaults.py index 3847360..c167e83 100644 --- a/adet/config/defaults.py +++ b/adet/config/defaults.py @@ -46,12 +46,11 @@ _C.MODEL.FCOS.NUM_SHARE_CONVS = 0 _C.MODEL.FCOS.CENTER_SAMPLE = True _C.MODEL.FCOS.POS_RADIUS = 1.5 _C.MODEL.FCOS.LOC_LOSS_TYPE = 'giou' - +_C.MODEL.FCOS.YIELD_PROPOSAL = False # ---------------------------------------------------------------------------- # # VoVNet backbone # ---------------------------------------------------------------------------- # - _C.MODEL.VOVNET = CN() _C.MODEL.VOVNET.CONV_BODY = "V-39-eSE" _C.MODEL.VOVNET.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] @@ -59,4 +58,33 @@ _C.MODEL.VOVNET.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] # Options: FrozenBN, GN, "SyncBN", "BN" _C.MODEL.VOVNET.NORM = "FrozenBN" _C.MODEL.VOVNET.OUT_CHANNELS = 256 -_C.MODEL.VOVNET.BACKBONE_OUT_CHANNELS = 256 \ No newline at end of file +_C.MODEL.VOVNET.BACKBONE_OUT_CHANNELS = 256 + +# ---------------------------------------------------------------------------- # +# BlendMask Options +# ---------------------------------------------------------------------------- # +_C.MODEL.BLENDMASK = CN() +_C.MODEL.BLENDMASK.ATTN_SIZE = 14 +_C.MODEL.BLENDMASK.TOP_INTERP = "bilinear" +_C.MODEL.BLENDMASK.BOTTOM_RESOLUTION = 56 +_C.MODEL.BLENDMASK.POOLER_TYPE = "ROIAlignV2" +_C.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO = 1 +_C.MODEL.BLENDMASK.POOLER_SCALES = (0.25,) +_C.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT = 1.0 +_C.MODEL.BLENDMASK.VISUALIZE = False + +# ---------------------------------------------------------------------------- # +# Basis Module Options +# ---------------------------------------------------------------------------- # +_C.MODEL.BASIS_MODULE = CN() +_C.MODEL.BASIS_MODULE.NAME = "ProtoNet" +_C.MODEL.BASIS_MODULE.NUM_BASES = 4 +_C.MODEL.BASIS_MODULE.LOSS_ON = False +_C.MODEL.BASIS_MODULE.ANN_SET = "coco" +_C.MODEL.BASIS_MODULE.CONVS_DIM = 128 +_C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"] +_C.MODEL.BASIS_MODULE.NORM = "SyncBN" +_C.MODEL.BASIS_MODULE.NUM_CONVS = 3 +_C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8 +_C.MODEL.BASIS_MODULE.NUM_CLASSES = 80 +_C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3 \ No newline at end of file diff --git a/adet/data/__init__.py b/adet/data/__init__.py index e67e0ab..062d071 100644 --- a/adet/data/__init__.py +++ b/adet/data/__init__.py @@ -1,5 +1,5 @@ from . 
import builtin # ensure the builtin datasets are registered -# from .dataset_mapper import DatasetMapperWithBasis +from .dataset_mapper import DatasetMapperWithBasis -# __all__ = ["DatasetMapperWithBasis"] +__all__ = ["DatasetMapperWithBasis"] diff --git a/adet/data/builtin.py b/adet/data/builtin.py index 4c388ec..e262a32 100644 --- a/adet/data/builtin.py +++ b/adet/data/builtin.py @@ -1,8 +1,9 @@ import os from detectron2.data.datasets.register_coco import register_coco_instances +from detectron2.data.datasets.builtin_meta import _get_builtin_metadata -# register person in context dataset +# register plane reconstruction _PREDEFINED_SPLITS_PIC = { "pic_person_train": ("pic/image/train", "pic/annotations/train_person.json"), @@ -24,4 +25,5 @@ def register_all_coco(root="datasets"): os.path.join(root, image_root), ) -register_all_coco() + +register_all_coco() \ No newline at end of file diff --git a/adet/data/dataset_mapper.py b/adet/data/dataset_mapper.py new file mode 100644 index 0000000..153032c --- /dev/null +++ b/adet/data/dataset_mapper.py @@ -0,0 +1,141 @@ +import copy +import numpy as np +import torch +from fvcore.common.file_io import PathManager +from PIL import Image + +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.data.detection_utils import SizeMismatchError +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T + +""" +This file contains the default mapping that's applied to "dataset dicts". +""" + +__all__ = ["DatasetMapperWithBasis"] + + +class DatasetMapperWithBasis(DatasetMapper): + """ + This caller enables the default Detectron2 mapper to read an additional basis semantic label + """ + + def __init__(self, cfg, is_train=True): + super().__init__(cfg, is_train) + + # fmt: off + self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON + self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET + # fmt: on + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + # USER: Write your own image loading if it's not from a file + try: + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + except Exception as e: + print(dataset_dict["file_name"]) + print(e) + raise e + try: + utils.check_image_size(dataset_dict, image) + except SizeMismatchError as e: + expected_wh = (dataset_dict["width"], dataset_dict["height"]) + image_wh = (image.shape[1], image.shape[0]) + if (image_wh[1], image_wh[0]) == expected_wh: + print("transposing image {}".format(dataset_dict["file_name"])) + image = image.transpose(1, 0, 2) + else: + raise e + + if "annotations" not in dataset_dict or len(dataset_dict["annotations"]) == 0: + image, transforms = T.apply_transform_gens( + ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image + ) + else: + # Crop around an instance if there are instances in the image. 
+ # USER: Remove if you don't use cropping + if self.crop_gen: + crop_tfm = utils.gen_crop_transform_with_instance( + self.crop_gen.get_crop_size(image.shape[:2]), + image.shape[:2], + np.random.choice(dataset_dict["annotations"]), + ) + image = crop_tfm.apply_image(image) + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + if self.crop_gen: + transforms = crop_tfm + transforms + + image_shape = image.shape[:2] # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32")) + # Can use uint8 if it turns out to be slow some day + + # USER: Remove if you don't use pre-computed proposals. + if self.load_proposals: + utils.transform_proposals( + dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk + ) + + if not self.is_train: + dataset_dict.pop("annotations", None) + dataset_dict.pop("sem_seg_file_name", None) + dataset_dict.pop("pano_seg_file_name", None) + return dataset_dict + + if "annotations" in dataset_dict: + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.mask_on: + anno.pop("segmentation", None) + if not self.keypoint_on: + anno.pop("keypoints", None) + + # USER: Implement additional transformations if you have other types of data + annos = [ + utils.transform_instance_annotations( + obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices + ) + for obj in dataset_dict.pop("annotations") + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.mask_format + ) + # Create a tight bounding box from masks, useful when image is cropped + if self.crop_gen and instances.has("gt_masks"): + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict["instances"] = utils.filter_empty_instances(instances) + + # USER: Remove if you don't do semantic/panoptic segmentation. 
+ if "sem_seg_file_name" in dataset_dict: + with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f: + sem_seg_gt = Image.open(f) + sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8") + sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + dataset_dict["sem_seg"] = sem_seg_gt + + if self.basis_loss_on and self.is_train: + # load basis supervisions + if self.ann_set == "coco": + basis_sem_path = dataset_dict["file_name"].replace('train2017', 'thing_train2017').replace('image/train', 'thing_train') + else: + basis_sem_path = dataset_dict["file_name"].replace('coco', 'lvis').replace('train2017', 'thing_train').replace('jpg', 'npz') + basis_sem_path = basis_sem_path.replace('jpg', 'npz') + basis_sem_gt = np.load(basis_sem_path)["mask"] + basis_sem_gt = transforms.apply_segmentation(basis_sem_gt) + basis_sem_gt = torch.as_tensor(basis_sem_gt.astype("long")) + dataset_dict["basis_sem"] = basis_sem_gt + return dataset_dict diff --git a/adet/layers/conv_with_kaiming_uniform.py b/adet/layers/conv_with_kaiming_uniform.py index f4bd001..88cb682 100644 --- a/adet/layers/conv_with_kaiming_uniform.py +++ b/adet/layers/conv_with_kaiming_uniform.py @@ -37,7 +37,7 @@ def conv_with_kaiming_uniform( if norm is None: nn.init.constant_(conv.bias, 0) module = [conv,] - if norm is not None: + if norm is not None and len(norm) > 0: if norm == "GN": norm_module = nn.GroupNorm(32, out_channels) else: diff --git a/adet/modeling/__init__.py b/adet/modeling/__init__.py index d39eb4b..228bb44 100644 --- a/adet/modeling/__init__.py +++ b/adet/modeling/__init__.py @@ -1,4 +1,5 @@ from .fcos import FCOS +from .blendmask import BlendMask from .backbone import build_fcos_resnet_fpn_backbone from .one_stage_detector import OneStageDetector diff --git a/adet/modeling/blendmask/__init__.py b/adet/modeling/blendmask/__init__.py new file mode 100644 index 0000000..66e4125 --- /dev/null +++ b/adet/modeling/blendmask/__init__.py @@ -0,0 +1,2 @@ +from .basis_module import build_basis_module +from .blendmask import BlendMask diff --git a/adet/modeling/blendmask/basis_module.py b/adet/modeling/blendmask/basis_module.py new file mode 100644 index 0000000..2adf2b2 --- /dev/null +++ b/adet/modeling/blendmask/basis_module.py @@ -0,0 +1,104 @@ +from typing import Dict +from torch import nn +from torch.nn import functional as F + +from detectron2.utils.registry import Registry +from detectron2.layers import ShapeSpec + +from adet.layers import conv_with_kaiming_uniform + + +BASIS_MODULE_REGISTRY = Registry("BASIS_MODULE") +BASIS_MODULE_REGISTRY.__doc__ = """ +Registry for basis module, which produces global bases from feature maps. + +The registered object will be called with `obj(cfg, input_shape)`. +The call should return a `nn.Module` object. 
+""" + + +def build_basis_module(cfg, input_shape): + name = cfg.MODEL.BASIS_MODULE.NAME + return BASIS_MODULE_REGISTRY.get(name)(cfg, input_shape) + + +@BASIS_MODULE_REGISTRY.register() +class ProtoNet(nn.Module): + def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): + """ + TODO: support deconv and variable channel width + """ + # official protonet has a relu after each conv + super().__init__() + # fmt: off + mask_dim = cfg.MODEL.BASIS_MODULE.NUM_BASES + planes = cfg.MODEL.BASIS_MODULE.CONVS_DIM + self.in_features = cfg.MODEL.BASIS_MODULE.IN_FEATURES + self.loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON + norm = cfg.MODEL.BASIS_MODULE.NORM + num_convs = cfg.MODEL.BASIS_MODULE.NUM_CONVS + self.visualize = cfg.MODEL.BLENDMASK.VISUALIZE + # fmt: on + + feature_channels = {k: v.channels for k, v in input_shape.items()} + + conv_block = conv_with_kaiming_uniform(norm, True) # conv relu bn + self.refine = nn.ModuleList() + for in_feature in self.in_features: + self.refine.append(conv_block( + feature_channels[in_feature], planes, 3, 1)) + tower = [] + for i in range(num_convs): + tower.append( + conv_block(planes, planes, 3, 1)) + tower.append( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)) + tower.append( + conv_block(planes, planes, 3, 1)) + tower.append( + nn.Conv2d(planes, mask_dim, 1)) + self.add_module('tower', nn.Sequential(*tower)) + + if self.loss_on: + # fmt: off + self.common_stride = cfg.MODEL.BASIS_MODULE.COMMON_STRIDE + num_classes = cfg.MODEL.BASIS_MODULE.NUM_CLASSES + 1 + self.sem_loss_weight = cfg.MODEL.BASIS_MODULE.LOSS_WEIGHT + # fmt: on + + inplanes = feature_channels[self.in_features[0]] + self.seg_head = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=3, + stride=1, padding=1, bias=False), + nn.BatchNorm2d(planes), + nn.ReLU(), + nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False), + nn.BatchNorm2d(planes), + nn.ReLU(), + nn.Conv2d(planes, num_classes, kernel_size=1, + stride=1)) + + def forward(self, features, targets=None): + for i, f in enumerate(self.in_features): + if i == 0: + x = self.refine[i](features[f]) + else: + x_p = self.refine[i](features[f]) + x_p = F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False) + # x_p = aligned_bilinear(x_p, x.size(3) // x_p.size(3)) + x = x + x_p + outputs = {"bases": [self.tower(x)]} + losses = {} + # auxiliary thing semantic loss + if self.training and self.loss_on: + sem_out = self.seg_head(features[self.in_features[0]]) + # resize target to reduce memory + gt_sem = targets.unsqueeze(1).float() + gt_sem = F.interpolate( + gt_sem, scale_factor=1 / self.common_stride) + seg_loss = F.cross_entropy( + sem_out, gt_sem.squeeze().long()) + losses['loss_basis_sem'] = seg_loss * self.sem_loss_weight + elif self.visualize and hasattr(self, "seg_head"): + outputs["seg_thing_out"] = self.seg_head(features[self.in_features[0]]) + return outputs, losses diff --git a/adet/modeling/blendmask/blender.py b/adet/modeling/blendmask/blender.py new file mode 100644 index 0000000..46a3f6c --- /dev/null +++ b/adet/modeling/blendmask/blender.py @@ -0,0 +1,110 @@ +import torch +from torch.nn import functional as F + +from detectron2.layers import cat +from detectron2.modeling.poolers import ROIPooler + + +def build_blender(cfg): + return Blender(cfg) + + +class Blender(object): + def __init__(self, cfg): + + # fmt: off + self.pooler_resolution = cfg.MODEL.BLENDMASK.BOTTOM_RESOLUTION + sampling_ratio = cfg.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO + pooler_type = 
cfg.MODEL.BLENDMASK.POOLER_TYPE + pooler_scales = cfg.MODEL.BLENDMASK.POOLER_SCALES + self.attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE + self.top_interp = cfg.MODEL.BLENDMASK.TOP_INTERP + num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES + # fmt: on + + self.attn_len = num_bases * self.attn_size * self.attn_size + + self.pooler = ROIPooler( + output_size=self.pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + canonical_level=2) + + def __call__(self, bases, proposals, gt_instances): + if gt_instances is not None: + # training + # reshape attns + extras = proposals["extras"] + attns = proposals["top_feats"] + pos_inds = extras["pos_inds"] + if pos_inds.numel() == 0: + return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0} + + gt_inds = extras["gt_inds"] + attns = cat( + [ + # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C) + x.permute(0, 2, 3, 1).reshape(-1, self.attn_len) + for x in attns + ], dim=0,) + attns = attns[pos_inds] + + rois = self.pooler(bases, [x.gt_boxes for x in gt_instances]) + rois = rois[gt_inds] + pred_mask_logits = self.merge_bases(rois, attns) + + # gen targets + gt_masks = [] + for instances_per_image in gt_instances: + if len(instances_per_image.gt_boxes.tensor) == 0: + continue + gt_mask_per_image = instances_per_image.gt_masks.crop_and_resize( + instances_per_image.gt_boxes.tensor, self.pooler_resolution + ).to(device=pred_mask_logits.device) + gt_masks.append(gt_mask_per_image) + gt_masks = cat(gt_masks, dim=0) + gt_masks = gt_masks[gt_inds] + N = gt_masks.size(0) + gt_masks = gt_masks.view(N, -1) + + gt_ctr = extras["gt_ctr"] + loss_denorm = extras["loss_denorm"] + mask_losses = F.binary_cross_entropy_with_logits( + pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none") + mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum() + / loss_denorm) + return None, {"loss_mask": mask_loss} + else: + # no proposals + total_instances = sum([len(x) for x in proposals]) + if total_instances == 0: + # add empty pred_masks results + for box in proposals: + box.pred_masks = box.pred_classes.view( + -1, 1, self.pooler_resolution, self.pooler_resolution) + return proposals, {} + rois = self.pooler(bases, [x.pred_boxes for x in proposals]) + attns = cat([x.top_feat for x in proposals], dim=0) + pred_mask_logits = self.merge_bases(rois, attns).sigmoid() + pred_mask_logits = pred_mask_logits.view( + -1, 1, self.pooler_resolution, self.pooler_resolution) + start_ind = 0 + for box in proposals: + end_ind = start_ind + len(box) + box.pred_masks = pred_mask_logits[start_ind:end_ind] + start_ind = end_ind + return proposals, {} + + def merge_bases(self, rois, coeffs, location_to_inds=None): + # merge predictions + N = coeffs.size(0) + if location_to_inds is not None: + rois = rois[location_to_inds] + N, B, H, W = rois.size() + + coeffs = coeffs.view(N, -1, self.attn_size, self.attn_size) + coeffs = F.interpolate(coeffs, (H, W), + mode=self.top_interp).softmax(dim=1) + masks_preds = (rois * coeffs).sum(dim=1) + return masks_preds.view(N, -1) diff --git a/adet/modeling/blendmask/blendmask.py b/adet/modeling/blendmask/blendmask.py new file mode 100644 index 0000000..5ebe77c --- /dev/null +++ b/adet/modeling/blendmask/blendmask.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + +import torch +from torch import nn + +from detectron2.structures import ImageList +from detectron2.modeling.postprocessing import detector_postprocess, sem_seg_postprocess +from detectron2.modeling.proposal_generator import build_proposal_generator +from detectron2.modeling.backbone import build_backbone +from detectron2.modeling.meta_arch.panoptic_fpn import combine_semantic_and_instance_outputs +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head + +from .blender import build_blender +from .basis_module import build_basis_module + +__all__ = ["BlendMask"] + + +@META_ARCH_REGISTRY.register() +class BlendMask(nn.Module): + """ + Main class for BlendMask architectures (see https://arxiv.org/abd/1901.02446). + """ + + def __init__(self, cfg): + super().__init__() + + self.device = torch.device(cfg.MODEL.DEVICE) + self.instance_loss_weight = cfg.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT + + self.backbone = build_backbone(cfg) + self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) + self.blender = build_blender(cfg) + self.basis_module = build_basis_module(cfg, self.backbone.output_shape()) + + # options when combining instance & semantic outputs + self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED + if self.combine_on: + self.panoptic_module = build_sem_seg_head(cfg, self.backbone.output_shape()) + self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH + self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT + self.combine_instances_confidence_threshold = ( + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH) + + # build top module + in_channels = cfg.MODEL.FPN.OUT_CHANNELS + num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES + attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE + attn_len = num_bases * attn_size * attn_size + self.top_layer = nn.Conv2d( + in_channels, attn_len, + kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.top_layer.weight, std=0.01) + torch.nn.init.constant_(self.top_layer.bias, 0) + + pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) + pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) + self.normalizer = lambda x: (x - pixel_mean) / pixel_std + self.to(self.device) + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + + For now, each item in the list is a dict that contains: + image: Tensor, image in (C, H, W) format. + instances: Instances + sem_seg: semantic segmentation ground truth. + Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + list[dict]: each dict is the results for one image. The dict + contains the following keys: + "instances": see :meth:`GeneralizedRCNN.forward` for its format. + "sem_seg": see :meth:`SemanticSegmentor.forward` for its format. + "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`. + See the return value of + :func:`combine_semantic_and_instance_outputs` for its format. 
+ """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [self.normalizer(x) for x in images] + images = ImageList.from_tensors(images, self.backbone.size_divisibility) + features = self.backbone(images.tensor) + + if self.combine_on: + if "sem_seg" in batched_inputs[0]: + gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs] + gt_sem = ImageList.from_tensors( + gt_sem, self.backbone.size_divisibility, self.panoptic_module.ignore_value + ).tensor + else: + gt_sem = None + sem_seg_results, sem_seg_losses = self.panoptic_module(features, gt_sem) + + if "basis_sem" in batched_inputs[0]: + basis_sem = [x["basis_sem"].to(self.device) for x in batched_inputs] + basis_sem = ImageList.from_tensors( + basis_sem, self.backbone.size_divisibility, 0).tensor + else: + basis_sem = None + basis_out, basis_losses = self.basis_module(features, basis_sem) + + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + else: + gt_instances = None + proposals, proposal_losses = self.proposal_generator( + images, features, gt_instances, self.top_layer) + detector_results, detector_losses = self.blender( + basis_out["bases"], proposals, gt_instances) + + if self.training: + losses = {} + losses.update(basis_losses) + losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()}) + losses.update(proposal_losses) + if self.combine_on: + losses.update(sem_seg_losses) + return losses + + processed_results = [] + for i, (detector_result, input_per_image, image_size) in enumerate(zip( + detector_results, batched_inputs, images.image_sizes)): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + detector_r = detector_postprocess(detector_result, height, width) + processed_result = {"instances": detector_r} + if self.combine_on: + sem_seg_r = sem_seg_postprocess( + sem_seg_results[i], image_size, height, width) + processed_result["sem_seg"] = sem_seg_r + if "seg_thing_out" in basis_out: + seg_thing_r = sem_seg_postprocess( + basis_out["seg_thing_out"], image_size, height, width) + processed_result["sem_thing_seg"] = seg_thing_r + if self.basis_module.visualize: + processed_result["bases"] = basis_out["bases"] + processed_results.append(processed_result) + + if self.combine_on: + panoptic_r = combine_semantic_and_instance_outputs( + detector_r, + sem_seg_r.argmax(dim=0), + self.combine_overlap_threshold, + self.combine_stuff_area_limit, + self.combine_instances_confidence_threshold) + processed_results[-1]["panoptic_seg"] = panoptic_r + return processed_results diff --git a/adet/modeling/fcos/fcos.py b/adet/modeling/fcos/fcos.py index 3b3da65..fb5895b 100644 --- a/adet/modeling/fcos/fcos.py +++ b/adet/modeling/fcos/fcos.py @@ -45,6 +45,7 @@ class FCOS(nn.Module): self.pre_nms_topk_train = cfg.MODEL.FCOS.PRE_NMS_TOPK_TRAIN self.pre_nms_topk_test = cfg.MODEL.FCOS.PRE_NMS_TOPK_TEST self.nms_thresh = cfg.MODEL.FCOS.NMS_TH + self.yield_proposal = cfg.MODEL.FCOS.YIELD_PROPOSAL self.post_nms_topk_train = cfg.MODEL.FCOS.POST_NMS_TOPK_TRAIN self.post_nms_topk_test = cfg.MODEL.FCOS.POST_NMS_TOPK_TEST self.thresh_with_ctr = cfg.MODEL.FCOS.THRESH_WITH_CTR @@ -60,7 +61,13 @@ class FCOS(nn.Module): self.sizes_of_interest = soi self.fcos_head = FCOSHead(cfg, [input_shape[f] for f in self.in_features]) - def forward(self, images, features, gt_instances): + def forward_head(self, features, top_module=None): + features = [features[f] for f in self.in_features] + 
pred_class_logits, pred_deltas, pred_centerness, top_feats, bbox_towers = self.fcos_head( + features, top_module, self.yield_proposal) + return pred_class_logits, pred_deltas, pred_centerness, top_feats, bbox_towers + + def forward(self, images, features, gt_instances=None, top_module=None): """ Arguments: images (list[Tensor] or ImageList): images to be processed @@ -75,7 +82,8 @@ class FCOS(nn.Module): """ features = [features[f] for f in self.in_features] locations = self.compute_locations(features) - logits_pred, reg_pred, ctrness_pred, bbox_towers = self.fcos_head(features) + logits_pred, reg_pred, ctrness_pred, top_feats, bbox_towers = self.fcos_head( + features, top_module, self.yield_proposal) if self.training: pre_nms_thresh = self.pre_nms_thresh_train @@ -108,12 +116,29 @@ class FCOS(nn.Module): gt_instances ) + results = {} + if self.yield_proposal: + results["features"] = { + f: b for f, b in zip(self.in_features, bbox_towers)} + if self.training: - losses, _ = outputs.losses() - return None, losses + losses, extras = outputs.losses() + + if top_module is not None: + results["extras"] = extras + results["top_feats"] = top_feats + if self.yield_proposal: + with torch.no_grad(): + results["proposals"] = outputs.predict_proposals(top_feats) else: - proposals = outputs.predict_proposals() - return proposals, {} + losses = {} + with torch.no_grad(): + proposals = outputs.predict_proposals(top_feats) + if self.yield_proposal: + results["proposals"] = proposals + else: + results = proposals + return results, losses def compute_locations(self, features): locations = [] @@ -173,9 +198,9 @@ class FCOSHead(nn.Module): conv_func = nn.Conv2d for i in range(num_convs): tower.append(conv_func( - in_channels, in_channels, - kernel_size=3, stride=1, - padding=1, bias=True + in_channels, in_channels, + kernel_size=3, stride=1, + padding=1, bias=True )) if norm == "GN": tower.append(nn.GroupNorm(32, in_channels)) @@ -192,7 +217,7 @@ class FCOSHead(nn.Module): in_channels, 4, kernel_size=3, stride=1, padding=1 ) - self.ctrness = nn.Conv2d( + self.centerness = nn.Conv2d( in_channels, 1, kernel_size=3, stride=1, padding=1 ) @@ -205,7 +230,7 @@ class FCOSHead(nn.Module): for modules in [ self.cls_tower, self.bbox_tower, self.share_tower, self.cls_logits, - self.bbox_pred, self.ctrness + self.bbox_pred, self.centerness ]: for l in modules.modules(): if isinstance(l, nn.Conv2d): @@ -217,22 +242,26 @@ class FCOSHead(nn.Module): bias_value = -math.log((1 - prior_prob) / prior_prob) torch.nn.init.constant_(self.cls_logits.bias, bias_value) - def forward(self, x): + def forward(self, x, top_module=None, yield_bbox_towers=False): logits = [] bbox_reg = [] ctrness = [] + top_feats = [] bbox_towers = [] for l, feature in enumerate(x): feature = self.share_tower(feature) cls_tower = self.cls_tower(feature) bbox_tower = self.bbox_tower(feature) + if yield_bbox_towers: + bbox_towers.append(bbox_tower) logits.append(self.cls_logits(cls_tower)) - ctrness.append(self.ctrness(bbox_tower)) + ctrness.append(self.centerness(bbox_tower)) reg = self.bbox_pred(bbox_tower) if self.scales is not None: reg = self.scales[l](reg) # Note that we use relu, as in the improved FCOS, instead of exp. 
bbox_reg.append(F.relu(reg)) - - return logits, bbox_reg, ctrness, bbox_towers + if top_module is not None: + top_feats.append(top_module(bbox_tower)) + return logits, bbox_reg, ctrness, top_feats, bbox_towers diff --git a/adet/modeling/fcos/fcos_outputs.py b/adet/modeling/fcos/fcos_outputs.py index e6fa285..4bcd5a5 100644 --- a/adet/modeling/fcos/fcos_outputs.py +++ b/adet/modeling/fcos/fcos_outputs.py @@ -34,7 +34,7 @@ Naming convention: reg_pred: the predicted (left, top, right, bottom), corresponding to reg_targets ctrness_pred: predicted centerness scores - + """ @@ -57,6 +57,7 @@ def fcos_losses( focal_loss_alpha, focal_loss_gamma, iou_loss, + gt_inds, ): num_classes = logits_pred.size(1) labels = labels.flatten() @@ -82,29 +83,40 @@ def fcos_losses( reg_pred = reg_pred[pos_inds] reg_targets = reg_targets[pos_inds] ctrness_pred = ctrness_pred[pos_inds] + gt_inds = gt_inds[pos_inds] ctrness_targets = compute_ctrness_targets(reg_targets) ctrness_targets_sum = ctrness_targets.sum() - ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6) - - reg_loss = iou_loss( - reg_pred, - reg_targets, - ctrness_targets - ) / ctrness_norm + loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6) - ctrness_loss = F.binary_cross_entropy_with_logits( - ctrness_pred, - ctrness_targets, - reduction="sum" - ) / num_pos_avg + if pos_inds.numel() > 0: + reg_loss = iou_loss( + reg_pred, + reg_targets, + ctrness_targets + ) / loss_denorm + + ctrness_loss = F.binary_cross_entropy_with_logits( + ctrness_pred, + ctrness_targets, + reduction="sum" + ) / num_pos_avg + else: + reg_loss = reg_pred.sum() * 0 + ctrness_loss = ctrness_pred.sum() * 0 losses = { "loss_fcos_cls": class_loss, "loss_fcos_loc": reg_loss, "loss_fcos_ctr": ctrness_loss } - return losses, {} + extras = { + "pos_inds": pos_inds, + "gt_inds": gt_inds, + "gt_ctr": ctrness_targets, + "loss_denorm": loss_denorm + } + return losses, extras class FCOSOutputs(object): @@ -236,8 +248,10 @@ class FCOSOutputs(object): def compute_targets_for_locations(self, locations, targets, size_ranges): labels = [] reg_targets = [] + target_inds = [] xs, ys = locations[:, 0], locations[:, 1] + num_targets = 0 for im_i in range(len(targets)): targets_per_im = targets[im_i] bboxes = targets_per_im.gt_boxes.tensor @@ -247,6 +261,7 @@ class FCOSOutputs(object): if bboxes.numel() == 0: labels.append(labels_per_im.new_zeros(locations.size(0)) + self.num_classes) reg_targets.append(locations.new_zeros((locations.size(0), 4))) + target_inds.append(labels_per_im.new_zeros(locations.size(0)) - 1) continue area = targets_per_im.gt_boxes.area() @@ -280,14 +295,19 @@ class FCOSOutputs(object): locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1) reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds] + target_inds_per_im = locations_to_gt_inds + num_targets labels_per_im = labels_per_im[locations_to_gt_inds] labels_per_im[locations_to_min_area == INF] = self.num_classes labels.append(labels_per_im) reg_targets.append(reg_targets_per_im) + target_inds.append(target_inds_per_im) - return {"labels": labels, "reg_targets": reg_targets} + return { + "labels": labels, + "reg_targets": reg_targets, + "target_inds": target_inds} def losses(self): """ @@ -298,7 +318,10 @@ class FCOSOutputs(object): """ training_targets = self._get_ground_truth() - labels, reg_targets = training_targets["labels"], training_targets["reg_targets"] + labels, reg_targets, gt_inds = ( + training_targets["labels"], + 
training_targets["reg_targets"], + training_targets["target_inds"]) # Collect all logits and regression predictions over feature maps # and images to arrive at the same shape as the labels and targets @@ -327,6 +350,12 @@ class FCOSOutputs(object): x.reshape(-1) for x in labels ], dim=0,) + gt_inds = cat( + [ + # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,) + x.reshape(-1) for x in gt_inds + ], dim=0,) + reg_targets = cat( [ # Reshape: (N, Hi, Wi, 4) -> (N*Hi*Wi, 4) @@ -341,25 +370,35 @@ class FCOSOutputs(object): ctrness_pred, self.focal_loss_alpha, self.focal_loss_gamma, - self.iou_loss + self.iou_loss, + gt_inds ) - def predict_proposals(self): + def predict_proposals(self, top_feats): sampled_boxes = [] - bundle = ( - self.locations, self.logits_pred, - self.reg_pred, self.ctrness_pred, - self.strides - ) + bundle = { + "l": self.locations, "o": self.logits_pred, + "r": self.reg_pred, "c": self.ctrness_pred, + "s": self.strides, + } + + if len(top_feats) > 0: + bundle["t"] = top_feats - for i, (l, o, r, c, s) in enumerate(zip(*bundle)): + for i, instance in enumerate(zip(*bundle.values())): + instance_dict = dict(zip(bundle.keys(), instance)) # recall that during training, we normalize regression targets with FPN's stride. # we denormalize them here. - r = r * s + l = instance_dict["l"] + o = instance_dict["o"] + r = instance_dict["r"] * instance_dict["s"] + c = instance_dict["c"] + t = instance_dict["t"] if "t" in bundle else None + sampled_boxes.append( self.forward_for_single_feature_map( - l, o, r, c, self.image_sizes + l, o, r, c, self.image_sizes, t ) ) @@ -370,8 +409,8 @@ class FCOSOutputs(object): def forward_for_single_feature_map( self, locations, box_cls, - reg_pred, ctrness, image_sizes - ): + reg_pred, ctrness, + image_sizes, top_feat=None): N, C, H, W = box_cls.shape # put in the same format as locations @@ -381,6 +420,9 @@ class FCOSOutputs(object): box_regression = box_regression.reshape(N, -1, 4) ctrness = ctrness.view(N, 1, H, W).permute(0, 2, 3, 1) ctrness = ctrness.reshape(N, -1).sigmoid() + if top_feat is not None: + top_feat = top_feat.view(N, -1, H, W).permute(0, 2, 3, 1) + top_feat = top_feat.reshape(N, H * W, -1) # if self.thresh_with_ctr is True, we multiply the classification # scores with centerness scores before applying the threshold. 
@@ -406,6 +448,9 @@ class FCOSOutputs(object): per_box_regression = box_regression[i] per_box_regression = per_box_regression[per_box_loc] per_locations = locations[per_box_loc] + if top_feat is not None: + per_top_feat = top_feat[i] + per_top_feat = per_top_feat[per_box_loc] per_pre_nms_top_n = pre_nms_top_n[i] @@ -415,6 +460,8 @@ class FCOSOutputs(object): per_class = per_class[top_k_indices] per_box_regression = per_box_regression[top_k_indices] per_locations = per_locations[top_k_indices] + if top_feat is not None: + per_top_feat = per_top_feat[top_k_indices] detections = torch.stack([ per_locations[:, 0] - per_box_regression[:, 0], @@ -428,7 +475,8 @@ class FCOSOutputs(object): boxlist.scores = torch.sqrt(per_box_cls) boxlist.pred_classes = per_class boxlist.locations = per_locations - + if top_feat is not None: + boxlist.top_feat = per_top_feat results.append(boxlist) return results diff --git a/configs/BlendMask/550_R_50_1x.yaml b/configs/BlendMask/550_R_50_1x.yaml new file mode 100644 index 0000000..4e609eb --- /dev/null +++ b/configs/BlendMask/550_R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-550.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/blendmask/550_R_50_1x" diff --git a/configs/BlendMask/550_R_50_3x.yaml b/configs/BlendMask/550_R_50_3x.yaml new file mode 100644 index 0000000..ca5b575 --- /dev/null +++ b/configs/BlendMask/550_R_50_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-550.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/blendmask/550_R_50_3x" diff --git a/configs/BlendMask/550_R_50_dcni3_5x.yaml b/configs/BlendMask/550_R_50_dcni3_5x.yaml new file mode 100644 index 0000000..d2a9501 --- /dev/null +++ b/configs/BlendMask/550_R_50_dcni3_5x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-550.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + DEFORM_ON_PER_STAGE: [False, True, True, True] + DEFORM_MODULATED: True + DEFORM_INTERVAL: 3 +INPUT: + MIN_SIZE_TRAIN: (440, 594) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 990 + CROP: + ENABLED: True +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/blendmask/550_R_50_dcni3_5x" diff --git a/configs/BlendMask/Base-550.yaml b/configs/BlendMask/Base-550.yaml new file mode 100644 index 0000000..3acc65d --- /dev/null +++ b/configs/BlendMask/Base-550.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-BlendMask.yaml" +MODEL: + FCOS: + TOP_LEVELS: 1 + IN_FEATURES: ["p3", "p4", "p5", "p6"] + FPN_STRIDES: [8, 16, 32, 64] + SIZES_OF_INTEREST: [64, 128, 256] + NUM_SHARE_CONVS: 3 + NUM_CLS_CONVS: 0 + NUM_BOX_CONVS: 0 + BASIS_MODULE: + NUM_CONVS: 2 +INPUT: + MIN_SIZE_TRAIN: (440, 462, 484, 506, 528, 550) + MAX_SIZE_TRAIN: 916 + MIN_SIZE_TEST: 550 + MAX_SIZE_TEST: 916 diff --git a/configs/BlendMask/Base-BlendMask.yaml b/configs/BlendMask/Base-BlendMask.yaml new file mode 100644 index 0000000..da455b9 --- /dev/null +++ b/configs/BlendMask/Base-BlendMask.yaml @@ -0,0 +1,29 @@ +MODEL: + META_ARCHITECTURE: "BlendMask" + MASK_ON: True + BACKBONE: + NAME: "build_fcos_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res3", "res4", "res5"] + PROPOSAL_GENERATOR: + NAME: "FCOS" + BASIS_MODULE: + LOSS_ON: True + PANOPTIC_FPN: + COMBINE: + ENABLED: False + FCOS: + THRESH_WITH_CTR: True + USE_SCALE: False +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) 
+SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) diff --git a/configs/BlendMask/Panoptic/Base-Panoptic.yaml b/configs/BlendMask/Panoptic/Base-Panoptic.yaml new file mode 100644 index 0000000..7fd16ec --- /dev/null +++ b/configs/BlendMask/Panoptic/Base-Panoptic.yaml @@ -0,0 +1,16 @@ +_BASE_: "../Base-BlendMask.yaml" +MODEL: + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + SEM_SEG_HEAD: + LOSS_WEIGHT: 0.5 + PANOPTIC_FPN: + COMBINE: + ENABLED: True + INSTANCES_CONFIDENCE_THRESH: 0.2 + OVERLAP_THRESH: 0.4 +DATASETS: + TRAIN: ("coco_2017_train_panoptic_separated",) + TEST: ("coco_2017_val_panoptic_separated",) diff --git a/configs/BlendMask/Panoptic/R_101_3x.yaml b/configs/BlendMask/Panoptic/R_101_3x.yaml new file mode 100644 index 0000000..c5afd60 --- /dev/null +++ b/configs/BlendMask/Panoptic/R_101_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-Panoptic.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/panoptic/blendmask/R_101_3x" diff --git a/configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml b/configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml new file mode 100644 index 0000000..45bc89c --- /dev/null +++ b/configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-Panoptic.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + DEFORM_ON_PER_STAGE: [False, True, True, True] + DEFORM_MODULATED: True + DEFORM_INTERVAL: 3 +SOLVER: + STEPS: (280000, 360000) + MAX_ITER: 400000 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + CROP: + ENABLED: True +OUTPUT_DIR: "output/panoptic/blendmask/R_101_dcni3_5x" diff --git a/configs/BlendMask/Panoptic/R_50_1x.yaml b/configs/BlendMask/Panoptic/R_50_1x.yaml new file mode 100644 index 0000000..24f9e99 --- /dev/null +++ b/configs/BlendMask/Panoptic/R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-Panoptic.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/panoptic/blendmask/R_50_1x" diff --git a/configs/BlendMask/Panoptic/R_50_3x.yaml b/configs/BlendMask/Panoptic/R_50_3x.yaml new file mode 100644 index 0000000..9de5502 --- /dev/null +++ b/configs/BlendMask/Panoptic/R_50_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-Panoptic.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/panoptic/blendmask/R_50_3x" diff --git a/configs/BlendMask/Panoptic/R_50_dcni3_5x.yaml b/configs/BlendMask/Panoptic/R_50_dcni3_5x.yaml new file mode 100644 index 0000000..b928f1d --- /dev/null +++ b/configs/BlendMask/Panoptic/R_50_dcni3_5x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-Panoptic.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + DEFORM_ON_PER_STAGE: [False, True, True, True] + DEFORM_MODULATED: True + DEFORM_INTERVAL: 3 +SOLVER: + STEPS: (280000, 360000) + MAX_ITER: 400000 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + CROP: + ENABLED: True +OUTPUT_DIR: "output/panoptic/blendmask/R_50_dcni3_5x" diff --git a/configs/BlendMask/Person/Base-Person.yaml 
b/configs/BlendMask/Person/Base-Person.yaml new file mode 100644 index 0000000..f7c8be0 --- /dev/null +++ b/configs/BlendMask/Person/Base-Person.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-BlendMask.yaml" +MODEL: + BASIS_MODULE: + NUM_CLASSES: 1 + FCOS: + NUM_CLASSES: 1 +DATASETS: + TRAIN: ("pic_person_train",) + TEST: ("pic_person_val",) diff --git a/configs/BlendMask/Person/R_50_1x.yaml b/configs/BlendMask/Person/R_50_1x.yaml new file mode 100644 index 0000000..0b044c6 --- /dev/null +++ b/configs/BlendMask/Person/R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-Person.yaml" +MODEL: + WEIGHTS: "https://cloudstor.aarnet.edu.au/plus/s/9u1cG2zXvEva5SM/download#R_50_3x.pth" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/person/blendmask/R_50_1x" diff --git a/configs/BlendMask/R_101_3x.yaml b/configs/BlendMask/R_101_3x.yaml new file mode 100644 index 0000000..835e629 --- /dev/null +++ b/configs/BlendMask/R_101_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-BlendMask.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/blendmask/R_101_3x" diff --git a/configs/BlendMask/R_101_dcni3_5x.yaml b/configs/BlendMask/R_101_dcni3_5x.yaml new file mode 100644 index 0000000..b89b93f --- /dev/null +++ b/configs/BlendMask/R_101_dcni3_5x.yaml @@ -0,0 +1,20 @@ +_BASE_: "Base-BlendMask.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + DEFORM_ON_PER_STAGE: [False, True, True, True] + DEFORM_MODULATED: True + DEFORM_INTERVAL: 3 +SOLVER: + STEPS: (280000, 360000) + MAX_ITER: 400000 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + CROP: + ENABLED: True +TEST: + EVAL_PERIOD: 20000 +OUTPUT_DIR: "output/blendmask/R_101_dcni3_5x" diff --git a/configs/BlendMask/R_50_1x.yaml b/configs/BlendMask/R_50_1x.yaml new file mode 100644 index 0000000..646430a --- /dev/null +++ b/configs/BlendMask/R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-BlendMask.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/blendmask/R_50_1x" diff --git a/configs/BlendMask/R_50_3x.yaml b/configs/BlendMask/R_50_3x.yaml new file mode 100644 index 0000000..f4acd8d --- /dev/null +++ b/configs/BlendMask/R_50_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-BlendMask.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/blendmask/R_50_3x" diff --git a/configs/RCNN/550_R_50_FPN_3x.yaml b/configs/RCNN/550_R_50_FPN_3x.yaml new file mode 100644 index 0000000..022ce44 --- /dev/null +++ b/configs/RCNN/550_R_50_FPN_3x.yaml @@ -0,0 +1,15 @@ +_BASE_: "Base-RCNN.yaml" +MODEL: + WEIGHTS: "output/mask_rcnn/550_R_50_3x/model_final.pth" + MASK_ON: True + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +INPUT: + MIN_SIZE_TRAIN: (440, 462, 484, 506, 528, 550) + MAX_SIZE_TRAIN: 916 + MIN_SIZE_TEST: 550 + MAX_SIZE_TEST: 916 +OUTPUT_DIR: "output/mask_rcnn/550_R_50_3x" diff --git a/configs/RCNN/Base-RCNN.yaml b/configs/RCNN/Base-RCNN.yaml new file mode 100644 index 0000000..3e020f2 --- /dev/null +++ b/configs/RCNN/Base-RCNN.yaml @@ -0,0 +1,42 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], 
[64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/configs/RCNN/LVIS/R_50_1x.yaml b/configs/RCNN/LVIS/R_50_1x.yaml new file mode 100644 index 0000000..a1b9751 --- /dev/null +++ b/configs/RCNN/LVIS/R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-LVIS.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/lvis/mask_rcnn/R_50_1x" diff --git a/configs/RCNN/R_101_3x.yaml b/configs/RCNN/R_101_3x.yaml new file mode 100644 index 0000000..dab490a --- /dev/null +++ b/configs/RCNN/R_101_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-RCNN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000..b503855 --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,33 @@ +## Expected dataset structure for AdelaiDet instance detection: + +``` +coco/ + thing_train2017/ + # thing class label maps for auxiliary semantic loss +lvis/ + thing_train/ + # semantic labels for LVIS +``` + +Run `python prepare_thing_sem_from_instance.py` to extract semantic labels from instance annotations. + +Run `python prepare_thing_sem_from_lvis.py` to extract semantic labels from LVIS annotations. + +## Expected dataset structure for Person In Context instance detection: + +``` +pic/ + thing_train/ + # thing class label maps for auxiliary semantic loss + annotations/ + train_person.json + val_person.json + image/ + train/ + ... + +``` + +First link the PIC_2.0 dataset to this folder with `ln -s /path/to/PIC_2.0 pic`. Then run `python gen_coco_person.py` to generate the train and validation annotation JSONs. + +Run `python prepare_thing_sem_from_instance.py --dataset-name pic` to extract semantic labels from instance annotations.
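For reference, the label maps produced by these scripts are compressed NumPy archives holding a single `mask` array, which is exactly what `DatasetMapperWithBasis` later reads back with `np.load(path)["mask"]`. A minimal inspection sketch (the file name below is hypothetical):

```python
# Minimal sketch: inspect one generated basis-semantic label map.
# The writers (prepare_thing_sem_from_instance.py / prepare_thing_sem_from_lvis.py)
# store each map via np.savez_compressed(path, mask=label_map); the reader in
# adet/data/dataset_mapper.py loads it back with np.load(path)["mask"].
import numpy as np

path = "coco/thing_train2017/000000000009.npz"  # hypothetical example file
label_map = np.load(path)["mask"]  # uint8 HxW array; 0 = unlabeled, positive values are shifted thing class ids
print(label_map.shape, label_map.dtype, np.unique(label_map))
```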
diff --git a/datasets/gen_coco_person.py b/datasets/gen_coco_person.py new file mode 100755 index 0000000..ea40bbd --- /dev/null +++ b/datasets/gen_coco_person.py @@ -0,0 +1,101 @@ +import numpy as np +import cv2 +import os +import json +error_list = ['23382.png', '23441.png', '20714.png', '20727.png', '23300.png', '21200.png'] + +def mask2box(mask): + index = np.argwhere(mask == 1) + rows = index[:, 0] + clos = index[:, 1] + y1 = int(np.min(rows)) # y + x1 = int(np.min(clos)) # x + y2 = int(np.max(rows)) + x2 = int(np.max(clos)) + return (x1, y1, x2, y2) + +def gen_coco(phase): + result = { + "info": {"description": "PIC2.0 dataset."}, + "categories": [ + {"supercategory": "none", "id": 1, "name": "person"} + ] + } + out_json = phase +'_person.json' + store_segmentation = True + + images_info = [] + labels_info = [] + img_id = 0 + files = tuple(open("pic/list5/"+phase+'_id', 'r')) + files = (_.strip() for _ in files) + + for index, image_name in enumerate(files): + image_name = image_name+".png" + print(index, image_name) + if image_name in error_list: + continue + instance = cv2.imread(os.path.join('instance', phase, image_name), flags=cv2.IMREAD_GRAYSCALE) + semantic = cv2.imread(os.path.join('semantic', phase, image_name), flags=cv2.IMREAD_GRAYSCALE) + # print(instance.shape, semantic.shape) + h = instance.shape[0] + w = instance.shape[1] + images_info.append( + { + "file_name": image_name[:-4]+'.jpg', + "height": h, + "width": w, + "id": index + } + ) + instance_max_num = instance.max() + instance_ids = np.unique(instance) + for instance_id in instance_ids: + if instance_id == 0: + continue + instance_part = instance == instance_id + object_pos = instance_part.nonzero() + # category_id_ = int(semantic[object_pos[0][0], object_pos[1][0]]) + category_id = int(np.max(semantic[object_pos[0], object_pos[1]])) + # assert category_id_ == category_id, (category_id_, category_id) + if category_id != 1: + continue + area = int(instance_part.sum()) + x1, y1, x2, y2 = mask2box(instance_part) + w = x2 - x1 + 1 + h = y2 - y1 + 1 + segmentation = [] + if store_segmentation: + contours, hierarchy = cv2.findContours((instance_part * 255).astype(np.uint8), cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + contour = contour.flatten().tolist() + if len(contour) > 4: + segmentation.append(contour) + if len(segmentation) == 0: + print('error') + continue + labels_info.append( + { + "segmentation": segmentation, # poly + "area": area, # segmentation area + "iscrowd": 0, + "image_id": index, + "bbox": [x1, y1, w, h], + "category_id": category_id, + "id": img_id + }, + ) + img_id += 1 + # break + result["images"] = images_info + result["annotations"] = labels_info + with open('pic/annotations/' + out_json, 'w') as f: + json.dump(result, f, indent=4) + +if __name__ == "__main__": + if not os.path.exists('pic/annotations/'): + os.makedirs('pic/annotations/') + gen_coco("train") + gen_coco("val") + #gen_coco("test") diff --git a/datasets/prepare_thing_sem_from_instance.py b/datasets/prepare_thing_sem_from_instance.py new file mode 100644 index 0000000..ef5ced9 --- /dev/null +++ b/datasets/prepare_thing_sem_from_instance.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved + +import time +import functools +import multiprocessing as mp +import numpy as np +import os +import argparse +from pycocotools.coco import COCO +from pycocotools import mask as maskUtils + +from detectron2.data.datasets.builtin_meta import _get_coco_instances_meta + + +def annToRLE(ann, img_size): + h, w = img_size + segm = ann['segmentation'] + if type(segm) == list: + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(segm, h, w) + rle = maskUtils.merge(rles) + elif type(segm['counts']) == list: + # uncompressed RLE + rle = maskUtils.frPyObjects(segm, h, w) + else: + # rle + rle = ann['segmentation'] + return rle + + +def annToMask(ann, img_size): + rle = annToRLE(ann, img_size) + m = maskUtils.decode(rle) + return m + + +def _process_instance_to_semantic(anns, output_semantic, img, categories): + img_size = (img["height"], img["width"]) + output = np.zeros(img_size, dtype=np.uint8) + for ann in anns: + mask = annToMask(ann, img_size) + output[mask == 1] = categories[ann["category_id"]] + 1 + # save as compressed npz + np.savez_compressed(output_semantic, mask=output) + # Image.fromarray(output).save(output_semantic) + + +def create_coco_semantic_from_instance(instance_json, sem_seg_root, categories): + """ + Create semantic segmentation annotations from panoptic segmentation + annotations, to be used by PanopticFPN. + + It maps all thing categories to contiguous ids starting from 1, and maps all unlabeled pixels to class 0 + + Args: + instance_json (str): path to the instance json file, in COCO's format. + sem_seg_root (str): a directory to output semantic annotation files + categories (dict): category metadata. Each dict needs to have: + "id": corresponds to the "category_id" in the json annotations + "isthing": 0 or 1 + """ + os.makedirs(sem_seg_root, exist_ok=True) + + coco_detection = COCO(instance_json) + + def iter_annotations(): + for img_id in coco_detection.getImgIds(): + anns_ids = coco_detection.getAnnIds(img_id) + anns = coco_detection.loadAnns(anns_ids) + img = coco_detection.loadImgs(int(img_id))[0] + output = os.path.join(sem_seg_root, img["file_name"].replace('jpg', 'npz')) + yield anns, output, img + + # single process + # print("Start writing to {} ...".format(sem_seg_root)) + # start = time.time() + # for anno, oup, img in iter_annotations(): + # _process_instance_to_semantic( + # anno, oup, img, categories) + # print("Finished. time: {:.2f}s".format(time.time() - start)) + # return + + pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) + + print("Start writing to {} ...".format(sem_seg_root)) + start = time.time() + pool.starmap( + functools.partial( + _process_instance_to_semantic, + categories=categories), + iter_annotations(), + chunksize=100, + ) + print("Finished. 
time: {:.2f}s".format(time.time() - start)) + + +def get_parser(): + parser = argparse.ArgumentParser(description="Keep only model in ckpt") + parser.add_argument( + "--dataset-name", + default="coco", + help="dataset to generate", + ) + return parser + + +if __name__ == "__main__": + args = get_parser().parse_args() + dataset_dir = os.path.join(os.path.dirname(__file__), args.dataset_name) + if args.dataset_name == "coco": + thing_id_to_contiguous_id = _get_coco_instances_meta()["thing_dataset_id_to_contiguous_id"] + split_name = 'train2017' + annotation_name = "annotations/instances_{}.json" + else: + thing_id_to_contiguous_id = {1: 0} + split_name = 'train' + annotation_name = "annotations/{}_person.json" + for s in ["train2017"]: + create_coco_semantic_from_instance( + os.path.join(dataset_dir, "annotations/instances_{}.json".format(s)), + os.path.join(dataset_dir, "thing_{}".format(s)), + thing_id_to_contiguous_id + ) diff --git a/datasets/prepare_thing_sem_from_lvis.py b/datasets/prepare_thing_sem_from_lvis.py new file mode 100644 index 0000000..44165c9 --- /dev/null +++ b/datasets/prepare_thing_sem_from_lvis.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import time +import functools +import multiprocessing as mp +import numpy as np +import os +from lvis import LVIS +from pycocotools import mask as maskUtils + + +def annToRLE(ann, img_size): + h, w = img_size + segm = ann['segmentation'] + if type(segm) == list: + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(segm, h, w) + rle = maskUtils.merge(rles) + elif type(segm['counts']) == list: + # uncompressed RLE + rle = maskUtils.frPyObjects(segm, h, w) + else: + # rle + rle = ann['segmentation'] + return rle + + +def annToMask(ann, img_size): + rle = annToRLE(ann, img_size) + m = maskUtils.decode(rle) + return m + + +def _process_instance_to_semantic(anns, output_semantic, img): + img_size = (img["height"], img["width"]) + output = np.zeros(img_size, dtype=np.uint8) + for ann in anns: + mask = annToMask(ann, img_size) + output[mask == 1] = ann["category_id"] // 5 + # save as compressed npz + np.savez_compressed(output_semantic, mask=output) + # Image.fromarray(output).save(output_semantic) + + +def create_lvis_semantic_from_instance(instance_json, sem_seg_root): + """ + Create semantic segmentation annotations from panoptic segmentation + annotations, to be used by PanopticFPN. + + It maps all thing categories to contiguous ids starting from 1, and maps all unlabeled pixels to class 0 + + Args: + instance_json (str): path to the instance json file, in COCO's format. + sem_seg_root (str): a directory to output semantic annotation files + """ + os.makedirs(sem_seg_root, exist_ok=True) + + lvis_detection = LVIS(instance_json) + + def iter_annotations(): + for img_id in lvis_detection.get_img_ids(): + anns_ids = lvis_detection.get_ann_ids([img_id]) + anns = lvis_detection.load_anns(anns_ids) + img = lvis_detection.load_imgs([img_id])[0] + output = os.path.join(sem_seg_root, img["file_name"].replace('jpg', 'npz')) + yield anns, output, img + + # # single process + # print("Start writing to {} ...".format(sem_seg_root)) + # start = time.time() + # for anno, oup, img in iter_annotations(): + # _process_instance_to_semantic( + # anno, oup, img) + # print("Finished. 
time: {:.2f}s".format(time.time() - start)) + # return + + pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) + + print("Start writing to {} ...".format(sem_seg_root)) + start = time.time() + pool.starmap( + functools.partial( + _process_instance_to_semantic), + iter_annotations(), + chunksize=100, + ) + print("Finished. time: {:.2f}s".format(time.time() - start)) + + +if __name__ == "__main__": + dataset_dir = os.path.join(os.path.dirname(__file__), "lvis") + for s in ["train"]: + create_lvis_semantic_from_instance( + os.path.join(dataset_dir, "lvis_v0.5_{}.json".format(s)), + os.path.join(dataset_dir, "thing_{}".format(s)), + ) diff --git a/tools/train_net.py b/tools/train_net.py index 0c91d3e..7e67871 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -37,7 +37,7 @@ from detectron2.evaluation import ( ) from detectron2.modeling import GeneralizedRCNNWithTTA -from detectron2.data.dataset_mapper import DatasetMapper +from adet.data.dataset_mapper import DatasetMapperWithBasis from adet.config import get_cfg from adet.checkpoint import AdetCheckpointer @@ -123,7 +123,7 @@ class Trainer(DefaultTrainer): It calls :func:`detectron2.data.build_detection_train_loader` with a customized DatasetMapper, which adds categorical labels as a semantic mask. """ - mapper = DatasetMapper(cfg, True) + mapper = DatasetMapperWithBasis(cfg, True) return build_detection_train_loader(cfg, mapper) @classmethod