diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
new file mode 100644
index 0000000..66c542f
--- /dev/null
+++ b/MODEL_ZOO.md
@@ -0,0 +1,44 @@
+# AdelaiDet Model Zoo and Baselines
+
+## Introduction
+This file documents a collection of models trained with AdelaiDet in November 2019.
+
+## Models
+
+The inference time is measured on a single 1080Ti based on the most recent commit of Detectron2 ([ffff8ac](https://github.com/facebookresearch/detectron2/commit/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9)).
+
+More models will be released soon. Stay tuned.
+
+### COCO Object Detection Baselines with FCOS
+
+Name | box AP | download
+--- |:---:|:---:
+[FCOS_R_50_1x](configs/FCOS-Detection/R_50_1x.yaml) | 38.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/glqFc13cCoEyHYy/download)
+
+### COCO Instance Segmentation Baselines with [BlendMask](https://arxiv.org/abs/2001.00309)
+
+Model | Name | inference time (ms/im) | box AP | mask AP | download
+--- |:---:|:---:|:---:|:---:|:---:
+Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 |
+BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 36 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/R3Qintf7N8UCiIt/download)
+Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 80 | 38.6 | 35.2 |
+BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 73 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/zoxXPnr6Hw3OJgK/download)
+Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | 80 | 41.0 | 37.2 |
+BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | 74 | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/ZnaInHFEKst6mvg/download)
+Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | 100 | 42.9 | 38.6 |
+BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | 94 | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/e4fXrliAcMtyEBy/download)
+BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | 105 | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/vbnKnQtaGlw8TKv/download)
+
+### COCO Panoptic Segmentation Baselines with BlendMask
+Model | Name | PQ | PQTh | PQSt | download
+--- |:---:|:---:|:---:|:---:|:---:
+Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 |
+BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/oDgi0826JOJXCr5/download)
+Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 |
+BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/u6gZwj06MWDEkYe/download)
+BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/Jwp41WEzDdrhWsN/download)
+
+### Person in Context with BlendMask
+Model | Name | box AP | mask AP | download
+--- |:---:|:---:|:---:|:---:
+BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 |
[model](https://cloudstor.aarnet.edu.au/plus/s/nvpcKTFA5fsagc0/download) \ No newline at end of file diff --git a/README.md b/README.md index b991a7d..294ea3c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ To date, AdelaiDet implements the following algorithms: ## Models -More models will be released soon. Stay tuned. +All of our trained models are available in the [Model Zoo](MODEL_ZOO.md). ### COCO Object Detecton Baselines with FCOS @@ -29,28 +29,14 @@ Name | box AP | download Model | Name |inference time (ms/im) | box AP | mask AP | download --- |:---:|:---:|:---:|:---:|:---: Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 | -BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 40 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/o0bpkmhMiuYgIcQ/download) -Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 90 | 38.6 | 35.2 | -BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 83 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/crpmeVCnQ3StvSz/download) -Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | | 41.0 | 37.2 | -BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/9u1cG2zXvEva5SM/download) -Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | | 42.9 | 38.6 | -BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/mYm5VCXICoeLNHq/download) -BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/TAZPxSDvPuhegKp/download) - -### COCO Panoptic Segmentation Baselines with BlendMask -Model | Name | PQ | PQTh | PQSt | download ---- |:---:|:---:|:---:|:---:|:---: -Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 | -BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/bG0IhYeMAvlTGTq/download) -Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 | -BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/AEwbhyQ9F3lqvsz/download) -BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/GyWDhsukAYYokZg/download) - -### Person in Context with BlendMask -Model | Name | box AP | mask AP | download ---- |:---:|:---:|:---:|:---: -BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/d4f16WshXYbOuIo) +BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 36 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/R3Qintf7N8UCiIt/download) +Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 80 | 38.6 | 35.2 | +BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 73 | 39.9 
| 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/zoxXPnr6Hw3OJgK/download) +Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | 80 | 41.0 | 37.2 | +BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | 74 | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/ZnaInHFEKst6mvg/download) +Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | 100 | 42.9 | 38.6 | +BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | 94 | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/e4fXrliAcMtyEBy/download) +BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | 105 | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/vbnKnQtaGlw8TKv/download) ## Installation diff --git a/adet/config/defaults.py b/adet/config/defaults.py index c167e83..c12d76f 100644 --- a/adet/config/defaults.py +++ b/adet/config/defaults.py @@ -6,6 +6,8 @@ from detectron2.config import CfgNode as CN # Additional Configs # ---------------------------------------------------------------------------- # _C.MODEL.MOBILENET = False +_C.MODEL.BACKBONE.ANTI_ALIAS = False +_C.MODEL.RESNETS.DEFORM_INTERVAL = 1 # ---------------------------------------------------------------------------- # # FCOS Head diff --git a/adet/modeling/backbone/__init__.py b/adet/modeling/backbone/__init__.py index 3ba7bda..154585f 100644 --- a/adet/modeling/backbone/__init__.py +++ b/adet/modeling/backbone/__init__.py @@ -1,2 +1,3 @@ from .fpn import build_fcos_resnet_fpn_backbone from .vovnet import build_vovnet_fpn_backbone, build_vovnet_backbone +from .resnet_lpf import build_resnet_lpf_backbone diff --git a/adet/modeling/backbone/fpn.py b/adet/modeling/backbone/fpn.py index da1bf95..70adaa0 100644 --- a/adet/modeling/backbone/fpn.py +++ b/adet/modeling/backbone/fpn.py @@ -6,6 +6,8 @@ from detectron2.modeling.backbone import FPN, build_resnet_backbone from detectron2.layers import ShapeSpec from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from .resnet_lpf import build_resnet_lpf_backbone +from .resnet_interval import build_resnet_interval_backbone from .mobilenet import build_mnv2_backbone @@ -57,7 +59,11 @@ def build_fcos_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): Returns: backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
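+
+    Backbone selection (from the config): ``MODEL.BACKBONE.ANTI_ALIAS`` builds the
+    anti-aliased (low-pass filtered) ResNet, ``MODEL.RESNETS.DEFORM_INTERVAL > 1`` builds
+    the ResNet with interleaved deformable bottleneck blocks, ``MODEL.MOBILENET`` builds
+    MobileNetV2, and otherwise detectron2's standard ResNet is used.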
""" - if cfg.MODEL.MOBILENET: + if cfg.MODEL.BACKBONE.ANTI_ALIAS: + bottom_up = build_resnet_lpf_backbone(cfg, input_shape) + elif cfg.MODEL.RESNETS.DEFORM_INTERVAL > 1: + bottom_up = build_resnet_interval_backbone(cfg, input_shape) + elif cfg.MODEL.MOBILENET: bottom_up = build_mnv2_backbone(cfg, input_shape) else: bottom_up = build_resnet_backbone(cfg, input_shape) diff --git a/adet/modeling/backbone/lpf.py b/adet/modeling/backbone/lpf.py new file mode 100644 index 0000000..f455fea --- /dev/null +++ b/adet/modeling/backbone/lpf.py @@ -0,0 +1,115 @@ +import torch +import torch.nn.parallel +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from IPython import embed + + +class Downsample(nn.Module): + def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] + self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride-1)/2.) + self.channels = channels + + # print('Filter size [%i]'%filt_size) + if(self.filt_size==1): + a = np.array([1.,]) + elif(self.filt_size==2): + a = np.array([1., 1.]) + elif(self.filt_size==3): + a = np.array([1., 2., 1.]) + elif(self.filt_size==4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size==5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size==6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size==7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a[:,None]*a[None,:]) + filt = filt/torch.sum(filt) + self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size==1): + if(self.pad_off==0): + return inp[:,:,::self.stride,::self.stride] + else: + return self.pad(inp)[:,:,::self.stride,::self.stride] + else: + return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + +def get_pad_layer(pad_type): + if(pad_type in ['refl','reflect']): + PadLayer = nn.ReflectionPad2d + elif(pad_type in ['repl','replicate']): + PadLayer = nn.ReplicationPad2d + elif(pad_type=='zero'): + PadLayer = nn.ZeroPad2d + else: + print('Pad type [%s] not recognized'%pad_type) + return PadLayer + + +class Downsample1D(nn.Module): + def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): + super(Downsample1D, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.) 
+ self.channels = channels + + # print('Filter size [%i]' % filt_size) + if(self.filt_size == 1): + a = np.array([1., ]) + elif(self.filt_size == 2): + a = np.array([1., 1.]) + elif(self.filt_size == 3): + a = np.array([1., 2., 1.]) + elif(self.filt_size == 4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size == 5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size == 6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size == 7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a) + filt = filt / torch.sum(filt) + self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) + + self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size == 1): + if(self.pad_off == 0): + return inp[:, :, ::self.stride] + else: + return self.pad(inp)[:, :, ::self.stride] + else: + return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +def get_pad_layer_1d(pad_type): + if(pad_type in ['refl', 'reflect']): + PadLayer = nn.ReflectionPad1d + elif(pad_type in ['repl', 'replicate']): + PadLayer = nn.ReplicationPad1d + elif(pad_type == 'zero'): + PadLayer = nn.ZeroPad1d + else: + print('Pad type [%s] not recognized' % pad_type) + return PadLayer diff --git a/adet/modeling/backbone/resnet_interval.py b/adet/modeling/backbone/resnet_interval.py new file mode 100644 index 0000000..b91be6e --- /dev/null +++ b/adet/modeling/backbone/resnet_interval.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from detectron2.layers import FrozenBatchNorm2d +from detectron2.modeling.backbone import BACKBONE_REGISTRY +from detectron2.modeling.backbone.resnet import ( + BasicStem, + DeformBottleneckBlock, + BottleneckBlock, + ResNet, +) + + +def make_stage_intervals(block_class, num_blocks, first_stride, **kwargs): + """ + Create a resnet stage by creating many blocks. + Args: + block_class (class): a subclass of ResNetBlockBase + num_blocks (int): + first_stride (int): the stride of the first block. The other blocks will have stride=1. + A `stride` argument will be passed to the block constructor. + kwargs: other arguments passed to the block constructor. + + Returns: + list[nn.Module]: a list of block module. + """ + blocks = [] + conv_kwargs = {key: kwargs[key] for key in kwargs if "deform" not in key} + deform_kwargs = {key: kwargs[key] for key in kwargs if key != "deform_interval"} + deform_interval = kwargs.get("deform_interval", None) + for i in range(num_blocks): + if deform_interval and i % deform_interval == 0: + blocks.append(block_class(stride=first_stride if i == 0 else 1, **deform_kwargs)) + else: + blocks.append(BottleneckBlock(stride=first_stride if i == 0 else 1, **conv_kwargs)) + conv_kwargs["in_channels"] = conv_kwargs["out_channels"] + deform_kwargs["in_channels"] = deform_kwargs["out_channels"] + return blocks + + +@BACKBONE_REGISTRY.register() +def build_resnet_interval_backbone(cfg, input_shape): + """ + Create a ResNet instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? 
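+    # Mirrors detectron2's build_resnet_backbone, except that stages are assembled by
+    # make_stage_intervals above: with MODEL.RESNETS.DEFORM_INTERVAL = k, only every k-th
+    # block of a deformable stage is a DeformBottleneckBlock and the rest remain plain
+    # BottleneckBlocks, reducing the cost of deformable convolutions.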
+ norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + + if freeze_at >= 1: + for p in stem.parameters(): + p.requires_grad = False + stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem) + + # fmt: off + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + deform_interval = cfg.MODEL.RESNETS.DEFORM_INTERVAL + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + + stages = [] + + # Avoid creating variables without gradients + # It consumes extra memory and may cause allreduce to fail + out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "first_stride": first_stride, + "in_channels": in_channels, + "bottleneck_channels": bottleneck_channels, + "out_channels": out_channels, + "num_groups": num_groups, + "norm": norm, + "stride_in_1x1": stride_in_1x1, + "dilation": dilation, + } + if deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + stage_kargs["deform_interval"] = deform_interval + else: + stage_kargs["block_class"] = BottleneckBlock + blocks = make_stage_intervals(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + + if freeze_at >= stage_idx: + for block in blocks: + block.freeze() + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features) diff --git a/adet/modeling/backbone/resnet_lpf.py b/adet/modeling/backbone/resnet_lpf.py new file mode 100644 index 0000000..867de6d --- /dev/null +++ b/adet/modeling/backbone/resnet_lpf.py @@ -0,0 +1,291 @@ +# This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. +# Copyright (c) 2017 Torch Contributors. +# The Pytorch examples are available under the BSD 3-Clause License. +# +# ========================================================================================== +# +# Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. +# Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike +# 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 
+# +# ========================================================================================== +# +# BSD-3 License +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + +import torch.nn as nn + +from detectron2.layers.batch_norm import NaiveSyncBatchNorm +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone import Backbone + +from .lpf import * + + +__all__ = ['ResNetLPF', 'build_resnet_lpf_backbone'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1: + raise ValueError('BasicBlock only supports groups=1') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + if(stride == 1): + self.conv2 = conv3x3(planes, planes) + else: + self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), + conv3x3(planes, planes),) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Both self.conv2 and 
self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = norm_layer(planes) + self.conv2 = conv3x3(planes, planes, groups) # stride moved + self.bn2 = norm_layer(planes) + if(stride == 1): + self.conv3 = conv1x1(planes, planes * self.expansion) + else: + self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), + conv1x1(planes, planes * self.expansion)) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNetLPF(Backbone): + + def __init__(self, cfg, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, norm_layer=None, filter_size=1, + pool_only=True, return_idx=[0, 1, 2, 3]): + super().__init__() + self.return_idx = return_idx + if norm_layer is None: + norm_layer = nn.BatchNorm2d + planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] + self.inplanes = planes[0] + + if(pool_only): + self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False) + self.bn1 = norm_layer(planes[0]) + self.relu = nn.ReLU(inplace=True) + + if(pool_only): + self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), + Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) + else: + self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]), + nn.MaxPool2d(kernel_size=2, stride=1), + Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) + + self.layer1 = self._make_layer( + block, planes[0], layers[0], groups=groups, norm_layer=norm_layer) + self.layer2 = self._make_layer( + block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + self.layer3 = self._make_layer( + block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + self.layer4 = self._make_layer( + block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None): + # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + print('Not initializing') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_AT) + if False: + self._freeze_bn() + + def _freeze_backbone(self, freeze_at): + if freeze_at < 0: + return + for stage_index in range(freeze_at): + if stage_index == 0: + # stage 0 is the stem + for p in self.conv1.parameters(): + p.requires_grad = False + for p in self.bn1.parameters(): + p.requires_grad = False + else: + m = getattr(self, "layer" + str(stage_index)) + for p in m.parameters(): + p.requires_grad = False + + def _freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1): + if norm_layer is None: + norm_layer = nn.BatchNorm2d + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + # downsample = nn.Sequential( + # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size), + # norm_layer(planes * block.expansion), + # ) + + downsample = [Downsample(filt_size=filter_size, stride=stride, + channels=self.inplanes), ] if(stride != 1) else [] + downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), + norm_layer(planes * block.expansion)] + # print(downsample) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, + groups, norm_layer, filter_size=filter_size)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups, + norm_layer=norm_layer, filter_size=filter_size)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + outs = [] + outs.append(self.layer1(x)) # 1/4 + outs.append(self.layer2(outs[-1])) # 1/8 + outs.append(self.layer3(outs[-1])) # 1/16 + outs.append(self.layer4(outs[-1])) # 1/32 + return {"res{}".format(idx + 2): outs[idx] for idx in self.return_idx} + + +@BACKBONE_REGISTRY.register() +def build_resnet_lpf_backbone(cfg, input_shape): + """ + Create a ResNet instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. 
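+
+    The returned backbone swaps strided convolution/pooling for blur-pool (low-pass
+    filtered) downsampling with a 3-tap binomial filter and uses NaiveSyncBatchNorm;
+    its output feature names, channels and strides match detectron2's standard ResNet,
+    so it drops into the FPN unchanged.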
+ """ + depth = cfg.MODEL.RESNETS.DEPTH + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + out_stage_idx = [{"res2": 0, "res3": 1, "res4": 2, "res5": 3}[f] for f in out_features] + out_feature_channels = {"res2": 256, "res3": 512, + "res4": 1024, "res5": 2048} + out_feature_strides = {"res2": 4, "res3": 8, "res4": 16, "res5": 32} + model = ResNetLPF(cfg, Bottleneck, num_blocks_per_stage, norm_layer=NaiveSyncBatchNorm, + filter_size=3, pool_only=True, return_idx=out_stage_idx) + model._out_features = out_features + model._out_feature_channels = out_feature_channels + model._out_feature_strides = out_feature_strides + return model diff --git a/adet/modeling/fcos/fcos.py b/adet/modeling/fcos/fcos.py index fb5895b..fbb0163 100644 --- a/adet/modeling/fcos/fcos.py +++ b/adet/modeling/fcos/fcos.py @@ -217,7 +217,7 @@ class FCOSHead(nn.Module): in_channels, 4, kernel_size=3, stride=1, padding=1 ) - self.centerness = nn.Conv2d( + self.ctrness = nn.Conv2d( in_channels, 1, kernel_size=3, stride=1, padding=1 ) @@ -230,7 +230,7 @@ class FCOSHead(nn.Module): for modules in [ self.cls_tower, self.bbox_tower, self.share_tower, self.cls_logits, - self.bbox_pred, self.centerness + self.bbox_pred, self.ctrness ]: for l in modules.modules(): if isinstance(l, nn.Conv2d): @@ -256,7 +256,7 @@ class FCOSHead(nn.Module): bbox_towers.append(bbox_tower) logits.append(self.cls_logits(cls_tower)) - ctrness.append(self.centerness(bbox_tower)) + ctrness.append(self.ctrness(bbox_tower)) reg = self.bbox_pred(bbox_tower) if self.scales is not None: reg = self.scales[l](reg) diff --git a/configs/BlendMask/Panoptic/Base-Panoptic.yaml b/configs/BlendMask/Panoptic/Base-Panoptic.yaml index 7fd16ec..ffd7ff0 100644 --- a/configs/BlendMask/Panoptic/Base-Panoptic.yaml +++ b/configs/BlendMask/Panoptic/Base-Panoptic.yaml @@ -9,7 +9,7 @@ MODEL: PANOPTIC_FPN: COMBINE: ENABLED: True - INSTANCES_CONFIDENCE_THRESH: 0.2 + INSTANCES_CONFIDENCE_THRESH: 0.45 OVERLAP_THRESH: 0.4 DATASETS: TRAIN: ("coco_2017_train_panoptic_separated",) diff --git a/tools/rename_blendmask.py b/tools/rename_blendmask.py new file mode 100644 index 0000000..694c502 --- /dev/null +++ b/tools/rename_blendmask.py @@ -0,0 +1,41 @@ +import argparse +from collections import OrderedDict + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser(description="FCOS Detectron2 Converter") + parser.add_argument( + "--model", + default="weights/blendmask/person/R_50_1x.pth", + metavar="FILE", + help="path to model weights", + ) + parser.add_argument( + "--output", + default="weights/blendmask/person/R_50_1x.pth", + metavar="FILE", + help="path to model weights", + ) + return parser + + +def rename_resnet_param_names(ckpt_state_dict): + converted_state_dict = OrderedDict() + for key in ckpt_state_dict.keys(): + value = ckpt_state_dict[key] + key = key.replace("centerness", "ctrness") + + converted_state_dict[key] = value + return converted_state_dict + + +if __name__ == "__main__": + args = get_parser().parse_args() + ckpt = torch.load(args.model) + if "model" in ckpt: + model = rename_resnet_param_names(ckpt["model"]) + else: + model = rename_resnet_param_names(ckpt) + torch.save(model, args.output)