Fix `xyxyxyxy2xywhr` for Numpy inputs (#13273)

pull/13281/head
Glenn Jocher 6 months ago committed by GitHub
parent 7593b4a301
commit dd13707bf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 43
      ultralytics/utils/ops.py

@ -518,59 +518,58 @@ def ltwh2xywh(x):
return y return y
def xyxyxyxy2xywhr(corners): def xyxyxyxy2xywhr(x):
""" """
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
expected in degrees from 0 to 90. expected in degrees from 0 to 90.
Args: Args:
corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8). x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).
Returns: Returns:
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
""" """
is_torch = isinstance(corners, torch.Tensor) is_torch = isinstance(x, torch.Tensor)
points = corners.cpu().numpy() if is_torch else corners points = x.cpu().numpy() if is_torch else x
points = points.reshape(len(corners), -1, 2) points = points.reshape(len(x), -1, 2)
rboxes = [] rboxes = []
for pts in points: for pts in points:
# NOTE: Use cv2.minAreaRect to get accurate xywhr, # NOTE: Use cv2.minAreaRect to get accurate xywhr,
# especially some objects are cut off by augmentations in dataloader. # especially some objects are cut off by augmentations in dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts) (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi]) rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
return ( return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
if is_torch
else np.asarray(rboxes, dtype=points.dtype)
) # rboxes
def xywhr2xyxyxyxy(rboxes): def xywhr2xyxyxyxy(x):
""" """
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
be in degrees from 0 to 90. be in degrees from 0 to 90.
Args: Args:
rboxes (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
Returns: Returns:
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
""" """
is_numpy = isinstance(rboxes, np.ndarray) cos, sin, cat, stack = (
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin) (torch.cos, torch.sin, torch.cat, torch.stack)
if isinstance(x, torch.Tensor)
else (np.cos, np.sin, np.concatenate, np.stack)
)
ctr = rboxes[..., :2] ctr = x[..., :2]
w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5)) w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle) cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value] vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value] vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
vec1 = np.concatenate(vec1, axis=-1) if is_numpy else torch.cat(vec1, dim=-1) vec1 = cat(vec1, -1)
vec2 = np.concatenate(vec2, axis=-1) if is_numpy else torch.cat(vec2, dim=-1) vec2 = cat(vec2, -1)
pt1 = ctr + vec1 + vec2 pt1 = ctr + vec1 + vec2
pt2 = ctr + vec1 - vec2 pt2 = ctr + vec1 - vec2
pt3 = ctr - vec1 - vec2 pt3 = ctr - vec1 - vec2
pt4 = ctr - vec1 + vec2 pt4 = ctr - vec1 + vec2
return np.stack([pt1, pt2, pt3, pt4], axis=-2) if is_numpy else torch.stack([pt1, pt2, pt3, pt4], dim=-2) return stack([pt1, pt2, pt3, pt4], -2)
def ltwh2xyxy(x): def ltwh2xyxy(x):
@ -785,7 +784,7 @@ def regularize_rboxes(rboxes):
Regularize rotated boxes in range [0, pi/2]. Regularize rotated boxes in range [0, pi/2].
Args: Args:
rboxes (torch.Tensor): (N, 5), xywhr. rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format.
Returns: Returns:
(torch.Tensor): The regularized boxes. (torch.Tensor): The regularized boxes.

Loading…
Cancel
Save