Fix seg test

own
Bobholamovic 2 years ago
parent 5dc93677ba
commit 107e0083bb
  1. 21
      tests/rs_models/test_seg_models.py

@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel):
MODEL_CLASS = paddlers.rs_models.seg.FarSeg MODEL_CLASS = paddlers.rs_models.seg.FarSeg
def set_specs(self): def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [ self.specs = [
dict(), dict( base_spec,
in_channels=6, num_classes=10), dict( dict(in_channels=6, num_classes=10),
backbone='resnet18', backbone_pretrained=False), dict( dict(**base_spec,
backbone='resnet18',
backbone_pretrained=False),
dict(**base_spec,
fpn_out_channels=128, fpn_out_channels=128,
fsr_out_channels=64, fsr_out_channels=64,
decoder_out_channels=32), dict(scale_aware_proj=False) decoder_out_channels=32),
] dict(**base_spec, scale_aware_proj=False)
] # yapf: disable
def set_targets(self): def set_targets(self):
self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)], self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
[self.get_zeros_array(16)], [self.get_zeros_array(16)], [self.get_zeros_array(2)], [self.get_zeros_array(2)],
[self.get_zeros_array(16)]] [self.get_zeros_array(2)]]

Loading…
Cancel
Save