|
|
@ -97,13 +97,9 @@ class BaseModel(nn.Module): |
|
|
|
Returns: |
|
|
|
Returns: |
|
|
|
(torch.Tensor): The output of the network. |
|
|
|
(torch.Tensor): The output of the network. |
|
|
|
""" |
|
|
|
""" |
|
|
|
y = [] |
|
|
|
if isinstance(x, dict): # for cases of training and validating while training. |
|
|
|
for m in self.model: |
|
|
|
return self.loss(x) |
|
|
|
if m.f != -1: |
|
|
|
return self.predict(x) |
|
|
|
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] |
|
|
|
|
|
|
|
x = m(x) |
|
|
|
|
|
|
|
y.append(x if m.i in self.save else None) |
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, x, profile=False, visualize=False, augment=False, embed=None): |
|
|
|
def predict(self, x, profile=False, visualize=False, augment=False, embed=None): |
|
|
|
""" |
|
|
|
""" |
|
|
|