[Enhancement]: Refactor SSD (#5291)
* add vgg neck
* refactor ssd neck and vgg
* refactor ssd head and neck
* init l2 norm
* update config
* change ssdvgg backbone
* revert to SSD_VGG
* add unit test
* fix ssd voc
* avoid BC-breaking
* add TODO
* add convert script
* avoid BC breaking
* update readme
* update download link
* Fix doc
parent d1ef85d9ff
commit 4058255b46
12 changed files with 428 additions and 115 deletions
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class SSDNeck(BaseModule):
    """Extra layers of SSD backbone to generate multi-scale feature maps.

    Args:
        in_channels (Sequence[int]): Number of input channels per scale.
        out_channels (Sequence[int]): Number of output channels per scale.
        level_strides (Sequence[int]): Stride of 3x3 conv per level.
        level_paddings (Sequence[int]): Padding size of 3x3 conv per level.
        l2_norm_scale (float|None): L2 normalization layer init scale.
            If None, L2 normalization is not applied to the first input
            feature.
        last_kernel_size (int): Kernel size of the last conv layer.
            Default: 3.
        use_depthwise (bool): Whether to use DepthwiseSeparableConv.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 level_strides,
                 level_paddings,
                 l2_norm_scale=20.,
                 last_kernel_size=3,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 init_cfg=[
                     dict(
                         type='Xavier', distribution='uniform',
                         layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super(SSDNeck, self).__init__(init_cfg)
        assert len(out_channels) > len(in_channels)
        assert len(out_channels) - len(in_channels) == len(level_strides)
        assert len(level_strides) == len(level_paddings)
        assert in_channels == out_channels[:len(in_channels)]

        if l2_norm_scale:
            self.l2_norm = L2Norm(in_channels[0], l2_norm_scale)
            self.init_cfg += [
                dict(
                    type='Constant',
                    val=self.l2_norm.scale,
                    override=dict(name='l2_norm'))
            ]

        self.extra_layers = nn.ModuleList()
        extra_layer_channels = out_channels[len(in_channels):]
        second_conv = DepthwiseSeparableConvModule if \
            use_depthwise else ConvModule

        for i, (out_channel, stride, padding) in enumerate(
                zip(extra_layer_channels, level_strides, level_paddings)):
            kernel_size = last_kernel_size \
                if i == len(extra_layer_channels) - 1 else 3
            # Each extra level is a 1x1 bottleneck conv followed by a
            # strided 3x3 (or last_kernel_size) conv that downsamples.
            per_lvl_convs = nn.Sequential(
                ConvModule(
                    out_channels[len(in_channels) - 1 + i],
                    out_channel // 2,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
                second_conv(
                    out_channel // 2,
                    out_channel,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.extra_layers.append(per_lvl_convs)

    def forward(self, inputs):
        """Forward function."""
        outs = [feat for feat in inputs]
        if hasattr(self, 'l2_norm'):
            outs[0] = self.l2_norm(outs[0])

        feat = outs[-1]
        for layer in self.extra_layers:
            feat = layer(feat)
            outs.append(feat)
        return tuple(outs)


class L2Norm(nn.Module):

    def __init__(self, n_dims, scale=20., eps=1e-10):
        """L2 normalization layer.

        Args:
            n_dims (int): Number of dimensions to be normalized.
            scale (float, optional): Scaling factor. Defaults to 20.
            eps (float, optional): Used to avoid division by zero.
                Defaults to 1e-10.
        """
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        """Forward function."""
        # Normalization layers are kept in FP32 during FP16 training.
        x_float = x.float()
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return (self.weight[None, :, None, None].float().expand_as(x_float) *
                x_float / norm).type_as(x)
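For orientation, here is a minimal sketch of driving the neck standalone. The channel, stride, and padding values follow the familiar SSD300-VGG16 layout and are an illustrative assumption, not values copied from this PR's configs; build_neck assumes an MMDetection 2.x install.

import torch
from mmdet.models import build_neck

# Assumed SSD300-style layout: two backbone scales plus four extra levels.
neck_cfg = dict(
    type='SSDNeck',
    in_channels=(512, 1024),
    out_channels=(512, 1024, 512, 256, 256, 256),
    level_strides=(2, 2, 1, 1),
    level_paddings=(1, 1, 0, 0),
    l2_norm_scale=20)
neck = build_neck(neck_cfg)
neck.init_weights()

# Two backbone feature maps for a 300x300 input: conv4_3 and fc7.
feats = (torch.rand(1, 512, 38, 38), torch.rand(1, 1024, 19, 19))
outs = neck(feats)
# Six scales in total; stride-2 levels halve the spatial size, while the
# stride-1, padding-0 levels shrink it by 2 per 3x3 conv:
# (38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1).
for out in outs:
    print(tuple(out.shape))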
@@ -0,0 +1,57 @@
import argparse
import tempfile
from collections import OrderedDict

import torch
from mmcv import Config


def parse_config(config_strings):
    # Dump the config string to a temporary .py file so that
    # Config.fromfile can parse it.
    temp_file = tempfile.NamedTemporaryFile()
    config_path = f'{temp_file.name}.py'
    with open(config_path, 'w') as f:
        f.write(config_strings)

    config = Config.fromfile(config_path)
    # check whether it is SSD
    if config.model.bbox_head.type != 'SSDHead':
        raise AssertionError('This is not an SSD model.')


def convert(in_file, out_file):
    checkpoint = torch.load(in_file)
    in_state_dict = checkpoint.pop('state_dict')
    out_state_dict = OrderedDict()
    meta_info = checkpoint['meta']
    parse_config('#' + meta_info['config'])
    for key, value in in_state_dict.items():
        if 'extra' in key:
            # Map the old flat extra-layer index onto the new
            # neck.extra_layers.<level>.<conv> layout: two convs per level.
            layer_idx = int(key.split('.')[2])
            new_key = 'neck.extra_layers.{}.{}.conv.'.format(
                layer_idx // 2, layer_idx % 2) + key.split('.')[-1]
        elif 'l2_norm' in key:
            new_key = 'neck.l2_norm.weight'
        elif 'bbox_head' in key:
            # Insert a submodule index after the first 21 characters, e.g.
            # 'bbox_head.cls_convs.0.weight'
            # -> 'bbox_head.cls_convs.0.0.weight'.
            new_key = key[:21] + '.0' + key[21:]
        else:
            new_key = key
        out_state_dict[new_key] = value
    checkpoint['state_dict'] = out_state_dict

    # Compare version components numerically; a plain string comparison
    # would mis-order versions such as '1.10' < '1.6'.
    torch_version = tuple(
        int(x) for x in torch.__version__.split('+')[0].split('.')[:2])
    if torch_version >= (1, 6):
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)


def main():
    parser = argparse.ArgumentParser(description='Upgrade SSD version')
    parser.add_argument('in_file', help='input checkpoint file')
    parser.add_argument('out_file', help='output checkpoint file')

    args = parser.parse_args()
    convert(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
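To make the key remapping concrete, the sanity sketch below replays the same rewrite rules on hand-written keys. The old-style key names are assumptions inferred from the split indices used in convert(), not read from a real checkpoint.

# Sanity sketch: replay the rewrite rules on assumed old-style keys.
# Sample keys are inferred from the parsing above (split index 2 must be
# the extra-layer index) and do NOT come from a real checkpoint.
def remap(key):
    if 'extra' in key:
        layer_idx = int(key.split('.')[2])
        return 'neck.extra_layers.{}.{}.conv.'.format(
            layer_idx // 2, layer_idx % 2) + key.split('.')[-1]
    if 'l2_norm' in key:
        return 'neck.l2_norm.weight'
    if 'bbox_head' in key:
        return key[:21] + '.0' + key[21:]
    return key

for key in ['backbone.extra.0.weight', 'backbone.extra.3.bias',
            'l2_norm.weight', 'bbox_head.cls_convs.0.weight']:
    print(key, '->', remap(key))
# backbone.extra.0.weight -> neck.extra_layers.0.0.conv.weight
# backbone.extra.3.bias -> neck.extra_layers.1.1.conv.bias
# l2_norm.weight -> neck.l2_norm.weight
# bbox_head.cls_convs.0.weight -> bbox_head.cls_convs.0.0.weight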