You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
8.7 KiB
275 lines
8.7 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
from __future__ import absolute_import |
|
from __future__ import print_function |
|
from __future__ import division |
|
|
|
import importlib |
|
import os |
|
import sys |
|
|
|
import yaml |
|
import collections |
|
|
|
try:
    # Import the submodule explicitly: a plain ``import collections`` only
    # exposes ``collections.abc`` as an attribute if something else already
    # imported that submodule.
    import collections.abc as collectionsAbc
except (ImportError, AttributeError):
    # Python 2: the ABCs (Mapping, ...) live directly in ``collections``.
    collectionsAbc = collections
|
|
|
from .config.schema import SchemaDict, SharedConfig, extract_schema |
|
from .config.yaml_helpers import serializable |
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
# ``serializable`` is re-exported from ``.config.yaml_helpers``.
__all__ = [
    'global_config',
    'load_config',
    'merge_config',
    'get_registered_modules',
    'create',
    'register',
    'serializable',
    'dump_value',
]
|
|
|
|
|
def dump_value(value):
    """
    Render a config value as a single-line string.

    Dicts, tuples, lists and any object carrying a ``__dict__`` are
    serialized with ``yaml.dump`` in flow style, collapsed onto one line and
    wrapped in single quotes. Every other value is returned as ``str(value)``.

    Args:
        value: the value to render.

    Returns:
        str: the rendered value.
    """
    # A plain type/attribute check (instead of collections.abc) so that
    # arbitrary objects with attributes are also routed through YAML.
    if not (hasattr(value, '__dict__') or
            isinstance(value, (dict, tuple, list))):
        # primitive types
        return str(value)

    dumped = yaml.dump(value, default_flow_style=True)
    # Join the emitter's wrapped lines and drop only the YAML document-end
    # marker ("..." on a line of its own). A "..." that is part of a scalar
    # value (e.g. 'a...b') is kept instead of being stripped out.
    lines = [line for line in dumped.split('\n') if line != '...']
    return "'{}'".format(''.join(lines))
|
|
|
|
|
class AttrDict(dict):
    """
    Dict whose top-level keys can also be read as attributes.

    Only the outermost level gets attribute access: nested dict values are
    stored as given (no conversion). Attribute assignment is not redirected
    to keys.
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Only called when regular attribute lookup has already failed.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
|
|
|
|
|
# Module-wide store: `register` stores each class's schema under the class
# name, and `merge_config` (used by `load_config`) merges loaded YAML configs
# into it when no other target dict is given.
global_config = AttrDict()

# Key in a YAML config file listing base config files to inherit from
# (relative paths are resolved against the including file's directory).
BASE_KEY = '_BASE_'
|
|
|
|
|
# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """
    Load a YAML config file, resolving its `_BASE_` inheritance chain.

    Each file listed under `_BASE_` is loaded recursively and merged in list
    order, so later bases override earlier ones; the file's own keys override
    all of its bases. Relative base paths are resolved against the directory
    of `file_path`, and `~` is expanded.

    Args:
        file_path (str): Path of the YAML file to load.

    Returns:
        The merged config (the `_BASE_` key itself is removed). This is an
        AttrDict when bases are present; otherwise it is the loaded YAML
        object.
    """
    with open(file_path) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects
        # from tags; only load config files from trusted sources.
        file_cfg = yaml.load(f, Loader=yaml.Loader)

    # An empty YAML file loads as None; treat it as an empty config.
    if file_cfg is None:
        file_cfg = {}

    if BASE_KEY not in file_cfg:
        return file_cfg

    base_ymls = file_cfg.pop(BASE_KEY)
    # Accept a single path as well as a list of paths.
    if isinstance(base_ymls, str):
        base_ymls = [base_ymls]

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    all_base_cfg = AttrDict()
    for base_yml in base_ymls:
        base_yml = os.path.expanduser(base_yml)
        if not os.path.isabs(base_yml):
            base_yml = os.path.join(os.path.dirname(file_path), base_yml)
        # Merge straight into the local accumulator. merge_config() would
        # treat the still-empty accumulator as "not given" and write into
        # global_config instead.
        dict_merge(all_base_cfg, _load_config_with_base(base_yml))

    return dict_merge(all_base_cfg, file_cfg)
|
|
|
|
|
def load_config(file_path):
    """
    Load a YAML config file (following its `_BASE_` chain) into global config.

    Args:
        file_path (str): Path of the config file to be loaded; its extension
            must be `.yml` or `.yaml`.

    Returns: global config
    """
    extension = os.path.splitext(file_path)[1]
    assert extension in ('.yml', '.yaml'), "only support yaml files for now"

    cfg = _load_config_with_base(file_path)
    # Record the config file's name without directory or extension,
    # e.g. "configs/yolo.yml" -> "yolo".
    stem = os.path.splitext(os.path.basename(file_path))[0]
    cfg['filename'] = stem
    # Merge into the module-level global config (mutated in place).
    merge_config(cfg)
    return global_config
|
|
|
|
|
def dict_merge(dct, merge_dct):
    """
    Recursively merge ``merge_dct`` into ``dct`` in place.

    Unlike ``dict.update``, nested dicts are merged key by key rather than
    replaced. Recursion happens only when ``dct`` already holds a ``dict``
    under the key and the incoming value is a Mapping. In every other case
    the incoming value is stored directly (by reference, not copied).

    Args:
        dct: dict that receives the merge (mutated).
        merge_dct: mapping whose entries are merged into ``dct``.

    Returns: dct
    """
    for key in merge_dct:
        incoming = merge_dct[key]
        both_nested = (key in dct and isinstance(dct[key], dict) and
                       isinstance(incoming, collectionsAbc.Mapping))
        if not both_nested:
            dct[key] = incoming
            continue
        dict_merge(dct[key], incoming)
    return dct
|
|
|
|
|
def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict, optional): Dict to merge into. When omitted (None),
            the module-level global config is used. An explicitly passed
            empty dict is merged into as well; it is not replaced by the
            global config.

    Returns: the dict that was merged into (global config by default)
    """
    # `is None` rather than truthiness: an empty target (e.g. a fresh
    # AttrDict accumulator) must not silently redirect the merge into
    # global_config.
    dct = global_config if another_cfg is None else another_cfg
    return dict_merge(dct, config)
