diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
new file mode 100644
index 0000000..66c542f
--- /dev/null
+++ b/MODEL_ZOO.md
@@ -0,0 +1,44 @@
+# AdelaiDet Model Zoo and Baselines
+
+## Introduction
+This file documents a collection of models trained with AdelaiDet in November 2019.
+
+## Models
+
+Inference time is measured on a single GTX 1080 Ti GPU with the most recent commit of Detectron2 ([ffff8ac](https://github.com/facebookresearch/detectron2/commit/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9)).
+
+More models will be released soon. Stay tuned.
+
+### COCO Object Detection Baselines with FCOS
+
+Name | box AP | download
+--- |:---:|:---:
+[FCOS_R_50_1x](configs/FCOS-Detection/R_50_1x.yaml) | 38.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/glqFc13cCoEyHYy/download)
+
+### COCO Instance Segmentation Baselines with [BlendMask](https://arxiv.org/abs/2001.00309)
+
+Model | Name | inference time (ms/im) | box AP | mask AP | download
+--- |:---:|:---:|:---:|:---:|:---:
+Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 |
+BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 36 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/R3Qintf7N8UCiIt/download)
+Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 80 | 38.6 | 35.2 |
+BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 73 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/zoxXPnr6Hw3OJgK/download)
+Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | 80 | 41.0 | 37.2 |
+BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | 74 | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/ZnaInHFEKst6mvg/download)
+Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | 100 | 42.9 | 38.6 |
+BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | 94 | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/e4fXrliAcMtyEBy/download)
+BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | 105 | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/vbnKnQtaGlw8TKv/download)
+
+### COCO Panoptic Segmentation Baselines with BlendMask
+Model | Name | PQ | PQTh | PQSt | download
+--- |:---:|:---:|:---:|:---:|:---:
+Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 |
+BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/oDgi0826JOJXCr5/download)
+Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 |
+BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/u6gZwj06MWDEkYe/download)
+BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/Jwp41WEzDdrhWsN/download)
+
+### Person in Context with BlendMask
+Model | Name | box AP | mask AP | download
+--- |:---:|:---:|:---:|:---:
+BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/nvpcKTFA5fsagc0/download)
\ No newline at end of file
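
To reproduce a number from the tables above, the checkpoints can be evaluated with the usual detectron2-style entry point. A minimal sketch, assuming AdelaiDet keeps detectron2's `tools/train_net.py` interface with `--eval-only`; the local weight path is illustrative:

```python
# Hedged sketch: evaluate the FCOS_R_50_1x checkpoint listed above.
# Assumes a detectron2-style "tools/train_net.py --eval-only" entry point and an
# illustrative path for the downloaded weights; run from the repository root.
import subprocess

subprocess.run(
    [
        "python", "tools/train_net.py",
        "--config-file", "configs/FCOS-Detection/R_50_1x.yaml",
        "--eval-only",
        "MODEL.WEIGHTS", "FCOS_R_50_1x.pth",  # file downloaded from the FCOS table
    ],
    check=True,
)
```
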
diff --git a/README.md b/README.md
index b991a7d..294ea3c 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ To date, AdelaiDet implements the following algorithms:
## Models
-More models will be released soon. Stay tuned.
+All of our trained models are available in the [Model Zoo](MODEL_ZOO.md).
### COCO Object Detecton Baselines with FCOS
@@ -29,28 +29,14 @@ Name | box AP | download
Model | Name |inference time (ms/im) | box AP | mask AP | download
--- |:---:|:---:|:---:|:---:|:---:
Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 |
-BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 40 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/o0bpkmhMiuYgIcQ/download)
-Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 90 | 38.6 | 35.2 |
-BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 83 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/crpmeVCnQ3StvSz/download)
-Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | | 41.0 | 37.2 |
-BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/9u1cG2zXvEva5SM/download)
-Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | | 42.9 | 38.6 |
-BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/mYm5VCXICoeLNHq/download)
-BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/TAZPxSDvPuhegKp/download)
-
-### COCO Panoptic Segmentation Baselines with BlendMask
-Model | Name | PQ | PQTh | PQSt | download
---- |:---:|:---:|:---:|:---:|:---:
-Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 |
-BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/bG0IhYeMAvlTGTq/download)
-Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 |
-BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/AEwbhyQ9F3lqvsz/download)
-BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/GyWDhsukAYYokZg/download)
-
-### Person in Context with BlendMask
-Model | Name | box AP | mask AP | download
---- |:---:|:---:|:---:|:---:
-BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/d4f16WshXYbOuIo)
+BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 36 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/R3Qintf7N8UCiIt/download)
+Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 80 | 38.6 | 35.2 |
+BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 73 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/zoxXPnr6Hw3OJgK/download)
+Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | 80 | 41.0 | 37.2 |
+BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | 74 | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/ZnaInHFEKst6mvg/download)
+Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | 100 | 42.9 | 38.6 |
+BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | 94 | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/e4fXrliAcMtyEBy/download)
+BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | 105 | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/vbnKnQtaGlw8TKv/download)
## Installation
diff --git a/adet/config/defaults.py b/adet/config/defaults.py
index c167e83..c12d76f 100644
--- a/adet/config/defaults.py
+++ b/adet/config/defaults.py
@@ -6,6 +6,8 @@ from detectron2.config import CfgNode as CN
# Additional Configs
# ---------------------------------------------------------------------------- #
_C.MODEL.MOBILENET = False
+_C.MODEL.BACKBONE.ANTI_ALIAS = False  # replace the ResNet bottom-up with the anti-aliased (LPF) variant
+_C.MODEL.RESNETS.DEFORM_INTERVAL = 1  # when > 1, use deformable conv only in every DEFORM_INTERVAL-th block
# ---------------------------------------------------------------------------- #
# FCOS Head
diff --git a/adet/modeling/backbone/__init__.py b/adet/modeling/backbone/__init__.py
index 3ba7bda..154585f 100644
--- a/adet/modeling/backbone/__init__.py
+++ b/adet/modeling/backbone/__init__.py
@@ -1,2 +1,3 @@
from .fpn import build_fcos_resnet_fpn_backbone
from .vovnet import build_vovnet_fpn_backbone, build_vovnet_backbone
+from .resnet_lpf import build_resnet_lpf_backbone
diff --git a/adet/modeling/backbone/fpn.py b/adet/modeling/backbone/fpn.py
index da1bf95..70adaa0 100644
--- a/adet/modeling/backbone/fpn.py
+++ b/adet/modeling/backbone/fpn.py
@@ -6,6 +6,8 @@ from detectron2.modeling.backbone import FPN, build_resnet_backbone
from detectron2.layers import ShapeSpec
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
+from .resnet_lpf import build_resnet_lpf_backbone
+from .resnet_interval import build_resnet_interval_backbone
from .mobilenet import build_mnv2_backbone
@@ -57,7 +59,11 @@ def build_fcos_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
- if cfg.MODEL.MOBILENET:
+ if cfg.MODEL.BACKBONE.ANTI_ALIAS:
+ bottom_up = build_resnet_lpf_backbone(cfg, input_shape)
+ elif cfg.MODEL.RESNETS.DEFORM_INTERVAL > 1:
+ bottom_up = build_resnet_interval_backbone(cfg, input_shape)
+ elif cfg.MODEL.MOBILENET:
bottom_up = build_mnv2_backbone(cfg, input_shape)
else:
bottom_up = build_resnet_backbone(cfg, input_shape)
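
Together with the two keys added in `adet/config/defaults.py`, this branch makes the bottom-up network selectable purely from the config. A minimal sketch of building one variant, assuming `adet.config.get_cfg` exposes the new keys and that the script is run from the repository root:

```python
# Sketch of config-driven backbone selection (assumption: adet.config.get_cfg
# returns a config that already contains the keys added in defaults.py).
from adet.config import get_cfg
from adet.modeling.backbone import build_fcos_resnet_fpn_backbone
from detectron2.layers import ShapeSpec

cfg = get_cfg()
cfg.merge_from_file("configs/FCOS-Detection/R_50_1x.yaml")

cfg.MODEL.BACKBONE.ANTI_ALIAS = True     # -> build_resnet_lpf_backbone (anti-aliased ResNet)
# cfg.MODEL.RESNETS.DEFORM_INTERVAL = 3  # -> build_resnet_interval_backbone (DCN in every 3rd block of DCN-enabled stages)
# cfg.MODEL.MOBILENET = True             # -> build_mnv2_backbone

backbone = build_fcos_resnet_fpn_backbone(cfg, ShapeSpec(channels=3))
print(backbone.output_shape())           # FPN feature levels with their strides and channels
```
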
diff --git a/adet/modeling/backbone/lpf.py b/adet/modeling/backbone/lpf.py
new file mode 100644
index 0000000..f455fea
--- /dev/null
+++ b/adet/modeling/backbone/lpf.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn.parallel
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Downsample(nn.Module):
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
+ super(Downsample, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
+ self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride-1)/2.)
+ self.channels = channels
+
+ # print('Filter size [%i]'%filt_size)
+ if(self.filt_size==1):
+ a = np.array([1.,])
+ elif(self.filt_size==2):
+ a = np.array([1., 1.])
+ elif(self.filt_size==3):
+ a = np.array([1., 2., 1.])
+ elif(self.filt_size==4):
+ a = np.array([1., 3., 3., 1.])
+ elif(self.filt_size==5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif(self.filt_size==6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif(self.filt_size==7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a[:,None]*a[None,:])
+ filt = filt/torch.sum(filt)
+ self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
+
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
+ if(self.filt_size==1):
+ if(self.pad_off==0):
+ return inp[:,:,::self.stride,::self.stride]
+ else:
+ return self.pad(inp)[:,:,::self.stride,::self.stride]
+ else:
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+def get_pad_layer(pad_type):
+ if(pad_type in ['refl','reflect']):
+ PadLayer = nn.ReflectionPad2d
+ elif(pad_type in ['repl','replicate']):
+ PadLayer = nn.ReplicationPad2d
+ elif(pad_type=='zero'):
+ PadLayer = nn.ZeroPad2d
+ else:
+        raise ValueError('Pad type [%s] not recognized' % pad_type)
+ return PadLayer
+
+
+class Downsample1D(nn.Module):
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
+ super(Downsample1D, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride - 1) / 2.)
+ self.channels = channels
+
+ # print('Filter size [%i]' % filt_size)
+ if(self.filt_size == 1):
+ a = np.array([1., ])
+ elif(self.filt_size == 2):
+ a = np.array([1., 1.])
+ elif(self.filt_size == 3):
+ a = np.array([1., 2., 1.])
+ elif(self.filt_size == 4):
+ a = np.array([1., 3., 3., 1.])
+ elif(self.filt_size == 5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif(self.filt_size == 6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif(self.filt_size == 7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a)
+ filt = filt / torch.sum(filt)
+ self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
+
+ self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
+ if(self.filt_size == 1):
+ if(self.pad_off == 0):
+ return inp[:, :, ::self.stride]
+ else:
+ return self.pad(inp)[:, :, ::self.stride]
+ else:
+ return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+
+def get_pad_layer_1d(pad_type):
+ if(pad_type in ['refl', 'reflect']):
+ PadLayer = nn.ReflectionPad1d
+ elif(pad_type in ['repl', 'replicate']):
+ PadLayer = nn.ReplicationPad1d
+ elif(pad_type == 'zero'):
+ PadLayer = nn.ZeroPad1d
+ else:
+        raise ValueError('Pad type [%s] not recognized' % pad_type)
+ return PadLayer
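
`Downsample` is the blur-then-subsample ("BlurPool") operation used by anti-aliased CNNs: a fixed binomial kernel (the outer product `a[:, None] * a[None, :]`) is applied as a depthwise convolution before striding, so the subsequent subsampling aliases less. A short usage sketch:

```python
# Usage sketch for the Downsample (BlurPool) layer defined above.
import torch
from adet.modeling.backbone.lpf import Downsample

x = torch.randn(2, 64, 56, 56)
blurpool = Downsample(pad_type='reflect', filt_size=3, stride=2, channels=64)
y = blurpool(x)
print(y.shape)  # torch.Size([2, 64, 28, 28]): 3x3 binomial low-pass filter, then stride-2 subsampling
```
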
diff --git a/adet/modeling/backbone/resnet_interval.py b/adet/modeling/backbone/resnet_interval.py
new file mode 100644
index 0000000..b91be6e
--- /dev/null
+++ b/adet/modeling/backbone/resnet_interval.py
@@ -0,0 +1,116 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from detectron2.layers import FrozenBatchNorm2d
+from detectron2.modeling.backbone import BACKBONE_REGISTRY
+from detectron2.modeling.backbone.resnet import (
+ BasicStem,
+ DeformBottleneckBlock,
+ BottleneckBlock,
+ ResNet,
+)
+
+
+def make_stage_intervals(block_class, num_blocks, first_stride, **kwargs):
+ """
+ Create a resnet stage by creating many blocks.
+ Args:
+ block_class (class): a subclass of ResNetBlockBase
+ num_blocks (int):
+ first_stride (int): the stride of the first block. The other blocks will have stride=1.
+ A `stride` argument will be passed to the block constructor.
+ kwargs: other arguments passed to the block constructor.
+
+ Returns:
+        list[nn.Module]: a list of block modules.
+ """
+ blocks = []
+ conv_kwargs = {key: kwargs[key] for key in kwargs if "deform" not in key}
+ deform_kwargs = {key: kwargs[key] for key in kwargs if key != "deform_interval"}
+ deform_interval = kwargs.get("deform_interval", None)
+ for i in range(num_blocks):
+ if deform_interval and i % deform_interval == 0:
+ blocks.append(block_class(stride=first_stride if i == 0 else 1, **deform_kwargs))
+ else:
+ blocks.append(BottleneckBlock(stride=first_stride if i == 0 else 1, **conv_kwargs))
+ conv_kwargs["in_channels"] = conv_kwargs["out_channels"]
+ deform_kwargs["in_channels"] = deform_kwargs["out_channels"]
+ return blocks
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnet_interval_backbone(cfg, input_shape):
+ """
+ Create a ResNet instance from config.
+
+ Returns:
+ ResNet: a :class:`ResNet` instance.
+ """
+ # need registration of new blocks/stems?
+ norm = cfg.MODEL.RESNETS.NORM
+ stem = BasicStem(
+ in_channels=input_shape.channels,
+ out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
+ norm=norm,
+ )
+ freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
+
+ if freeze_at >= 1:
+ for p in stem.parameters():
+ p.requires_grad = False
+ stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)
+
+ # fmt: off
+ out_features = cfg.MODEL.RESNETS.OUT_FEATURES
+ depth = cfg.MODEL.RESNETS.DEPTH
+ num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
+ width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
+ bottleneck_channels = num_groups * width_per_group
+ in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
+ out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
+ stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
+ res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
+ deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
+ deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
+ deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
+ deform_interval = cfg.MODEL.RESNETS.DEFORM_INTERVAL
+ # fmt: on
+ assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
+
+ num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
+
+ stages = []
+
+ # Avoid creating variables without gradients
+ # It consumes extra memory and may cause allreduce to fail
+ out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
+ max_stage_idx = max(out_stage_idx)
+ for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
+ dilation = res5_dilation if stage_idx == 5 else 1
+ first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
+ stage_kargs = {
+ "num_blocks": num_blocks_per_stage[idx],
+ "first_stride": first_stride,
+ "in_channels": in_channels,
+ "bottleneck_channels": bottleneck_channels,
+ "out_channels": out_channels,
+ "num_groups": num_groups,
+ "norm": norm,
+ "stride_in_1x1": stride_in_1x1,
+ "dilation": dilation,
+ }
+ if deform_on_per_stage[idx]:
+ stage_kargs["block_class"] = DeformBottleneckBlock
+ stage_kargs["deform_modulated"] = deform_modulated
+ stage_kargs["deform_num_groups"] = deform_num_groups
+ stage_kargs["deform_interval"] = deform_interval
+ else:
+ stage_kargs["block_class"] = BottleneckBlock
+ blocks = make_stage_intervals(**stage_kargs)
+ in_channels = out_channels
+ out_channels *= 2
+ bottleneck_channels *= 2
+
+ if freeze_at >= stage_idx:
+ for block in blocks:
+ block.freeze()
+ stages.append(blocks)
+ return ResNet(stem, stages, out_features=out_features)
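
The effect of `MODEL.RESNETS.DEFORM_INTERVAL` can be read directly off the loop in `make_stage_intervals`: only every `deform_interval`-th block becomes a `DeformBottleneckBlock`, and the remaining blocks stay plain `BottleneckBlock`s. A small illustration for the 23-block res4 stage of ResNet-101, assuming an interval of 3 (presumably what the `dcni3` configs use):

```python
# Which res4 blocks of a ResNet-101 get deformable conv with DEFORM_INTERVAL = 3
# (mirrors the "i % deform_interval == 0" test in make_stage_intervals above).
num_blocks, deform_interval = 23, 3
deform_positions = [i for i in range(num_blocks) if i % deform_interval == 0]
print(deform_positions)  # [0, 3, 6, 9, 12, 15, 18, 21] -> 8 of 23 blocks are deformable
```
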
diff --git a/adet/modeling/backbone/resnet_lpf.py b/adet/modeling/backbone/resnet_lpf.py
new file mode 100644
index 0000000..867de6d
--- /dev/null
+++ b/adet/modeling/backbone/resnet_lpf.py
@@ -0,0 +1,291 @@
+# This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
+# Copyright (c) 2017 Torch Contributors.
+# The Pytorch examples are available under the BSD 3-Clause License.
+#
+# ==========================================================================================
+#
+# Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
+# Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
+# 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
+# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
+#
+# ==========================================================================================
+#
+# BSD-3 License
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+
+import torch.nn as nn
+
+from detectron2.layers.batch_norm import NaiveSyncBatchNorm
+from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
+from detectron2.modeling.backbone import Backbone
+
+from .lpf import *
+
+
+__all__ = ['ResNetLPF', 'build_resnet_lpf_backbone']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, groups=groups, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1:
+ raise ValueError('BasicBlock only supports groups=1')
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ if(stride == 1):
+ self.conv2 = conv3x3(planes, planes)
+ else:
+ self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
+ conv3x3(planes, planes),)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, planes)
+ self.bn1 = norm_layer(planes)
+ self.conv2 = conv3x3(planes, planes, groups) # stride moved
+ self.bn2 = norm_layer(planes)
+ if(stride == 1):
+ self.conv3 = conv1x1(planes, planes * self.expansion)
+ else:
+ self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
+ conv1x1(planes, planes * self.expansion))
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetLPF(Backbone):
+
+ def __init__(self, cfg, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, norm_layer=None, filter_size=1,
+ pool_only=True, return_idx=[0, 1, 2, 3]):
+ super().__init__()
+ self.return_idx = return_idx
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
+ self.inplanes = planes[0]
+
+ if(pool_only):
+ self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False)
+ self.bn1 = norm_layer(planes[0])
+ self.relu = nn.ReLU(inplace=True)
+
+ if(pool_only):
+ self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1),
+ Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
+ else:
+ self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
+ nn.MaxPool2d(kernel_size=2, stride=1),
+ Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
+
+ self.layer1 = self._make_layer(
+ block, planes[0], layers[0], groups=groups, norm_layer=norm_layer)
+ self.layer2 = self._make_layer(
+ block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
+ self.layer3 = self._make_layer(
+ block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
+ self.layer4 = self._make_layer(
+ block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None):
+ # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ else:
+ print('Not initializing')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+ self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_AT)
+        # Optionally freeze all BatchNorm layers; disabled by default.
+        # self._freeze_bn()
+
+ def _freeze_backbone(self, freeze_at):
+ if freeze_at < 0:
+ return
+ for stage_index in range(freeze_at):
+ if stage_index == 0:
+ # stage 0 is the stem
+ for p in self.conv1.parameters():
+ p.requires_grad = False
+ for p in self.bn1.parameters():
+ p.requires_grad = False
+ else:
+ m = getattr(self, "layer" + str(stage_index))
+ for p in m.parameters():
+ p.requires_grad = False
+
+ def _freeze_bn(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+
+ def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1):
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ # downsample = nn.Sequential(
+ # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size),
+ # norm_layer(planes * block.expansion),
+ # )
+
+ downsample = [Downsample(filt_size=filter_size, stride=stride,
+ channels=self.inplanes), ] if(stride != 1) else []
+ downsample += [conv1x1(self.inplanes, planes * block.expansion, 1),
+ norm_layer(planes * block.expansion)]
+ # print(downsample)
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample,
+ groups, norm_layer, filter_size=filter_size))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=groups,
+ norm_layer=norm_layer, filter_size=filter_size))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ outs = []
+ outs.append(self.layer1(x)) # 1/4
+ outs.append(self.layer2(outs[-1])) # 1/8
+ outs.append(self.layer3(outs[-1])) # 1/16
+ outs.append(self.layer4(outs[-1])) # 1/32
+ return {"res{}".format(idx + 2): outs[idx] for idx in self.return_idx}
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnet_lpf_backbone(cfg, input_shape):
+ """
+ Create a ResNet instance from config.
+
+ Returns:
+        ResNetLPF: an anti-aliased :class:`ResNetLPF` backbone instance.
+ """
+ depth = cfg.MODEL.RESNETS.DEPTH
+ out_features = cfg.MODEL.RESNETS.OUT_FEATURES
+
+ num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
+ out_stage_idx = [{"res2": 0, "res3": 1, "res4": 2, "res5": 3}[f] for f in out_features]
+ out_feature_channels = {"res2": 256, "res3": 512,
+ "res4": 1024, "res5": 2048}
+ out_feature_strides = {"res2": 4, "res3": 8, "res4": 16, "res5": 32}
+ model = ResNetLPF(cfg, Bottleneck, num_blocks_per_stage, norm_layer=NaiveSyncBatchNorm,
+ filter_size=3, pool_only=True, return_idx=out_stage_idx)
+ model._out_features = out_features
+ model._out_feature_channels = out_feature_channels
+ model._out_feature_strides = out_feature_strides
+ return model
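
A quick way to sanity-check the anti-aliased bottom-up is to build it with a plain detectron2 config and look at the returned feature maps. A sketch, assuming the default ResNet-50 depth; note that `input_shape` is unused by this builder, so `None` is passed:

```python
# Sketch: build the anti-aliased ResNet defined above and print feature shapes.
import torch
from detectron2.config import get_cfg
from adet.modeling.backbone.resnet_lpf import build_resnet_lpf_backbone

cfg = get_cfg()  # MODEL.RESNETS.DEPTH defaults to 50
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
backbone = build_resnet_lpf_backbone(cfg, None).eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    print(name, tuple(f.shape))  # strides 4/8/16/32 -> 56/28/14/7 spatial sizes, 256..2048 channels
```
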
diff --git a/adet/modeling/fcos/fcos.py b/adet/modeling/fcos/fcos.py
index fb5895b..fbb0163 100644
--- a/adet/modeling/fcos/fcos.py
+++ b/adet/modeling/fcos/fcos.py
@@ -217,7 +217,7 @@ class FCOSHead(nn.Module):
in_channels, 4, kernel_size=3,
stride=1, padding=1
)
- self.centerness = nn.Conv2d(
+ self.ctrness = nn.Conv2d(
in_channels, 1, kernel_size=3,
stride=1, padding=1
)
@@ -230,7 +230,7 @@ class FCOSHead(nn.Module):
for modules in [
self.cls_tower, self.bbox_tower,
self.share_tower, self.cls_logits,
- self.bbox_pred, self.centerness
+ self.bbox_pred, self.ctrness
]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
@@ -256,7 +256,7 @@ class FCOSHead(nn.Module):
bbox_towers.append(bbox_tower)
logits.append(self.cls_logits(cls_tower))
- ctrness.append(self.centerness(bbox_tower))
+ ctrness.append(self.ctrness(bbox_tower))
reg = self.bbox_pred(bbox_tower)
if self.scales is not None:
reg = self.scales[l](reg)
diff --git a/configs/BlendMask/Panoptic/Base-Panoptic.yaml b/configs/BlendMask/Panoptic/Base-Panoptic.yaml
index 7fd16ec..ffd7ff0 100644
--- a/configs/BlendMask/Panoptic/Base-Panoptic.yaml
+++ b/configs/BlendMask/Panoptic/Base-Panoptic.yaml
@@ -9,7 +9,7 @@ MODEL:
PANOPTIC_FPN:
COMBINE:
ENABLED: True
- INSTANCES_CONFIDENCE_THRESH: 0.2
+ INSTANCES_CONFIDENCE_THRESH: 0.45
OVERLAP_THRESH: 0.4
DATASETS:
TRAIN: ("coco_2017_train_panoptic_separated",)
diff --git a/tools/rename_blendmask.py b/tools/rename_blendmask.py
new file mode 100644
index 0000000..694c502
--- /dev/null
+++ b/tools/rename_blendmask.py
@@ -0,0 +1,41 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Rename 'centerness' parameters to 'ctrness' in FCOS/BlendMask checkpoints")
+ parser.add_argument(
+ "--model",
+ default="weights/blendmask/person/R_50_1x.pth",
+ metavar="FILE",
+ help="path to model weights",
+ )
+ parser.add_argument(
+ "--output",
+ default="weights/blendmask/person/R_50_1x.pth",
+ metavar="FILE",
+ help="path to model weights",
+ )
+ return parser
+
+
+def rename_resnet_param_names(ckpt_state_dict):
+ converted_state_dict = OrderedDict()
+ for key in ckpt_state_dict.keys():
+ value = ckpt_state_dict[key]
+ key = key.replace("centerness", "ctrness")
+
+ converted_state_dict[key] = value
+ return converted_state_dict
+
+
+if __name__ == "__main__":
+ args = get_parser().parse_args()
+    ckpt = torch.load(args.model, map_location="cpu")
+ if "model" in ckpt:
+ model = rename_resnet_param_names(ckpt["model"])
+ else:
+ model = rename_resnet_param_names(ckpt)
+ torch.save(model, args.output)
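
For checkpoints saved before the `centerness` to `ctrness` rename in `FCOSHead` (see the `fcos.py` change above), this script rewrites the affected parameter names so the weights load cleanly again. A tiny sanity check, assuming `rename_resnet_param_names` from the script above is in scope; the key name is illustrative rather than taken from a real checkpoint:

```python
# Quick check of the renaming behaviour (illustrative key, not a real checkpoint entry).
from collections import OrderedDict

import torch

old = OrderedDict({"fcos_head.centerness.weight": torch.zeros(1, 256, 3, 3)})
new = rename_resnet_param_names(old)
print(list(new.keys()))  # ['fcos_head.ctrness.weight']
```
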