OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
import numpy as np
import torch


def bbox_flip(bboxes, img_shape, direction='horizontal'):
    """Flip bboxes horizontally or vertically.

    Args:
        bboxes (Tensor): Shape (..., 4*k)
        img_shape (tuple): Image shape.
        direction (str): Flip direction, options are "horizontal", "vertical",
            "diagonal". Default: "horizontal"

    Returns:
        Tensor: Flipped bboxes.
    """
    assert bboxes.shape[-1] % 4 == 0
    assert direction in ['horizontal', 'vertical', 'diagonal']
    flipped = bboxes.clone()
    if direction == 'horizontal':
        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
    elif direction == 'vertical':
        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
    else:
        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
    return flipped
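
# Usage sketch (illustrative, not part of the original module). It assumes the
# (h, w, ...) img_shape convention used above, so img_shape[1] is the width:
#     >>> boxes = torch.tensor([[10., 20., 50., 60.]])
#     >>> bbox_flip(boxes, img_shape=(100, 200, 3))
#     tensor([[150.,  20., 190.,  60.]])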


def bbox_mapping(bboxes,
                 img_shape,
                 scale_factor,
                 flip,
                 flip_direction='horizontal'):
    """Map bboxes from the original image scale to testing scale."""
    new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
    if flip:
        new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
    return new_bboxes
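
# Usage sketch (illustrative, not part of the original module). scale_factor
# can be a scalar or a per-coordinate (w, h, w, h) factor that broadcasts over
# the last dimension:
#     >>> boxes = torch.tensor([[10., 20., 50., 60.]])
#     >>> bbox_mapping(boxes, img_shape=(100, 200, 3), scale_factor=0.5,
#     ...              flip=False)
#     tensor([[ 5., 10., 25., 30.]])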


def bbox_mapping_back(bboxes,
                      img_shape,
                      scale_factor,
                      flip,
                      flip_direction='horizontal'):
    """Map bboxes from testing scale to original image scale."""
    new_bboxes = bbox_flip(bboxes, img_shape,
                           flip_direction) if flip else bboxes
    new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
    return new_bboxes.view(bboxes.shape)
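
# Usage sketch (illustrative, not part of the original module). This inverts
# bbox_mapping: the flip is undone first, then the scaling is divided out:
#     >>> boxes = torch.tensor([[ 5., 10., 25., 30.]])
#     >>> bbox_mapping_back(boxes, img_shape=(100, 200, 3), scale_factor=0.5,
#     ...                   flip=False)
#     tensor([[10., 20., 50., 60.]])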


def bbox2roi(bbox_list):
    """Convert a list of bboxes to roi format.

    Args:
        bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
            of images.

    Returns:
        Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
    """
    rois_list = []
    for img_id, bboxes in enumerate(bbox_list):
        if bboxes.size(0) > 0:
            img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
            rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
        else:
            rois = bboxes.new_zeros((0, 5))
        rois_list.append(rois)
    rois = torch.cat(rois_list, 0)
    return rois
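
# Usage sketch (illustrative, not part of the original module). Each box gets
# its image index prepended so boxes from a whole batch can be concatenated:
#     >>> bbox2roi([torch.tensor([[1., 2., 3., 4.]]),
#     ...           torch.tensor([[5., 6., 7., 8.]])])
#     tensor([[0., 1., 2., 3., 4.],
#             [1., 5., 6., 7., 8.]])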


def roi2bbox(rois):
    """Convert rois to bounding box format.

    Args:
        rois (torch.Tensor): RoIs with the shape (n, 5) where the first
            column indicates batch id of each RoI.

    Returns:
        list[torch.Tensor]: Converted boxes of corresponding rois.
    """
    bbox_list = []
    img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
    for img_id in img_ids:
        inds = (rois[:, 0] == img_id.item())
        bbox = rois[inds, 1:]
        bbox_list.append(bbox)
    return bbox_list
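
# Usage sketch (illustrative, not part of the original module). roi2bbox is
# the inverse of bbox2roi, splitting the rois back per image id:
#     >>> rois = torch.tensor([[0., 1., 2., 3., 4.],
#     ...                      [1., 5., 6., 7., 8.]])
#     >>> roi2bbox(rois)
#     [tensor([[1., 2., 3., 4.]]), tensor([[5., 6., 7., 8.]])]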


def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list[np.ndarray]: bbox results of each class
    """
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
    else:
        if isinstance(bboxes, torch.Tensor):
            bboxes = bboxes.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
        return [bboxes[labels == i, :] for i in range(num_classes)]
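
# Usage sketch (illustrative, not part of the original module). The flat
# (n, 5) detections are grouped into one array per class label:
#     >>> dets = torch.tensor([[1., 2., 3., 4., 0.9],
#     ...                      [5., 6., 7., 8., 0.8]])
#     >>> labels = torch.tensor([0, 1])
#     >>> [r.shape for r in bbox2result(dets, labels, num_classes=3)]
#     [(1, 5), (1, 5), (0, 5)]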


def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded bboxes.
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)
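
# Usage sketch (illustrative, not part of the original module), in the style
# of anchor-free heads such as FCOS that predict per-point distances:
#     >>> points = torch.tensor([[50., 50.]])
#     >>> dists = torch.tensor([[10., 20., 30., 40.]])
#     >>> distance2bbox(points, dists)
#     tensor([[40., 30., 80., 90.]])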


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode bounding box to distances from points to its four boundaries.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        bbox (Tensor): Shape (n, 4), "xyxy" format
        max_dis (float): Upper bound of the distance.
        eps (float): a small value to keep the target strictly smaller than
            max_dis rather than equal to it.

    Returns:
        Tensor: Distances (left, top, right, bottom).
    """
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)
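
# Usage sketch (illustrative, not part of the original module); this is the
# inverse of distance2bbox for points that lie inside their boxes:
#     >>> points = torch.tensor([[50., 50.]])
#     >>> boxes = torch.tensor([[40., 30., 80., 90.]])
#     >>> bbox2distance(points, boxes)
#     tensor([[10., 20., 30., 40.]])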


def bbox_rescale(bboxes, scale_factor=1.0):
    """Rescale bounding box w.r.t. scale_factor.

    Args:
        bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
        scale_factor (float): rescale factor

    Returns:
        Tensor: Rescaled bboxes.
    """
    if bboxes.size(1) == 5:
        bboxes_ = bboxes[:, 1:]
        inds_ = bboxes[:, 0]
    else:
        bboxes_ = bboxes
    cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
    cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
    w = bboxes_[:, 2] - bboxes_[:, 0]
    h = bboxes_[:, 3] - bboxes_[:, 1]
    w = w * scale_factor
    h = h * scale_factor
    x1 = cx - 0.5 * w
    x2 = cx + 0.5 * w
    y1 = cy - 0.5 * h
    y2 = cy + 0.5 * h
    if bboxes.size(1) == 5:
        rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
    else:
        rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
    return rescaled_bboxes
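
# Usage sketch (illustrative, not part of the original module). Boxes are
# scaled about their own centers, so the center point stays fixed:
#     >>> boxes = torch.tensor([[10., 10., 30., 30.]])
#     >>> bbox_rescale(boxes, scale_factor=2.0)
#     tensor([[ 0.,  0., 40., 40.]])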