You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

129 lines
4.2 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import traceback
class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party users' custom modules.

    To create a registry (inside ppgan):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        # Maps a registered name to the registered class / function.
        self._obj_map = {}

    def _do_register(self, name, obj):
        """Store `obj` under `name`, refusing to overwrite an existing entry.

        Args:
            name (str): key under which the object is stored.
            obj: the class or function to store.

        Raises:
            AssertionError: if `name` is already present in this registry.
        """
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name)
        self._obj_map[name] = obj

    def register(self, obj=None, name=None):
        """
        Register the given object under the name `obj.__name__`, or under
        `name` when it is given.

        Can be used as either a decorator or not. See docstring of this class for usage.

        Args:
            obj: the class or function to register. Leave as None to get a
                decorator back (``@registry.register()``).
            name (str, optional): explicit registration name. Defaults to
                the object's ``__name__``.

        Returns:
            A decorator when `obj` is None; otherwise `obj` itself, so that
            the bare form ``@registry.register`` does not rebind the
            decorated name to None.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class, name=name):
                if name is None:
                    name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        if name is None:
            name = obj.__name__
        self._do_register(name, obj)
        # Hand the object back so `@registry.register` (no parentheses) works.
        return obj

    def get(self, name):
        """Look up a registered object by name.

        Args:
            name (str): the name the object was registered under.

        Returns:
            The object stored under `name`.

        Raises:
            KeyError: if nothing is stored under `name` (or the stored
                value is None).
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError("No object named '{}' found in '{}' registry!".
                           format(name, self._name))
        return ret
# The code below is based on mmcv.
# Copyright (c) OpenMMLab.
def build_from_config(cfg, registry, default_args=None):
    """Build an object from a config dict.

    The "name" entry selects what to construct; every other entry of `cfg`
    (completed with `default_args` for keys that `cfg` lacks) is passed to
    the constructor as a keyword argument.

    Args:
        cfg (dict): Config dict. It should at least contain the key "name",
            unless `default_args` supplies it. "name" is either a string
            looked up in `registry` or a class used directly.
        registry (ppgan.utils.Registry): The registry to search the name from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed instance.

    Raises:
        TypeError: if `cfg`, `registry`, `default_args` or the resolved
            "name" value has an unsupported type.
        KeyError: if neither `cfg` nor `default_args` contains "name", or
            if the name is not registered in `registry`.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'name' not in cfg:
        if default_args is None or 'name' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "name", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an ppgan.utils.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a shallow copy so the caller's config dict is not mutated.
    args = cfg.copy()

    if default_args is not None:
        # Values already present in cfg take precedence over the defaults.
        for name, value in default_args.items():
            args.setdefault(name, value)

    cls_name = args.pop('name')
    if isinstance(cls_name, str):
        obj_cls = registry.get(cls_name)
    elif inspect.isclass(cls_name):
        # "name" may carry the class itself; use it directly.
        obj_cls = cls_name
    else:
        raise TypeError(
            f'name must be a str or valid name, but got {type(cls_name)}')

    try:
        instance = obj_cls(**args)
    except Exception as e:
        # Report which class failed to build, then propagate the original
        # exception with its traceback intact.
        stack_info = traceback.format_exc()
        print("Fail to initial class [{}] with error: "
              "{} and stack:\n{}".format(cls_name, e, str(stack_info)))
        raise
    return instance