diff --git a/models/common.py b/models/common.py index 3141900..63b8dd9 100644 --- a/models/common.py +++ b/models/common.py @@ -153,7 +153,7 @@ class PostSeg(nn.Module): self.shape = shape x = [i.view(b, self.no, -1) for i in res] y = torch.cat(x, 2) - boxes, scores = y[:, :, ...], y[:, b_reg_num:, ...].sigmoid() + boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid() boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2) boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes) boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...]