Update style

own
Bobholamovic 3 years ago
parent dc1a407581
commit 27bae3505c
  1. 87
      tests/rs_models/test_cd_models.py

@ -76,12 +76,12 @@ class TestBITModel(TestCDModel):
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
base_spec, dict(
**base_spec, backbone='resnet34'), dict(
**base_spec, n_stages=3), dict(
**base_spec, enc_depth=4, dec_head_dim=16), dict(
in_channels=4, num_classes=2), dict(
in_channels=3, num_classes=8)
base_spec,
dict(**base_spec, backbone='resnet34'),
dict(**base_spec, n_stages=3),
dict(**base_spec, enc_depth=4, dec_head_dim=16),
dict(in_channels=4, num_classes=2),
dict(in_channels=3, num_classes=8)
]
@ -91,10 +91,9 @@ class TestCDNetModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(
in_channels=6, num_classes=2), dict(
in_channels=8, num_classes=2), dict(
in_channels=6, num_classes=8)
dict(in_channels=6, num_classes=2),
dict(in_channels=8, num_classes=2),
dict(in_channels=6, num_classes=8)
]
@ -103,9 +102,9 @@ class TestChangeStarModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(num_classes=2), dict(num_classes=10), dict(
num_classes=2, mid_channels=128, num_convs=2), dict(
num_classes=2, _phase='eval', _stop_grad=True)
dict(num_classes=2), dict(num_classes=10),
dict(num_classes=2, mid_channels=128, num_convs=2),
dict(num_classes=2, _phase='eval', _stop_grad=True)
]
def set_targets(self):
@ -124,11 +123,11 @@ class TestDSAMNetModel(TestCDModel):
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
base_spec, dict(
in_channels=8, num_classes=2), dict(
in_channels=3, num_classes=8), dict(
**base_spec, ca_ratio=4, sa_kernel=5), dict(
**base_spec, _phase='eval', _stop_grad=True)
base_spec,
dict(in_channels=8, num_classes=2),
dict(in_channels=3, num_classes=8),
dict(**base_spec, ca_ratio=4, sa_kernel=5),
dict(*base_spec, _phase='eval', _stop_grad=True)
]
def set_targets(self):
@ -145,9 +144,9 @@ class TestDSIFNModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(num_classes=2), dict(num_classes=10), dict(
num_classes=2, use_dropout=True), dict(
num_classes=2, _phase='eval', _stop_grad=True)
dict(num_classes=2), dict(num_classes=10),
dict(num_classes=2, use_dropout=True),
dict(num_classes=2, _phase='eval', _stop_grad=True)
]
def set_targets(self):
@ -165,11 +164,10 @@ class TestFCEarlyFusionModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(
in_channels=6, num_classes=2), dict(
in_channels=8, num_classes=2), dict(
in_channels=6, num_classes=8), dict(
in_channels=6, num_classes=2, use_dropout=True)
dict(in_channels=6, num_classes=2),
dict(in_channels=8, num_classes=2),
dict(in_channels=6, num_classes=8),
dict(in_channels=6, num_classes=2, use_dropout=True)
]
@ -178,11 +176,10 @@ class TestFCSiamConcModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(
in_channels=3, num_classes=2), dict(
in_channels=8, num_classes=2), dict(
in_channels=3, num_classes=8), dict(
in_channels=3, num_classes=2, use_dropout=True)
dict(in_channels=3, num_classes=2),
dict(in_channels=8, num_classes=2),
dict(in_channels=3, num_classes=8),
dict(in_channels=3, num_classes=2, use_dropout=True)
]
@ -191,11 +188,10 @@ class TestFCSiamDiffModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(
in_channels=3, num_classes=2), dict(
in_channels=8, num_classes=2), dict(
in_channels=3, num_classes=8), dict(
in_channels=3, num_classes=2, use_dropout=True)
dict(in_channels=3, num_classes=2),
dict(in_channels=8, num_classes=2),
dict(in_channels=3, num_classes=8),
dict(in_channels=3, num_classes=2, use_dropout=True)
]
@ -204,11 +200,10 @@ class TestSNUNetModel(TestCDModel):
def set_specs(self):
self.specs = [
dict(
in_channels=3, num_classes=2), dict(
in_channels=8, num_classes=2), dict(
in_channels=3, num_classes=8), dict(
in_channels=3, num_classes=2, width=64)
dict(in_channels=3, num_classes=2),
dict(in_channels=8, num_classes=2),
dict(in_channels=3, num_classes=8),
dict(in_channels=3, num_classes=2, width=64)
]
@ -218,9 +213,9 @@ class TestSTANetModel(TestCDModel):
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
base_spec, dict(
in_channels=8, num_classes=2), dict(
in_channels=3, num_classes=8), dict(
**base_spec, att_type='PAM'), dict(
**base_spec, ds_factor=4)
base_spec,
dict(in_channels=8, num_classes=2),
dict(in_channels=3, num_classes=8),
dict(**base_spec, att_type='PAM'),
dict(**base_spec, ds_factor=4)
]

Loading…
Cancel
Save