parent
38a599c94f
commit
281ee240d9
11 changed files with 630 additions and 28 deletions
@ -0,0 +1,44 @@ |
||||
# AdelaiDet Model Zoo and Baselines |
||||
|
||||
## Introduction |
||||
This file documents a collection of models trained with AdelaiDet in Nov, 2019. |
||||
|
||||
## Models |
||||
|
||||
The inference time is measured on one 1080Ti based on the most recent commit on Detectron2 ([ffff8ac](https://github.com/facebookresearch/detectron2/commit/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9)). |
||||
|
||||
More models will be released soon. Stay tuned. |
||||
|
||||
### COCO Object Detection Baselines with FCOS |
||||
|
||||
Name | box AP | download |
||||
--- |:---:|:---: |
||||
[FCOS_R_50_1x](configs/FCOS-Detection/R_50_1x.yaml) | 38.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/glqFc13cCoEyHYy/download) |
||||
|
||||
### COCO Instance Segmentation Baselines with [BlendMask](https://arxiv.org/abs/2001.00309) |
||||
|
||||
Model | Name |inference time (ms/im) | box AP | mask AP | download |
||||
--- |:---:|:---:|:---:|:---:|:---: |
||||
Mask R-CNN | [550_R_50_3x](configs/RCNN/550_R_50_FPN_3x.yaml) | 63 | 39.1 | 35.3 | |
||||
BlendMask | [550_R_50_3x](configs/BlendMask/550_R_50_3x.yaml) | 36 | 38.7 | 34.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/R3Qintf7N8UCiIt/download) |
||||
Mask R-CNN | [R_50_1x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml) | 80 | 38.6 | 35.2 | |
||||
BlendMask | [R_50_1x](configs/BlendMask/R_50_1x.yaml) | 73 | 39.9 | 35.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/zoxXPnr6Hw3OJgK/download) |
||||
Mask R-CNN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml) | 80 | 41.0 | 37.2 | |
||||
BlendMask | [R_50_3x](configs/BlendMask/R_50_3x.yaml) | 74 | 42.7 | 37.8 | [model](https://cloudstor.aarnet.edu.au/plus/s/ZnaInHFEKst6mvg/download) |
||||
Mask R-CNN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml) | 100 | 42.9 | 38.6 | |
||||
BlendMask | [R_101_3x](configs/BlendMask/R_101_3x.yaml) | 94 | 44.8 | 39.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/e4fXrliAcMtyEBy/download) |
||||
BlendMask | [R_101_dcni3_5x](configs/BlendMask/R_101_dcni3_5x.yaml) | 105 | 46.8 | 41.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/vbnKnQtaGlw8TKv/download) |
||||
|
||||
### COCO Panoptic Segmentation Baselines with BlendMask |
||||
Model | Name | PQ | PQ<sup>Th</sup> | PQ<sup>St</sup> | download |
||||
--- |:---:|:---:|:---:|:---:|:---: |
||||
Panoptic FPN | [R_50_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml) | 41.5 | 48.3 | 31.2 | |
||||
BlendMask | [R_50_3x](configs/BlendMask/Panoptic/R_50_3x.yaml) | 42.5 | 49.5 | 32.0 | [model](https://cloudstor.aarnet.edu.au/plus/s/oDgi0826JOJXCr5/download) |
||||
Panoptic FPN | [R_101_3x](https://github.com/facebookresearch/detectron2/blob/master/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml) | 43.0 | 49.7 | 32.9 | |
||||
BlendMask | [R_101_3x](configs/BlendMask/Panoptic/R_101_3x.yaml) | 44.3 | 51.6 | 33.2 | [model](https://cloudstor.aarnet.edu.au/plus/s/u6gZwj06MWDEkYe/download) |
||||
BlendMask | [R_101_dcni3_5x](configs/BlendMask/Panoptic/R_101_dcni3_5x.yaml) | 46.0 | 52.9 | 35.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/Jwp41WEzDdrhWsN/download) |
||||
|
||||
### Person in Context with BlendMask |
||||
Model | Name | box AP | mask AP | download |
||||
--- |:---:|:---:|:---:|:---: |
||||
BlendMask | [R_50_1x](configs/BlendMask/Person/R_50_1x.yaml) | 70.6 | 66.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/nvpcKTFA5fsagc0/download) |
@ -1,2 +1,3 @@ |
||||
from .fpn import build_fcos_resnet_fpn_backbone |
||||
from .vovnet import build_vovnet_fpn_backbone, build_vovnet_backbone |
||||
from .resnet_lpf import build_resnet_lpf_backbone |
||||
|
@ -0,0 +1,115 @@ |
||||
import torch |
||||
import torch.nn.parallel |
||||
import numpy as np |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
from IPython import embed |
||||
|
||||
|
||||
class Downsample(nn.Module):
    """Blur-then-subsample layer for 4-D (N, C, H, W) inputs.

    The input is padded, convolved with a fixed (non-learned) normalized
    binomial kernel applied to every channel independently, and subsampled
    by ``stride``.

    Args:
        pad_type (str): padding name understood by ``get_pad_layer``
            (for example ``'reflect'``, ``'replicate'`` or ``'zero'``).
        filt_size (int): side length of the square blur kernel; any integer
            >= 1 is accepted. ``1`` means plain strided subsampling.
        stride (int): subsampling stride used for both spatial dimensions.
        channels (int): channel count of the tensors given to ``forward``.
            It sizes the filter buffer, so the ``None`` default is not usable.
        pad_off (int): extra padding added to every side on top of the
            padding derived from ``filt_size``.

    Raises:
        ValueError: if ``filt_size`` is not an integer value >= 1.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError('filt_size must be an integer >= 1, got %r' % (filt_size,))
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Split the (filt_size - 1) border pixels between the two sides of each
        # axis; the trailing side gets the extra pixel for even kernel sizes.
        pad_lo = int(1. * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1. * (filt_size - 1) / 2))
        self.pad_sizes = [size + pad_off for size in (pad_lo, pad_hi, pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Row `filt_size - 1` of Pascal's triangle, built by repeatedly
        # convolving with [1, 1]; e.g. filt_size=3 gives [1, 2, 1].
        a = np.array([1.])
        for _ in range(int(filt_size) - 1):
            a = np.convolve(a, [1., 1.])

        # Outer product -> 2-D kernel, normalized so the weights sum to one.
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # One copy of the kernel per channel, stored as a buffer (not a parameter).
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the blurred and subsampled version of ``inp``."""
        if self.filt_size == 1:
            # A 1x1 kernel is the identity, so only the striding remains.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        # groups == number of input channels: every channel is filtered on its own.
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
||||
|
||||
def get_pad_layer(pad_type):
    """Return the 2-D padding layer class selected by ``pad_type``.

    Args:
        pad_type (str): ``'refl'``/``'reflect'``, ``'repl'``/``'replicate'``
            or ``'zero'``.

    Returns:
        The matching ``torch.nn`` padding class (not an instance).

    Raises:
        ValueError: if ``pad_type`` is not one of the names above.
    """
    if pad_type in ('refl', 'reflect'):
        return nn.ReflectionPad2d
    if pad_type in ('repl', 'replicate'):
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    # Fail with a clear message instead of falling through to an unbound local.
    raise ValueError('Pad type [%s] not recognized' % pad_type)
||||
|
||||
|
||||
class Downsample1D(nn.Module):
    """Blur-then-subsample layer for 3-D (N, C, L) inputs.

    The input is padded, convolved with a fixed (non-learned) normalized
    binomial kernel applied to every channel independently, and subsampled
    by ``stride``.

    Args:
        pad_type (str): padding name understood by ``get_pad_layer_1d``
            (for example ``'reflect'``, ``'replicate'`` or ``'zero'``).
        filt_size (int): length of the blur kernel; any integer >= 1 is
            accepted. ``1`` means plain strided subsampling.
        stride (int): subsampling stride along the last dimension.
        channels (int): channel count of the tensors given to ``forward``.
            It sizes the filter buffer, so the ``None`` default is not usable.
        pad_off (int): extra padding added to both ends on top of the
            padding derived from ``filt_size``.

    Raises:
        ValueError: if ``filt_size`` is not an integer value >= 1.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample1D, self).__init__()
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError('filt_size must be an integer >= 1, got %r' % (filt_size,))
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Split the (filt_size - 1) border samples between both ends; the
        # trailing end gets the extra sample for even kernel sizes.
        pad_lo = int(1. * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1. * (filt_size - 1) / 2))
        self.pad_sizes = [size + pad_off for size in (pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Row `filt_size - 1` of Pascal's triangle, built by repeatedly
        # convolving with [1, 1]; e.g. filt_size=3 gives [1, 2, 1].
        a = np.array([1.])
        for _ in range(int(filt_size) - 1):
            a = np.convolve(a, [1., 1.])

        # Normalize so the weights sum to one.
        filt = torch.Tensor(a)
        filt = filt / torch.sum(filt)
        # One copy of the kernel per channel, stored as a buffer (not a parameter).
        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the blurred and subsampled version of ``inp``."""
        if self.filt_size == 1:
            # A length-1 kernel is the identity, so only the striding remains.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride]
        # groups == number of input channels: every channel is filtered on its own.
        return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
||||
|
||||
|
||||
def get_pad_layer_1d(pad_type):
    """Return a factory for the 1-D padding layer selected by ``pad_type``.

    Args:
        pad_type (str): ``'refl'``/``'reflect'``, ``'repl'``/``'replicate'``
            or ``'zero'``.

    Returns:
        A callable that takes the padding sizes and builds the padding
        module. For the reflect/replicate types this is the ``torch.nn``
        class itself.

    Raises:
        ValueError: if ``pad_type`` is not one of the names above.
    """
    if pad_type in ('refl', 'reflect'):
        return nn.ReflectionPad1d
    if pad_type in ('repl', 'replicate'):
        return nn.ReplicationPad1d
    if pad_type == 'zero':
        # nn.ZeroPad1d is not available in every PyTorch release; a constant
        # pad with value 0 produces the same result where it is missing.
        zero_pad = getattr(nn, 'ZeroPad1d', None)
        if zero_pad is None:
            def zero_pad(padding):
                return nn.ConstantPad1d(padding, 0.)
        return zero_pad
    # Fail with a clear message instead of falling through to an unbound local.
    raise ValueError('Pad type [%s] not recognized' % pad_type)
@ -0,0 +1,116 @@ |
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
||||
from detectron2.layers import FrozenBatchNorm2d |
||||
from detectron2.modeling.backbone import BACKBONE_REGISTRY |
||||
from detectron2.modeling.backbone.resnet import ( |
||||
BasicStem, |
||||
DeformBottleneckBlock, |
||||
BottleneckBlock, |
||||
ResNet, |
||||
) |
||||
|
||||
|
||||
def make_stage_intervals(block_class, num_blocks, first_stride, **kwargs):
    """
    Build the blocks of one ResNet stage, using ``block_class`` only at a fixed interval.

    Block ``i`` is a ``block_class`` instance when ``deform_interval`` is truthy
    and ``i`` is a multiple of it; every other block is a plain ``BottleneckBlock``.

    Args:
        block_class (class): block type used at the interval positions.
        num_blocks (int): number of blocks in the stage.
        first_stride (int): stride of the first block; all later blocks use stride 1.
            It is passed to the block constructor as ``stride``.
        kwargs: remaining constructor arguments. ``deform_interval`` is consumed
            here and never forwarded; any key containing "deform" is withheld
            from ``BottleneckBlock``. ``in_channels`` and ``out_channels`` must be
            present when ``num_blocks`` > 0.

    Returns:
        list[nn.Module]: the blocks, in order.
    """
    interval = kwargs.get("deform_interval", None)
    plain_args = {name: value for name, value in kwargs.items() if "deform" not in name}
    special_args = {name: value for name, value in kwargs.items() if name != "deform_interval"}

    stage = []
    for position in range(num_blocks):
        stride = 1 if position else first_stride
        at_interval = bool(interval) and position % interval == 0
        if at_interval:
            module = block_class(stride=stride, **special_args)
        else:
            module = BottleneckBlock(stride=stride, **plain_args)
        stage.append(module)
        # Every block after the first consumes the previous block's output.
        plain_args["in_channels"] = plain_args["out_channels"]
        special_args["in_channels"] = special_args["out_channels"]
    return stage
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def build_resnet_interval_backbone(cfg, input_shape):
    """
    Build a ResNet whose deformable stages are assembled by ``make_stage_intervals``.

    Args:
        cfg: config node; ``cfg.MODEL.RESNETS`` (including ``DEFORM_INTERVAL``)
            and ``cfg.MODEL.BACKBONE.FREEZE_AT`` are read.
        input_shape: object whose ``channels`` attribute gives the stem's input channels.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    resnets = cfg.MODEL.RESNETS
    norm = resnets.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets.STEM_OUT_CHANNELS,
        norm=norm,
    )
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT

    if freeze_at >= 1:
        for param in stem.parameters():
            param.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    out_features = resnets.OUT_FEATURES
    depth = resnets.DEPTH
    num_groups = resnets.NUM_GROUPS
    width_per_group = resnets.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets.STEM_OUT_CHANNELS
    out_channels = resnets.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets.STRIDE_IN_1X1
    res5_dilation = resnets.RES5_DILATION
    deform_on_per_stage = resnets.DEFORM_ON_PER_STAGE
    deform_modulated = resnets.DEFORM_MODULATED
    deform_num_groups = resnets.DEFORM_NUM_GROUPS
    deform_interval = resnets.DEFORM_INTERVAL
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    blocks_by_depth = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
    num_blocks_per_stage = blocks_by_depth[depth]

    # Only build stages up to the deepest requested feature: parameters without
    # gradients consume extra memory and may cause allreduce to fail.
    stage_of_feature = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    last_stage = max([stage_of_feature[name] for name in out_features])

    stages = []
    for stage_idx in range(2, last_stage + 1):
        idx = stage_idx - 2
        is_res5 = stage_idx == 5
        dilation = res5_dilation if is_res5 else 1
        # res2 never strides, and a dilated res5 keeps its resolution as well.
        keeps_resolution = idx == 0 or (is_res5 and dilation == 2)
        stage_kargs = dict(
            num_blocks=num_blocks_per_stage[idx],
            first_stride=1 if keeps_resolution else 2,
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            norm=norm,
            stride_in_1x1=stride_in_1x1,
            dilation=dilation,
        )
        if deform_on_per_stage[idx]:
            stage_kargs.update(
                block_class=DeformBottleneckBlock,
                deform_modulated=deform_modulated,
                deform_num_groups=deform_num_groups,
                deform_interval=deform_interval,
            )
        else:
            stage_kargs["block_class"] = BottleneckBlock
        blocks = make_stage_intervals(**stage_kargs)

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)

        # Channel widths double from one stage to the next.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
    return ResNet(stem, stages, out_features=out_features)
@ -0,0 +1,291 @@ |
||||
# This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. |
||||
# Copyright (c) 2017 Torch Contributors. |
||||
# The Pytorch examples are available under the BSD 3-Clause License. |
||||
# |
||||
# ========================================================================================== |
||||
# |
||||
# Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. |
||||
# Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike |
||||
# 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit |
||||
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. |
||||
# |
||||
# ========================================================================================== |
||||
# |
||||
# BSD-3 License |
||||
# |
||||
# Redistribution and use in source and binary forms, with or without |
||||
# modification, are permitted provided that the following conditions are met: |
||||
# |
||||
# * Redistributions of source code must retain the above copyright notice, this |
||||
# list of conditions and the following disclaimer. |
||||
# |
||||
# * Redistributions in binary form must reproduce the above copyright notice, |
||||
# this list of conditions and the following disclaimer in the documentation |
||||
# and/or other materials provided with the distribution. |
||||
# |
||||
# * Neither the name of the copyright holder nor the names of its |
||||
# contributors may be used to endorse or promote products derived from |
||||
# this software without specific prior written permission. |
||||
# |
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
||||
|
||||
import torch.nn as nn |
||||
|
||||
from detectron2.layers.batch_norm import NaiveSyncBatchNorm |
||||
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY |
||||
from detectron2.modeling.backbone import Backbone |
||||
|
||||
from .lpf import * |
||||
|
||||
|
||||
__all__ = ['ResNetLPF', 'build_resnet_lpf_backbone'] |
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding on each side."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
        bias=False,
    )
    return layer
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two-convolution residual block with blurred downsampling.

    Both 3x3 convolutions run at stride 1. When ``stride != 1`` the spatial
    reduction is done by a ``Downsample`` layer placed in front of the
    second convolution.

    Args:
        inplanes (int): input channels.
        planes (int): output channels (``expansion`` is 1).
        stride (int): spatial reduction factor of the block.
        downsample (nn.Module or None): module applied to the input to form
            the shortcut; ``None`` uses the input unchanged.
        groups (int): must be 1.
        norm_layer (callable or None): normalization constructor taking a
            channel count; defaults to ``nn.BatchNorm2d``.
        filter_size (int): blur kernel size given to ``Downsample``.

    Raises:
        ValueError: if ``groups`` is not 1.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1:
            raise ValueError('BasicBlock only supports groups=1')

        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        if stride != 1:
            # Blur and subsample first, then convolve at stride 1.
            blur = Downsample(filt_size=filter_size, stride=stride, channels=planes)
            self.conv2 = nn.Sequential(blur, conv3x3(planes, planes))
        else:
            self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with blurred downsampling.

    All convolutions run at stride 1. When ``stride != 1`` the spatial
    reduction is done by a ``Downsample`` layer placed in front of the last
    1x1 convolution.

    Args:
        inplanes (int): input channels.
        planes (int): width of the inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride (int): spatial reduction factor of the block.
        downsample (nn.Module or None): module applied to the input to form
            the shortcut; ``None`` uses the input unchanged.
        groups (int): number of groups of the 3x3 convolution.
        norm_layer (callable or None): normalization constructor taking a
            channel count; defaults to ``nn.BatchNorm2d``.
        filter_size (int): blur kernel size given to ``Downsample``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        # `groups` must be passed by keyword: the third positional parameter of
        # conv3x3 is `stride`, so a positional `groups` would silently become
        # the stride. The block's stride is handled by Downsample below.
        self.conv2 = conv3x3(planes, planes, groups=groups)
        self.bn2 = norm_layer(planes)
        if stride == 1:
            self.conv3 = conv1x1(planes, planes * self.expansion)
        else:
            # Blur and subsample first, then expand the channels at stride 1.
            self.conv3 = nn.Sequential(
                Downsample(filt_size=filter_size, stride=stride, channels=planes),
                conv1x1(planes, planes * self.expansion),
            )
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
||||
|
||||
|
||||
class ResNetLPF(Backbone):
    """ResNet backbone whose strided operations are replaced by blur + subsample.

    Args:
        cfg: config node; only ``cfg.MODEL.BACKBONE.FREEZE_AT`` is read here.
        block (class): residual block type (``BasicBlock`` or ``Bottleneck``).
        layers (list[int]): number of blocks in each of the four stages.
        num_classes (int): accepted for signature compatibility; not used.
        zero_init_residual (bool): zero the weight of the last norm layer of
            every residual branch so each block starts as an identity.
        groups (int): groups forwarded to the blocks; also scales the widths.
        width_per_group (int): base channel width per group.
        norm_layer (callable or None): normalization constructor taking a
            channel count; defaults to ``nn.BatchNorm2d``.
        filter_size (int): blur kernel size given to every ``Downsample``.
        pool_only (bool): if True the stem convolution keeps stride 2 and only
            the max-pool is blurred; otherwise the stem convolution runs at
            stride 1 and is followed by an extra ``Downsample``.
        return_idx (list[int] or None): indices (0-3) of the stages returned by
            ``forward``; ``None`` means all four.
    """

    def __init__(self, cfg, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None, filter_size=1,
                 pool_only=True, return_idx=None):
        super().__init__()
        # Build the default per instance so models never share one mutable list.
        self.return_idx = [0, 1, 2, 3] if return_idx is None else return_idx
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
        self.inplanes = planes[0]

        stem_stride = 2 if pool_only else 1
        self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=stem_stride, padding=3, bias=False)
        self.bn1 = norm_layer(planes[0])
        self.relu = nn.ReLU(inplace=True)

        if pool_only:
            self.maxpool = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=1),
                Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
            )
        else:
            self.maxpool = nn.Sequential(
                Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
                nn.MaxPool2d(kernel_size=2, stride=1),
                Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
            )

        self.layer1 = self._make_layer(
            block, planes[0], layers[0], groups=groups, norm_layer=norm_layer)
        self.layer2 = self._make_layer(
            block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
        self.layer3 = self._make_layer(
            block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
        self.layer4 = self._make_layer(
            block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None:
                    # don't want to reinitialize downsample layers, code assuming
                    # normal conv layers will not have these characteristics
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                else:
                    print('Not initializing')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_AT)

    def _freeze_backbone(self, freeze_at):
        """Disable gradients for the first ``freeze_at`` stages (stage 0 is the stem)."""
        if freeze_at < 0:
            return
        for stage_index in range(freeze_at):
            if stage_index == 0:
                # stage 0 is the stem
                for p in self.conv1.parameters():
                    p.requires_grad = False
                for p in self.bn1.parameters():
                    p.requires_grad = False
            else:
                m = getattr(self, "layer" + str(stage_index))
                for p in m.parameters():
                    p.requires_grad = False

    def _freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into eval mode (not called by ``__init__``)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block receives ``stride``; its shortcut is blurred and
        subsampled by ``Downsample`` before a stride-1 1x1 projection.
        """
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            shortcut = []
            if stride != 1:
                shortcut.append(Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes))
            shortcut += [conv1x1(self.inplanes, planes * block.expansion, 1),
                         norm_layer(planes * block.expansion)]
            downsample = nn.Sequential(*shortcut)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample,
                            groups, norm_layer, filter_size=filter_size))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups,
                                norm_layer=norm_layer, filter_size=filter_size))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run all four stages and return ``{"res<i+2>": features}`` for each ``i`` in ``return_idx``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        outs = []
        outs.append(self.layer1(x))  # 1/4
        outs.append(self.layer2(outs[-1]))  # 1/8
        outs.append(self.layer3(outs[-1]))  # 1/16
        outs.append(self.layer4(outs[-1]))  # 1/32
        return {"res{}".format(idx + 2): outs[idx] for idx in self.return_idx}
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def build_resnet_lpf_backbone(cfg, input_shape):
    """
    Create a ``ResNetLPF`` backbone (Bottleneck blocks, blur filter size 3) from config.

    Args:
        cfg: config node; ``cfg.MODEL.RESNETS.DEPTH`` (50, 101 or 152) and
            ``cfg.MODEL.RESNETS.OUT_FEATURES`` are read here.
        input_shape: not used by this builder.

    Returns:
        ResNetLPF: the backbone with its output-feature metadata filled in.
    """
    resnets = cfg.MODEL.RESNETS
    depth = resnets.DEPTH
    out_features = resnets.OUT_FEATURES

    blocks_by_depth = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
    num_blocks_per_stage = blocks_by_depth[depth]

    # Translate feature names into the stage indices ResNetLPF.forward returns.
    index_of_feature = {"res2": 0, "res3": 1, "res4": 2, "res5": 3}
    out_stage_idx = [index_of_feature[name] for name in out_features]

    model = ResNetLPF(
        cfg,
        Bottleneck,
        num_blocks_per_stage,
        norm_layer=NaiveSyncBatchNorm,
        filter_size=3,
        pool_only=True,
        return_idx=out_stage_idx,
    )
    model._out_features = out_features
    model._out_feature_channels = {"res2": 256, "res3": 512,
                                   "res4": 1024, "res5": 2048}
    model._out_feature_strides = {"res2": 4, "res3": 8, "res4": 16, "res5": 32}
    return model
