From 2f60abd89995dc0a803aeb9b579c108f682ff03b Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 8 Sep 2022 14:59:17 +0800 Subject: [PATCH] Refactor and fix bugs --- paddlers/tasks/utils/slider_predict.py | 212 ++++++++++++++++++++----- 1 file changed, 172 insertions(+), 40 deletions(-) diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index 1672eb0..620997f 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict import numpy as np +from tqdm import tqdm import paddlers.utils.logging as logging @@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta): class SlowCache(Cache): def __init__(self): + super(SlowCache, self).__init__() self.cache = defaultdict(Counter) def push_pixel(self, i, j, l): @@ -66,6 +68,7 @@ class SlowCache(Cache): class ProbCache(Cache): def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): + super(ProbCache, self).__init__() self.cache = None self.h = h self.w = w @@ -116,20 +119,139 @@ class ProbCache(Cache): self._alloc_memory(nc) self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map - def roll_cache(self): + def roll_cache(self, shift): if self.order == 'c': - self.cache[:-self.sh] = self.cache[self.sh:] - self.cache[-self.sh:, :] = 0 + self.cache[:-shift] = self.cache[shift:] + self.cache[-shift:, :] = 0 elif self.order == 'f': - self.cache[:, :-self.sw] = self.cache[:, self.sw:] - self.cache[:, -self.sw:] = 0 + self.cache[:, :-shift] = self.cache[:, shift:] + self.cache[:, -shift:] = 0 def get_block(self, i_st, j_st, h, w): return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) -def slider_predict(predict_func, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy, batch_size): +class OverlapProcessor(metaclass=ABCMeta): + def __init__(self, h, w, ch, cw, sh, sw): + super(OverlapProcessor, self).__init__() + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + + @abstractmethod + def process_pred(self, out, xoff, yoff): + pass + + +class KeepFirstProcessor(OverlapProcessor): + def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255): + super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.ds = ds + self.inval = inval + + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch) + mask = rd_block != self.inval + pred = np.where(mask, rd_block, pred) + return pred + + +class KeepLastProcessor(OverlapProcessor): + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + return pred + + +class AccumProcessor(OverlapProcessor): + def __init__(self, + h, + w, + ch, + cw, + sh, + sw, + dtype=np.float16, + assign_weight=True): + super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c') + self.prev_yoff = None + self.assign_weight = assign_weight + + def process_pred(self, out, xoff, yoff): + if self.prev_yoff is not None and yoff != self.prev_yoff: + if yoff < self.prev_yoff: + raise RuntimeError + self.cache.roll_cache(yoff - self.prev_yoff) + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + prob = out['score_map'] + prob = prob[:self.ch, :self.cw] + if self.assign_weight: + prob = assign_border_weights(prob, border_ratio=0.25, inplace=True) + self.cache.update_block(0, xoff, self.ch, self.cw, prob) + pred = self.cache.get_block(0, xoff, self.ch, self.cw) + self.prev_yoff = yoff + return pred + + +def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True): + if not inplace: + array = array.copy() + h, w = array.shape[:2] + hm, wm = int(h * border_ratio), int(w * border_ratio) + array[:hm] *= weight + array[-hm:] *= weight + array[:, :wm] *= weight + array[:, -wm:] *= weight + return array + + +def read_block(ds, + xoff, + yoff, + xsize, + ysize, + tar_xsize=None, + tar_ysize=None, + pad_val=0): + if tar_xsize is None: + tar_xsize = xsize + if tar_ysize is None: + tar_ysize = ysize + # Read data from dataset + block = ds.ReadAsArray(xoff, yoff, xsize, ysize) + c, real_ysize, real_xsize = block.shape + assert real_ysize == ysize and real_xsize == xsize + # [c, h, w] -> [h, w, c] + block = block.transpose((1, 2, 0)) + if (real_ysize, real_xsize) != (tar_ysize, tar_xsize): + if real_ysize >= tar_ysize or real_xsize >= tar_xsize: + raise ValueError + padded_block = np.full( + (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype) + # Fill + padded_block[:real_ysize, :real_xsize] = block + return padded_block + else: + return block + + +def slider_predict(predict_func, + img_file, + save_dir, + block_size, + overlap, + transforms, + invalid_value, + merge_strategy, + batch_size, + show_progress=False): """ Do inference using sliding windows. @@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. batch_size (int): Batch size used in inference. + show_progress (bool, optional): Whether to show prediction progress with a + progress bar. Defaults to True. """ try: @@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") - if merge_strategy not in ('keep_first', 'keep_last', 'accum'): - raise ValueError("{} is not a supported stragegy for block merging.". - format(merge_strategy)) - step = np.array( block_size, dtype=np.int32) - np.array( overlap, dtype=np.int32) @@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, # When there is no overlap or the whole image is used as input, # use 'keep_last' strategy as it introduces least overheads merge_strategy = 'keep_last' - if merge_strategy == 'accum': - cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) + if merge_strategy == 'keep_first': + overlap_processor = KeepFirstProcessor( + height, + width, + *block_size[::-1], + *step[::-1], + band, + inval=invalid_value) + elif merge_strategy == 'keep_last': + overlap_processor = KeepLastProcessor(height, width, *block_size[::-1], + *step[::-1]) + elif merge_strategy == 'accum': + overlap_processor = AccumProcessor(height, width, *block_size[::-1], + *step[::-1]) + else: + raise ValueError("{} is not a supported stragegy for block merging.". + format(merge_strategy)) + + xsize, ysize = block_size + num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0]) + cnt = 0 + if show_progress: + pb = tqdm(total=num_blocks) batch_data = [] batch_offsets = [] for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): - xsize, ysize = block_size if xoff + xsize > width: xoff = width - xsize + is_end_of_row = True + else: + is_end_of_row = False if yoff + ysize > height: yoff = height - ysize + is_end_of_col = True + else: + is_end_of_col = False - is_end_of_col = yoff + ysize >= height - is_end_of_row = xoff + xsize >= width - - # Read and fill - im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) + # Read + im = read_block(src_data, xoff, yoff, xsize, ysize) if isinstance(img_file, tuple): - im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) + im2 = read_block(src2_data, xoff, yoff, xsize, ysize) batch_data.append((im, im2)) else: batch_data.append(im) @@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, batch_out = predict_func(batch_data, transforms=transforms) for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): - pred = out['label_map'].astype('uint8') - pred = pred[:ysize, :xsize] - - # Deal with overlapping pixels - if merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - elif merge_strategy == 'accum': - prob = out['score_map'] - prob = prob[:ysize, :xsize] - cache.update_block(0, xoff_, ysize, xsize, prob) - pred = cache.get_block(0, xoff_, ysize, xsize) - if xoff_ + xsize >= width: - cache.roll_cache() - + # Get processed result + pred = overlap_processor.process_pred(out, xoff_, yoff_) # Write to file band.WriteArray(pred, xoff_, yoff_) @@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, batch_data.clear() batch_offsets.clear() + cnt += 1 + + if show_progress: + pb.update(1) + pb.set_description("{} out of {} blocks processed.".format( + cnt, num_blocks)) + dst_data = None logging.info("GeoTiff file saved in {}.".format(save_file))