|
|
@ -104,7 +104,7 @@ class PostDetect(nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
def forward(self, x): |
|
|
|
shape = x[0].shape |
|
|
|
shape = x[0].shape |
|
|
|
b, res = shape[0], [] |
|
|
|
b, res, b_reg_num = shape[0], [], self.reg_max * 4 |
|
|
|
for i in range(self.nl): |
|
|
|
for i in range(self.nl): |
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) |
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) |
|
|
|
if self.dynamic or self.shape != shape: |
|
|
|
if self.dynamic or self.shape != shape: |
|
|
@ -113,15 +113,14 @@ class PostDetect(nn.Module): |
|
|
|
self.shape = shape |
|
|
|
self.shape = shape |
|
|
|
x = [i.view(b, self.no, -1) for i in res] |
|
|
|
x = [i.view(b, self.no, -1) for i in res] |
|
|
|
y = torch.cat(x, 2) |
|
|
|
y = torch.cat(x, 2) |
|
|
|
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:, |
|
|
|
boxes, scores = y[:, :b_reg_num, ...], y[:, b_reg_num:, ...].sigmoid() |
|
|
|
...].sigmoid() |
|
|
|
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2) |
|
|
|
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous() |
|
|
|
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes) |
|
|
|
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box) |
|
|
|
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...] |
|
|
|
box0, box1 = -box[:, :2, ...], box[:, 2:, ...] |
|
|
|
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1) |
|
|
|
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1) |
|
|
|
boxes = boxes * self.strides |
|
|
|
box = box * self.strides |
|
|
|
|
|
|
|
|
|
|
|
return TRT_NMS.apply(boxes.transpose(1, 2), scores.transpose(1, 2), |
|
|
|
return TRT_NMS.apply(box.transpose(1, 2), cls.transpose(1, 2), |
|
|
|
|
|
|
|
self.iou_thres, self.conf_thres, self.topk) |
|
|
|
self.iou_thres, self.conf_thres, self.topk) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -139,30 +138,29 @@ class PostSeg(nn.Module): |
|
|
|
mc = torch.cat( |
|
|
|
mc = torch.cat( |
|
|
|
[self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], |
|
|
|
[self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], |
|
|
|
2) # mask coefficients |
|
|
|
2) # mask coefficients |
|
|
|
box, score, cls = self.forward_det(x) |
|
|
|
boxes, scores, labels = self.forward_det(x) |
|
|
|
out = torch.cat([box, score, cls, mc.transpose(1, 2)], 2) |
|
|
|
out = torch.cat([boxes, scores, labels.float(), mc.transpose(1, 2)], 2) |
|
|
|
return out, p.flatten(2) |
|
|
|
return out, p.flatten(2) |
|
|
|
|
|
|
|
|
|
|
|
def forward_det(self, x): |
|
|
|
def forward_det(self, x): |
|
|
|
shape = x[0].shape |
|
|
|
shape = x[0].shape |
|
|
|
b, res = shape[0], [] |
|
|
|
b, res, b_reg_num = shape[0], [], self.reg_max * 4 |
|
|
|
for i in range(self.nl): |
|
|
|
for i in range(self.nl): |
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) |
|
|
|
res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)) |
|
|
|
if self.dynamic or self.shape != shape: |
|
|
|
if self.dynamic or self.shape != shape: |
|
|
|
self.anchors, self.strides = (x.transpose( |
|
|
|
self.anchors, self.strides = \ |
|
|
|
0, 1) for x in make_anchors(x, self.stride, 0.5)) |
|
|
|
(x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) |
|
|
|
self.shape = shape |
|
|
|
self.shape = shape |
|
|
|
x = [i.view(b, self.no, -1) for i in res] |
|
|
|
x = [i.view(b, self.no, -1) for i in res] |
|
|
|
y = torch.cat(x, 2) |
|
|
|
y = torch.cat(x, 2) |
|
|
|
box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:, |
|
|
|
boxes, scores = y[:, :, ...], y[:, b_reg_num:, ...].sigmoid() |
|
|
|
...].sigmoid() |
|
|
|
boxes = boxes.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2) |
|
|
|
box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous() |
|
|
|
boxes = boxes.softmax(-1) @ torch.arange(self.reg_max).to(boxes) |
|
|
|
box = box.softmax(-1) @ torch.arange(self.reg_max).to(box) |
|
|
|
boxes0, boxes1 = -boxes[:, :2, ...], boxes[:, 2:, ...] |
|
|
|
box0, box1 = -box[:, :2, ...], box[:, 2:, ...] |
|
|
|
boxes = self.anchors.repeat(b, 2, 1) + torch.cat([boxes0, boxes1], 1) |
|
|
|
box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1) |
|
|
|
boxes = boxes * self.strides |
|
|
|
box = box * self.strides |
|
|
|
scores, labels = scores.transpose(1, 2).max(dim=-1, keepdim=True) |
|
|
|
score, cls = cls.transpose(1, 2).max(dim=-1, keepdim=True) |
|
|
|
return boxes.transpose(1, 2), scores, labels |
|
|
|
return box.transpose(1, 2), score, cls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def optim(module: nn.Module): |
|
|
|
def optim(module: nn.Module): |
|
|
|