|
|
|
@ -12,6 +12,8 @@ |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import inspect |
|
|
|
|
|
|
|
|
|
import paddle |
|
|
|
|
import paddlers |
|
|
|
|
from paddlers.tasks.change_detector import BaseChangeDetector |
|
|
|
@ -27,12 +29,21 @@ def make_trainer(net_type, *args, **kwargs): |
|
|
|
|
use_mixed_loss=False, |
|
|
|
|
losses=None, |
|
|
|
|
**params): |
|
|
|
|
super().__init__( |
|
|
|
|
sig = inspect.signature(net_type.__init__) |
|
|
|
|
net_params = { |
|
|
|
|
k: p.default |
|
|
|
|
for k, p in sig.parameters.items() if not p.default is p.empty |
|
|
|
|
} |
|
|
|
|
net_params.pop('self', None) |
|
|
|
|
net_params.pop('num_classes', None) |
|
|
|
|
net_params.update(params) |
|
|
|
|
|
|
|
|
|
super(trainer_type, self).__init__( |
|
|
|
|
model_name=net_type.__name__, |
|
|
|
|
num_classes=num_classes, |
|
|
|
|
use_mixed_loss=use_mixed_loss, |
|
|
|
|
losses=losses, |
|
|
|
|
**params) |
|
|
|
|
**net_params) |
|
|
|
|
|
|
|
|
|
if not issubclass(net_type, paddle.nn.Layer): |
|
|
|
|
raise TypeError("net must be a subclass of paddle.nn.Layer") |
|
|
|
|