diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 984eb519f7..d6ceff79dd 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -65,7 +65,7 @@ class BaseDataset(Dataset): self.ims = [None] * self.ni self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] if cache: - self.cache_images() + self.cache_images(cache) # transforms self.transforms = self.build_transforms(hyp=hyp) @@ -127,20 +127,20 @@ class BaseDataset(Dataset): return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized - def cache_images(self): + def cache_images(self, cache): # cache images to memory or disk gb = 0 # Gigabytes of cached images self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni - fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image + fcn = self.cache_images_to_disk if cache == "disk" else self.load_image results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni)) pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) for i, x in pbar: - if self.cache == "disk": + if cache == "disk": gb += self.npy_files[i].stat().st_size else: # 'ram' self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) gb += self.ims[i].nbytes - pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {self.cache})" + pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})" pbar.close() def cache_images_to_disk(self, i):