From bf54499f5a6ab0eb28e7b01d9b9c58f5d507b520 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sun, 4 Sep 2022 03:13:09 +0800 Subject: [PATCH] Add accum strategy --- paddlers/tasks/change_detector.py | 137 +------------ paddlers/tasks/segmenter.py | 137 +------------ paddlers/tasks/utils/slider_predict.py | 254 ++++++++++++++++++++++++- tests/tasks/test_slider_predict.py | 50 +++-- 4 files changed, 301 insertions(+), 277 deletions(-) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 30babb5..790173b 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -35,7 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet -from .utils.slider_predict import SlowCache as Cache +from .utils.slider_predict import slider_predict __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", @@ -606,139 +606,8 @@ class BaseChangeDetector(BaseModel): there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ - try: - from osgeo import gdal - except: - import gdal - - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - if merge_strategy not in ('keep_first', 'keep_last', 'vote'): - raise ValueError( - "{} is not a supported stragegy for block merging.".format( - merge_strategy)) - if overlap == (0, 0): - # When there is no overlap, use 'keep_last' strategy as it introduces least overheads - merge_strategy = 'keep_last' - if merge_strategy == 'vote': - logging.warning( - "Currently, a naive Python-implemented cache is used for aggregating voting results. " - "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " - "'keep_last'.") - cache = Cache() - - src1_data = gdal.Open(img_files[0]) - src2_data = gdal.Open(img_files[1]) - - # Assume that two input images have the same size - width = src1_data.RasterXSize - height = src1_data.RasterYSize - bands = src1_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - # Output name is the same as the name of the first image - file_name = osp.basename(osp.normpath(img_files[0])) - # Replace extension name with '.tif' - file_name = osp.splitext(file_name)[0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - - # Set meta-information (consistent with the first image) - dst_data.SetGeoTransform(src1_data.GetGeoTransform()) - dst_data.SetProjection(src1_data.GetProjection()) - - band = dst_data.GetRasterBand(1) - band.WriteArray( - np.full( - (height, width), fill_value=invalid_value, dtype="uint8")) - - prev_yoff, prev_xoff = None, None - prev_h, prev_w = None, None - step = np.array( - block_size, dtype=np.int32) - np.array( - overlap, dtype=np.int32) - if step[0] == 0 or step[1] == 0: - raise ValueError("`block_size` and `overlap` should not be equal.") - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - - xoff = int(xoff) - yoff = int(yoff) - im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - # Fill - h, w = im1.shape[:2] - im1_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im1.dtype) - im1_fill[:h, :w, :] = im1 - - im2_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im2.dtype) - im2_fill[:h, :w, :] = im2 - - # Predict - pred = self.predict((im1_fill, im2_fill), transforms) - pred = pred["label_map"].astype('uint8') - pred = pred[:h, :w] - - # Deal with overlapping pixels - if merge_strategy == 'vote': - cache.push_block(yoff, xoff, h, w, pred) - pred = cache.get_block(yoff, xoff, h, w) - pred = pred.astype('uint8') - if prev_yoff is not None: - pop_h = yoff - prev_yoff - else: - pop_h = 0 - if prev_xoff is not None: - if xoff < prev_xoff: - pop_w = prev_w - else: - pop_w = xoff - prev_xoff - else: - pop_w = 0 - cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) - elif merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() - - prev_xoff = xoff - prev_w = w - - prev_yoff = yoff - prev_h = h - - dst_data = None - logging.info("GeoTiff file saved in {}.".format(save_file)) + slider_predict(self, img_files, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 1037041..8b332aa 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -34,7 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferSegNet -from .utils.slider_predict import SlowCache as Cache +from .utils.slider_predict import slider_predict __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -572,135 +572,16 @@ class BaseSegmenter(BaseModel): invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices - are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' - means keeping the values of the first and the last block in traversal - order, respectively. 'vote' means applying a simple voting strategy when - there are conflicts in the overlapping pixels. Defaults to 'keep_last'. + are {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and + 'keep_last' means keeping the values of the first and the last block in + traversal order, respectively. 'vote' means applying a simple voting + strategy when there are conflicts in the overlapping pixels. 'accum' + means determining the class of an overlapping pixel according to + accumulated probabilities. Defaults to 'keep_last'. """ - try: - from osgeo import gdal - except: - import gdal - - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - if merge_strategy not in ('keep_first', 'keep_last', 'vote'): - raise ValueError( - "{} is not a supported stragegy for block merging.".format( - merge_strategy)) - if overlap == (0, 0): - # When there is no overlap, use 'keep_last' strategy as it introduces least overheads - merge_strategy = 'keep_last' - if merge_strategy == 'vote': - logging.warning( - "Currently, a naive Python-implemented cache is used for aggregating voting results. " - "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " - "'keep_last'.") - cache = Cache() - - src_data = gdal.Open(img_file) - width = src_data.RasterXSize - height = src_data.RasterYSize - bands = src_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - file_name = osp.basename(osp.normpath(img_file)) - # Replace extension name with '.tif' - file_name = osp.splitext(file_name)[0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - - # Set meta-information - dst_data.SetGeoTransform(src_data.GetGeoTransform()) - dst_data.SetProjection(src_data.GetProjection()) - - band = dst_data.GetRasterBand(1) - band.WriteArray( - np.full( - (height, width), fill_value=invalid_value, dtype="uint8")) - - prev_yoff, prev_xoff = None, None - prev_h, prev_w = None, None - step = np.array( - block_size, dtype=np.int32) - np.array( - overlap, dtype=np.int32) - if step[0] == 0 or step[1] == 0: - raise ValueError("`block_size` and `overlap` should not be equal.") - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - - xoff = int(xoff) - yoff = int(yoff) - im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - # Fill - h, w = im.shape[:2] - im_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im.dtype) - im_fill[:h, :w, :] = im - - # Predict - pred = self.predict(im_fill, transforms) - pred = pred["label_map"].astype('uint8') - pred = pred[:h, :w] - - # Deal with overlapping pixels - if merge_strategy == 'vote': - cache.push_block(yoff, xoff, h, w, pred) - pred = cache.get_block(yoff, xoff, h, w) - pred = pred.astype('uint8') - if prev_yoff is not None: - pop_h = yoff - prev_yoff - else: - pop_h = 0 - if prev_xoff is not None: - if xoff < prev_xoff: - pop_w = prev_w - else: - pop_w = xoff - prev_xoff - else: - pop_w = 0 - cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) - elif merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() - - prev_xoff = xoff - prev_w = w - - prev_yoff = yoff - prev_h = h - - dst_data = None - logging.info("GeoTiff file saved in {}.".format(save_file)) + slider_predict(self, img_file, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index 254882b..a9ecdf6 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import os.path as osp +from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict import numpy as np +import paddlers.utils.logging as logging -class SlowCache(object): + +class Cache(metaclass=ABCMeta): + @abstractmethod + def get_block(self, i_st, j_st, h, w): + pass + + +class SlowCache(Cache): def __init__(self): self.cache = defaultdict(Counter) @@ -50,3 +61,244 @@ class SlowCache(object): row.append(self.get_pixel(i, j)) block.append(row) return np.asarray(block) + + +class ProbCache(Cache): + def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): + self.cache = None + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + if not issubclass(dtype, np.floating): + raise TypeError("`dtype` must be one of the floating types.") + self.dtype = dtype + order = order.lower() + if order not in ('c', 'f'): + raise ValueError("`order` other than 'c' and 'f' is not supported.") + self.order = order + + def _alloc_memory(self, nc): + if self.order == 'c': + # Colomn-first order (C-style) + # + # <-- cw --> + # |--------|---------------------|^ ^ + # | || | sh + # |--------|---------------------|| ch v + # | || + # |--------|---------------------|v + # <------------ w ---------------> + self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype) + elif self.order == 'f': + # Row-first order (Fortran-style) + # + # <-- sw --> + # <---- cw ----> + # |--------|---|^ ^ + # | | || | + # | | || ch + # | | || | + # |--------|---|| h v + # | | || + # | | || + # | | || + # |--------|---|v + self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype) + + def update_block(self, i_st, j_st, h, w, prob_map): + if self.cache is None: + nc = prob_map.shape[2] + # Lazy allocation of memory + self._alloc_memory(nc) + self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map + + def roll_cache(self): + if self.order == 'c': + self.cache = np.roll(self.cache, -self.sh, axis=0) + self.cache[self.sh:self.ch, :] = 0 + elif self.order == 'f': + self.cache = np.roll(self.cache, -self.sw, axis=1) + self.cache[:, self.sw:self.cw] = 0 + + def get_block(self, i_st, j_st, h, w): + return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) + + +def slider_predict(predictor, img_file, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy): + """ + Do inference using sliding windows. + + Args: + predictor (object): Object that implements `predict()` method. + img_file (str|tuple[str]): Image path(s). + save_dir (str): Directory that contains saved geotiff file. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int): + Overlap between two blocks. If `overlap` is list or tuple, it should + be in (W, H) format. + transforms (paddlers.transforms.Compose|None): Transforms for inputs. If + None, the transforms for evaluation process will be used. + invalid_value (int): Value that marks invalid pixels in output image. + Defaults to 255. + merge_strategy (str): Strategy to merge overlapping blocks. Choices are + {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and + 'keep_last' means keeping the values of the first and the last block in + traversal order, respectively. 'vote' means applying a simple voting + strategy when there are conflicts in the overlapping pixels. 'accum' + means determining the class of an overlapping pixel according to + accumulated probabilities. + """ + + try: + from osgeo import gdal + except: + import gdal + + if isinstance(block_size, int): + block_size = (block_size, block_size) + elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: + block_size = tuple(block_size) + else: + raise ValueError( + "`block_size` must be a tuple/list of length 2 or an integer.") + if isinstance(overlap, int): + overlap = (overlap, overlap) + elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: + overlap = tuple(overlap) + else: + raise ValueError( + "`overlap` must be a tuple/list of length 2 or an integer.") + + if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'): + raise ValueError("{} is not a supported stragegy for block merging.". + format(merge_strategy)) + + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") + + if isinstance(img_file, tuple): + if len(img_file) != 2: + raise ValueError("Tuple `img_file` must have the length of two.") + # Assume that two input images have the same size + src_data = gdal.Open(img_file[0]) + src2_data = gdal.Open(img_file[1]) + # Output name is the same as the name of the first image + file_name = osp.basename(osp.normpath(img_file[0])) + else: + src_data = gdal.Open(img_file) + file_name = osp.basename(osp.normpath(img_file)) + + # Get size of original raster + width = src_data.RasterXSize + height = src_data.RasterYSize + bands = src_data.RasterCount + + if block_size[0] > width or block_size[1] > height: + raise ValueError("`block_size` should not be larger than image size.") + + driver = gdal.GetDriverByName("GTiff") + if not osp.exists(save_dir): + os.makedirs(save_dir) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" + save_file = osp.join(save_dir, file_name) + dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information + dst_data.SetGeoTransform(src_data.GetGeoTransform()) + dst_data.SetProjection(src_data.GetProjection()) + + # Initialize raster with `invalid_value` + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + if overlap == (0, 0) or block_size == (width, height): + # When there is no overlap or the whole image is used as input, + # use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', " + "'keep_last', or 'accum'.") + cache = SlowCache() + elif merge_strategy == 'accum': + cache = ProbCache(height, width, *block_size, *step) + + prev_yoff, prev_xoff = None, None + + for yoff in range(0, height, step[1]): + for xoff in range(0, width, step[0]): + xsize, ysize = block_size + if xoff + xsize > width: + xoff = width - xsize + if yoff + ysize > height: + yoff = height - ysize + + # Read and fill + im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + + if isinstance(img_file, tuple): + im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + # Predict + out = predictor.predict((im, im2), transforms) + else: + # Predict + out = predictor.predict(im, transforms) + + pred = out['label_map'].astype('uint8') + pred = pred[:ysize, :xsize] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, ysize, xsize, pred) + pred = cache.get_block(yoff, xoff, ysize, xsize) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = xsize + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + elif merge_strategy == 'accum': + prob = out['score_map'] + prob = prob[:ysize, :xsize] + cache.update_block(0, yoff, ysize, xsize, prob) + pred = cache.get_block(0, yoff, ysize, xsize) + if xoff + step[0] >= width: + cache.roll_cache() + + # Write to file + band.WriteArray(pred, xoff, yoff) + dst_data.FlushCache() + + prev_xoff = xoff + prev_yoff = yoff + + dst_data = None + logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py index bc3bd5a..fce8550 100644 --- a/tests/tasks/test_slider_predict.py +++ b/tests/tasks/test_slider_predict.py @@ -75,13 +75,9 @@ class TestSegSliderPredict(CommonTest): # `block_size` larger than image size save_dir = osp.join(td, 'pred5') - self.model.slider_predict(self.image_path, save_dir, 512, 0, - self.transforms) - pred5 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred5.shape, pred_whole.shape) + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) def test_merge_strategy(self): with tempfile.TemporaryDirectory() as td: @@ -134,6 +130,21 @@ class TestSegSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_vote.shape, pred_whole.shape) + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + def test_geo_info(self): with tempfile.TemporaryDirectory() as td: _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) @@ -202,13 +213,9 @@ class TestCDSliderPredict(CommonTest): # `block_size` larger than image size save_dir = osp.join(td, 'pred5') - self.model.slider_predict(self.image_paths, save_dir, 512, 0, - self.transforms) - pred5 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred5.shape, pred_whole.shape) + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_paths, save_dir, 512, 0, + self.transforms) def test_merge_strategy(self): with tempfile.TemporaryDirectory() as td: @@ -261,6 +268,21 @@ class TestCDSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_vote.shape, pred_whole.shape) + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + def test_geo_info(self): with tempfile.TemporaryDirectory() as td: _, geo_info_in = T.decode_image(