diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index bb6590e58c..9cf1e12731 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -975,17 +975,22 @@ class Format: 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio ) labels["masks"] = masks - if self.normalize: - instances.normalize(w, h) labels["img"] = self._format_img(img) labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) if self.return_keypoint: labels["keypoints"] = torch.from_numpy(instances.keypoints) + if self.normalize: + labels["keypoints"][..., 0] /= w + labels["keypoints"][..., 1] /= h if self.return_obb: labels["bboxes"] = ( xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) ) + # NOTE: need to normalize obb in xywhr format for width-height consistency + if self.normalize: + labels["bboxes"][:, [0, 2]] /= w + labels["bboxes"][:, [1, 3]] /= h # Then we can use collate_fn if self.batch_idx: labels["batch_idx"] = torch.zeros(nl) diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 303228ffc5..946425b348 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -838,16 +838,16 @@ def plot_images( if len(bboxes): boxes = bboxes[idx] conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) - is_obb = boxes.shape[-1] == 5 # xywhr - boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) if len(boxes): if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 - boxes[..., 0::2] *= w # scale to pixels - boxes[..., 1::2] *= h + boxes[..., [0, 2]] *= w # scale to pixels + boxes[..., [1, 3]] *= h elif scale < 1: # absolute coords need scale if image scales boxes[..., :4] *= scale - boxes[..., 0::2] += x - boxes[..., 1::2] += y + boxes[..., 0] += x + boxes[..., 1] += y + is_obb = boxes.shape[-1] == 5 # xywhr + boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) for j, box in enumerate(boxes.astype(np.int64).tolist()): c = classes[j] color = colors(c)