[Feature] Support ConvNeXt (#7281)
* update * update * fix init_cfg * update * update * update * update * update1 * final * update * update * fix lint * fix backbone config * update cascade_mask_rcnn * update and fix lint * update * fix DefaultOptimizerConstructor error * update * update * update * fix year * update * fix lintpull/7635/merge
parent
a828499d28
commit
1fd48f7318
14 changed files with 771 additions and 3 deletions
@ -0,0 +1,40 @@ |
||||
# ConvNeXt |
||||
|
||||
> [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) |
||||
|
||||
## Abstract |
||||
|
||||
The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets. |
||||
|
||||
<div align=center> |
||||
<img src="https://user-images.githubusercontent.com/8370623/148624004-e9581042-ea4d-4e10-b3bd-42c92b02053b.png" width="90%"/> |
||||
</div> |
||||
|
||||
## Results and models |
||||
|
||||
| Method | Backbone | Pretrain | Lr schd | Multi-scale crop | FP16 | Mem (GB) | box AP | mask AP | Config | Download | |
||||
| :----------------: | :--------: | :---------: | :-----: | :--------------: | :--: | :------: | :----: | :-----: | :-------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | |
||||
| Mask R-CNN | ConvNeXt-T | ImageNet-1K | 3x | yes | yes | 7.3 | 46.2 | 41.7 | [config](./mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953-050731f4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953.log.json) | |
||||
| Cascade Mask R-CNN | ConvNeXt-T | ImageNet-1K | 3x | yes | yes | 9.0 | 50.3 | 43.6 | [config](./cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200-8f07c40b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200.log.json) | |
||||
| Cascade Mask R-CNN | ConvNeXt-S | ImageNet-1K | 3x | yes | yes | 12.3 | 51.8 | 44.8 | [config](./cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004-3d24f5a4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004.log.json) | |
||||
|
||||
**Note**: |
||||
|
||||
- ConvNeXt backbone needs to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks. |
||||
|
||||
```shell |
||||
pip install mmcls>=0.22.0 |
||||
``` |
||||
|
||||
- The performance is unstable. `Cascade Mask R-CNN` may fluctuate about 0.2 mAP. |
||||
|
||||
## Citation |
||||
|
||||
```bibtex |
||||
@article{liu2022convnet, |
||||
title={A ConvNet for the 2020s}, |
||||
author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining}, |
||||
journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, |
||||
year={2022} |
||||
} |
||||
``` |
@ -0,0 +1,32 @@ |
||||
_base_ = './cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py' # noqa |
||||
|
||||
# please install mmcls>=0.22.0 |
||||
# import mmcls.models to trigger register_module in mmcls |
||||
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) |
||||
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth' # noqa |
||||
|
||||
model = dict( |
||||
backbone=dict( |
||||
_delete_=True, |
||||
type='mmcls.ConvNeXt', |
||||
arch='small', |
||||
out_indices=[0, 1, 2, 3], |
||||
drop_path_rate=0.6, |
||||
layer_scale_init_value=1.0, |
||||
gap_before_final_norm=False, |
||||
init_cfg=dict( |
||||
type='Pretrained', checkpoint=checkpoint_file, |
||||
prefix='backbone.'))) |
||||
|
||||
optimizer = dict( |
||||
_delete_=True, |
||||
constructor='LearningRateDecayOptimizerConstructor', |
||||
type='AdamW', |
||||
lr=0.0002, |
||||
betas=(0.9, 0.999), |
||||
weight_decay=0.05, |
||||
paramwise_cfg={ |
||||
'decay_rate': 0.7, |
||||
'decay_type': 'layer_wise', |
||||
'num_layers': 12 |
||||
}) |
@ -0,0 +1,149 @@ |
||||
_base_ = [ |
||||
'../_base_/models/cascade_mask_rcnn_r50_fpn.py', |
||||
'../_base_/datasets/coco_instance.py', |
||||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' |
||||
] |
||||
|
||||
# please install mmcls>=0.22.0 |
||||
# import mmcls.models to trigger register_module in mmcls |
||||
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) |
||||
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa |
||||
|
||||
model = dict( |
||||
backbone=dict( |
||||
_delete_=True, |
||||
type='mmcls.ConvNeXt', |
||||
arch='tiny', |
||||
out_indices=[0, 1, 2, 3], |
||||
drop_path_rate=0.4, |
||||
layer_scale_init_value=1.0, |
||||
gap_before_final_norm=False, |
||||
init_cfg=dict( |
||||
type='Pretrained', checkpoint=checkpoint_file, |
||||
prefix='backbone.')), |
||||
neck=dict(in_channels=[96, 192, 384, 768]), |
||||
roi_head=dict(bbox_head=[ |
||||
dict( |
||||
type='ConvFCBBoxHead', |
||||
num_shared_convs=4, |
||||
num_shared_fcs=1, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.1, 0.1, 0.2, 0.2]), |
||||
reg_class_agnostic=False, |
||||
reg_decoded_bbox=True, |
||||
norm_cfg=dict(type='SyncBN', requires_grad=True), |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), |
||||
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), |
||||
dict( |
||||
type='ConvFCBBoxHead', |
||||
num_shared_convs=4, |
||||
num_shared_fcs=1, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.05, 0.05, 0.1, 0.1]), |
||||
reg_class_agnostic=False, |
||||
reg_decoded_bbox=True, |
||||
norm_cfg=dict(type='SyncBN', requires_grad=True), |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), |
||||
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), |
||||
dict( |
||||
type='ConvFCBBoxHead', |
||||
num_shared_convs=4, |
||||
num_shared_fcs=1, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.033, 0.033, 0.067, 0.067]), |
||||
reg_class_agnostic=False, |
||||
reg_decoded_bbox=True, |
||||
norm_cfg=dict(type='SyncBN', requires_grad=True), |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), |
||||
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)) |
||||
])) |
||||
|
||||
img_norm_cfg = dict( |
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) |
||||
|
||||
# augmentation strategy originates from DETR / Sparse RCNN |
||||
train_pipeline = [ |
||||
dict(type='LoadImageFromFile'), |
||||
dict(type='LoadAnnotations', with_bbox=True, with_mask=True), |
||||
dict(type='RandomFlip', flip_ratio=0.5), |
||||
dict( |
||||
type='AutoAugment', |
||||
policies=[[ |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), |
||||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), |
||||
(736, 1333), (768, 1333), (800, 1333)], |
||||
multiscale_mode='value', |
||||
keep_ratio=True) |
||||
], |
||||
[ |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(400, 1333), (500, 1333), (600, 1333)], |
||||
multiscale_mode='value', |
||||
keep_ratio=True), |
||||
dict( |
||||
type='RandomCrop', |
||||
crop_type='absolute_range', |
||||
crop_size=(384, 600), |
||||
allow_negative_crop=True), |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333), |
||||
(576, 1333), (608, 1333), (640, 1333), |
||||
(672, 1333), (704, 1333), (736, 1333), |
||||
(768, 1333), (800, 1333)], |
||||
multiscale_mode='value', |
||||
override=True, |
||||
keep_ratio=True) |
||||
]]), |
||||
dict(type='Normalize', **img_norm_cfg), |
||||
dict(type='Pad', size_divisor=32), |
||||
dict(type='DefaultFormatBundle'), |
||||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), |
||||
] |
||||
data = dict(train=dict(pipeline=train_pipeline), persistent_workers=True) |
||||
|
||||
optimizer = dict( |
||||
_delete_=True, |
||||
constructor='LearningRateDecayOptimizerConstructor', |
||||
type='AdamW', |
||||
lr=0.0002, |
||||
betas=(0.9, 0.999), |
||||
weight_decay=0.05, |
||||
paramwise_cfg={ |
||||
'decay_rate': 0.7, |
||||
'decay_type': 'layer_wise', |
||||
'num_layers': 6 |
||||
}) |
||||
|
||||
lr_config = dict(warmup_iters=1000, step=[27, 33]) |
||||
runner = dict(max_epochs=36) |
||||
|
||||
# you need to set mode='dynamic' if you are using pytorch<=1.5.0 |
||||
fp16 = dict(loss_scale=dict(init_scale=512)) |
@ -0,0 +1,90 @@ |
||||
_base_ = [ |
||||
'../_base_/models/mask_rcnn_r50_fpn.py', |
||||
'../_base_/datasets/coco_instance.py', |
||||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' |
||||
] |
||||
|
||||
# please install mmcls>=0.22.0 |
||||
# import mmcls.models to trigger register_module in mmcls |
||||
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) |
||||
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa |
||||
|
||||
model = dict( |
||||
backbone=dict( |
||||
_delete_=True, |
||||
type='mmcls.ConvNeXt', |
||||
arch='tiny', |
||||
out_indices=[0, 1, 2, 3], |
||||
drop_path_rate=0.4, |
||||
layer_scale_init_value=1.0, |
||||
gap_before_final_norm=False, |
||||
init_cfg=dict( |
||||
type='Pretrained', checkpoint=checkpoint_file, |
||||
prefix='backbone.')), |
||||
neck=dict(in_channels=[96, 192, 384, 768])) |
||||
|
||||
img_norm_cfg = dict( |
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) |
||||
|
||||
# augmentation strategy originates from DETR / Sparse RCNN |
||||
train_pipeline = [ |
||||
dict(type='LoadImageFromFile'), |
||||
dict(type='LoadAnnotations', with_bbox=True, with_mask=True), |
||||
dict(type='RandomFlip', flip_ratio=0.5), |
||||
dict( |
||||
type='AutoAugment', |
||||
policies=[[ |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), |
||||
(608, 1333), (640, 1333), (672, 1333), (704, 1333), |
||||
(736, 1333), (768, 1333), (800, 1333)], |
||||
multiscale_mode='value', |
||||
keep_ratio=True) |
||||
], |
||||
[ |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(400, 1333), (500, 1333), (600, 1333)], |
||||
multiscale_mode='value', |
||||
keep_ratio=True), |
||||
dict( |
||||
type='RandomCrop', |
||||
crop_type='absolute_range', |
||||
crop_size=(384, 600), |
||||
allow_negative_crop=True), |
||||
dict( |
||||
type='Resize', |
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333), |
||||
(576, 1333), (608, 1333), (640, 1333), |
||||
(672, 1333), (704, 1333), (736, 1333), |
||||
(768, 1333), (800, 1333)], |
||||
multiscale_mode='value', |
||||
override=True, |
||||
keep_ratio=True) |
||||
]]), |
||||
dict(type='Normalize', **img_norm_cfg), |
||||
dict(type='Pad', size_divisor=32), |
||||
dict(type='DefaultFormatBundle'), |
||||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), |
||||
] |
||||
data = dict(train=dict(pipeline=train_pipeline), persistent_workers=True) |
||||
|
||||
optimizer = dict( |
||||
_delete_=True, |
||||
constructor='LearningRateDecayOptimizerConstructor', |
||||
type='AdamW', |
||||
lr=0.0001, |
||||
betas=(0.9, 0.999), |
||||
weight_decay=0.05, |
||||
paramwise_cfg={ |
||||
'decay_rate': 0.95, |
||||
'decay_type': 'layer_wise', |
||||
'num_layers': 6 |
||||
}) |
||||
|
||||
lr_config = dict(warmup_iters=1000, step=[27, 33]) |
||||
runner = dict(max_epochs=36) |
||||
|
||||
# you need to set mode='dynamic' if you are using pytorch<=1.5.0 |
||||
fp16 = dict(loss_scale=dict(init_scale=512)) |
@ -0,0 +1,93 @@ |
||||
Models: |
||||
- Name: mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco |
||||
In Collection: Mask R-CNN |
||||
Config: configs/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco |
||||
Metadata: |
||||
Training Memory (GB): 7.3 |
||||
Epochs: 36 |
||||
Training Data: COCO |
||||
Training Techniques: |
||||
- AdamW |
||||
- Mixed Precision Training |
||||
Training Resources: 8x A100 GPUs |
||||
Architecture: |
||||
- ConvNeXt |
||||
Results: |
||||
- Task: Object Detection |
||||
Dataset: COCO |
||||
Metrics: |
||||
box AP: 46.2 |
||||
- Task: Instance Segmentation |
||||
Dataset: COCO |
||||
Metrics: |
||||
mask AP: 41.7 |
||||
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco/mask_rcnn_convnext-t_p4_w7_fpn_fp16_ms-crop_3x_coco_20220426_154953-050731f4.pth |
||||
Paper: |
||||
URL: https://arxiv.org/abs/2201.03545 |
||||
Title: 'A ConvNet for the 2020s' |
||||
README: configs/convnext/README.md |
||||
Code: |
||||
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465 |
||||
Version: v2.16.0 |
||||
|
||||
- Name: cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco |
||||
In Collection: Cascade Mask R-CNN |
||||
Config: configs/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py |
||||
Metadata: |
||||
Training Memory (GB): 9.0 |
||||
Epochs: 36 |
||||
Training Data: COCO |
||||
Training Techniques: |
||||
- AdamW |
||||
- Mixed Precision Training |
||||
Training Resources: 8x A100 GPUs |
||||
Architecture: |
||||
- ConvNeXt |
||||
Results: |
||||
- Task: Object Detection |
||||
Dataset: COCO |
||||
Metrics: |
||||
box AP: 50.3 |
||||
- Task: Instance Segmentation |
||||
Dataset: COCO |
||||
Metrics: |
||||
mask AP: 43.6 |
||||
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-t_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220509_204200-8f07c40b.pth |
||||
Paper: |
||||
URL: https://arxiv.org/abs/2201.03545 |
||||
Title: 'A ConvNet for the 2020s' |
||||
README: configs/convnext/README.md |
||||
Code: |
||||
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465 |
||||
Version: v2.25.0 |
||||
|
||||
- Name: cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco |
||||
In Collection: Cascade Mask R-CNN |
||||
Config: configs/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco.py |
||||
Metadata: |
||||
Training Memory (GB): 12.3 |
||||
Epochs: 36 |
||||
Training Data: COCO |
||||
Training Techniques: |
||||
- AdamW |
||||
- Mixed Precision Training |
||||
Training Resources: 8x A100 GPUs |
||||
Architecture: |
||||
- ConvNeXt |
||||
Results: |
||||
- Task: Object Detection |
||||
Dataset: COCO |
||||
Metrics: |
||||
box AP: 51.8 |
||||
- Task: Instance Segmentation |
||||
Dataset: COCO |
||||
Metrics: |
||||
mask AP: 44.8 |
||||
Weights: https://download.openmmlab.com/mmdetection/v2.0/convnext/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco/cascade_mask_rcnn_convnext-s_p4_w7_fpn_giou_4conv1f_fp16_ms-crop_3x_coco_20220510_201004-3d24f5a4.pth |
||||
Paper: |
||||
URL: https://arxiv.org/abs/2201.03545 |
||||
Title: 'A ConvNet for the 2020s' |
||||
README: configs/convnext/README.md |
||||
Code: |
||||
URL: https://github.com/open-mmlab/mmdetection/blob/v2.16.0/mmdet/models/backbones/swin.py#L465 |
||||
Version: v2.25.0 |
@ -0,0 +1,9 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
from .builder import OPTIMIZER_BUILDERS, build_optimizer |
||||
from .layer_decay_optimizer_constructor import \ |
||||
LearningRateDecayOptimizerConstructor |
||||
|
||||
__all__ = [ |
||||
'LearningRateDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', |
||||
'build_optimizer' |
||||
] |
@ -0,0 +1,33 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import copy |
||||
|
||||
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS |
||||
from mmcv.utils import Registry, build_from_cfg |
||||
|
||||
OPTIMIZER_BUILDERS = Registry( |
||||
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) |
||||
|
||||
|
||||
def build_optimizer_constructor(cfg): |
||||
constructor_type = cfg.get('type') |
||||
if constructor_type in OPTIMIZER_BUILDERS: |
||||
return build_from_cfg(cfg, OPTIMIZER_BUILDERS) |
||||
elif constructor_type in MMCV_OPTIMIZER_BUILDERS: |
||||
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) |
||||
else: |
||||
raise KeyError(f'{constructor_type} is not registered ' |
||||
'in the optimizer builder registry.') |
||||
|
||||
|
||||
def build_optimizer(model, cfg): |
||||
optimizer_cfg = copy.deepcopy(cfg) |
||||
constructor_type = optimizer_cfg.pop('constructor', |
||||
'DefaultOptimizerConstructor') |
||||
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) |
||||
optim_constructor = build_optimizer_constructor( |
||||
dict( |
||||
type=constructor_type, |
||||
optimizer_cfg=optimizer_cfg, |
||||
paramwise_cfg=paramwise_cfg)) |
||||
optimizer = optim_constructor(model) |
||||
return optimizer |
@ -0,0 +1,154 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import json |
||||
|
||||
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info |
||||
|
||||
from mmdet.utils import get_root_logger |
||||
from .builder import OPTIMIZER_BUILDERS |
||||
|
||||
|
||||
def get_layer_id_for_convnext(var_name, max_layer_id): |
||||
"""Get the layer id to set the different learning rates in ``layer_wise`` |
||||
decay_type. |
||||
|
||||
Args: |
||||
var_name (str): The key of the model. |
||||
max_layer_id (int): Maximum layer id. |
||||
|
||||
Returns: |
||||
int: The id number corresponding to different learning rate in |
||||
``LearningRateDecayOptimizerConstructor``. |
||||
""" |
||||
|
||||
if var_name in ('backbone.cls_token', 'backbone.mask_token', |
||||
'backbone.pos_embed'): |
||||
return 0 |
||||
elif var_name.startswith('backbone.downsample_layers'): |
||||
stage_id = int(var_name.split('.')[2]) |
||||
if stage_id == 0: |
||||
layer_id = 0 |
||||
elif stage_id == 1: |
||||
layer_id = 2 |
||||
elif stage_id == 2: |
||||
layer_id = 3 |
||||
elif stage_id == 3: |
||||
layer_id = max_layer_id |
||||
return layer_id |
||||
elif var_name.startswith('backbone.stages'): |
||||
stage_id = int(var_name.split('.')[2]) |
||||
block_id = int(var_name.split('.')[3]) |
||||
if stage_id == 0: |
||||
layer_id = 1 |
||||
elif stage_id == 1: |
||||
layer_id = 2 |
||||
elif stage_id == 2: |
||||
layer_id = 3 + block_id // 3 |
||||
elif stage_id == 3: |
||||
layer_id = max_layer_id |
||||
return layer_id |
||||
else: |
||||
return max_layer_id + 1 |
||||
|
||||
|
||||
def get_stage_id_for_convnext(var_name, max_stage_id): |
||||
"""Get the stage id to set the different learning rates in ``stage_wise`` |
||||
decay_type. |
||||
|
||||
Args: |
||||
var_name (str): The key of the model. |
||||
max_stage_id (int): Maximum stage id. |
||||
|
||||
Returns: |
||||
int: The id number corresponding to different learning rate in |
||||
``LearningRateDecayOptimizerConstructor``. |
||||
""" |
||||
|
||||
if var_name in ('backbone.cls_token', 'backbone.mask_token', |
||||
'backbone.pos_embed'): |
||||
return 0 |
||||
elif var_name.startswith('backbone.downsample_layers'): |
||||
return 0 |
||||
elif var_name.startswith('backbone.stages'): |
||||
stage_id = int(var_name.split('.')[2]) |
||||
return stage_id + 1 |
||||
else: |
||||
return max_stage_id - 1 |
||||
|
||||
|
||||
@OPTIMIZER_BUILDERS.register_module() |
||||
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): |
||||
# Different learning rates are set for different layers of backbone. |
||||
# Note: Currently, this optimizer constructor is built for ConvNeXt. |
||||
|
||||
def add_params(self, params, module, **kwargs): |
||||
"""Add all parameters of module to the params list. |
||||
|
||||
The parameters of the given module will be added to the list of param |
||||
groups, with specific rules defined by paramwise_cfg. |
||||
|
||||
Args: |
||||
params (list[dict]): A list of param groups, it will be modified |
||||
in place. |
||||
module (nn.Module): The module to be added. |
||||
""" |
||||
logger = get_root_logger() |
||||
|
||||
parameter_groups = {} |
||||
logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') |
||||
num_layers = self.paramwise_cfg.get('num_layers') + 2 |
||||
decay_rate = self.paramwise_cfg.get('decay_rate') |
||||
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') |
||||
logger.info('Build LearningRateDecayOptimizerConstructor ' |
||||
f'{decay_type} {decay_rate} - {num_layers}') |
||||
weight_decay = self.base_wd |
||||
for name, param in module.named_parameters(): |
||||
if not param.requires_grad: |
||||
continue # frozen weights |
||||
if len(param.shape) == 1 or name.endswith('.bias') or name in ( |
||||
'pos_embed', 'cls_token'): |
||||
group_name = 'no_decay' |
||||
this_weight_decay = 0. |
||||
else: |
||||
group_name = 'decay' |
||||
this_weight_decay = weight_decay |
||||
if 'layer_wise' in decay_type: |
||||
if 'ConvNeXt' in module.backbone.__class__.__name__: |
||||
layer_id = get_layer_id_for_convnext( |
||||
name, self.paramwise_cfg.get('num_layers')) |
||||
logger.info(f'set param {name} as id {layer_id}') |
||||
else: |
||||
raise NotImplementedError() |
||||
elif decay_type == 'stage_wise': |
||||
if 'ConvNeXt' in module.backbone.__class__.__name__: |
||||
layer_id = get_stage_id_for_convnext(name, num_layers) |
||||
logger.info(f'set param {name} as id {layer_id}') |
||||
else: |
||||
raise NotImplementedError() |
||||
group_name = f'layer_{layer_id}_{group_name}' |
||||
|
||||
if group_name not in parameter_groups: |
||||
scale = decay_rate**(num_layers - layer_id - 1) |
||||
|
||||
parameter_groups[group_name] = { |
||||
'weight_decay': this_weight_decay, |
||||
'params': [], |
||||
'param_names': [], |
||||
'lr_scale': scale, |
||||
'group_name': group_name, |
||||
'lr': scale * self.base_lr, |
||||
} |
||||
|
||||
parameter_groups[group_name]['params'].append(param) |
||||
parameter_groups[group_name]['param_names'].append(name) |
||||
rank, _ = get_dist_info() |
||||
if rank == 0: |
||||
to_display = {} |
||||
for key in parameter_groups: |
||||
to_display[key] = { |
||||
'param_names': parameter_groups[key]['param_names'], |
||||
'lr_scale': parameter_groups[key]['lr_scale'], |
||||
'lr': parameter_groups[key]['lr'], |
||||
'weight_decay': parameter_groups[key]['weight_decay'], |
||||
} |
||||
logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') |
||||
params.extend(parameter_groups.values()) |
@ -0,0 +1,164 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import torch |
||||
import torch.nn as nn |
||||
from mmcv.cnn import ConvModule |
||||
|
||||
from mmdet.core.optimizers import LearningRateDecayOptimizerConstructor |
||||
|
||||
base_lr = 1 |
||||
decay_rate = 2 |
||||
base_wd = 0.05 |
||||
weight_decay = 0.05 |
||||
|
||||
expected_stage_wise_lr_wd_convnext = [{ |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 128 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 1 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 64 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 64 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 32 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 32 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 16 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 16 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 8 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 8 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 128 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 1 |
||||
}] |
||||
|
||||
expected_layer_wise_lr_wd_convnext = [{ |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 128 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 1 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 64 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 64 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 32 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 32 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 16 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 16 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 2 |
||||
}, { |
||||
'weight_decay': 0.0, |
||||
'lr_scale': 2 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 128 |
||||
}, { |
||||
'weight_decay': 0.05, |
||||
'lr_scale': 1 |
||||
}] |
||||
|
||||
|
||||
class ToyConvNeXt(nn.Module): |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
self.stages = nn.ModuleList() |
||||
for i in range(4): |
||||
stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True)) |
||||
self.stages.append(stage) |
||||
self.norm0 = nn.BatchNorm2d(2) |
||||
|
||||
# add some variables to meet unit test coverate rate |
||||
self.cls_token = nn.Parameter(torch.ones(1)) |
||||
self.mask_token = nn.Parameter(torch.ones(1)) |
||||
self.pos_embed = nn.Parameter(torch.ones(1)) |
||||
self.stem_norm = nn.Parameter(torch.ones(1)) |
||||
self.downsample_norm0 = nn.BatchNorm2d(2) |
||||
self.downsample_norm1 = nn.BatchNorm2d(2) |
||||
self.downsample_norm2 = nn.BatchNorm2d(2) |
||||
self.lin = nn.Parameter(torch.ones(1)) |
||||
self.lin.requires_grad = False |
||||
self.downsample_layers = nn.ModuleList() |
||||
for _ in range(4): |
||||
stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True)) |
||||
self.downsample_layers.append(stage) |
||||
|
||||
|
||||
class ToyDetector(nn.Module): |
||||
|
||||
def __init__(self, backbone): |
||||
super().__init__() |
||||
self.backbone = backbone |
||||
self.head = nn.Conv2d(2, 2, kernel_size=1, groups=2) |
||||
|
||||
|
||||
class PseudoDataParallel(nn.Module): |
||||
|
||||
def __init__(self, model): |
||||
super().__init__() |
||||
self.module = model |
||||
|
||||
|
||||
def check_optimizer_lr_wd(optimizer, gt_lr_wd): |
||||
assert isinstance(optimizer, torch.optim.AdamW) |
||||
assert optimizer.defaults['lr'] == base_lr |
||||
assert optimizer.defaults['weight_decay'] == base_wd |
||||
param_groups = optimizer.param_groups |
||||
print(param_groups) |
||||
assert len(param_groups) == len(gt_lr_wd) |
||||
for i, param_dict in enumerate(param_groups): |
||||
assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay'] |
||||
assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale'] |
||||
assert param_dict['lr_scale'] == param_dict['lr'] |
||||
|
||||
|
||||
def test_learning_rate_decay_optimizer_constructor(): |
||||
|
||||
# Test lr wd for ConvNeXT |
||||
backbone = ToyConvNeXt() |
||||
model = PseudoDataParallel(ToyDetector(backbone)) |
||||
optimizer_cfg = dict( |
||||
type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05) |
||||
# stagewise decay |
||||
stagewise_paramwise_cfg = dict( |
||||
decay_rate=decay_rate, decay_type='stage_wise', num_layers=6) |
||||
optim_constructor = LearningRateDecayOptimizerConstructor( |
||||
optimizer_cfg, stagewise_paramwise_cfg) |
||||
optimizer = optim_constructor(model) |
||||
check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext) |
||||
# layerwise decay |
||||
layerwise_paramwise_cfg = dict( |
||||
decay_rate=decay_rate, decay_type='layer_wise', num_layers=6) |
||||
optim_constructor = LearningRateDecayOptimizerConstructor( |
||||
optimizer_cfg, layerwise_paramwise_cfg) |
||||
optimizer = optim_constructor(model) |
||||
check_optimizer_lr_wd(optimizer, expected_layer_wise_lr_wd_convnext) |
Loading…
Reference in new issue