OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

196 lines
6.3 KiB

import numpy as np
import torch
def bbox_flip(bboxes, img_shape, direction='horizontal'):
    """Mirror boxes along the x axis, the y axis, or both.

    Args:
        bboxes (Tensor): Shape (..., 4*k); every group of four values is
            read as (x1, y1, x2, y2).
        img_shape (tuple): Image shape. Index 0 is used as the extent for
            the y coordinates and index 1 as the extent for the x
            coordinates.
        direction (str): One of "horizontal", "vertical" or "diagonal".
            "diagonal" applies both the horizontal and the vertical flip.
            Default: "horizontal"

    Returns:
        Tensor: A new tensor with the flipped boxes; the input is left
            untouched.
    """
    assert bboxes.shape[-1] % 4 == 0
    assert direction in ['horizontal', 'vertical', 'diagonal']
    mirrored = bboxes.clone()
    # Always read from the untouched input so the two assignments of a
    # pair cannot interfere with each other.
    if direction in ('horizontal', 'diagonal'):
        width = img_shape[1]
        mirrored[..., 0::4] = width - bboxes[..., 2::4]
        mirrored[..., 2::4] = width - bboxes[..., 0::4]
    if direction in ('vertical', 'diagonal'):
        height = img_shape[0]
        mirrored[..., 1::4] = height - bboxes[..., 3::4]
        mirrored[..., 3::4] = height - bboxes[..., 1::4]
    return mirrored
def bbox_mapping(bboxes,
                 img_shape,
                 scale_factor,
                 flip,
                 flip_direction='horizontal'):
    """Map bboxes from the original image scale to testing scale.

    The boxes are first multiplied by ``scale_factor`` and then, if
    requested, flipped inside ``img_shape``.

    Args:
        bboxes (Tensor): Boxes to transform.
        img_shape (tuple): Image shape forwarded to ``bbox_flip``.
        scale_factor: Value(s) converted with ``bboxes.new_tensor`` and
            multiplied onto ``bboxes``.
        flip (bool): Whether to flip the scaled boxes.
        flip_direction (str): Direction forwarded to ``bbox_flip``.
            Default: "horizontal"

    Returns:
        Tensor: The scaled (and possibly flipped) boxes.
    """
    scale = bboxes.new_tensor(scale_factor)
    scaled = bboxes * scale
    if not flip:
        return scaled
    return bbox_flip(scaled, img_shape, flip_direction)
def bbox_mapping_back(bboxes,
                      img_shape,
                      scale_factor,
                      flip,
                      flip_direction='horizontal'):
    """Map bboxes from testing scale to original image scale.

    The flip (if any) is undone first, then the boxes are divided by
    ``scale_factor``.

    Args:
        bboxes (Tensor): Boxes in the testing frame. The number of
            elements must be a multiple of 4.
        img_shape (tuple): Image shape forwarded to ``bbox_flip``.
        scale_factor: Value(s) converted with ``new_tensor`` and used as
            divisor; must broadcast against a (-1, 4) tensor.
        flip (bool): Whether the boxes were flipped and must be flipped
            back.
        flip_direction (str): Direction forwarded to ``bbox_flip``.
            Default: "horizontal"

    Returns:
        Tensor: Boxes in the original frame, same shape as ``bboxes``.
    """
    new_bboxes = bbox_flip(bboxes, img_shape,
                           flip_direction) if flip else bboxes
    # reshape instead of view: without a flip the input tensor is used
    # as-is, and view() raises on non-contiguous inputs (e.g. a column
    # slice holding several boxes per row).
    new_bboxes = new_bboxes.reshape(-1, 4) / new_bboxes.new_tensor(
        scale_factor)
    return new_bboxes.reshape(bboxes.shape)
def bbox2roi(bbox_list):
    """Stack per-image boxes into a single roi tensor.

    Each box is prefixed with the position of its image inside
    ``bbox_list``; only the first four columns of every box are kept.

    Args:
        bbox_list (list[Tensor]): One tensor of boxes per image of the
            batch.

    Returns:
        Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
    """
    per_image = []
    for batch_ind, boxes in enumerate(bbox_list):
        num_boxes = boxes.size(0)
        if num_boxes == 0:
            # Keep an empty placeholder so the final cat sees one entry
            # per image with a consistent column count.
            per_image.append(boxes.new_zeros((0, 5)))
            continue
        ind_column = boxes.new_full((num_boxes, 1), batch_ind)
        per_image.append(torch.cat([ind_column, boxes[:, :4]], dim=-1))
    return torch.cat(per_image, 0)
