Update head.py

mct-2.1.1
Laughing-q 7 months ago
parent 00a4ef7e30
commit e0b3b1bf0f
  1. 11
      ultralytics/nn/modules/head.py

@ -48,8 +48,6 @@ class Detect(nn.Module):
self.one2one_cv2 = copy.deepcopy(self.cv2)
self.one2one_cv3 = copy.deepcopy(self.cv3)
self.relu = nn.ReLU()
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
if self.end2end:
@ -120,11 +118,10 @@ class Detect(nn.Module):
dbox = self.decode_bboxes(
self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
)
# NOTE: the relu could be removed in the future.
y1 = self.relu(dbox[:, 0, :])
x1 = self.relu(dbox[:, 1, :])
y2 = self.relu(dbox[:, 2, :])
x2 = self.relu(dbox[:, 3, :])
y1 = dbox[:, 0, :]
x1 = dbox[:, 1, :]
y2 = dbox[:, 2, :]
x2 = dbox[:, 3, :]
y_bb = torch.stack((x1, y1, x2, y2), 1).transpose(1, 2)
return y_bb, cls.sigmoid().permute(0, 2, 1)
else:

Loading…
Cancel
Save