From 8beb15c3c833848d0031fd371703daa27093b8a1 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Wed, 7 Sep 2022 10:45:19 +0800 Subject: [PATCH] [Feat] Optimize `slide_predict()` (#3) --- docs/apis/data.md | 6 +- paddlers/deploy/predictor.py | 17 +- paddlers/tasks/change_detector.py | 10 +- paddlers/tasks/classifier.py | 2 +- paddlers/tasks/object_detector.py | 2 +- paddlers/tasks/restorer.py | 2 +- paddlers/tasks/segmenter.py | 8 +- paddlers/tasks/utils/slider_predict.py | 86 ++++-- paddlers/transforms/__init__.py | 17 +- paddlers/transforms/functions.py | 6 +- paddlers/transforms/operators.py | 10 +- tests/deploy/test_predictor.py | 29 +- tests/tasks/test_slider_predict.py | 404 +++++++++++-------------- 13 files changed, 302 insertions(+), 297 deletions(-) diff --git a/docs/apis/data.md b/docs/apis/data.md index 6be2e32..34afb1c 100644 --- a/docs/apis/data.md +++ b/docs/apis/data.md @@ -134,10 +134,12 @@ |-------|----|--------|-----| |`im_path`|`str`|输入图像路径。|| |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`| -|`to_uint8`|`bool`|若为`True`,则将读取的图像数据量化并转换为uint8类型。|`True`| +|`to_uint8`|`bool`|若为`True`,则将读取的影像数据量化并转换为uint8类型。|`True`| |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`| -|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| +|`decode_sar`|`bool`|若为`True`,则自动将单通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`| +|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`| +|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`| 返回格式如下: diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 157abca..ab474d2 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -282,8 +282,8 @@ class Predictor(object): img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to - paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks, - `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. + paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change + detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms from `model.yml`. Defaults to None. @@ -332,7 +332,8 @@ class Predictor(object): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the sliding-predicting mode. @@ -340,9 +341,9 @@ class Predictor(object): Args: img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be - consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of - image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a - tuple of decoded images, or a list of tuples. + consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)), + or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of + image paths, a tuple of decoded images, or a list of tuples. save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in (W, H) format. @@ -355,6 +356,7 @@ class Predictor(object): {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict( partial( @@ -365,7 +367,8 @@ class Predictor(object): overlap, transforms, invalid_value, - merge_strategy) + merge_strategy, + batch_size) def batch_predict(self, image_list, **params): return self.predict(img_file=image_list, **params) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index afccdc4..347e996 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -588,7 +588,8 @@ class BaseChangeDetector(BaseModel): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. @@ -611,10 +612,11 @@ class BaseChangeDetector(BaseModel): means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict(self.predict, img_files, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy) + transforms, invalid_value, merge_strategy, batch_size) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -622,8 +624,8 @@ class BaseChangeDetector(BaseModel): batch_ori_shape = list() for im1, im2 in images: if isinstance(im1, str) or isinstance(im2, str): - im1 = decode_image(im1, to_rgb=False) - im2 = decode_image(im2, to_rgb=False) + im1 = decode_image(im1, read_raw=True) + im2 = decode_image(im2, read_raw=True) ori_shape = im1.shape[:2] # XXX: sample do not contain 'image_t1' and 'image_t2'. sample = {'image': im1, 'image2': im2} diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 33ba5f3..12b6640 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -497,7 +497,7 @@ class BaseClassifier(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample) diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index 313a893..f42d7cc 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -617,7 +617,7 @@ class BaseDetector(BaseModel): batch_samples = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) sample = {'image': im} sample = transforms(sample) batch_samples.append(sample) diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py index 61691f0..ce7e931 100644 --- a/paddlers/tasks/restorer.py +++ b/paddlers/tasks/restorer.py @@ -481,7 +481,7 @@ class BaseRestorer(BaseModel): batch_tar_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index a319fc5..29f2aa4 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -560,7 +560,8 @@ class BaseSegmenter(BaseModel): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. @@ -583,10 +584,11 @@ class BaseSegmenter(BaseModel): means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict(self.predict, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy) + transforms, invalid_value, merge_strategy, batch_size) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -594,7 +596,7 @@ class BaseSegmenter(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index f32f7f0..1672eb0 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -14,6 +14,7 @@ import os import os.path as osp +import math from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict @@ -117,10 +118,10 @@ class ProbCache(Cache): def roll_cache(self): if self.order == 'c': - self.cache = np.roll(self.cache, -self.sh, axis=0) + self.cache[:-self.sh] = self.cache[self.sh:] self.cache[-self.sh:, :] = 0 elif self.order == 'f': - self.cache = np.roll(self.cache, -self.sw, axis=1) + self.cache[:, :-self.sw] = self.cache[:, self.sw:] self.cache[:, -self.sw:] = 0 def get_block(self, i_st, j_st, h, w): @@ -128,7 +129,7 @@ class ProbCache(Cache): def slider_predict(predict_func, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy): + transforms, invalid_value, merge_strategy, batch_size): """ Do inference using sliding windows. @@ -151,6 +152,7 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. + batch_size (int): Batch size used in inference. """ try: @@ -200,6 +202,13 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, height = src_data.RasterYSize bands = src_data.RasterCount + # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True) + # except for SAR images. + if bands == 1: + logging.warning( + f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images." + ) + if block_size[0] > width or block_size[1] > height: raise ValueError("`block_size` should not be larger than image size.") @@ -228,6 +237,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if merge_strategy == 'accum': cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) + batch_data = [] + batch_offsets = [] for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -236,6 +247,9 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if yoff + ysize > height: yoff = height - ysize + is_end_of_col = yoff + ysize >= height + is_end_of_row = xoff + xsize >= width + # Read and fill im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( (1, 2, 0)) @@ -243,33 +257,49 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if isinstance(img_file, tuple): im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( (1, 2, 0)) - # Predict - out = predict_func((im, im2), transforms=transforms) + batch_data.append((im, im2)) else: + batch_data.append(im) + + batch_offsets.append((xoff, yoff)) + + len_batch = len(batch_data) + + if is_end_of_row and is_end_of_col and len_batch < batch_size: + # Pad `batch_data` by repeating the last element + batch_data = batch_data + [batch_data[-1]] * (batch_size - + len_batch) + # While keeping `len(batch_offsets)` the number of valid elements in the batch + + if len(batch_data) == batch_size: # Predict - out = predict_func(im, transforms=transforms) - - pred = out['label_map'].astype('uint8') - pred = pred[:ysize, :xsize] - - # Deal with overlapping pixels - if merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - elif merge_strategy == 'accum': - prob = out['score_map'] - prob = prob[:ysize, :xsize] - cache.update_block(0, xoff, ysize, xsize, prob) - pred = cache.get_block(0, xoff, ysize, xsize) - if xoff + xsize >= width: - cache.roll_cache() - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() + batch_out = predict_func(batch_data, transforms=transforms) + + for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): + pred = out['label_map'].astype('uint8') + pred = pred[:ysize, :xsize] + + # Deal with overlapping pixels + if merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + elif merge_strategy == 'accum': + prob = out['score_map'] + prob = prob[:ysize, :xsize] + cache.update_block(0, xoff_, ysize, xsize, prob) + pred = cache.get_block(0, xoff_, ysize, xsize) + if xoff_ + xsize >= width: + cache.roll_cache() + + # Write to file + band.WriteArray(pred, xoff_, yoff_) + + dst_data.FlushCache() + batch_data.clear() + batch_offsets.clear() dst_data = None logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/paddlers/transforms/__init__.py b/paddlers/transforms/__init__.py index aafc4dd..ec470fb 100644 --- a/paddlers/transforms/__init__.py +++ b/paddlers/transforms/__init__.py @@ -25,7 +25,9 @@ def decode_image(im_path, to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + read_geo_info=False, + use_stretch=False, + read_raw=False): """ Decode an image. @@ -37,11 +39,16 @@ def decode_image(im_path, uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. + read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8` + to False. Setting `read_raw` takes precedence over setting `to_rgb` and + `to_uint8`. Defaults to False. Returns: np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. @@ -53,12 +60,16 @@ def decode_image(im_path, # Do a presence check. osp.exists() assumes `im_path` is a path-like object. if not osp.exists(im_path): raise ValueError(f"{im_path} does not exist!") + if read_raw: + to_rgb = False + to_uint8 = False decoder = T.DecodeImg( to_rgb=to_rgb, to_uint8=to_uint8, decode_bgr=decode_bgr, decode_sar=decode_sar, - read_geo_info=read_geo_info) + read_geo_info=read_geo_info, + use_stretch=use_stretch) # Deepcopy to avoid inplace modification sample = {'image': copy.deepcopy(im_path)} sample = decoder(sample) diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 5550e33..8518de5 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp): return rle -def to_uint8(im, is_linear=False): +def to_uint8(im, stretch=False): """ Convert raster data to uint8 type. Args: im (np.ndarray): Input raster image. - is_linear (bool, optional): Use 2% linear stretch or not. Default is False. + stretch (bool, optional): Use 2% linear stretch or not. Default is False. Returns: np.ndarray: Image data with unit8 type. @@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False): dtype = im.dtype.name if dtype != "uint8": im = _sample_norm(im) - if is_linear: + if stretch: im = _two_percent_linear(im) return im diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index 6110bdb..47e6cc4 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -179,11 +179,13 @@ class DecodeImg(Transform): uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. """ def __init__(self, @@ -191,13 +193,15 @@ class DecodeImg(Transform): to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + read_geo_info=False, + use_stretch=False): super(DecodeImg, self).__init__() self.to_rgb = to_rgb self.to_uint8 = to_uint8 self.decode_bgr = decode_bgr self.decode_sar = decode_sar self.read_geo_info = read_geo_info + self.use_stretch = use_stretch def read_img(self, img_path): img_format = imghdr.what(img_path) @@ -264,7 +268,7 @@ class DecodeImg(Transform): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.to_uint8: - image = to_uint8(image) + image = to_uint8(image, stretch=self.use_stretch) if self.read_geo_info: return image, geo_info_dict diff --git a/tests/deploy/test_predictor.py b/tests/deploy/test_predictor.py index 8675e65..081b34a 100644 --- a/tests/deploy/test_predictor.py +++ b/tests/deploy/test_predictor.py @@ -145,8 +145,8 @@ class TestCDPredictor(TestPredictor): # Single input (ndarrays) input_ = (decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False)) # Reuse the name `input_` + t1_path, read_raw=True), decode_image( + t2_path, read_raw=True)) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -169,8 +169,9 @@ class TestCDPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [(decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_` + t1_path, read_raw=True), decode_image( + t2_path, + read_raw=True))] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -211,7 +212,7 @@ class TestClasPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -235,7 +236,8 @@ class TestClasPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -276,7 +278,7 @@ class TestDetPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -295,7 +297,8 @@ class TestDetPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -329,7 +332,7 @@ class TestResPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -348,7 +351,8 @@ class TestResPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -386,7 +390,7 @@ class TestSegPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -409,7 +413,8 @@ class TestSegPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py index d5eff91..5b4392d 100644 --- a/tests/tasks/test_slider_predict.py +++ b/tests/tasks/test_slider_predict.py @@ -20,7 +20,174 @@ import paddlers.transforms as T from testing_utils import CommonTest -class TestSegSliderPredict(CommonTest): +class _TestSliderPredictNamespace: + class TestSliderPredict(CommonTest): + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, + self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_path, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 128, + 128, self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_path, save_dir, (128, 32), + 0, self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_path, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, + self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='accum') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image( + self.ref_path, read_geo_info=True) + self.model.slider_predict(self.image_path, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], + geo_info_in['geo_proj']) + + def test_batch_size(self): + with tempfile.TemporaryDirectory() as td: + # batch_size = 1 + save_dir = osp.join(td, 'bs1') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=1) + pred_bs1 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + + # batch_size = 4 + save_dir = osp.join(td, 'bs4') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=4) + pred_bs4 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_bs4, pred_bs1) + + # batch_size = 8 + save_dir = osp.join(td, 'bs4') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=8) + pred_bs8 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_bs8, pred_bs1) + + +class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict): def setUp(self): self.model = pdrs.tasks.seg.UNet(in_channels=10) self.transforms = T.Compose([ @@ -28,239 +195,18 @@ class TestSegSliderPredict(CommonTest): T.ArrangeSegmenter('test') ]) self.image_path = "data/ssst/multispectral.tif" - self.basename = osp.basename(self.image_path) - - def test_blocksize_and_overlap_whole(self): - # Original image size (256, 256) - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_path, self.transforms) - pred_whole = pred_whole['label_map'] - - # Whole-image inference using slider_predict() - save_dir = osp.join(td, 'pred1') - self.model.slider_predict(self.image_path, save_dir, 256, 0, - self.transforms) - pred1 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred1.shape, pred_whole.shape) - - # `block_size` == `overlap` - save_dir = osp.join(td, 'pred2') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_path, save_dir, 128, 128, - self.transforms) - - # `block_size` is a tuple - save_dir = osp.join(td, 'pred3') - self.model.slider_predict(self.image_path, save_dir, (128, 32), 0, - self.transforms) - pred3 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred3.shape, pred_whole.shape) - - # `block_size` and `overlap` are both tuples - save_dir = osp.join(td, 'pred4') - self.model.slider_predict(self.image_path, save_dir, (128, 100), - (10, 5), self.transforms) - pred4 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred4.shape, pred_whole.shape) - - # `block_size` larger than image size - save_dir = osp.join(td, 'pred5') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_path, save_dir, 512, 0, - self.transforms) - - def test_merge_strategy(self): - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_path, self.transforms) - pred_whole = pred_whole['label_map'] - - # 'keep_first' - save_dir = osp.join(td, 'keep_first') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_first') - pred_keepfirst = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) - - # 'keep_last' - save_dir = osp.join(td, 'keep_last') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_last') - pred_keeplast = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - - # 'accum' - save_dir = osp.join(td, 'accum') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='accum') - pred_accum = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_accum.shape, pred_whole.shape) + self.ref_path = self.image_path + self.basename = osp.basename(self.ref_path) - def test_geo_info(self): - with tempfile.TemporaryDirectory() as td: - _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) - self.model.slider_predict(self.image_path, td, 128, 0, - self.transforms) - _, geo_info_out = T.decode_image( - osp.join(td, self.basename), read_geo_info=True) - self.assertEqual(geo_info_out['geo_trans'], - geo_info_in['geo_trans']) - self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) - -class TestCDSliderPredict(CommonTest): +class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict): def setUp(self): self.model = pdrs.tasks.cd.BIT(in_channels=10) self.transforms = T.Compose([ T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), T.ArrangeChangeDetector('test') ]) - self.image_paths = ("data/ssmt/multispectral_t1.tif", - "data/ssmt/multispectral_t2.tif") - self.basename = osp.basename(self.image_paths[0]) - - def test_blocksize_and_overlap_whole(self): - # Original image size (256, 256) - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_paths, self.transforms) - pred_whole = pred_whole['label_map'] - - # Whole-image inference using slider_predict() - save_dir = osp.join(td, 'pred1') - self.model.slider_predict(self.image_paths, save_dir, 256, 0, - self.transforms) - pred1 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred1.shape, pred_whole.shape) - - # `block_size` == `overlap` - save_dir = osp.join(td, 'pred2') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_paths, save_dir, 128, 128, - self.transforms) - - # `block_size` is a tuple - save_dir = osp.join(td, 'pred3') - self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0, - self.transforms) - pred3 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred3.shape, pred_whole.shape) - - # `block_size` and `overlap` are both tuples - save_dir = osp.join(td, 'pred4') - self.model.slider_predict(self.image_paths, save_dir, (128, 100), - (10, 5), self.transforms) - pred4 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred4.shape, pred_whole.shape) - - # `block_size` larger than image size - save_dir = osp.join(td, 'pred5') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_paths, save_dir, 512, 0, - self.transforms) - - def test_merge_strategy(self): - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_paths, self.transforms) - pred_whole = pred_whole['label_map'] - - # 'keep_first' - save_dir = osp.join(td, 'keep_first') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_first') - pred_keepfirst = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) - - # 'keep_last' - save_dir = osp.join(td, 'keep_last') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_last') - pred_keeplast = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - - # 'accum' - save_dir = osp.join(td, 'accum') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='accum') - pred_accum = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_accum.shape, pred_whole.shape) - - def test_geo_info(self): - with tempfile.TemporaryDirectory() as td: - _, geo_info_in = T.decode_image( - self.image_paths[0], read_geo_info=True) - self.model.slider_predict(self.image_paths, td, 128, 0, - self.transforms) - _, geo_info_out = T.decode_image( - osp.join(td, self.basename), read_geo_info=True) - self.assertEqual(geo_info_out['geo_trans'], - geo_info_in['geo_trans']) - self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) + self.image_path = ("data/ssmt/multispectral_t1.tif", + "data/ssmt/multispectral_t2.tif") + self.ref_path = self.image_path[0] + self.basename = osp.basename(self.ref_path)