diff --git a/tests/rs_models/test_seg_models.py b/tests/rs_models/test_seg_models.py index 813cfd9..b4d6ab5 100644 --- a/tests/rs_models/test_seg_models.py +++ b/tests/rs_models/test_seg_models.py @@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel): MODEL_CLASS = paddlers.rs_models.seg.FarSeg def set_specs(self): + base_spec = dict(in_channels=3, num_classes=2) self.specs = [ - dict(), dict( - in_channels=6, num_classes=10), dict( - backbone='resnet18', backbone_pretrained=False), dict( - fpn_out_channels=128, - fsr_out_channels=64, - decoder_out_channels=32), dict(scale_aware_proj=False) - ] + base_spec, + dict(in_channels=6, num_classes=10), + dict(**base_spec, + backbone='resnet18', + backbone_pretrained=False), + dict(**base_spec, + fpn_out_channels=128, + fsr_out_channels=64, + decoder_out_channels=32), + dict(**base_spec, scale_aware_proj=False) + ] # yapf: disable def set_targets(self): - self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)], - [self.get_zeros_array(16)], [self.get_zeros_array(16)], - [self.get_zeros_array(16)]] + self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)], + [self.get_zeros_array(2)], [self.get_zeros_array(2)], + [self.get_zeros_array(2)]]