from typing import List, Dict, Set, Optional, Callable, Any
import torch
import copy

from detectron2.solver.build import reduce_param_groups


def lr_factor_func(para_name: str, is_resnet50: bool, dec: float, debug: bool = False) -> float:
    """Layer-wise lr decay factor for a ResNet (C4 R-CNN) parameter.

    Maps a parameter name to ``dec ** (N + 1 - layer_id)``, where ``layer_id`` runs from
    the stem (0) to the head (N + 1), so earlier layers get smaller learning rates.
    ``N`` is 5 for ResNet-50 and 11 otherwise. With ``debug=True`` the expression is
    returned as a string instead of being evaluated.
    """
    if dec == 0:
        dec = 1.

    N = 5 if is_resnet50 else 11
    if '.stem.' in para_name:
        layer_id = 0
    elif '.res' in para_name:
        ls = para_name.split('.res')[1].split('.')
        if ls[0].isnumeric() and ls[1].isnumeric():
            stage_id, block_id = int(ls[0]), int(ls[1])
            if stage_id == 2:    # res2
                layer_id = 1
            elif stage_id == 3:  # res3
                layer_id = 2
            elif stage_id == 4:  # res4
                layer_id = 3 + block_id // 3  # 3, 4 or 4, 5
            else:                # res5
                layer_id = N
        else:
            assert para_name.startswith('roi_heads.res5.norm.')
            layer_id = N + 1  # roi_heads.res5.norm.weight and roi_heads.res5.norm.bias of C4
    else:
        layer_id = N + 1

    exp = N + 1 - layer_id
    return f'{dec:g} ** {exp}' if debug else dec ** exp


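# Illustrative mapping of the layer-wise decay above (a sketch, not exhaustive), assuming
# parameter names as produced by a detectron2 C4 ResNet-50 model (N = 5) and dec = 0.9:
#   'backbone.stem.conv1.weight'   -> layer_id 0 -> factor 0.9 ** 6
#   'backbone.res2.0.conv1.weight' -> layer_id 1 -> factor 0.9 ** 5
#   'backbone.res4.5.conv3.weight' -> layer_id 4 -> factor 0.9 ** 2
#   'roi_heads.res5.norm.weight'   -> layer_id 6 -> factor 0.9 ** 0 = 1.0

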
# [modification] see: https://github.com/facebookresearch/detectron2/blob/v0.6/detectron2/solver/build.py#L134
# add the `lr_factor_func` to implement lr decay
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    lr_factor_func: Optional[Callable] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """
    Get default param list for optimizer, with support for a few types of
    overrides. If no overrides needed, this is equivalent to `model.parameters()`.

    Args:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters.
        lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
            corresponding lr decay rate. Note that setting this option requires
            also setting ``base_lr``.
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common detection models, ``weight_decay_norm`` is the only option
    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
    from Detectron1 that are not found useful.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                        lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides
    if lr_factor_func is not None:
        if base_lr is None:
            raise ValueError("lr_factor_func requires base_lr")
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            if lr_factor_func is not None:
                hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")

            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)
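

# Usage sketch for combining the two helpers (the model, decay rate, and optimizer
# hyperparameters below are illustrative placeholders, not values prescribed by this file):
#
#   from functools import partial
#   param_groups = get_default_optimizer_params(
#       model,
#       base_lr=0.02,
#       weight_decay_norm=0.0,
#       lr_factor_func=partial(lr_factor_func, is_resnet50=True, dec=0.7),
#   )
#   optimizer = torch.optim.SGD(param_groups, lr=0.02, momentum=0.9, weight_decay=1e-4)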