From 0e3d44fc8d45d226937f16f3125b2607865f2af4 Mon Sep 17 00:00:00 2001
From: Laughing-q <1185102784@qq.com>
Date: Tue, 2 Apr 2024 15:38:53 +0800
Subject: [PATCH] attempt to update cache

---
 ultralytics/data/base.py | 50 +++++++++++-----------------------------
 1 file changed, 13 insertions(+), 37 deletions(-)

diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py
index 7aa3928a7..7cf9195a0 100644
--- a/ultralytics/data/base.py
+++ b/ultralytics/data/base.py
@@ -81,20 +81,19 @@ class BaseDataset(Dataset):
         if self.rect:
             assert self.batch_size is not None
             self.set_rectangle()
-        if isinstance(cache, str):
-            cache = cache.lower()
-
         # Buffer thread for mosaic images
         self.buffer = []  # buffer size = batch size
         self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

         # Cache images
+        if isinstance(cache, str):
+            cache = cache.lower()
         if cache == "ram" and not self.check_cache_ram():
             cache = False
         self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
         self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
-        if cache:
-            self.cache_images(cache)
+        self.cache_disk = cache == "disk"  # cache images on hard drive as uncompressed *.npy files
+        self.cache_ram = cache and not self.cache_disk  # cache images into RAM

         # Transforms
         self.transforms = self.build_transforms(hyp=hyp)
@@ -149,16 +148,15 @@ class BaseDataset(Dataset):
     def load_image(self, i, rect_mode=True):
         """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
         im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
-        if im is None:  # not cached in RAM
-            if fn.exists():  # load npy
-                try:
-                    im = np.load(fn)
-                except Exception as e:
-                    LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
-                    Path(fn).unlink(missing_ok=True)
-                    im = cv2.imread(f)  # BGR
-            else:  # read image
-                im = cv2.imread(f)  # BGR
+        if im is None:
+            if self.cache_ram:
+                im = self.ims[i] = cv2.imread(f)
+            elif self.cache_disk:
+                if not fn.exists():
+                    np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+                im = np.load(fn)
+            else:
+                im = cv2.imread(f)
             if im is None:
                 raise FileNotFoundError(f"Image Not Found {f}")

@@ -183,28 +181,6 @@ class BaseDataset(Dataset):

         return self.ims[i], self.im_hw0[i], self.im_hw[i]

-    def cache_images(self, cache):
-        """Cache images to memory or disk."""
-        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
-        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
-        with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(fcn, range(self.ni))
-            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
-            for i, x in pbar:
-                if cache == "disk":
-                    b += self.npy_files[i].stat().st_size
-                else:  # 'ram'
-                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
-                    b += self.ims[i].nbytes
-                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
-            pbar.close()
-
-    def cache_images_to_disk(self, i):
-        """Saves an image as an *.npy file for faster loading."""
-        f = self.npy_files[i]
-        if not f.exists():
-            np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)
-
     def check_cache_ram(self, safety_margin=0.5):
         """Check image caching requirements vs available memory."""
         b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
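
Note on the change above: the patch drops the eager cache_images() / cache_images_to_disk() pass and instead caches each image lazily inside load_image(), driven by the new self.cache_ram and self.cache_disk flags. The following is a rough standalone sketch of that lazy pattern only, not code from ultralytics/data/base.py; the lazy_load name, the RAM_CACHE dict, and the flag defaults are illustrative assumptions.

    from pathlib import Path

    import cv2
    import numpy as np

    RAM_CACHE = {}  # illustrative stand-in for BaseDataset.ims

    def lazy_load(f, cache_ram=False, cache_disk=False):
        """Read one image, caching it in RAM or as an uncompressed .npy file on first use."""
        if cache_ram:
            im = RAM_CACHE.get(f)
            if im is None:
                im = RAM_CACHE[f] = cv2.imread(f)  # BGR; read and cached only once
        elif cache_disk:
            fn = Path(f).with_suffix(".npy")
            if not fn.exists():
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:
            im = cv2.imread(f)  # no caching
        if im is None:
            raise FileNotFoundError(f"Image Not Found {f}")
        return im

Compared with the removed ThreadPool-based cache_images() loop, nothing is read until a sample is actually requested, so the caching cost is paid incrementally during the first epoch rather than up front at dataset construction.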