You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
129 lines
4.2 KiB
129 lines
4.2 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import inspect |
|
import traceback |
|
|
|
|
|
class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party users' custom modules.

    To create a registry (inside ppgan):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry, used in error messages.
        """
        self._name = name
        # Maps registered name -> registered object (class or function).
        self._obj_map = {}

    def _do_register(self, name, obj):
        """Store ``obj`` under ``name``.

        Raises:
            AssertionError: if ``name`` is already registered. Raised
                explicitly rather than via ``assert`` so the duplicate check
                still runs under ``python -O``; the exception type is kept
                for backward compatibility with existing callers.
        """
        if name in self._obj_map:
            raise AssertionError(
                "An object named '{}' was already registered in '{}' registry!".format(
                    name, self._name))
        self._obj_map[name] = obj

    def register(self, obj=None, name=None):
        """
        Register an object under ``name`` (defaults to ``obj.__name__``).

        Supported forms:
            * decorator factory: ``@reg.register()`` or ``@reg.register(name='alias')``
            * bare decorator: ``@reg.register``
            * plain call: ``reg.register(MyClass)`` or ``reg.register(MyClass, name='alias')``

        Args:
            obj: the object to register. If None, a decorator is returned.
            name (str, optional): key to register under; defaults to the
                object's ``__name__``.

        Returns:
            A decorator when ``obj`` is None, otherwise ``obj`` itself.

        Raises:
            AssertionError: if the resolved name is already registered.
        """
        if obj is None:
            # used as a decorator factory
            def deco(func_or_class, name=name):
                if name is None:
                    name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call (or as a bare decorator)
        if name is None:
            name = obj.__name__
        self._do_register(name, obj)
        # Returning obj keeps ``@reg.register`` (no parentheses) from
        # rebinding the decorated class/function to None.
        return obj

    def get(self, name):
        """Return the object registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        try:
            return self._obj_map[name]
        except KeyError:
            # Keyed lookup (not dict.get) so an object stored as None is
            # still returned rather than reported as missing.
            raise KeyError("No object named '{}' found in '{}' registry!".format(
                name, self._name)) from None
|
|
|
|
|
# code was based on mmcv
# Copyright (c) OpenMMLab.
def build_from_config(cfg, registry, default_args=None):
    """Build an object from a config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "name",
            whose value is either a name registered in ``registry`` (str) or
            a class to instantiate directly. All other keys are passed to the
            constructor as keyword arguments.
        registry (ppgan.utils.Registry): The registry to search the name from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence.

    Returns:
        object: The constructed instance.

    Raises:
        TypeError: if ``cfg``, ``registry``, ``default_args`` or the "name"
            value has the wrong type.
        KeyError: if neither ``cfg`` nor ``default_args`` contains "name",
            or the name is not found in ``registry``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'name' not in cfg:
        if default_args is None or 'name' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "name", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an ppgan.utils.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Copy so the caller's cfg is never mutated by the pop below.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    cls_name = args.pop('name')
    if isinstance(cls_name, str):
        obj_cls = registry.get(cls_name)
    elif inspect.isclass(cls_name):
        # A class may be given directly instead of a registered name.
        # (Previously `obj_cls = obj_cls`, which raised UnboundLocalError.)
        obj_cls = cls_name
    else:
        raise TypeError(
            f'name must be a str or valid name, but got {type(cls_name)}')

    try:
        instance = obj_cls(**args)
    except Exception as e:
        stack_info = traceback.format_exc()
        print("Fail to initial class [{}] with error: "
              "{} and stack:\n{}".format(cls_name, e, str(stack_info)))
        # Bare raise re-raises the original exception with its traceback intact.
        raise
    return instance
|
|
|