# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import print_function from __future__ import division import importlib import os import sys import yaml import collections try: collectionsAbc = collections.abc except AttributeError: collectionsAbc = collections from .config.schema import SchemaDict, SharedConfig, extract_schema from .config.yaml_helpers import serializable __all__ = [ 'global_config', 'load_config', 'merge_config', 'get_registered_modules', 'create', 'register', 'serializable', 'dump_value', ] def dump_value(value): # XXX this is hackish, but collections.abc is not available in python 2 if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)): value = yaml.dump(value, default_flow_style=True) value = value.replace('\n', '') value = value.replace('...', '') return "'{}'".format(value) else: # primitive types return str(value) class AttrDict(dict): """Single level attribute dict, NOT recursive""" def __init__(self, **kwargs): super(AttrDict, self).__init__() super(AttrDict, self).update(kwargs) def __getattr__(self, key): if key in self: return self[key] raise AttributeError("object has no attribute '{}'".format(key)) global_config = AttrDict() BASE_KEY = '_BASE_' # parse and load _BASE_ recursively def _load_config_with_base(file_path): with open(file_path) as f: file_cfg = yaml.load(f, Loader=yaml.Loader) # NOTE: cfgs outside have higher priority than cfgs in _BASE_ if BASE_KEY in file_cfg: all_base_cfg = AttrDict() base_ymls = list(file_cfg[BASE_KEY]) for base_yml in base_ymls: if base_yml.startswith("~"): base_yml = os.path.expanduser(base_yml) if not base_yml.startswith('/'): base_yml = os.path.join(os.path.dirname(file_path), base_yml) with open(base_yml) as f: base_cfg = _load_config_with_base(base_yml) all_base_cfg = merge_config(base_cfg, all_base_cfg) del file_cfg[BASE_KEY] return merge_config(file_cfg, all_base_cfg) return file_cfg def load_config(file_path): """ Load config from file. Args: file_path (str): Path of the config file to be loaded. Returns: global config """ _, ext = os.path.splitext(file_path) assert ext in ['.yml', '.yaml'], "only support yaml files for now" # load config from file and merge into global config cfg = _load_config_with_base(file_path) cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0] merge_config(cfg) return global_config def dict_merge(dct, merge_dct): """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of updating only top-level keys, dict_merge recurses down into dicts nested to an arbitrary depth, updating keys. The ``merge_dct`` is merged into ``dct``. Args: dct: dict onto which the merge is executed merge_dct: dct merged into dct Returns: dct """ for k, v in merge_dct.items(): if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], collectionsAbc.Mapping)): dict_merge(dct[k], merge_dct[k]) else: dct[k] = merge_dct[k] return dct def merge_config(config, another_cfg=None): """ Merge config into global config or another_cfg. Args: config (dict): Config to be merged. Returns: global config """ global global_config dct = another_cfg or global_config return dict_merge(dct, config) def get_registered_modules(): return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)} def make_partial(cls): op_module = importlib.import_module(cls.__op__.__module__) op = getattr(op_module, cls.__op__.__name__) cls.__category__ = getattr(cls, '__category__', None) or 'op' def partial_apply(self, *args, **kwargs): kwargs_ = self.__dict__.copy() kwargs_.update(kwargs) return op(*args, **kwargs_) if getattr(cls, '__append_doc__', True): # XXX should default to True? if sys.version_info[0] > 2: cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__) cls.__init__.__doc__ = op.__doc__ cls.__call__ = partial_apply cls.__call__.__doc__ = op.__doc__ else: # XXX work around for python 2 partial_apply.__doc__ = op.__doc__ cls.__call__ = partial_apply return cls def register(cls): """ Register a given module class. Args: cls (type): Module class to be registered. Returns: cls """ if cls.__name__ in global_config: raise ValueError("Module class already registered: {}".format( cls.__name__)) if hasattr(cls, '__op__'): cls = make_partial(cls) global_config[cls.__name__] = extract_schema(cls) return cls def create(cls_or_name, **kwargs): """ Create an instance of given module class. Args: cls_or_name (type or str): Class of which to create instance. Returns: instance of type `cls_or_name` """ assert type(cls_or_name) in [type, str ], "should be a class or name of a class" name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__ if name in global_config: if isinstance(global_config[name], SchemaDict): pass elif hasattr(global_config[name], "__dict__"): # support instance return directly return global_config[name] else: raise ValueError("The module {} is not registered".format(name)) else: raise ValueError("The module {} is not registered".format(name)) config = global_config[name] cls = getattr(config.pymodule, name) cls_kwargs = {} cls_kwargs.update(global_config[name]) # parse `shared` annoation of registered modules if getattr(config, 'shared', None): for k in config.shared: target_key = config[k] shared_conf = config.schema[k].default assert isinstance(shared_conf, SharedConfig) if target_key is not None and not isinstance(target_key, SharedConfig): continue # value is given for the module elif shared_conf.key in global_config: # `key` is present in config cls_kwargs[k] = global_config[shared_conf.key] else: cls_kwargs[k] = shared_conf.default_value # parse `inject` annoation of registered modules if getattr(cls, 'from_config', None): cls_kwargs.update(cls.from_config(config, **kwargs)) if getattr(config, 'inject', None): for k in config.inject: target_key = config[k] # optional dependency if target_key is None: continue if isinstance(target_key, dict) or hasattr(target_key, '__dict__'): if 'name' not in target_key.keys(): continue inject_name = str(target_key['name']) if inject_name not in global_config: raise ValueError( "Missing injection name {} and check it's name in cfg file". format(k)) target = global_config[inject_name] for i, v in target_key.items(): if i == 'name': continue target[i] = v if isinstance(target, SchemaDict): cls_kwargs[k] = create(inject_name) elif isinstance(target_key, str): if target_key not in global_config: raise ValueError("Missing injection config:", target_key) target = global_config[target_key] if isinstance(target, SchemaDict): cls_kwargs[k] = create(target_key) elif hasattr(target, '__dict__'): # serialized object cls_kwargs[k] = target else: raise ValueError("Unsupported injection type:", target_key) # prevent modification of global config values of reference types # (e.g., list, dict) from within the created module instances #kwargs = copy.deepcopy(kwargs) return cls(**cls_kwargs)