|
|
|
@ -23,12 +23,12 @@ from attach_tools import Attach |
|
|
|
|
attach = Attach.to(paddlers.tasks.change_detector) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_trainer(net_type, *args, **kwargs): |
|
|
|
|
def make_trainer(net_type, attach_trainer=True): |
|
|
|
|
def _init_func(self, |
|
|
|
|
num_classes=2, |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
**params): |
|
|
|
|
**_params_): |
|
|
|
|
sig = inspect.signature(net_type.__init__) |
|
|
|
|
net_params = { |
|
|
|
|
k: p.default |
|
|
|
@ -36,7 +36,13 @@ def make_trainer(net_type, *args, **kwargs): |
|
|
|
|
} |
|
|
|
|
net_params.pop('self', None) |
|
|
|
|
net_params.pop('num_classes', None) |
|
|
|
|
net_params.update(params) |
|
|
|
|
# Special rule to parse arguments from `_params_`. |
|
|
|
|
# When using pdrs.tasks.load_model, `_params_`` is a dict with the key '_params_'. |
|
|
|
|
# This bypasses the dynamic modification/creation of function signature. |
|
|
|
|
if '_params_' not in _params_: |
|
|
|
|
net_params.update(_params_) |
|
|
|
|
else: |
|
|
|
|
net_params.update(_params_['_params_']) |
|
|
|
|
|
|
|
|
|
super(trainer_type, self).__init__( |
|
|
|
|
model_name=net_type.__name__, |
|
|
|
@ -52,7 +58,13 @@ def make_trainer(net_type, *args, **kwargs): |
|
|
|
|
|
|
|
|
|
trainer_type = type(trainer_name, (BaseChangeDetector, ), |
|
|
|
|
{'__init__': _init_func}) |
|
|
|
|
if attach_trainer: |
|
|
|
|
trainer_type = attach(trainer_type) |
|
|
|
|
return trainer_type |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_trainer_and_build(net_type, *args, **kwargs): |
|
|
|
|
trainer_type = make_trainer(net_type, attach_trainer=True) |
|
|
|
|
return trainer_type(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|