|
|
|
@ -1560,20 +1560,20 @@ class NMSModel(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
preds = self.model(x) |
|
|
|
|
pred = preds[0] if isinstance(preds, tuple) else preds |
|
|
|
|
kwargs = dict(device=pred.device, dtype=pred.dtype) |
|
|
|
|
bs = pred.shape[0] |
|
|
|
|
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) |
|
|
|
|
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose |
|
|
|
|
if self.args.dynamic and self.args.batch > 1: # batch size needs to always be same due to loop unroll |
|
|
|
|
pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs) |
|
|
|
|
pred = torch.cat((pred, pad)) |
|
|
|
|
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2) |
|
|
|
|
scores, classes = scores.max(dim=-1) |
|
|
|
|
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det |
|
|
|
|
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape). |
|
|
|
|
out = torch.zeros( |
|
|
|
|
boxes.shape[0], |
|
|
|
|
self.args.max_det, |
|
|
|
|
boxes.shape[-1] + 2 + extra_shape, |
|
|
|
|
device=boxes.device, |
|
|
|
|
dtype=boxes.dtype, |
|
|
|
|
) |
|
|
|
|
for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)): |
|
|
|
|
out = torch.zeros(bs, self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs) |
|
|
|
|
for i in range(bs): |
|
|
|
|
box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i] |
|
|
|
|
mask = score > self.args.conf |
|
|
|
|
if self.is_tf: |
|
|
|
|
# TFLite GatherND error if mask is empty |
|
|
|
@ -1593,7 +1593,7 @@ class NMSModel(torch.nn.Module): |
|
|
|
|
if self.args.format == "tflite": # TFLite is already normalized |
|
|
|
|
nmsbox *= multiplier |
|
|
|
|
else: |
|
|
|
|
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max() |
|
|
|
|
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max() |
|
|
|
|
if not self.args.agnostic_nms: # class-specific NMS |
|
|
|
|
end = 2 if self.obb else 4 |
|
|
|
|
# fully explicit expansion otherwise reshape error |
|
|
|
@ -1624,4 +1624,4 @@ class NMSModel(torch.nn.Module): |
|
|
|
|
# Zero-pad to max_det size to avoid reshape error |
|
|
|
|
pad = (0, 0, 0, self.args.max_det - dets.shape[0]) |
|
|
|
|
out[i] = torch.nn.functional.pad(dets, pad) |
|
|
|
|
return (out, preds[1]) if self.model.task == "segment" else out |
|
|
|
|
return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs] |
|
|
|
|