|
|
|
|
|
def get_registered_modules():
    """
    Return the global config entries whose values are SchemaDicts.

    These are the entries that `register` stores under each class name.

    Returns:
        dict: key -> SchemaDict, for every such entry in the global config.
    """
    registered = {}
    for module_name, entry in global_config.items():
        if isinstance(entry, SchemaDict):
            registered[module_name] = entry
    return registered
|
|
|
|
|
def make_partial(cls):
    """
    Turn `cls` into a callable wrapper around the op named by `cls.__op__`.

    The op is looked up again by name in the module that defines it.
    `cls.__category__` defaults to 'op'. When `cls.__append_doc__` is truthy
    (the default), a `__call__` is installed. It invokes the op with the
    instance's attributes as keyword arguments, overridden by call-time
    kwargs. On Python 3 the class and `__init__` docstrings are also taken
    from the op.

    NOTE(review): when `__append_doc__` is falsy, no `__call__` is installed
    at all, not just the docstrings; confirm this is intended.

    Args:
        cls (type): class with an `__op__` attribute.

    Returns: cls (modified in place)
    """
    op_name = cls.__op__.__name__
    op = getattr(importlib.import_module(cls.__op__.__module__), op_name)
    cls.__category__ = getattr(cls, '__category__', None) or 'op'

    def partial_apply(self, *args, **kwargs):
        merged_kwargs = dict(self.__dict__)
        merged_kwargs.update(kwargs)
        return op(*args, **merged_kwargs)

    if not getattr(cls, '__append_doc__', True):  # XXX should default to True?
        return cls

    partial_apply.__doc__ = op.__doc__
    if sys.version_info[0] > 2:
        cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
        cls.__init__.__doc__ = op.__doc__
    cls.__call__ = partial_apply
    return cls
|
|
|
|
|
def register(cls):
    """
    Register a module class in the global config.

    Classes that declare `__op__` are first wrapped by `make_partial`. The
    schema produced by `extract_schema` is then stored in the global config
    under the class name.

    Args:
        cls (type): Module class to be registered.

    Raises:
        ValueError: if the class name is already a key of the global config.

    Returns: cls
    """
    class_name = cls.__name__
    if class_name in global_config:
        raise ValueError(
            "Module class already registered: {}".format(class_name))
    wrapped = make_partial(cls) if hasattr(cls, '__op__') else cls
    global_config[wrapped.__name__] = extract_schema(wrapped)
    return wrapped
|
|
|
|
|
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Constructor kwargs start as a copy of the module's registered entry in
    the global config. The `shared` and `inject` annotations from its schema
    are then resolved against the global config before instantiation.

    Args:
        cls_or_name (type or str): Class of which to create instance, or the
            name it is registered under.
        **kwargs: forwarded only to the class's `from_config` hook, if any.

    Raises:
        AssertionError: if `cls_or_name` is neither a class nor a str, or the
            module is not registered.
        ValueError: if an injected dependency is missing or of an
            unsupported type.

    Returns: instance of type `cls_or_name`
    """
    # isinstance() also accepts classes built with a custom metaclass, whose
    # type() is not exactly `type`.
    assert isinstance(cls_or_name, (type, str)), \
        "should be a class or name of a class"
    name = (cls_or_name if isinstance(cls_or_name, str)
            else cls_or_name.__name__)
    assert name in global_config and \
        isinstance(global_config[name], SchemaDict), \
        "the module {} is not registered".format(name)
    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(config)

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]
            else:
                cls_kwargs[k] = shared_conf.default_value

    # optional class hook that derives extra kwargs from the config
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    # parse `inject` annotation of registered modules
    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                # dict form: {'name': <global config key>, <overrides>...}
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} for '{}', check its name "
                        "in cfg file".format(inject_name, k))
                target = global_config[inject_name]
                # Overrides are written into the shared global entry, so they
                # remain in effect for later `create` calls as well.
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError(
                        "Missing injection config: {}".format(target_key))
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError(
                    "Unsupported injection type: {}".format(target_key))

    # NOTE: values are passed to the constructor without copying, so mutable
    # values (e.g. lists, dicts) are shared with the global config.
    return cls(**cls_kwargs)
|
|
|