Fix seg test

own
Bobholamovic 2 years ago
parent 5dc93677ba
commit 107e0083bb
  1. 25
      tests/rs_models/test_seg_models.py

@ -52,16 +52,21 @@ class TestFarSegModel(TestSegModel):
MODEL_CLASS = paddlers.rs_models.seg.FarSeg
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
dict(), dict(
in_channels=6, num_classes=10), dict(
backbone='resnet18', backbone_pretrained=False), dict(
fpn_out_channels=128,
fsr_out_channels=64,
decoder_out_channels=32), dict(scale_aware_proj=False)
]
base_spec,
dict(in_channels=6, num_classes=10),
dict(**base_spec,
backbone='resnet18',
backbone_pretrained=False),
dict(**base_spec,
fpn_out_channels=128,
fsr_out_channels=64,
decoder_out_channels=32),
dict(**base_spec, scale_aware_proj=False)
] # yapf: disable
def set_targets(self):
self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
[self.get_zeros_array(16)], [self.get_zeros_array(16)],
[self.get_zeros_array(16)]]
self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
[self.get_zeros_array(2)], [self.get_zeros_array(2)],
[self.get_zeros_array(2)]]

Loading…
Cancel
Save