Update tasks.py

mct-2.1.1
Laughing-q 8 months ago
parent 6e607a1362
commit 9560c868af
  1. 10
      ultralytics/nn/tasks.py

@ -97,13 +97,9 @@ class BaseModel(nn.Module):
Returns:
(torch.Tensor): The output of the network.
"""
y = []
for m in self.model:
if m.f != -1:
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
x = m(x)
y.append(x if m.i in self.save else None)
return x
if isinstance(x, dict): # for cases of training and validating while training.
return self.loss(x)
return self.predict(x)
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
"""

Loading…
Cancel
Save