|
|
|
@ -80,7 +80,7 @@ class BaseDataset(Dataset): |
|
|
|
|
# Cache stuff |
|
|
|
|
if cache == 'ram' and not self.check_cache_ram(): |
|
|
|
|
cache = False |
|
|
|
|
self.ims = [None] * self.ni |
|
|
|
|
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni |
|
|
|
|
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] |
|
|
|
|
if cache: |
|
|
|
|
self.cache_images(cache) |
|
|
|
@ -88,6 +88,10 @@ class BaseDataset(Dataset): |
|
|
|
|
# Transforms |
|
|
|
|
self.transforms = self.build_transforms(hyp=hyp) |
|
|
|
|
|
|
|
|
|
# Buffer thread for mosaic images |
|
|
|
|
self.buffer = [] # buffer size = batch size |
|
|
|
|
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 |
|
|
|
|
|
|
|
|
|
def get_img_files(self, img_path): |
|
|
|
|
"""Read image files.""" |
|
|
|
|
try: |
|
|
|
@ -147,13 +151,22 @@ class BaseDataset(Dataset): |
|
|
|
|
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA |
|
|
|
|
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)), |
|
|
|
|
interpolation=interp) |
|
|
|
|
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized |
|
|
|
|
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized |
|
|
|
|
|
|
|
|
|
# Add to buffer if training with augmentations |
|
|
|
|
if self.augment: |
|
|
|
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized |
|
|
|
|
self.buffer.append(i) |
|
|
|
|
if len(self.buffer) >= self.max_buffer_length: |
|
|
|
|
j = self.buffer.pop(0) |
|
|
|
|
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None |
|
|
|
|
|
|
|
|
|
return im, (h0, w0), im.shape[:2] |
|
|
|
|
|
|
|
|
|
return self.ims[i], self.im_hw0[i], self.im_hw[i] |
|
|
|
|
|
|
|
|
|
def cache_images(self, cache): |
|
|
|
|
"""Cache images to memory or disk.""" |
|
|
|
|
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes |
|
|
|
|
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni |
|
|
|
|
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image |
|
|
|
|
with ThreadPool(NUM_THREADS) as pool: |
|
|
|
|
results = pool.imap(fcn, range(self.ni)) |
|
|
|
@ -218,9 +231,9 @@ class BaseDataset(Dataset): |
|
|
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
|
|
"""Returns transformed label information for given index.""" |
|
|
|
|
return self.transforms(self.get_label_info(index)) |
|
|
|
|
return self.transforms(self.get_image_and_label(index)) |
|
|
|
|
|
|
|
|
|
def get_label_info(self, index): |
|
|
|
|
def get_image_and_label(self, index): |
|
|
|
|
"""Get and return label information from the dataset.""" |
|
|
|
|
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 |
|
|
|
|
label.pop('shape', None) # shape is for rect, remove it |
|
|
|
@ -229,8 +242,7 @@ class BaseDataset(Dataset): |
|
|
|
|
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation |
|
|
|
|
if self.rect: |
|
|
|
|
label['rect_shape'] = self.batch_shapes[self.batch[index]] |
|
|
|
|
label = self.update_labels_info(label) |
|
|
|
|
return label |
|
|
|
|
return self.update_labels_info(label) |
|
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
|
"""Returns the length of the labels list for the dataset.""" |
|
|
|
|