diff --git a/tests/rs_models/test_cd_models.py b/tests/rs_models/test_cd_models.py index 2f7cc14..6deed6b 100644 --- a/tests/rs_models/test_cd_models.py +++ b/tests/rs_models/test_cd_models.py @@ -242,7 +242,12 @@ class TestFCCDNModel(TestCDModel): ] # yapf: disable def set_targets(self): - tar_c2 = [self.get_zeros_array(2), [self.get_zeros_array(1)] * 2] + b = self.DEFAULT_BATCH_SIZE + h = self.DEFAULT_HW[0] // 2 + w = self.DEFAULT_HW[1] // 2 + tar_c2 = [ + self.get_zeros_array(2), [self.get_zeros_array(1, b, h, w)] * 2 + ] self.targets = [ tar_c2, tar_c2, [self.get_zeros_array(8), tar_c2[1]], [self.get_zeros_array(2)]