|
|
|
@ -91,6 +91,7 @@ class RTDETRValidator(DetectionValidator): |
|
|
|
|
"""Apply Non-maximum suppression to prediction outputs.""" |
|
|
|
|
bs, _, nd = preds[0].shape |
|
|
|
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1) |
|
|
|
|
bboxes *= self.args.imgsz |
|
|
|
|
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs |
|
|
|
|
for i, bbox in enumerate(bboxes): # (300, 4) |
|
|
|
|
bbox = ops.xywh2xyxy(bbox) |
|
|
|
@ -126,8 +127,8 @@ class RTDETRValidator(DetectionValidator): |
|
|
|
|
if self.args.single_cls: |
|
|
|
|
pred[:, 5] = 0 |
|
|
|
|
predn = pred.clone() |
|
|
|
|
predn[..., [0, 2]] *= shape[1] # native-space pred |
|
|
|
|
predn[..., [1, 3]] *= shape[0] # native-space pred |
|
|
|
|
predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred |
|
|
|
|
predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred |
|
|
|
|
|
|
|
|
|
# Evaluate |
|
|
|
|
if nl: |
|
|
|
|