[Feature] Support ConvNeXt (#7281)

* update

* update

* fix init_cfg

* update

* update

* update

* update

* update1

* final

* update

* update

* fix lint

* fix backbone config

* update cascade_mask_rcnn

* update and fix lint

* update

* fix DefaultOptimizerConstructor error

* update

* update

* update

* fix year

* update

* fix lint
pull/7635/merge
Haian Huang(深度眸) 3 years ago committed by GitHub
parent a828499d28
commit 1fd48f7318
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      README.md
  2. 1
      README_zh-CN.md
  3. 40
      configs/convnext/README.md
  4. 32
      configs/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py
  5. 149
      configs/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py
  6. 90
      configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py
  7. 93
      configs/convnext/metafile.yml
  8. 6
      mmdet/apis/train.py
  9. 1
      mmdet/core/__init__.py
  10. 9
      mmdet/core/optimizers/__init__.py
  11. 33
      mmdet/core/optimizers/builder.py
  12. 154
      mmdet/core/optimizers/layer_decay_optimizer_constructor.py
  13. 1
      model-index.yml
  14. 164
      tests/test_utils/test_layer_decay_optimizer_constructor.py

@ -231,6 +231,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="configs/pvt">PVTv2 (ArXiv'2021)</a></li>
<li><a href="configs/resnet_strikes_back">ResNet strikes back (ArXiv'2021)</a></li>
<li><a href="configs/efficientnet">EfficientNet (ArXiv'2021)</a></li>
<li><a href="configs/convnext">ConvNeXt (CVPR'2022)</a></li>
</ul>
</td>
<td>

@ -230,6 +230,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
<li><a href="configs/pvt">PVTv2 (ArXiv'2021)</a></li>
<li><a href="configs/resnet_strikes_back">ResNet strikes back (ArXiv'2021)</a></li>
<li><a href="configs/efficientnet">EfficientNet (ArXiv'2021)</a></li>
<li><a href="configs/convnext">ConvNeXt (CVPR'2022)</a></li>
</ul>
</td>
<td>

@ -0,0 +1,40 @@
# ConvNeXt
> [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)
## Abstract
The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets.
<div align=center>
<img src="https://user-images.githubusercontent.com/8370623/148624004-e9581042-ea4d-4e10-b3bd-42c92b02053b.png" width="90%"/>
</div>
## Results and models
| Method | Backbone | Pretrain | Lr schd | Multi-scale crop | FP16 | Mem (GB) | box AP | mask AP | Config | Download |
| :----------------: | :--------: | :---------: | :-----: | :--------------: | :--: | :------: | :----: | :-----: | :-------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Mask R-CNN | ConvNeXt-T | ImageNet-1K | 3x | yes | yes | 7.3 | 46.2 | 41.7 | [config](./mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953-050731f4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953.log.json) |
| Cascade Mask R-CNN | ConvNeXt-T | ImageNet-1K | 3x | yes | yes | 9.0 | 50.3 | 43.6 | [config](./cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200-8f07c40b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200.log.json) |
| Cascade Mask R-CNN | ConvNeXt-S | ImageNet-1K | 3x | yes | yes | 12.3 | 51.8 | 44.8 | [config](./cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004-3d24f5a4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004.log.json) |
**Note**:
- ConvNeXt backbone needs to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
```shell
pip install mmcls>=0.22.0
```
- The performance is unstable. `Cascade Mask R-CNN` may fluctuate about 0.2 mAP.
## Citation
```bibtex
@article{liu2022convnet,
title={A ConvNet for the 2020s},
author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2022}
}
```

@ -0,0 +1,32 @@
_base_ = './cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py' # noqa
# please install mmcls>=0.22.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.ConvNeXt',
arch='small',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.6,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')))
optimizer = dict(
_delete_=True,
constructor='LearningRateDecayOptimizerConstructor',
type='AdamW',
lr=0.0002,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.7,
'decay_type': 'layer_wise',
'num_layers': 12
})

@ -0,0 +1,149 @@
_base_ = [
'../_base_/models/cascade_mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# please install mmcls>=0.22.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.ConvNeXt',
arch='tiny',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]),
roi_head=dict(bbox_head=[
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
]))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='AutoAugment',
policies=[[
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(
type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline), persistent_workers=True)
optimizer = dict(
_delete_=True,
constructor='LearningRateDecayOptimizerConstructor',
type='AdamW',
lr=0.0002,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.7,
'decay_type': 'layer_wise',
'num_layers': 6
})
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(max_epochs=36)
# you need to set mode='dynamic' if you are using pytorch<=1.5.0
fp16 = dict(loss_scale=dict(init_scale=512))

@ -0,0 +1,90 @@
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# please install mmcls>=0.22.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.ConvNeXt',
arch='tiny',
out_indices=[0, 1, 2, 3],
drop_path_rate=0.4,
layer_scale_init_value=1.0,
gap_before_final_norm=False,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='AutoAugment',
policies=[[
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(
type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline), persistent_workers=True)
optimizer = dict(
_delete_=True,
constructor='LearningRateDecayOptimizerConstructor',
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.95,
'decay_type': 'layer_wise',
'num_layers': 6
})
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(max_epochs=36)
# you need to set mode='dynamic' if you are using pytorch<=1.5.0
fp16 = dict(loss_scale=dict(init_scale=512))

@ -0,0 +1,93 @@
Models:
- Name: mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco
In Collection: Mask R-CNN
Config: configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco
Metadata:
Training Memory (GB): 7.3
Epochs: 36
Training Data: COCO
Training Techniques:
- AdamW
- Mixed Precision Training
Training Resources: 8x A100 GPUs
Architecture:
- ConvNeXt
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 46.2
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 41.7
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953-050731f4.pth
Paper:
URL: https://arxiv.org/abs/2201.03545
Title: 'A ConvNet for the 2020s'
README: configs/convnext/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465
Version: v2.16.0
- Name: cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco
In Collection: Cascade Mask R-CNN
Config: configs/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py
Metadata:
Training Memory (GB): 9.0
Epochs: 36
Training Data: COCO
Training Techniques:
- AdamW
- Mixed Precision Training
Training Resources: 8x A100 GPUs
Architecture:
- ConvNeXt
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 50.3
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 43.6
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200-8f07c40b.pth
Paper:
URL: https://arxiv.org/abs/2201.03545
Title: 'A ConvNet for the 2020s'
README: configs/convnext/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465
Version: v2.25.0
- Name: cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco
In Collection: Cascade Mask R-CNN
Config: configs/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py
Metadata:
Training Memory (GB): 12.3
Epochs: 36
Training Data: COCO
Training Techniques:
- AdamW
- Mixed Precision Training
Training Resources: 8x A100 GPUs
Architecture:
- ConvNeXt
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 51.8
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 44.8
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004-3d24f5a4.pth
Paper:
URL: https://arxiv.org/abs/2201.03545
Title: 'A ConvNet for the 2020s'
README: configs/convnext/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465
Version: v2.25.0

@ -6,10 +6,10 @@ import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
Fp16OptimizerHook, OptimizerHook, build_runner,
get_dist_info)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.core import DistEvalHook, EvalHook, build_optimizer
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import (build_ddp, build_dp, compat_cfg,

@ -5,5 +5,6 @@ from .data_structures import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .hook import * # noqa: F401, F403
from .mask import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .utils import * # noqa: F401, F403

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import OPTIMIZER_BUILDERS, build_optimizer
from .layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
__all__ = [
'LearningRateDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS',
'build_optimizer'
]

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
from mmcv.utils import Registry, build_from_cfg
OPTIMIZER_BUILDERS = Registry(
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
def build_optimizer_constructor(cfg):
constructor_type = cfg.get('type')
if constructor_type in OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
else:
raise KeyError(f'{constructor_type} is not registered '
'in the optimizer builder registry.')
def build_optimizer(model, cfg):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer

@ -0,0 +1,154 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from mmdet.utils import get_root_logger
from .builder import OPTIMIZER_BUILDERS
def get_layer_id_for_convnext(var_name, max_layer_id):
"""Get the layer id to set the different learning rates in ``layer_wise``
decay_type.
Args:
var_name (str): The key of the model.
max_layer_id (int): Maximum layer id.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
stage_id = int(var_name.split('.')[2])
if stage_id == 0:
layer_id = 0
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3
elif stage_id == 3:
layer_id = max_layer_id
return layer_id
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
block_id = int(var_name.split('.')[3])
if stage_id == 0:
layer_id = 1
elif stage_id == 1:
layer_id = 2
elif stage_id == 2:
layer_id = 3 + block_id // 3
elif stage_id == 3:
layer_id = max_layer_id
return layer_id
else:
return max_layer_id + 1
def get_stage_id_for_convnext(var_name, max_stage_id):
"""Get the stage id to set the different learning rates in ``stage_wise``
decay_type.
Args:
var_name (str): The key of the model.
max_stage_id (int): Maximum stage id.
Returns:
int: The id number corresponding to different learning rate in
``LearningRateDecayOptimizerConstructor``.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.downsample_layers'):
return 0
elif var_name.startswith('backbone.stages'):
stage_id = int(var_name.split('.')[2])
return stage_id + 1
else:
return max_stage_id - 1
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
# Different learning rates are set for different layers of backbone.
# Note: Currently, this optimizer constructor is built for ConvNeXt.
def add_params(self, params, module, **kwargs):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
"""
logger = get_root_logger()
parameter_groups = {}
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
num_layers = self.paramwise_cfg.get('num_layers') + 2
decay_rate = self.paramwise_cfg.get('decay_rate')
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
logger.info('Build LearningRateDecayOptimizerConstructor '
f'{decay_type} {decay_rate} - {num_layers}')
weight_decay = self.base_wd
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith('.bias') or name in (
'pos_embed', 'cls_token'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
if 'layer_wise' in decay_type:
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_layer_id_for_convnext(
name, self.paramwise_cfg.get('num_layers'))
logger.info(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
elif decay_type == 'stage_wise':
if 'ConvNeXt' in module.backbone.__class__.__name__:
layer_id = get_stage_id_for_convnext(name, num_layers)
logger.info(f'set param {name} as id {layer_id}')
else:
raise NotImplementedError()
group_name = f'layer_{layer_id}_{group_name}'
if group_name not in parameter_groups:
scale = decay_rate**(num_layers - layer_id - 1)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
'params': [],
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
'lr': scale * self.base_lr,
}
parameter_groups[group_name]['params'].append(param)
parameter_groups[group_name]['param_names'].append(name)
rank, _ = get_dist_info()
if rank == 0:
to_display = {}
for key in parameter_groups:
to_display[key] = {
'param_names': parameter_groups[key]['param_names'],
'lr_scale': parameter_groups[key]['lr_scale'],
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay'],
}
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
params.extend(parameter_groups.values())

@ -7,6 +7,7 @@ Import:
- configs/centernet/metafile.yml
- configs/centripetalnet/metafile.yml
- configs/cornernet/metafile.yml
- configs/convnext/metafile.yml
- configs/dcn/metafile.yml
- configs/dcnv2/metafile.yml
- configs/deformable_detr/metafile.yml

@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.core.optimizers import LearningRateDecayOptimizerConstructor
base_lr = 1
decay_rate = 2
base_wd = 0.05
weight_decay = 0.05
expected_stage_wise_lr_wd_convnext = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 8
}, {
'weight_decay': 0.0,
'lr_scale': 8
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]
expected_layer_wise_lr_wd_convnext = [{
'weight_decay': 0.0,
'lr_scale': 128
}, {
'weight_decay': 0.0,
'lr_scale': 1
}, {
'weight_decay': 0.05,
'lr_scale': 64
}, {
'weight_decay': 0.0,
'lr_scale': 64
}, {
'weight_decay': 0.05,
'lr_scale': 32
}, {
'weight_decay': 0.0,
'lr_scale': 32
}, {
'weight_decay': 0.05,
'lr_scale': 16
}, {
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 2
}, {
'weight_decay': 0.0,
'lr_scale': 2
}, {
'weight_decay': 0.05,
'lr_scale': 128
}, {
'weight_decay': 0.05,
'lr_scale': 1
}]
class ToyConvNeXt(nn.Module):
def __init__(self):
super().__init__()
self.stages = nn.ModuleList()
for i in range(4):
stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True))
self.stages.append(stage)
self.norm0 = nn.BatchNorm2d(2)
# add some variables to meet unit test coverate rate
self.cls_token = nn.Parameter(torch.ones(1))
self.mask_token = nn.Parameter(torch.ones(1))
self.pos_embed = nn.Parameter(torch.ones(1))
self.stem_norm = nn.Parameter(torch.ones(1))
self.downsample_norm0 = nn.BatchNorm2d(2)
self.downsample_norm1 = nn.BatchNorm2d(2)
self.downsample_norm2 = nn.BatchNorm2d(2)
self.lin = nn.Parameter(torch.ones(1))
self.lin.requires_grad = False
self.downsample_layers = nn.ModuleList()
for _ in range(4):
stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True))
self.downsample_layers.append(stage)
class ToyDetector(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.head = nn.Conv2d(2, 2, kernel_size=1, groups=2)
class PseudoDataParallel(nn.Module):
def __init__(self, model):
super().__init__()
self.module = model
def check_optimizer_lr_wd(optimizer, gt_lr_wd):
assert isinstance(optimizer, torch.optim.AdamW)
assert optimizer.defaults['lr'] == base_lr
assert optimizer.defaults['weight_decay'] == base_wd
param_groups = optimizer.param_groups
print(param_groups)
assert len(param_groups) == len(gt_lr_wd)
for i, param_dict in enumerate(param_groups):
assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay']
assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale']
assert param_dict['lr_scale'] == param_dict['lr']
def test_learning_rate_decay_optimizer_constructor():
# Test lr wd for ConvNeXT
backbone = ToyConvNeXt()
model = PseudoDataParallel(ToyDetector(backbone))
optimizer_cfg = dict(
type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
# stagewise decay
stagewise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, stagewise_paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext)
# layerwise decay
layerwise_paramwise_cfg = dict(
decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
optim_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg, layerwise_paramwise_cfg)
optimizer = optim_constructor(model)
check_optimizer_lr_wd(optimizer, expected_layer_wise_lr_wd_convnext)
Loading…
Cancel
Save