[Enhancement]: Refactor SSD (#5291)

* add vgg neck

* refactor ssd neck and vgg

* refactor ssd head and neck

* init l2 norm

* update config

* change ssdvgg backbone

* revert to SSD_VGG

* add unit test

* fix ssd voc

* avoid BC-breaking

* add TODO

* add convert script

* avoid BC breaking

* update readme

* update download link

* Fix doc
pull/5415/head
RangiLyu 3 years ago committed by GitHub
parent d1ef85d9ff
commit 4058255b46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      configs/_base_/models/ssd300.py
  2. 1
      configs/pascal_voc/ssd512_voc0712.py
  3. 18
      configs/ssd/README.md
  4. 12
      configs/ssd/metafile.yml
  5. 6
      configs/ssd/ssd512_coco.py
  6. 13
      docs/compatibility.md
  7. 107
      mmdet/models/backbones/ssd_vgg.py
  8. 118
      mmdet/models/dense_heads/ssd_head.py
  9. 4
      mmdet/models/necks/__init__.py
  10. 128
      mmdet/models/necks/ssd_neck.py
  11. 69
      tests/test_models/test_necks.py
  12. 57
      tools/model_converters/upgrade_ssd_version.py

@ -5,14 +5,18 @@ model = dict(
pretrained='open-mmlab://vgg16_caffe',
backbone=dict(
type='SSDVGG',
input_size=input_size,
depth=16,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
out_feature_indices=(22, 34)),
neck=dict(
type='SSDNeck',
in_channels=(512, 1024),
out_channels=(512, 1024, 512, 256, 256, 256),
level_strides=(2, 2, 1, 1),
level_paddings=(1, 1, 0, 0),
l2_norm_scale=20),
neck=None,
bbox_head=dict(
type='SSDHead',
in_channels=(512, 1024, 512, 256, 256, 256),

@ -1,7 +1,6 @@
_base_ = 'ssd300_voc0712.py'
input_size = 512
model = dict(
backbone=dict(input_size=input_size),
bbox_head=dict(
in_channels=(512, 1024, 512, 256, 256, 256, 256),
anchor_generator=dict(

@ -17,5 +17,19 @@
| Backbone | Size | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :---: | :---: | :-----: | :------: | :------------: | :----: | :------: | :--------: |
| VGG16 | 300 | caffe | 120e | 10.2 | 43.7 | 25.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd300_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307_174216.log.json) |
| VGG16 | 512 | caffe | 120e | 9.3 | 30.7 | 29.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd512_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308-038c5591.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308_134447.log.json) |
| VGG16 | 300 | caffe | 120e | 9.9 | 43.7 | 25.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd300_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052-b61137df.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052.log.json) |
| VGG16 | 512 | caffe | 120e | 19.4 | 30.7 | 29.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd/ssd512_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835-d3eba047.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835.log.json) |
## Notice
In v2.14.0, [PR5291](https://github.com/open-mmlab/mmdetection/pull/5291) refactored SSD neck and head for more
flexible usage. If users want to use the SSD checkpoint trained in the older versions, we provide a scripts
`tools/model_converters/upgrade_ssd_version.py` to convert the model weights.
```bash
python tools/model_converters/upgrade_ssd_version.py ${OLD_MODEL_PATH} ${NEW_MODEL_PATH}
```
- OLD_MODEL_PATH: the path to load the old version SSD model.
- NEW_MODEL_PATH: the path to save the converted model weights.

@ -16,24 +16,24 @@ Models:
In Collection: SSD
Config: configs/ssd/ssd300_coco.py
Metadata:
Training Memory (GB): 10.2
Training Memory (GB): 9.9
inference time (s/im): 0.02288
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 25.6
Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth
box AP: 25.9
Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20210604_193052-b61137df.pth
- Name: ssd512_coco
In Collection: SSD
Config: configs/ssd/ssd512_coco.py
Metadata:
Training Memory (GB): 9.3
Training Memory (GB): 19.4
inference time (s/im): 0.03257
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 29.4
Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20200308-038c5591.pth
box AP: 29.8
Weights: https://download.openmmlab.com/mmdetection/v2.0/ssd/ssd512_coco/ssd512_coco_20210604_111835-d3eba047.pth

@ -1,7 +1,11 @@
_base_ = 'ssd300_coco.py'
input_size = 512
model = dict(
backbone=dict(input_size=input_size),
neck=dict(
out_channels=(512, 1024, 512, 256, 256, 256, 256),
level_strides=(2, 2, 2, 2, 1),
level_paddings=(1, 1, 1, 1, 1),
last_kernel_size=4),
bbox_head=dict(
in_channels=(512, 1024, 512, 256, 256, 256, 256),
anchor_generator=dict(

@ -1,5 +1,18 @@
# Compatibility of MMDetection 2.x
## MMDetection 2.14.0
### SSD compatibility
In v2.14.0, to make SSD more flexible to use, [PR5291](https://github.com/open-mmlab/mmdetection/pull/5291) refactored its backbone, neck and head. The users can use the script `tools/model_converters/upgrade_ssd_version.py` to convert their models.
```bash
python tools/model_converters/upgrade_ssd_version.py ${OLD_MODEL_PATH} ${NEW_MODEL_PATH}
```
- OLD_MODEL_PATH: the path to load the old version SSD model.
- NEW_MODEL_PATH: the path to save the converted model weights.
## MMDetection 2.12.0
MMDetection is going through big refactoring for more general and convenient usages during the releases from v2.12.0 to v2.15.0 (maybe longer).

@ -1,12 +1,11 @@
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG
from mmcv.runner import BaseModule, Sequential
from mmcv.runner import BaseModule
from ..builder import BACKBONES
from ..necks import ssd_neck
@BACKBONES.register_module()
@ -14,12 +13,20 @@ class SSDVGG(VGG, BaseModule):
"""VGG Backbone network for single-shot-detection.
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
with_last_pool (bool): Whether to add a pooling layer at the last
of the model
ceil_mode (bool): When True, will use `ceil` instead of `floor`
to compute the output shape.
out_indices (Sequence[int]): Output from which stages.
out_feature_indices (Sequence[int]): Output from which feature map.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
input_size (int, optional): Deprecated argumment.
Width and height of input, from {300, 512}.
l2_norm_scale (float, optional) : Deprecated argumment.
L2 normalization layer init scale.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
@ -40,23 +47,21 @@ class SSDVGG(VGG, BaseModule):
}
def __init__(self,
input_size,
depth,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.,
pretrained=None,
init_cfg=None):
init_cfg=None,
input_size=None,
l2_norm_scale=None):
# TODO: in_channels for mmcv.VGG
super(SSDVGG, self).__init__(
depth,
with_last_pool=with_last_pool,
ceil_mode=ceil_mode,
out_indices=out_indices)
assert input_size in (300, 512)
self.input_size = input_size
self.features.add_module(
str(len(self.features)),
@ -72,18 +77,17 @@ class SSDVGG(VGG, BaseModule):
str(len(self.features)), nn.ReLU(inplace=True))
self.out_feature_indices = out_feature_indices
self.inplanes = 1024
self.extra = self._make_extra_layers(self.extra_setting[input_size])
self.l2_norm = L2Norm(
self.features[out_feature_indices[0] - 1].out_channels,
l2_norm_scale)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = [dict(type='Pretrained', checkpoint=pretrained)]
if input_size is not None:
warnings.warn('DeprecationWarning: input_size is deprecated')
if l2_norm_scale is not None:
warnings.warn('DeprecationWarning: l2_norm_scale in VGG is '
'deprecated, it has been moved to SSDNeck.')
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
@ -94,18 +98,6 @@ class SSDVGG(VGG, BaseModule):
else:
raise TypeError('pretrained must be a str or None')
if init_cfg is None:
self.init_cfg += [
dict(
type='Xavier',
distribution='uniform',
override=dict(name='extra')),
dict(
type='Constant',
val=self.l2_norm.scale,
override=dict(name='l2_norm'))
]
def init_weights(self, pretrained=None):
super(VGG, self).init_weights()
@ -116,64 +108,17 @@ class SSDVGG(VGG, BaseModule):
x = layer(x)
if i in self.out_feature_indices:
outs.append(x)
for i, layer in enumerate(self.extra):
x = F.relu(layer(x), inplace=True)
if i % 2 == 1:
outs.append(x)
outs[0] = self.l2_norm(outs[0])
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def _make_extra_layers(self, outplanes):
layers = []
kernel_sizes = (1, 3)
num_layers = 0
outplane = None
for i in range(len(outplanes)):
if self.inplanes == 'S':
self.inplanes = outplane
continue
k = kernel_sizes[num_layers % 2]
if outplanes[i] == 'S':
outplane = outplanes[i + 1]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=2, padding=1)
else:
outplane = outplanes[i]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=1, padding=0)
layers.append(conv)
self.inplanes = outplanes[i]
num_layers += 1
if self.input_size == 512:
layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))
return Sequential(*layers)
class L2Norm(ssd_neck.L2Norm):
class L2Norm(nn.Module):
def __init__(self, n_dims, scale=20., eps=1e-10):
"""L2 normalization layer.
Args:
n_dims (int): Number of dimensions to be normalized
scale (float, optional): Defaults to 20..
eps (float, optional): Used to avoid division by zero.
Defaults to 1e-10.
"""
super(L2Norm, self).__init__()
self.n_dims = n_dims
self.weight = nn.Parameter(torch.Tensor(self.n_dims))
self.eps = eps
self.scale = scale
def forward(self, x):
"""Forward function."""
# normalization layer convert to FP32 in FP16 training
x_float = x.float()
norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
return (self.weight[None, :, None, None].float().expand_as(x_float) *
x_float / norm).type_as(x)
def __init__(self, **kwargs):
super(L2Norm, self).__init__(**kwargs)
warnings.warn('DeprecationWarning: L2Norm in ssd_vgg.py '
'is deprecated, please use L2Norm in '
'mmdet/models/necks/ssd_neck.py instead')

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList, force_fp32
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import force_fp32
from mmdet.core import (build_anchor_generator, build_assigner,
build_bbox_coder, build_sampler, multi_apply)
@ -19,6 +20,18 @@ class SSDHead(AnchorHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
stacked_convs (int): Number of conv layers in cls and reg tower.
Default: 0.
feat_channels (int): Number of hidden channels when stacked_convs
> 0. Default: 256.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Default: False.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: None.
act_cfg (dict): Dictionary to construct and config activation layer.
Default: None.
anchor_generator (dict): Config dict for anchor generator
bbox_coder (dict): Config of bounding box coder.
reg_decoded_bbox (bool): If true, the regression loss would be
@ -34,6 +47,12 @@ class SSDHead(AnchorHead):
def __init__(self,
num_classes=80,
in_channels=(512, 1024, 512, 256, 256, 256),
stacked_convs=0,
feat_channels=256,
use_depthwise=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
anchor_generator=dict(
type='SSDAnchorGenerator',
scale_major=False,
@ -58,27 +77,18 @@ class SSDHead(AnchorHead):
super(AnchorHead, self).__init__(init_cfg)
self.num_classes = num_classes
self.in_channels = in_channels
self.stacked_convs = stacked_convs
self.feat_channels = feat_channels
self.use_depthwise = use_depthwise
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.cls_out_channels = num_classes + 1 # add background class
self.anchor_generator = build_anchor_generator(anchor_generator)
num_anchors = self.anchor_generator.num_base_anchors
self.num_anchors = self.anchor_generator.num_base_anchors
reg_convs = []
cls_convs = []
for i in range(len(in_channels)):
reg_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * 4,
kernel_size=3,
padding=1))
cls_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * (num_classes + 1),
kernel_size=3,
padding=1))
self.reg_convs = ModuleList(reg_convs)
self.cls_convs = ModuleList(cls_convs)
self._init_layers()
self.bbox_coder = build_bbox_coder(bbox_coder)
self.reg_decoded_bbox = reg_decoded_bbox
@ -95,6 +105,76 @@ class SSDHead(AnchorHead):
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False
def _init_layers(self):
"""Initialize layers of the head."""
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
# TODO: Use registry to choose ConvModule type
conv = DepthwiseSeparableConvModule \
if self.use_depthwise else ConvModule
for channel, num_anchors in zip(self.in_channels, self.num_anchors):
cls_layers = []
reg_layers = []
in_channel = channel
# build stacked conv tower, not used in default ssd
for i in range(self.stacked_convs):
cls_layers.append(
conv(
in_channel,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
reg_layers.append(
conv(
in_channel,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
in_channel = self.feat_channels
# SSD-Lite head
if self.use_depthwise:
cls_layers.append(
ConvModule(
in_channel,
in_channel,
3,
padding=1,
groups=in_channel,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
reg_layers.append(
ConvModule(
in_channel,
in_channel,
3,
padding=1,
groups=in_channel,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
cls_layers.append(
nn.Conv2d(
in_channel,
num_anchors * self.cls_out_channels,
kernel_size=1 if self.use_depthwise else 3,
padding=0 if self.use_depthwise else 1))
reg_layers.append(
nn.Conv2d(
in_channel,
num_anchors * 4,
kernel_size=1 if self.use_depthwise else 3,
padding=0 if self.use_depthwise else 1))
self.cls_convs.append(nn.Sequential(*cls_layers))
self.reg_convs.append(nn.Sequential(*reg_layers))
def forward(self, feats):
"""Forward features from the upstream network.

@ -10,9 +10,11 @@ from .nas_fpn import NASFPN
from .nasfcos_fpn import NASFCOS_FPN
from .pafpn import PAFPN
from .rfp import RFP
from .ssd_neck import SSDNeck
from .yolo_neck import YOLOV3Neck
__all__ = [
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder', 'CTResNetNeck'
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
'CTResNetNeck', 'SSDNeck'
]

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule
from ..builder import NECKS
@NECKS.register_module()
class SSDNeck(BaseModule):
"""Extra layers of SSD backbone to generate multi-scale feature maps.
Args:
in_channels (Sequence[int]): Number of input channels per scale.
out_channels (Sequence[int]): Number of output channels per scale.
level_strides (Sequence[int]): Stride of 3x3 conv per level.
level_paddings (Sequence[int]): Padding size of 3x3 conv per level.
l2_norm_scale (float|None): L2 normalization layer init scale.
If None, not use L2 normalization on the first input feature.
last_kernel_size (int): Kernel size of the last conv layer.
Default: 3.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: None.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
out_channels,
level_strides,
level_paddings,
l2_norm_scale=20.,
last_kernel_size=3,
use_depthwise=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
init_cfg=[
dict(
type='Xavier', distribution='uniform',
layer='Conv2d'),
dict(type='Constant', val=1, layer='BatchNorm2d'),
]):
super(SSDNeck, self).__init__(init_cfg)
assert len(out_channels) > len(in_channels)
assert len(out_channels) - len(in_channels) == len(level_strides)
assert len(level_strides) == len(level_paddings)
assert in_channels == out_channels[:len(in_channels)]
if l2_norm_scale:
self.l2_norm = L2Norm(in_channels[0], l2_norm_scale)
self.init_cfg += [
dict(
type='Constant',
val=self.l2_norm.scale,
override=dict(name='l2_norm'))
]
self.extra_layers = nn.ModuleList()
extra_layer_channels = out_channels[len(in_channels):]
second_conv = DepthwiseSeparableConvModule if \
use_depthwise else ConvModule
for i, (out_channel, stride, padding) in enumerate(
zip(extra_layer_channels, level_strides, level_paddings)):
kernel_size = last_kernel_size \
if i == len(extra_layer_channels) - 1 else 3
per_lvl_convs = nn.Sequential(
ConvModule(
out_channels[len(in_channels) - 1 + i],
out_channel // 2,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
second_conv(
out_channel // 2,
out_channel,
kernel_size,
stride=stride,
padding=padding,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.extra_layers.append(per_lvl_convs)
def forward(self, inputs):
"""Forward function."""
outs = [feat for feat in inputs]
if hasattr(self, 'l2_norm'):
outs[0] = self.l2_norm(outs[0])
feat = outs[-1]
for layer in self.extra_layers:
feat = layer(feat)
outs.append(feat)
return tuple(outs)
class L2Norm(nn.Module):
def __init__(self, n_dims, scale=20., eps=1e-10):
"""L2 normalization layer.
Args:
n_dims (int): Number of dimensions to be normalized
scale (float, optional): Defaults to 20..
eps (float, optional): Used to avoid division by zero.
Defaults to 1e-10.
"""
super(L2Norm, self).__init__()
self.n_dims = n_dims
self.weight = nn.Parameter(torch.Tensor(self.n_dims))
self.eps = eps
self.scale = scale
def forward(self, x):
"""Forward function."""
# normalization layer convert to FP32 in FP16 training
x_float = x.float()
norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
return (self.weight[None, :, None, None].float().expand_as(x_float) *
x_float / norm).type_as(x)

@ -3,7 +3,7 @@ import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.necks import (FPN, ChannelMapper, CTResNetNeck,
DilatedEncoder, YOLOV3Neck)
DilatedEncoder, SSDNeck, YOLOV3Neck)
def test_fpn():
@ -338,3 +338,70 @@ def test_yolov3_neck():
for i in range(len(outs)):
assert outs[i].shape == \
(1, out_channels[i], feat_sizes[i], feat_sizes[i])
def test_ssd_neck():
# level_strides/level_paddings must be same length
with pytest.raises(AssertionError):
SSDNeck(
in_channels=[8, 16],
out_channels=[8, 16, 32],
level_strides=[2],
level_paddings=[2, 1])
# length of out_channels must larger than in_channels
with pytest.raises(AssertionError):
SSDNeck(
in_channels=[8, 16],
out_channels=[8],
level_strides=[2],
level_paddings=[2])
# len(out_channels) - len(in_channels) must equal to len(level_strides)
with pytest.raises(AssertionError):
SSDNeck(
in_channels=[8, 16],
out_channels=[4, 16, 64],
level_strides=[2, 2],
level_paddings=[2, 2])
# in_channels must be same with out_channels[:len(in_channels)]
with pytest.raises(AssertionError):
SSDNeck(
in_channels=[8, 16],
out_channels=[4, 16, 64],
level_strides=[2],
level_paddings=[2])
ssd_neck = SSDNeck(
in_channels=[4],
out_channels=[4, 8, 16],
level_strides=[2, 1],
level_paddings=[1, 0])
feats = (torch.rand(1, 4, 16, 16), )
outs = ssd_neck(feats)
assert outs[0].shape == (1, 4, 16, 16)
assert outs[1].shape == (1, 8, 8, 8)
assert outs[2].shape == (1, 16, 6, 6)
# test SSD-Lite Neck
ssd_neck = SSDNeck(
in_channels=[4, 8],
out_channels=[4, 8, 16],
level_strides=[1],
level_paddings=[1],
l2_norm_scale=None,
use_depthwise=True,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'))
assert not hasattr(ssd_neck, 'l2_norm')
from mmcv.cnn.bricks import DepthwiseSeparableConvModule
assert isinstance(ssd_neck.extra_layers[0][-1],
DepthwiseSeparableConvModule)
feats = (torch.rand(1, 4, 8, 8), torch.rand(1, 8, 8, 8))
outs = ssd_neck(feats)
assert outs[0].shape == (1, 4, 8, 8)
assert outs[1].shape == (1, 8, 8, 8)
assert outs[2].shape == (1, 16, 8, 8)

@ -0,0 +1,57 @@
import argparse
import tempfile
from collections import OrderedDict
import torch
from mmcv import Config
def parse_config(config_strings):
temp_file = tempfile.NamedTemporaryFile()
config_path = f'{temp_file.name}.py'
with open(config_path, 'w') as f:
f.write(config_strings)
config = Config.fromfile(config_path)
# check whether it is SSD
if config.model.bbox_head.type != 'SSDHead':
raise AssertionError('This is not a SSD model.')
def convert(in_file, out_file):
checkpoint = torch.load(in_file)
in_state_dict = checkpoint.pop('state_dict')
out_state_dict = OrderedDict()
meta_info = checkpoint['meta']
parse_config('#' + meta_info['config'])
for key, value in in_state_dict.items():
if 'extra' in key:
layer_idx = int(key.split('.')[2])
new_key = 'neck.extra_layers.{}.{}.conv.'.format(
layer_idx // 2, layer_idx % 2) + key.split('.')[-1]
elif 'l2_norm' in key:
new_key = 'neck.l2_norm.weight'
elif 'bbox_head' in key:
new_key = key[:21] + '.0' + key[21:]
else:
new_key = key
out_state_dict[new_key] = value
checkpoint['state_dict'] = out_state_dict
if torch.__version__ >= '1.6':
torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, out_file)
def main():
parser = argparse.ArgumentParser(description='Upgrade SSD version')
parser.add_argument('in_file', help='input checkpoint file')
parser.add_argument('out_file', help='output checkpoint file')
args = parser.parse_args()
convert(args.in_file, args.out_file)
if __name__ == '__main__':
main()
Loading…
Cancel
Save