from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Sequence
from abc import ABC

import paddle
import paddle.nn as nn

__all__ = ['GANAdapter', 'OptimizerAdapter']


class _AttrDesc:
    def __init__(self, key):
        self.key = key

    def __get__(self, instance, owner):
        return tuple(getattr(ele, self.key) for ele in instance)

    def __set__(self, instance, value):
        for ele in instance:
            setattr(ele, self.key, value)
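

# A minimal sketch of how _AttrDesc behaves (the names below are
# hypothetical, not part of this module): reads fan out to every element and
# come back as a tuple, while writes broadcast one value to all elements.
#
#   class _Item:
#       def __init__(self, v):
#           self.v = v
#
#   class _Items(list):
#       v = _AttrDesc('v')
#
#   items = _Items([_Item(1), _Item(2)])
#   items.v      # -> (1, 2)
#   items.v = 0  # every element's `v` becomes 0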


def _func_deco(cls, func_name):
    @wraps(getattr(cls.__ducktype__, func_name))
    def _wrapper(self, *args, **kwargs):
        return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)

    return _wrapper
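

# _func_deco builds a method that forwards one call to every wrapped element
# and gathers the results. Roughly (for a hypothetical adapter wrapping two
# optimizers):
#
#   adapter.state_dict()  ==  tuple(opt.state_dict() for opt in adapter)
#
# wraps(...) copies the duck type's method name and docstring onto the
# wrapper.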


def _generator_deco(cls, func_name):
    @wraps(getattr(cls.__ducktype__, func_name))
    def _wrapper(self, *args, **kwargs):
        for ele in self:
            yield from getattr(ele, func_name)(*args, **kwargs)

    return _wrapper
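

# _generator_deco is the lazy counterpart used for generator-function
# members: the wrapper yields from each element's generator in turn, which
# is equivalent to chaining the per-element generators with itertools.chain.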


class Adapter(Sequence, ABC):
    __ducktype__ = object
    __ava__ = ()

    def __init__(self, *args):
        if not all(map(self._check, args)):
            raise TypeError("Please check the input type.")
        self._seq = tuple(args)

    def __getitem__(self, key):
        return self._seq[key]

    def __len__(self):
        return len(self._seq)

    def __repr__(self):
        return repr(self._seq)

    @classmethod
    def _check(cls, obj):
        for attr in cls.__ava__:
            try:
                getattr(obj, attr)
                # TODO: Check function signature
            except AttributeError:
                return False
        return True
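

# Hedged sketch of the duck-typing contract (hypothetical subclass shown;
# note that _check only verifies attribute *existence*, not signatures, as
# the TODO above says):
#
#   class StepperAdapter(Adapter):
#       __ava__ = ('step',)
#
#   StepperAdapter(obj_with_step_method)  # OK
#   StepperAdapter(object())              # raises TypeError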


def make_adapter(cls):
    members = dict(getmembers(cls.__ducktype__))
    for k in cls.__ava__:
        if hasattr(cls, k):
            continue
        if k in members:
            v = members[k]
            if isgeneratorfunction(v):
                setattr(cls, k, _generator_deco(cls, k))
            elif isfunction(v):
                setattr(cls, k, _func_deco(cls, k))
        else:
            setattr(cls, k, _AttrDesc(k))
    return cls
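

# make_adapter is a class decorator: for every name in __ava__ that the
# class does not already define, it installs a fan-out wrapper derived from
# __ducktype__. Generator functions get _generator_deco, plain functions get
# _func_deco, and names that are not class members (assumed to be instance
# attributes) are exposed through an _AttrDesc descriptor.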


class GANAdapter(nn.Layer):
    __ducktype__ = nn.Layer
    __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval')

    def __init__(self, generators, discriminators):
        super(GANAdapter, self).__init__()
        self.generators = nn.LayerList(generators)
        self.discriminators = nn.LayerList(discriminators)
        self._m = [*generators, *discriminators]

    def __len__(self):
        return len(self._m)

    def __getitem__(self, key):
        return self._m[key]

    def __contains__(self, m):
        return m in self._m

    def __repr__(self):
        return repr(self._m)

    @property
    def generator(self):
        return self.generators[0]

    @property
    def discriminator(self):
        return self.discriminators[0]


Adapter.register(GANAdapter)
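

# Hedged usage sketch (Generator and Discriminator stand for user-defined
# nn.Layer subclasses):
#
#   gan = GANAdapter([Generator()], [Discriminator(), Discriminator()])
#   gan.generator           # convenience alias for gan.generators[0]
#   gan.discriminators[1]   # extra discriminators remain addressable
#
# Adapter.register(...) above makes GANAdapter a *virtual* subclass of
# Adapter, so isinstance(gan, Adapter) holds even though GANAdapter derives
# from nn.Layer instead.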


@make_adapter
class OptimizerAdapter(Adapter):
    __ducktype__ = paddle.optimizer.Optimizer
    __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')

    def set_state_dict(self, state_dicts):
        # Special dispatching rule
        for optim, state_dict in zip(self, state_dicts):
            optim.set_state_dict(state_dict)

    def get_lr(self):
        # Return the lr of the first optimizer
        return self[0].get_lr()
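

# Hedged end-to-end sketch (assumes standard Paddle APIs; the layers and
# optimizers below are illustrative only):
#
#   g = paddle.nn.Linear(4, 4)
#   d = paddle.nn.Linear(4, 1)
#   opt = OptimizerAdapter(
#       paddle.optimizer.Adam(parameters=g.parameters()),
#       paddle.optimizer.Adam(parameters=d.parameters()))
#
#   opt.step()                  # fans out to every optimizer (_func_deco)
#   opt.clear_grad()            # likewise
#   opt.get_lr()                # lr of the first optimizer (override above)
#   states = opt.state_dict()   # tuple of per-optimizer state dicts
#   opt.set_state_dict(states)  # zips the states back one-to-one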