def roi2bbox(rois):
    """Split a roi tensor back into one box tensor per batch id.

    Args:
        rois (torch.Tensor): RoIs with the shape (n, 5) where the first
            column indicates batch id of each RoI.

    Returns:
        list[torch.Tensor]: Boxes (all columns after the first) grouped
            by batch id, in ascending order of the ids that occur in
            ``rois``.
    """
    batch_col = rois[:, 0]
    # The distinct ids are computed on a CPU copy of the batch column.
    present_ids = torch.unique(batch_col.cpu(), sorted=True)
    return [rois[batch_col == bid.item(), 1:] for bid in present_ids]
def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): Number of per-class arrays to produce. Boxes
            whose label is outside ``range(num_classes)`` do not appear
            in the result.

    Returns:
        list(ndarray): bbox results of each class; entry ``i`` holds the
            rows of ``bboxes`` whose label equals ``i``. For empty input
            every entry is a (0, 5) float32 array.
    """
    if bboxes.shape[0] == 0:
        return [
            np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)
        ]
    # Convert each argument on its own so that a Tensor/ndarray mix is
    # handled as well as the homogeneous cases.
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    return [bboxes[labels == i, :] for i in range(num_classes)]
def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (..., 2), [x, y]. Typically (n, 2).
        distance (Tensor): Shape (..., 4), distance from the given point
            to 4 boundaries (left, top, right, bottom).
        max_shape (tuple, optional): Shape of the image. When given, x
            coordinates are clamped to [0, max_shape[1]] and y
            coordinates to [0, max_shape[0]]. Default: None

    Returns:
        Tensor: Decoded bboxes with shape (..., 4), [x1, y1, x2, y2].
    """
    # Ellipsis indexing selects along the last axis, so any number of
    # leading (e.g. batch) dimensions is supported.
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)
def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode bounding boxes as distances from points to the box sides.

    Args:
        points (Tensor): Shape (..., 2), [x, y]. Typically (n, 2).
        bbox (Tensor): Shape (..., 4), "xyxy" format
        max_dis (float, optional): Upper bound of the distance. When
            given, every distance is clamped to [0, max_dis - eps];
            when None, no clamping is applied at all. Default: None
        eps (float): a small value to ensure target < max_dis, instead <=

    Returns:
        Tensor: Distances with shape (..., 4), ordered
            (left, top, right, bottom).
    """
    # Ellipsis indexing selects along the last axis, so any number of
    # leading (e.g. batch) dimensions is supported.
    left = points[..., 0] - bbox[..., 0]
    top = points[..., 1] - bbox[..., 1]
    right = bbox[..., 2] - points[..., 0]
    bottom = bbox[..., 3] - points[..., 1]
    if max_dis is not None:
        upper = max_dis - eps
        left = left.clamp(min=0, max=upper)
        top = top.clamp(min=0, max=upper)
        right = right.clamp(min=0, max=upper)
        bottom = bottom.clamp(min=0, max=upper)
    return torch.stack([left, top, right, bottom], -1)
def bbox_rescale(bboxes, scale_factor=1.0):
    """Grow or shrink boxes around their centers.

    Width and height are multiplied by ``scale_factor`` while the center
    of every box stays where it is. For 5-column input the first column
    is carried over unchanged.

    Args:
        bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
        scale_factor (float): rescale factor

    Returns:
        Tensor: Rescaled bboxes, with the same number of columns as the
            input.
    """
    is_roi = bboxes.size(1) == 5
    coords = bboxes[:, 1:] if is_roi else bboxes

    center_x = (coords[:, 0] + coords[:, 2]) * 0.5
    center_y = (coords[:, 1] + coords[:, 3]) * 0.5
    scaled_w = (coords[:, 2] - coords[:, 0]) * scale_factor
    scaled_h = (coords[:, 3] - coords[:, 1]) * scale_factor

    half_w = 0.5 * scaled_w
    half_h = 0.5 * scaled_h
    columns = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    if is_roi:
        # Put the untouched leading column back in front.
        columns.insert(0, bboxes[:, 0])
    return torch.stack(columns, dim=-1)