OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
1.2 KiB
33 lines
1.2 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
import copy |
|
|
|
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS |
|
from mmcv.utils import Registry, build_from_cfg |
|
|
|
# Registry for optimizer constructors defined in this package. mmcv's own
# OPTIMIZER_BUILDERS registry is attached as its parent.
OPTIMIZER_BUILDERS = Registry(
    name='optimizer builder',
    parent=MMCV_OPTIMIZER_BUILDERS,
)
|
|
|
|
|
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from a config dict.

    A string ``type`` is resolved against this module's
    ``OPTIMIZER_BUILDERS`` registry first and then against mmcv's
    ``OPTIMIZER_BUILDERS`` registry. A non-string ``type`` (for example a
    constructor class) is handed to ``build_from_cfg`` directly, without a
    registry lookup.

    Args:
        cfg (dict): Constructor config. Must contain ``type``; the whole
            dict is passed to ``build_from_cfg``.

    Returns:
        The optimizer constructor instance built by ``build_from_cfg``.

    Raises:
        KeyError: If ``cfg`` has no ``type`` key, or if a string ``type``
            is registered in neither registry.
    """
    constructor_type = cfg.get('type')
    if constructor_type is None:
        # Checked explicitly: a registry membership test on None would fail
        # with an unhelpful AttributeError instead.
        raise KeyError('`cfg` must contain the key "type" to build an '
                       'optimizer constructor.')
    if not isinstance(constructor_type, str):
        # Not a registry name (e.g. a class object). build_from_cfg accepts
        # classes directly, so skip the name-based lookup.
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    # A child registry's plain-name lookup stays local, so names registered
    # only in mmcv's parent registry need an explicit fallback here.
    if constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    raise KeyError(f'{constructor_type} is not registered '
                   'in the optimizer builder registry.')
|
|
|
|
|
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    Two keys are read from ``cfg`` and removed before the rest is forwarded:

    * ``constructor``: name of the optimizer constructor; defaults to
      ``'DefaultOptimizerConstructor'``.
    * ``paramwise_cfg``: passed to the constructor; defaults to ``None``.

    The remaining keys become the constructor's ``optimizer_cfg``. The
    constructor is built with ``build_optimizer_constructor``, and the
    return value of calling it on ``model`` is returned.

    Args:
        model: The model whose parameters the optimizer will manage.
        cfg (dict): Optimizer config. It is deep-copied, so the caller's
            dict is not modified.

    Returns:
        Whatever the constructor returns when called with ``model``.
    """
    # Private copy: the pops below must not mutate the caller's config.
    cfg_copy = copy.deepcopy(cfg)
    # These keys go to the constructor separately, so strip them from the
    # dict that is forwarded as ``optimizer_cfg``.
    constructor_name = cfg_copy.pop('constructor',
                                    'DefaultOptimizerConstructor')
    paramwise = cfg_copy.pop('paramwise_cfg', None)
    constructor_cfg = dict(
        type=constructor_name,
        optimizer_cfg=cfg_copy,
        paramwise_cfg=paramwise)
    constructor = build_optimizer_constructor(constructor_cfg)
    return constructor(model)
|
|
|