OBB: Fix distorted plotting (#9899)

pull/9950/head
Laughing 10 months ago committed by GitHub
parent e040ce0618
commit 417c429ec4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 9
      ultralytics/data/augment.py
  2. 12
      ultralytics/utils/plotting.py

@ -975,17 +975,22 @@ class Format:
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
)
labels["masks"] = masks
if self.normalize:
instances.normalize(w, h)
labels["img"] = self._format_img(img)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels["keypoints"] = torch.from_numpy(instances.keypoints)
if self.normalize:
labels["keypoints"][..., 0] /= w
labels["keypoints"][..., 1] /= h
if self.return_obb:
labels["bboxes"] = (
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
)
# NOTE: need to normalize obb in xywhr format for width-height consistency
if self.normalize:
labels["bboxes"][:, [0, 2]] /= w
labels["bboxes"][:, [1, 3]] /= h
# Then we can use collate_fn
if self.batch_idx:
labels["batch_idx"] = torch.zeros(nl)

@ -838,16 +838,16 @@ def plot_images(
if len(bboxes):
boxes = bboxes[idx]
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
if len(boxes):
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
boxes[..., 0::2] *= w # scale to pixels
boxes[..., 1::2] *= h
boxes[..., [0, 2]] *= w # scale to pixels
boxes[..., [1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes[..., :4] *= scale
boxes[..., 0::2] += x
boxes[..., 1::2] += y
boxes[..., 0] += x
boxes[..., 1] += y
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
for j, box in enumerate(boxes.astype(np.int64).tolist()):
c = classes[j]
color = colors(c)

Loading…
Cancel
Save