|
|
|
@ -152,19 +152,34 @@ class PoseValidator(DetectionValidator): |
|
|
|
|
|
|
|
|
|
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None): |
|
|
|
|
""" |
|
|
|
|
Return correct prediction matrix. |
|
|
|
|
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
detections (torch.Tensor): Tensor of shape [N, 6] representing detections. |
|
|
|
|
Each detection is of the format: x1, y1, x2, y2, conf, class. |
|
|
|
|
labels (torch.Tensor): Tensor of shape [M, 5] representing labels. |
|
|
|
|
Each label is of the format: class, x1, y1, x2, y2. |
|
|
|
|
pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints. |
|
|
|
|
51 corresponds to 17 keypoints each with 3 values. |
|
|
|
|
gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints. |
|
|
|
|
detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each |
|
|
|
|
detection is of the format (x1, y1, x2, y2, conf, class). |
|
|
|
|
gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each |
|
|
|
|
box is of the format (x1, y1, x2, y2). |
|
|
|
|
gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices. |
|
|
|
|
pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where |
|
|
|
|
51 corresponds to 17 keypoints each having 3 values. |
|
|
|
|
gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels. |
|
|
|
|
torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels, |
|
|
|
|
where N is the number of detections. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
```python |
|
|
|
|
detections = torch.rand(100, 6) # 100 predictions: (x1, y1, x2, y2, conf, class) |
|
|
|
|
gt_bboxes = torch.rand(50, 4) # 50 ground truth boxes: (x1, y1, x2, y2) |
|
|
|
|
gt_cls = torch.randint(0, 2, (50,)) # 50 ground truth class indices |
|
|
|
|
pred_kpts = torch.rand(100, 51) # 100 predicted keypoints |
|
|
|
|
gt_kpts = torch.rand(50, 51) # 50 ground truth keypoints |
|
|
|
|
correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts) |
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
`0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384. |
|
|
|
|
""" |
|
|
|
|
if pred_kpts is not None and gt_kpts is not None: |
|
|
|
|
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384 |
|
|
|
|