You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.3 KiB
39 lines
1.3 KiB
# Ultralytics YOLO 🚀, GPL-3.0 license |
|
|
|
import collections
import collections.abc
from copy import deepcopy

from .augment import LetterBox
|
|
|
|
|
class MixAndRectDataset:
    """Dataset wrapper that applies the wrapped dataset's transforms, with multi-image mixing and rect shapes.

    For every transform in ``dataset.transforms.tolist()``:

    * if the transform exposes ``get_indexes(dataset)`` (mosaic / mixup style), deep copies of
      the samples at the returned index(es) are placed in ``labels['mix_labels']`` before the
      transform is called, and that key is removed again right after it runs;
    * if ``dataset.rect`` is truthy and the transform is a ``LetterBox``, its ``new_shape`` is set
      to ``dataset.batch_shapes[dataset.batch[index]]`` for the requested sample.

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be mixed. It must support ``len()`` and
            integer indexing, and provide ``imgsz``, ``rect`` and ``transforms`` (an object with a
            ``tolist()`` method returning the callable transforms in order). In rect mode it must
            also provide ``batch`` and ``batch_shapes``.
    """

    def __init__(self, dataset):
        """Wrap ``dataset`` and expose its ``imgsz`` directly on the wrapper."""
        self.dataset = dataset
        self.imgsz = dataset.imgsz

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the labels of sample ``index`` after every transform has been applied.

        A deep copy of ``self.dataset[index]`` is transformed, so the object returned by the
        wrapped dataset is never modified. ``'mix_labels'`` is never present in the result.
        """
        dataset = self.dataset
        labels = deepcopy(dataset[index])
        for transform in dataset.transforms.tolist():
            # mosaic and mixup: the transform itself picks which extra samples it needs
            if hasattr(transform, 'get_indexes'):
                mix_indexes = transform.get_indexes(dataset)
                if not isinstance(mix_indexes, collections.abc.Sequence):
                    mix_indexes = [mix_indexes]  # a single index was returned
                # `i`, not `index`: the requested sample's index is still needed below
                labels['mix_labels'] = [deepcopy(dataset[i]) for i in mix_indexes]
            if dataset.rect and isinstance(transform, LetterBox):
                # Rect mode: letterbox to the shape associated with this sample's batch entry.
                # NOTE: this mutates the shared transform object in place.
                transform.new_shape = dataset.batch_shapes[dataset.batch[index]]
            labels = transform(labels)
            # Drop the extra samples so later transforms (and the caller) never see them.
            labels.pop('mix_labels', None)
        return labels
|
|
|