OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
47 lines
1.5 KiB
47 lines
1.5 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
import torch.nn as nn |
|
from mmcv.utils import Registry, build_from_cfg |
|
|
|
# Registry of Transformer implementations; build_transformer resolves
# configs against it.
TRANSFORMER = Registry('Transformer')

# Registry of linear layer types; build_linear_layer looks layer types up here.
LINEAR_LAYERS = Registry('linear layers')
|
|
|
|
|
def build_transformer(cfg, default_args=None):
    """Build a Transformer from its config via the ``TRANSFORMER`` registry.

    Args:
        cfg: Config describing the Transformer. It is forwarded unchanged to
            ``build_from_cfg`` together with the ``TRANSFORMER`` registry.
        default_args (optional): Default arguments forwarded to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` builds for ``cfg``.
    """
    registry = TRANSFORMER
    return build_from_cfg(cfg, registry, default_args=default_args)
|
|
|
|
|
# Expose torch's nn.Linear under the type name 'Linear'. build_linear_layer
# falls back to this type when it is called with cfg=None.
LINEAR_LAYERS.register_module(name='Linear', module=nn.Linear)
|
|
|
|
|
def build_linear_layer(cfg, *args, **kwargs):
    """Build a linear layer from a config dict.

    Args:
        cfg (None or dict): The linear layer config. ``None`` selects the
            ``'Linear'`` type with no extra options. A dict must contain:

            - type (str): Layer type, looked up in ``LINEAR_LAYERS``.
            - layer args: Any remaining keys, passed as keyword arguments
              to instantiate a linear layer.
        args (argument list): Positional arguments passed to the
            ``__init__`` method of the corresponding linear layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding linear layer.

    Returns:
        nn.Module: Created linear layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict, or if a key
            appears both in ``kwargs`` and in ``cfg`` (duplicate keyword
            argument).
        KeyError: If ``cfg`` has no ``'type'`` key, or its type is not
            registered in ``LINEAR_LAYERS``.
    """
    if cfg is not None:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Work on a copy so popping 'type' below never mutates the caller's
        # config.
        layer_cfg = cfg.copy()
    else:
        layer_cfg = dict(type='Linear')

    layer_type = layer_cfg.pop('type')
    if layer_type not in LINEAR_LAYERS:
        raise KeyError(f'Unrecognized linear type {layer_type}')
    layer_cls = LINEAR_LAYERS.get(layer_type)

    # Caller kwargs and the remaining cfg entries are both passed as keyword
    # arguments; a key present in both raises TypeError.
    return layer_cls(*args, **kwargs, **layer_cfg)
|
|
|