@ -0,0 +1,41 @@ |
||||
import argparse |
||||
from collections import OrderedDict |
||||
|
||||
import torch |
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for the checkpoint converter.

    Returns:
        argparse.ArgumentParser: parser with ``--model`` (input checkpoint)
        and ``--output`` (destination file). Both default to the same path,
        so running without arguments converts that checkpoint in place.
    """
    parser = argparse.ArgumentParser(description="FCOS Detectron2 Converter")
    parser.add_argument(
        "--model",
        default="weights/blendmask/person/R_50_1x.pth",
        metavar="FILE",
        help="path to model weights",
    )
    parser.add_argument(
        "--output",
        default="weights/blendmask/person/R_50_1x.pth",
        metavar="FILE",
        help="path to save the converted model weights",
    )
    return parser
||||
|
||||
|
||||
def rename_resnet_param_names(ckpt_state_dict):
    """Return a copy of the state dict with "centerness" replaced by "ctrness" in every key.

    Args:
        ckpt_state_dict: mapping from parameter names to values.

    Returns:
        OrderedDict: the renamed entries in their original order; values are
        passed through untouched and the input mapping is not modified.
    """
    return OrderedDict(
        (name.replace("centerness", "ctrness"), tensor)
        for name, tensor in ckpt_state_dict.items()
    )
||||
|
||||
|
||||
if __name__ == "__main__":
    # Load a checkpoint, rename its parameter keys, and save the bare state dict.
    args = get_parser().parse_args()
    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints that were saved from a GPU be
    # converted on machines without one.
    ckpt = torch.load(args.model, map_location="cpu")
    # Checkpoints may wrap the weights under a "model" key; accept both layouts.
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    model = rename_resnet_param_names(state_dict)
    torch.save(model, args.output)
Loading…
Reference in new issue