diff --git a/ultralytics/tests/data/dataloader/hyp_test.yaml b/ultralytics/tests/data/dataloader/hyp_test.yaml new file mode 100644 index 000000000..a31eef724 --- /dev/null +++ b/ultralytics/tests/data/dataloader/hyp_test.yaml @@ -0,0 +1,29 @@ +lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.0005 # optimizer weight decay 5e-4 +warmup_epochs: 3.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.1 # warmup initial bias lr +box: 0.05 # box loss gain +cls: 0.5 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 1.0 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.20 # IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +# anchors: 3 # anchors per output layer (0 to ignore) +fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.015 # image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # image HSV-Value augmentation (fraction) +degrees: 0.0 # image rotation (+/- deg) +translate: 0.1 # image translation (+/- fraction) +scale: 0.5 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 1.0 # image mosaic (probability) +mixup: 0.0 # image mixup (probability) +copy_paste: 0.0 # segment copy-paste (probability) diff --git a/ultralytics/tests/data/dataloader/yolodetection.py b/ultralytics/tests/data/dataloader/yolodetection.py new file mode 100644 index 000000000..7d37a84fd --- /dev/null +++ b/ultralytics/tests/data/dataloader/yolodetection.py @@ -0,0 +1,97 @@ +import cv2 +import numpy as np +from omegaconf import OmegaConf + +from ultralytics.yolo.data import build_dataloader + + +class Colors: + # Ultralytics color palette https://ultralytics.com/ + def __init__(self): + # hex = matplotlib.colors.TABLEAU_COLORS.values() + hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', + '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') + self.palette = [self.hex2rgb(f'#{c}') for c in hexs] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order (PIL) + return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() # create instance for 'from utils.plots import colors' + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + import random + + # Plots one bounding box on image img + tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, + (c1[0], c1[1] - 2), + 0, + tl / 3, + [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA, + ) + + +with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f: + hyp = OmegaConf.load(f) + +dataloader, dataset = build_dataloader( + img_path="/d/dataset/COCO/coco128-seg/images", + img_size=640, + label_path=None, + cache=False, + hyp=hyp, + augment=False, + prefix="", + rect=False, + batch_size=4, + stride=32, + pad=0.5, + use_segments=True, + use_keypoints=False, +) + +for d in dataloader: + idx = 1 # show which image inside one batch + img = d["img"][idx].numpy() + img = np.ascontiguousarray(img.transpose(1, 2, 0)) + ih, iw = img.shape[:2] + # print(img.shape) + bidx = d["batch_idx"] + cls = d["cls"][bidx == idx].numpy() + bboxes = d["bboxes"][bidx == idx].numpy() + print(bboxes.shape) + bboxes[:, [0, 2]] *= iw + bboxes[:, [1, 3]] *= ih + nl = len(cls) + + for i, b in enumerate(bboxes): + x, y, w, h = b + x1 = x - w / 2 + x2 = x + w / 2 + y1 = y - h / 2 + y2 = y + h / 2 + c = int(cls[i][0]) + plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c)) + cv2.imshow("p", img) + if cv2.waitKey(0) == ord("q"): + break diff --git a/ultralytics/tests/data/dataloader/yolopose.py b/ultralytics/tests/data/dataloader/yolopose.py new file mode 100644 index 000000000..e36ed1d1a --- /dev/null +++ b/ultralytics/tests/data/dataloader/yolopose.py @@ -0,0 +1,114 @@ +import cv2 +import numpy as np +import torch +from omegaconf import OmegaConf + +from ultralytics.yolo.data import build_dataloader + + +class Colors: + # Ultralytics color palette https://ultralytics.com/ + def __init__(self): + # hex = matplotlib.colors.TABLEAU_COLORS.values() + hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', + '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') + self.palette = [self.hex2rgb(f'#{c}') for c in hexs] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order (PIL) + return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() # create instance for 'from utils.plots import colors' + + +def plot_one_box(x, img, keypoints=None, color=None, label=None, line_thickness=None): + import random + + # Plots one bounding box on image img + tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, + (c1[0], c1[1] - 2), + 0, + tl / 3, + [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA, + ) + if keypoints is not None: + plot_keypoint(img, keypoints, color, tl) + + +def plot_keypoint(img, keypoints, color, tl): + num_l = len(keypoints) + # clors = [(255, 0, 0),(0, 255, 0),(0, 0, 255),(255, 255, 0),(0, 255, 255)] + # clors = [[random.randint(0, 255) for _ in range(3)] for _ in range(num_l)] + for i in range(num_l): + point_x = int(keypoints[i][0]) + point_y = int(keypoints[i][1]) + cv2.circle(img, (point_x, point_y), tl + 3, color, -1) + + +with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f: + hyp = OmegaConf.load(f) + +dataloader, dataset = build_dataloader( + img_path="/d/dataset/COCO/images/val2017", + img_size=640, + label_path=None, + cache=False, + hyp=hyp, + augment=False, + prefix="", + rect=False, + batch_size=4, + stride=32, + pad=0.5, + use_segments=False, + use_keypoints=True, +) + +for d in dataloader: + idx = 1 # show which image inside one batch + img = d["img"][idx].numpy() + img = np.ascontiguousarray(img.transpose(1, 2, 0)) + ih, iw = img.shape[:2] + # print(img.shape) + bidx = d["batch_idx"] + cls = d["cls"][bidx == idx].numpy() + bboxes = d["bboxes"][bidx == idx].numpy() + bboxes[:, [0, 2]] *= iw + bboxes[:, [1, 3]] *= ih + keypoints = d["keypoints"][bidx == idx] + keypoints[..., 0] *= iw + keypoints[..., 1] *= ih + # print(keypoints, keypoints.shape) + # print(d["im_file"]) + + for i, b in enumerate(bboxes): + x, y, w, h = b + x1 = x - w / 2 + x2 = x + w / 2 + y1 = y - h / 2 + y2 = y + h / 2 + c = int(cls[i][0]) + # print(x1, y1, x2, y2) + plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, keypoints=keypoints[i], label=f"{c}", color=colors(c)) + cv2.imshow("p", img) + if cv2.waitKey(0) == ord("q"): + break diff --git a/ultralytics/tests/data/dataloader/yolosegment.py b/ultralytics/tests/data/dataloader/yolosegment.py new file mode 100644 index 000000000..ae99aa5ee --- /dev/null +++ b/ultralytics/tests/data/dataloader/yolosegment.py @@ -0,0 +1,112 @@ +import cv2 +import numpy as np +import torch +from omegaconf import OmegaConf + +from ultralytics.yolo.data import build_dataloader + + +class Colors: + # Ultralytics color palette https://ultralytics.com/ + def __init__(self): + # hex = matplotlib.colors.TABLEAU_COLORS.values() + hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', + '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') + self.palette = [self.hex2rgb(f'#{c}') for c in hexs] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order (PIL) + return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() # create instance for 'from utils.plots import colors' + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + import random + + # Plots one bounding box on image img + tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, + (c1[0], c1[1] - 2), + 0, + tl / 3, + [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA, + ) + + +with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f: + hyp = OmegaConf.load(f) + +dataloader, dataset = build_dataloader( + img_path="/d/dataset/COCO/coco128-seg/images", + img_size=640, + label_path=None, + cache=False, + hyp=hyp, + augment=False, + prefix="", + rect=False, + batch_size=4, + stride=32, + pad=0.5, + use_segments=True, + use_keypoints=False, +) + +for d in dataloader: + idx = 1 # show which image inside one batch + img = d["img"][idx].numpy() + img = np.ascontiguousarray(img.transpose(1, 2, 0)) + ih, iw = img.shape[:2] + # print(img.shape) + bidx = d["batch_idx"] + cls = d["cls"][bidx == idx].numpy() + bboxes = d["bboxes"][bidx == idx].numpy() + masks = d["masks"][idx] + print(bboxes.shape) + bboxes[:, [0, 2]] *= iw + bboxes[:, [1, 3]] *= ih + nl = len(cls) + + index = torch.arange(nl).view(nl, 1, 1) + 1 + masks = masks.repeat(nl, 1, 1) + # print(masks.shape, index.shape) + masks = torch.where(masks == index, 1, 0) + masks = masks.numpy().astype(np.uint8) + print(masks.shape) + # keypoints = d["keypoints"] + + for i, b in enumerate(bboxes): + x, y, w, h = b + x1 = x - w / 2 + x2 = x + w / 2 + y1 = y - h / 2 + y2 = y + h / 2 + c = int(cls[i][0]) + # print(x1, y1, x2, y2) + plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c)) + mask = masks[i] + mask = cv2.resize(mask, (iw, ih)) + mask = mask.astype(bool) + img[mask] = img[mask] * 0.5 + np.array(colors(c)) * 0.5 + cv2.imshow("p", img) + if cv2.waitKey(0) == ord("q"): + break diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index 6c936ad89..af7524084 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform): self.border = border def get_indexes(self, dataset): - return [random.randint(0, len(dataset)) for _ in range(3)] + return [random.randint(0, len(dataset) - 1) for _ in range(3)] def _mix_transform(self, labels): mosaic_labels = [] @@ -200,7 +200,7 @@ class MixUp(BaseMixTransform): super().__init__(pre_transform=pre_transform, p=p) def get_indexes(self, dataset): - return random.randint(0, len(dataset)) + return random.randint(0, len(dataset) - 1) def _mix_transform(self, labels): im = labels["img"] @@ -366,7 +366,7 @@ class RandomPerspective: segments = instances.segments keypoints = instances.keypoints # update bboxes if there are segments. - if segments is not None: + if len(segments): bboxes, segments = self.apply_segments(segments, M) if keypoints is not None: @@ -379,7 +379,7 @@ class RandomPerspective: # make the bboxes have the same scale with new_bboxes i = self.box_candidates(box1=instances.bboxes.T, box2=new_instances.bboxes.T, - area_thr=0.01 if segments is not None else 0.10) + area_thr=0.01 if len(segments) else 0.10) labels["instances"] = new_instances[i] # clip labels["cls"] = cls[i] @@ -518,7 +518,7 @@ class CopyPaste: bboxes = labels["instances"].bboxes segments = labels["instances"].segments # n, 1000, 2 keypoints = labels["instances"].keypoints - if self.p and segments is not None: + if self.p and len(segments): n = len(segments) h, w, _ = im.shape # height, width, channels im_new = np.zeros(im.shape, np.uint8) @@ -593,10 +593,18 @@ class Albumentations: # TODO: technically this is not an augmentation, maybe we should put this to another files class Format: - def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True): + def __init__(self, + bbox_format="xywh", + normalize=True, + return_mask=False, + return_keypoint=False, + mask_ratio=4, + mask_overlap=True, + batch_idx=True): self.bbox_format = bbox_format self.normalize = normalize - self.mask = mask # set False when training detection only + self.return_mask = return_mask # set False when training detection only + self.return_keypoint = return_keypoint self.mask_ratio = mask_ratio self.mask_overlap = mask_overlap self.batch_idx = batch_idx # keep the batch indexes @@ -610,16 +618,20 @@ class Format: instances.denormalize(w, h) nl = len(instances) - if instances.segments is not None and self.mask: - masks, instances, cls = self._format_segments(instances, cls, w, h) - labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros( - 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio)) + if self.return_mask: + if nl: + masks, instances, cls = self._format_segments(instances, cls, w, h) + masks = torch.from_numpy(masks) + else: + masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, + img.shape[1] // self.mask_ratio) + labels["masks"] = masks if self.normalize: instances.normalize(w, h) labels["img"] = self._format_img(img) labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) - if instances.keypoints is not None: + if self.return_keypoint: labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2)) # then we can use collate_fn if self.batch_idx: diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index b17fc46f9..e332aed12 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -132,7 +132,12 @@ class YOLODataset(BaseDataset): transforms = affine_transforms(self.img_size, hyp) else: transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))]) - transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True)) + transforms.append( + Format(bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True)) return transforms def update_labels_info(self, label): @@ -140,7 +145,7 @@ class YOLODataset(BaseDataset): # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label # we can make it also support classification and semantic segmentation by add or remove some dict keys there. bboxes = label.pop("bboxes") - segments = label.pop("segments", None) + segments = label.pop("segments") keypoints = label.pop("keypoints", None) bbox_format = label.pop("bbox_format") normalized = label.pop("normalized") @@ -158,9 +163,9 @@ class YOLODataset(BaseDataset): value = values[i] if k == "img": value = torch.stack(value, 0) - if k in ["mask", "keypoint", "bboxes", "cls"]: + if k in ["masks", "keypoints", "bboxes", "cls"]: value = torch.cat(value, 0) - new_batch[k] = values[i] + new_batch[k] = value new_batch["batch_idx"] = list(new_batch["batch_idx"]) for i in range(len(new_batch["batch_idx"])): new_batch["batch_idx"][i] += i # add target image index for build_targets() diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index 002774246..63d596221 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -52,7 +52,7 @@ def verify_image_label(args): # Verify one image-label pair im_file, lb_file, prefix, keypoint = args # number (missing, found, empty, corrupt), message, segments, keypoints - nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None + nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None try: # verify images im = Image.open(im_file) diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py index 1481ce344..0c1b7ff54 100644 --- a/ultralytics/yolo/utils/instance.py +++ b/ultralytics/yolo/utils/instance.py @@ -162,7 +162,7 @@ class Bboxes: class Instances: - def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None: + def __init__(self, bboxes, segments=[], keypoints=None, bbox_format="xywh", normalized=True) -> None: """ Args: bboxes (ndarray): bboxes with shape [N, 4]. @@ -173,11 +173,13 @@ class Instances: self.keypoints = keypoints self.normalized = normalized - if isinstance(segments, list) and len(segments) > 0: + if len(segments) > 0: # list[np.array(1000, 2)] * num_samples segments = resample_segments(segments) # (N, 1000, 2) segments = np.stack(segments, axis=0) + else: + segments = np.zeros((0, 1000, 2), dtype=np.float32) self.segments = segments def convert_bbox(self, format): @@ -191,9 +193,8 @@ class Instances: self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h)) if bbox_only: return - if self.segments is not None: - self.segments[..., 0] *= scale_w - self.segments[..., 1] *= scale_h + self.segments[..., 0] *= scale_w + self.segments[..., 1] *= scale_h if self.keypoints is not None: self.keypoints[..., 0] *= scale_w self.keypoints[..., 1] *= scale_h @@ -202,9 +203,8 @@ class Instances: if not self.normalized: return self._bboxes.mul(scale=(w, h, w, h)) - if self.segments is not None: - self.segments[..., 0] *= w - self.segments[..., 1] *= h + self.segments[..., 0] *= w + self.segments[..., 1] *= h if self.keypoints is not None: self.keypoints[..., 0] *= w self.keypoints[..., 1] *= h @@ -214,9 +214,8 @@ class Instances: if self.normalized: return self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h)) - if self.segments is not None: - self.segments[..., 0] /= w - self.segments[..., 1] /= h + self.segments[..., 0] /= w + self.segments[..., 1] /= h if self.keypoints is not None: self.keypoints[..., 0] /= w self.keypoints[..., 1] /= h @@ -226,9 +225,8 @@ class Instances: # handle rect and mosaic situation assert not self.normalized, "you should add padding with absolute coordinates." self._bboxes.add(offset=(padw, padh, padw, padh)) - if self.segments is not None: - self.segments[..., 0] += padw - self.segments[..., 1] += padh + self.segments[..., 0] += padw + self.segments[..., 1] += padh if self.keypoints is not None: self.keypoints[..., 0] += padw self.keypoints[..., 1] += padh @@ -241,7 +239,7 @@ class Instances: Returns: Instances: Create a new :class:`Instances` by indexing. """ - segments = self.segments[index] if self.segments is not None else None + segments = self.segments[index] if len(self.segments) else self.segments keypoints = self.keypoints[index] if self.keypoints is not None else None bboxes = self.bboxes[index] bbox_format = self._bboxes.format @@ -256,16 +254,14 @@ class Instances: def flipud(self, h): # this function may not be very logical, just for clean code when using augment flipud self.bboxes[:, 1] = h - self.bboxes[:, 1] - if self.segments is not None: - self.segments[..., 1] = h - self.segments[..., 1] + self.segments[..., 1] = h - self.segments[..., 1] if self.keypoints is not None: self.keypoints[..., 1] = h - self.keypoints[..., 1] def fliplr(self, w): # this function may not be very logical, just for clean code when using augment fliplr self.bboxes[:, 0] = w - self.bboxes[:, 0] - if self.segments is not None: - self.segments[..., 0] = w - self.segments[..., 0] + self.segments[..., 0] = w - self.segments[..., 0] if self.keypoints is not None: self.keypoints[..., 0] = w - self.keypoints[..., 0] @@ -273,9 +269,8 @@ class Instances: self.convert_bbox(format="xyxy") self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h) - if self.segments is not None: - self.segments[..., 0] = self.segments[..., 0].clip(0, w) - self.segments[..., 1] = self.segments[..., 1].clip(0, h) + self.segments[..., 0] = self.segments[..., 0].clip(0, w) + self.segments[..., 1] = self.segments[..., 1].clip(0, h) if self.keypoints is not None: self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w) self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) @@ -311,13 +306,12 @@ class Instances: if len(instances_list) == 1: return instances_list[0] - use_segment = instances_list[0].segments is not None use_keypoint = instances_list[0].keypoints is not None bbox_format = instances_list[0]._bboxes.format normalized = instances_list[0].normalized cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis) - cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) if use_segment else None + cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)