You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
3.2 KiB

from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Sequence
from abc import ABC
import paddle
import paddle.nn as nn
# Public API of this module; ``Adapter`` and ``make_adapter`` are not
# re-exported by ``from ... import *``.
__all__ = [
    "GANAdapter",
    "OptimizerAdapter",
]
class _AttrDesc:
def __init__(self, key):
self.key = key
def __get__(self, instance, owner):
return tuple(getattr(ele, self.key) for ele in instance)
def __set__(self, instance, value):
for ele in instance:
setattr(ele, self.key, value)
def _func_deco(cls, func_name):
@wraps(getattr(cls.__ducktype__, func_name))
def _wrapper(self, *args, **kwargs):
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
return _wrapper
def _generator_deco(cls, func_name):
@wraps(getattr(cls.__ducktype__, func_name))
def _wrapper(self, *args, **kwargs):
for ele in self:
yield from getattr(ele, func_name)(*args, **kwargs)
return _wrapper
class Adapter(Sequence, ABC):
    """Read-only sequence that stands in for a group of duck-typed objects.

    Subclasses set ``__ducktype__`` to the class whose interface is mirrored
    and ``__ava__`` to the attribute/method names every wrapped object must
    expose. Indexing, ``len`` and ``repr`` delegate to the internal tuple of
    wrapped objects; the other sequence methods come from
    ``collections.abc.Sequence``.

    Args:
        *args: Objects to wrap; each must have every name in ``__ava__``.

    Raises:
        TypeError: If any argument lacks one of the ``__ava__`` names.
    """

    # Class whose interface the adapter mirrors (set by subclasses).
    __ducktype__ = object
    # Names each wrapped element is required to provide (set by subclasses).
    __ava__ = ()

    def __init__(self, *args):
        for candidate in args:
            if not self._check(candidate):
                raise TypeError("Please check the input type.")
        self._seq = tuple(args)

    def __getitem__(self, key):
        return self._seq[key]

    def __len__(self):
        return len(self._seq)

    def __repr__(self):
        return repr(self._seq)

    @classmethod
    def _check(cls, obj):
        """Return True if ``obj`` has every attribute named in ``__ava__``."""
        # TODO: Check function signature
        # hasattr only swallows AttributeError, like the getattr/except form.
        return all(hasattr(obj, name) for name in cls.__ava__)
def make_adapter(cls):
    """Class decorator that fills in broadcasting members for ``cls.__ava__``.

    For every name in ``cls.__ava__`` that ``cls`` does not already provide
    (defined on the class or inherited), a member is installed according to
    how the name is defined on ``cls.__ducktype__``:

    * generator function -> ``_generator_deco`` (chains the generators of
      all elements);
    * other plain function -> ``_func_deco`` (tuple of per-element results);
    * anything else -> ``_AttrDesc`` (broadcast attribute get/set).

    A name that ``cls.__ducktype__`` does not define at class level (for
    instance an attribute only assigned in ``__init__``) is treated as a
    plain attribute and also gets an ``_AttrDesc``.

    Args:
        cls: Class carrying ``__ducktype__`` and ``__ava__`` class attributes.

    Returns:
        ``cls`` itself, modified in place.
    """
    members = dict(getmembers(cls.__ducktype__))
    for k in cls.__ava__:
        if hasattr(cls, k):
            # Hand-written members (e.g. a custom dispatch rule) take priority.
            continue
        # Names missing from the ducktype's class namespace map to None and
        # fall through to the attribute branch rather than being skipped
        # silently; for Adapter subclasses, _check already guarantees every
        # element carries them.
        v = members.get(k)
        if isgeneratorfunction(v):
            setattr(cls, k, _generator_deco(cls, k))
        elif isfunction(v):
            setattr(cls, k, _func_deco(cls, k))
        else:
            setattr(cls, k, _AttrDesc(k))
    return cls
@Adapter.register
class GANAdapter(nn.Layer):
    """Layer grouping the generators and discriminators of a GAN.

    The layers are held in two ``nn.LayerList`` attributes, ``generators``
    and ``discriminators``. The instance also behaves like a read-only
    sequence over all generators followed by all discriminators (``len``,
    indexing, ``in``). It is registered as a virtual subclass of ``Adapter``,
    so ``isinstance(obj, Adapter)`` holds.

    Args:
        generators: Iterable of generator layers.
        discriminators: Iterable of discriminator layers.
    """

    # Adapter-style metadata. This class is not passed through
    # ``make_adapter``; the listed names are presumably served by the
    # inherited ``nn.Layer`` methods (TODO confirm they cover all sublayers).
    __ducktype__ = nn.Layer
    __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval')

    def __init__(self, generators, discriminators):
        super().__init__()
        self.generators = nn.LayerList(generators)
        self.discriminators = nn.LayerList(discriminators)
        # Build the flat view from the LayerLists, not from the arguments:
        # the arguments may be one-shot iterators (e.g. generator
        # expressions) that LayerList has already consumed above.
        self._m = [*self.generators, *self.discriminators]

    def __len__(self):
        return len(self._m)

    def __getitem__(self, key):
        return self._m[key]

    def __contains__(self, m):
        return m in self._m

    def __repr__(self):
        return repr(self._m)

    @property
    def generator(self):
        """First generator layer (``self.generators[0]``)."""
        return self.generators[0]

    @property
    def discriminator(self):
        """First discriminator layer (``self.discriminators[0]``)."""
        return self.discriminators[0]
@make_adapter
class OptimizerAdapter(Adapter):
    """Adapter that drives several paddle optimizers as a single object.

    ``make_adapter`` generates broadcasting members for the names in
    ``__ava__`` based on how ``paddle.optimizer.Optimizer`` defines them,
    except ``set_state_dict``, which is defined explicitly below.
    """

    __ducktype__ = paddle.optimizer.Optimizer
    __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')

    # Special dispatching rule: each optimizer receives its own state dict
    # instead of the same argument being broadcast to all of them.
    def set_state_dict(self, state_dicts):
        """Restore each wrapped optimizer from its matching state dict.

        Args:
            state_dicts: Iterable with one state dict per optimizer, in the
                same order as this adapter.

        Raises:
            ValueError: If the number of state dicts differs from the number
                of optimizers.
        """
        state_dicts = tuple(state_dicts)
        # zip() would silently stop at the shorter side and leave some
        # optimizers unrestored; fail loudly instead.
        if len(state_dicts) != len(self):
            raise ValueError(
                "Expected {} state dicts (one per optimizer), got {}.".format(
                    len(self), len(state_dicts)))
        for optim, state_dict in zip(self, state_dicts):
            optim.set_state_dict(state_dict)