diff --git a/docs/apis/data.md b/docs/apis/data.md index 6be2e32..34afb1c 100644 --- a/docs/apis/data.md +++ b/docs/apis/data.md @@ -134,10 +134,12 @@ |-------|----|--------|-----| |`im_path`|`str`|输入图像路径。|| |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`| -|`to_uint8`|`bool`|若为`True`,则将读取的图像数据量化并转换为uint8类型。|`True`| +|`to_uint8`|`bool`|若为`True`,则将读取的影像数据量化并转换为uint8类型。|`True`| |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`| -|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| +|`decode_sar`|`bool`|若为`True`,则自动将单通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`| +|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`| +|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`| 返回格式如下: diff --git a/docs/apis/infer.md b/docs/apis/infer.md index 880b5cc..0217a20 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -155,7 +155,11 @@ def slider_predict(self, save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last', + batch_size=1, + quiet=False): ``` 输入参数列表: @@ -164,11 +168,15 @@ def slider_predict(self, |-------|----|--------|-----| |`img_file`|`str`|输入影像路径。|| |`save_dir`|`str`|预测结果输出路径。|| -|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定长、宽或以一个整数指定相同的长宽)。|| -|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定长、宽或以一个整数指定相同的长宽)。|`36`| +|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|| +|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`| |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| +|`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| 
+|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`| +|`batch_size`|`int`|预测时使用的mini-batch大小。|`1`| +|`quiet`|`bool`|若为`True`,不显示预测进度。|`False`| -变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。 +变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。 ## 静态图推理API @@ -216,5 +224,10 @@ def predict(self, |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`| |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`| |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`| +|`quiet`|`bool`|若为`True`,不打印计时信息。|`False`| `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。 + +### `Predictor.slider_predict()` + +实现滑窗推理功能。用法与`BaseSegmenter`和`BaseChangeDetector`的`slider_predict()`方法相同。 diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 1b2c493..24bb337 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -14,6 +14,7 @@ import os.path as osp from operator import itemgetter +from functools import partial import numpy as np import paddle @@ -23,6 +24,7 @@ from paddle.inference import PrecisionType from paddlers.tasks import load_model from paddlers.utils import logging, Timer +from paddlers.tasks.utils.slider_predict import slider_predict class Predictor(object): @@ -271,22 +273,24 @@ class Predictor(object): topk=1, transforms=None, warmup_iters=0, - repeats=1): + repeats=1, + quiet=False): """ - Do prediction. + Do inference. 
Args: img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to - paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks, - img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples. + paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change + detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms from `model.yml`. Defaults to None. warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0. repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than 1, the reported time consumption is the average of all repeats. Defaults to 1. + quiet (bool, optional): If True, do not display the timing information. Defaults to False. """ if repeats < 1: @@ -313,12 +317,61 @@ class Predictor(object): self.timer.repeats = repeats self.timer.img_num = len(images) - self.timer.info(average=True) + if not quiet: + self.timer.info(average=True) if isinstance(img_file, (str, np.ndarray, tuple)): results = results[0] return results + def slider_predict(self, + img_file, + save_dir, + block_size, + overlap=36, + transforms=None, + invalid_value=255, + merge_strategy='keep_last', + batch_size=1, + quiet=False): + """ + Do inference using sliding windows. 
Only semantic segmentation and change detection models are supported in the + sliding-predicting mode. + + Args: + img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` + should be either the path of the image to predict, a decoded image (a np.ndarray, which should be + consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)), + or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of + image paths, a tuple of decoded images, or a list of tuples. + save_dir (str): Directory that contains saved geotiff file. + block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple, + it should be in (W, H) format. Defaults to 36. + transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms + from `model.yml`. Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are + {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and + the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel + according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. 
+ """ + slider_predict( + partial( + self.predict, quiet=True), + img_file, + save_dir, + block_size, + overlap, + transforms, + invalid_value, + merge_strategy, + batch_size, + not quiet) + def batch_predict(self, image_list, **params): return self.predict(img_file=image_list, **params) diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 34e684f..0c32bb5 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -86,7 +86,7 @@ class BaseModel(metaclass=ModelMeta): self.quant_config = None self.fixed_input_shape = None - def net_initialize(self, + def initialize_net(self, pretrain_weights=None, save_dir='.', resume_checkpoint=None, @@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta): raise ValueError( f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}." ) + + def run(self, net, inputs, mode): + raise NotImplementedError + + def train(self, *args, **kwargs): + raise NotImplementedError + + def evaluate(self, *args, **kwargs): + raise NotImplementedError + + def preprocess(self, images, transforms, to_tensor): + raise NotImplementedError + + def postprocess(self, *args, **kwargs): + raise NotImplementedError \ No newline at end of file diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index f822b90..6df35d8 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import cd_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet +from .utils.slider_predict import slider_predict __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", @@ -315,7 +316,7 @@ class BaseChangeDetector(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, 
resume_checkpoint=resume_checkpoint, @@ -581,96 +582,44 @@ class BaseChangeDetector(BaseModel): return prediction def slider_predict(self, - img_file, + img_files, save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last', + batch_size=1, + quiet=False): """ - Do inference. + Do inference using sliding windows. Args: - img_file (tuple[str]): Tuple of image paths. + img_files (tuple[str]): Tuple of image paths. save_dir (str): Directory that contains saved geotiff file. - block_size (list[int] | tuple[int] | int, optional): Size of block. - overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. - Defaults to 36. - transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. - If None, the transforms for evaluation process will be used. Defaults to None. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int, optional): + Overlap between two blocks. If `overlap` is a list or tuple, it should + be in (W, H) format. Defaults to 36. + transforms (paddlers.transforms.Compose|None, optional): Transforms for + inputs. If None, the transforms for evaluation process will be used. + Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'accum' means determining the class of an overlapping + pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. 
""" - try: - from osgeo import gdal - except: - import gdal - - if not isinstance(img_file, tuple) or len(img_file) != 2: - raise ValueError("`img_file` must be a tuple of length 2.") - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - src1_data = gdal.Open(img_file[0]) - src2_data = gdal.Open(img_file[1]) - width = src1_data.RasterXSize - height = src1_data.RasterYSize - bands = src1_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[ - 0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - dst_data.SetGeoTransform(src1_data.GetGeoTransform()) - dst_data.SetProjection(src1_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - - step = np.array(block_size) - np.array(overlap) - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - im1 = src1_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - im2 = src2_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - # Fill - h, w = im1.shape[:2] - im1_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im1.dtype) - im2_fill = im1_fill.copy() - im1_fill[:h, :w, :] 
= im1 - im2_fill[:h, :w, :] = im2 - im_fill = (im1_fill, im2_fill) - # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - band.WriteArray(temp, int(xoff), int(yoff)) - dst_data.FlushCache() - dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + slider_predict(self.predict, img_files, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy, batch_size, + not quiet) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -678,8 +627,8 @@ class BaseChangeDetector(BaseModel): batch_ori_shape = list() for im1, im2 in images: if isinstance(im1, str) or isinstance(im2, str): - im1 = decode_image(im1, to_rgb=False) - im2 = decode_image(im2, to_rgb=False) + im1 = decode_image(im1, read_raw=True) + im2 = decode_image(im2, read_raw=True) ori_shape = im1.shape[:2] # XXX: sample do not contain 'image_t1' and 'image_t2'. 
sample = {'image': im1, 'image2': im2} diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 83c20fb..12b6640 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -288,7 +288,7 @@ class BaseClassifier(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = False - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -497,7 +497,7 @@ class BaseClassifier(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample) diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index ca25213..f42d7cc 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -347,7 +347,7 @@ class BaseDetector(BaseModel): "Invalid pretrained weights. 
Please specify a .pdparams file.", exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -617,7 +617,7 @@ class BaseDetector(BaseModel): batch_samples = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) sample = {'image': im} sample = transforms(sample) batch_samples.append(sample) diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py index d9ce6ad..ce7e931 100644 --- a/paddlers/tasks/restorer.py +++ b/paddlers/tasks/restorer.py @@ -283,7 +283,7 @@ class BaseRestorer(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -481,7 +481,7 @@ class BaseRestorer(BaseModel): batch_tar_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 83cbffa..aa54f7c 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferSegNet +from .utils.slider_predict import slider_predict __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -307,7 +308,7 @@ class BaseSegmenter(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, 
save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -557,86 +558,40 @@ class BaseSegmenter(BaseModel): save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last', + batch_size=1, + quiet=False): """ - Do inference. + Do inference using sliding windows. Args: img_file (str): Image path. save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): - Size of block. + Size of block. If `block_size` is list or tuple, it should be in + (W, H) format. overlap (list[int] | tuple[int] | int, optional): - Overlap between two blocks. Defaults to 36. + Overlap between two blocks. If `overlap` is list or tuple, it should + be in (W, H) format. Defaults to 36. transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'accum' means determining the class of an overlapping + pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. 
""" - try: - from osgeo import gdal - except: - import gdal - - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - src_data = gdal.Open(img_file) - width = src_data.RasterXSize - height = src_data.RasterYSize - bands = src_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[ - 0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - dst_data.SetGeoTransform(src_data.GetGeoTransform()) - dst_data.SetProjection(src_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - - step = np.array(block_size) - np.array(overlap) - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, - ysize).transpose((1, 2, 0)) - # Fill - h, w = im.shape[:2] - im_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im.dtype) - im_fill[:h, :w, :] = im - # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - 
band.WriteArray(temp, int(xoff), int(yoff)) - dst_data.FlushCache() - dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + slider_predict(self.predict, img_file, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy, batch_size, + not quiet) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -644,7 +599,7 @@ class BaseSegmenter(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py new file mode 100644 index 0000000..620997f --- /dev/null +++ b/paddlers/tasks/utils/slider_predict.py @@ -0,0 +1,437 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import os.path as osp +import math +from abc import ABCMeta, abstractmethod +from collections import Counter, defaultdict + +import numpy as np +from tqdm import tqdm + +import paddlers.utils.logging as logging + + +class Cache(metaclass=ABCMeta): + @abstractmethod + def get_block(self, i_st, j_st, h, w): + pass + + +class SlowCache(Cache): + def __init__(self): + super(SlowCache, self).__init__() + self.cache = defaultdict(Counter) + + def push_pixel(self, i, j, l): + self.cache[(i, j)][l] += 1 + + def push_block(self, i_st, j_st, h, w, data): + for i in range(0, h): + for j in range(0, w): + self.push_pixel(i_st + i, j_st + j, data[i, j]) + + def pop_pixel(self, i, j): + self.cache.pop((i, j)) + + def pop_block(self, i_st, j_st, h, w): + for i in range(0, h): + for j in range(0, w): + self.pop_pixel(i_st + i, j_st + j) + + def get_pixel(self, i, j): + winners = self.cache[(i, j)].most_common(1) + winner = winners[0] + return winner[0] + + def get_block(self, i_st, j_st, h, w): + block = [] + for i in range(i_st, i_st + h): + row = [] + for j in range(j_st, j_st + w): + row.append(self.get_pixel(i, j)) + block.append(row) + return np.asarray(block) + + +class ProbCache(Cache): + def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): + super(ProbCache, self).__init__() + self.cache = None + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + if not issubclass(dtype, np.floating): + raise TypeError("`dtype` must be one of the floating types.") + self.dtype = dtype + order = order.lower() + if order not in ('c', 'f'): + raise ValueError("`order` other than 'c' and 'f' is not supported.") + self.order = order + + def _alloc_memory(self, nc): + if self.order == 'c': + # Colomn-first order (C-style) + # + # <-- cw --> + # |--------|---------------------|^ ^ + # | || | sh + # |--------|---------------------|| ch v + # | || + # |--------|---------------------|v + # <------------ w ---------------> + 
self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype) + elif self.order == 'f': + # Row-first order (Fortran-style) + # + # <-- sw --> + # <---- cw ----> + # |--------|---|^ ^ + # | | || | + # | | || ch + # | | || | + # |--------|---|| h v + # | | || + # | | || + # | | || + # |--------|---|v + self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype) + + def update_block(self, i_st, j_st, h, w, prob_map): + if self.cache is None: + nc = prob_map.shape[2] + # Lazy allocation of memory + self._alloc_memory(nc) + self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map + + def roll_cache(self, shift): + if self.order == 'c': + self.cache[:-shift] = self.cache[shift:] + self.cache[-shift:, :] = 0 + elif self.order == 'f': + self.cache[:, :-shift] = self.cache[:, shift:] + self.cache[:, -shift:] = 0 + + def get_block(self, i_st, j_st, h, w): + return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) + + +class OverlapProcessor(metaclass=ABCMeta): + def __init__(self, h, w, ch, cw, sh, sw): + super(OverlapProcessor, self).__init__() + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + + @abstractmethod + def process_pred(self, out, xoff, yoff): + pass + + +class KeepFirstProcessor(OverlapProcessor): + def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255): + super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.ds = ds + self.inval = inval + + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch) + mask = rd_block != self.inval + pred = np.where(mask, rd_block, pred) + return pred + + +class KeepLastProcessor(OverlapProcessor): + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + return pred + + +class AccumProcessor(OverlapProcessor): + def __init__(self, + h, + w, + ch, + cw, + sh, + sw, + dtype=np.float16, + assign_weight=True): + 
super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c') + self.prev_yoff = None + self.assign_weight = assign_weight + + def process_pred(self, out, xoff, yoff): + if self.prev_yoff is not None and yoff != self.prev_yoff: + if yoff < self.prev_yoff: + raise RuntimeError + self.cache.roll_cache(yoff - self.prev_yoff) + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + prob = out['score_map'] + prob = prob[:self.ch, :self.cw] + if self.assign_weight: + prob = assign_border_weights(prob, border_ratio=0.25, inplace=True) + self.cache.update_block(0, xoff, self.ch, self.cw, prob) + pred = self.cache.get_block(0, xoff, self.ch, self.cw) + self.prev_yoff = yoff + return pred + + +def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True): + if not inplace: + array = array.copy() + h, w = array.shape[:2] + hm, wm = int(h * border_ratio), int(w * border_ratio) + array[:hm] *= weight + array[-hm:] *= weight + array[:, :wm] *= weight + array[:, -wm:] *= weight + return array + + +def read_block(ds, + xoff, + yoff, + xsize, + ysize, + tar_xsize=None, + tar_ysize=None, + pad_val=0): + if tar_xsize is None: + tar_xsize = xsize + if tar_ysize is None: + tar_ysize = ysize + # Read data from dataset + block = ds.ReadAsArray(xoff, yoff, xsize, ysize) + c, real_ysize, real_xsize = block.shape + assert real_ysize == ysize and real_xsize == xsize + # [c, h, w] -> [h, w, c] + block = block.transpose((1, 2, 0)) + if (real_ysize, real_xsize) != (tar_ysize, tar_xsize): + if real_ysize >= tar_ysize or real_xsize >= tar_xsize: + raise ValueError + padded_block = np.full( + (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype) + # Fill + padded_block[:real_ysize, :real_xsize] = block + return padded_block + else: + return block + + +def slider_predict(predict_func, + img_file, + save_dir, + block_size, + overlap, + transforms, + invalid_value, + merge_strategy, + batch_size, + 
show_progress=False): + """ + Do inference using sliding windows. + + Args: + predict_func (callable): A callable object that makes the prediction. + img_file (str|tuple[str]): Image path(s). + save_dir (str): Directory that contains saved geotiff file. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int): + Overlap between two blocks. If `overlap` is a list or tuple, it should + be in (W, H) format. + transforms (paddlers.transforms.Compose|None): Transforms for inputs. If + None, the transforms for evaluation process will be used. + invalid_value (int): Value that marks invalid pixels in output image. + Defaults to 255. + merge_strategy (str): Strategy to merge overlapping blocks. Choices are + {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' + mean keeping the values of the first and the last block in + traversal order, respectively. 'accum' means determining the class + of an overlapping pixel according to accumulated probabilities. + batch_size (int): Batch size used in inference. + show_progress (bool, optional): Whether to show prediction progress with a + progress bar. Defaults to False. 
+ """ + + try: + from osgeo import gdal + except: + import gdal + + if isinstance(block_size, int): + block_size = (block_size, block_size) + elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: + block_size = tuple(block_size) + else: + raise ValueError( + "`block_size` must be a tuple/list of length 2 or an integer.") + if isinstance(overlap, int): + overlap = (overlap, overlap) + elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: + overlap = tuple(overlap) + else: + raise ValueError( + "`overlap` must be a tuple/list of length 2 or an integer.") + + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") + + if isinstance(img_file, tuple): + if len(img_file) != 2: + raise ValueError("Tuple `img_file` must have the length of two.") + # Assume that two input images have the same size + src_data = gdal.Open(img_file[0]) + src2_data = gdal.Open(img_file[1]) + # Output name is the same as the name of the first image + file_name = osp.basename(osp.normpath(img_file[0])) + else: + src_data = gdal.Open(img_file) + file_name = osp.basename(osp.normpath(img_file)) + + # Get size of original raster + width = src_data.RasterXSize + height = src_data.RasterYSize + bands = src_data.RasterCount + + # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True) + # except for SAR images. + if bands == 1: + logging.warning( + f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images." 
+ ) + + if block_size[0] > width or block_size[1] > height: + raise ValueError("`block_size` should not be larger than image size.") + + driver = gdal.GetDriverByName("GTiff") + if not osp.exists(save_dir): + os.makedirs(save_dir) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" + save_file = osp.join(save_dir, file_name) + dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information + dst_data.SetGeoTransform(src_data.GetGeoTransform()) + dst_data.SetProjection(src_data.GetProjection()) + + # Initialize raster with `invalid_value` + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + if overlap == (0, 0) or block_size == (width, height): + # When there is no overlap or the whole image is used as input, + # use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + + if merge_strategy == 'keep_first': + overlap_processor = KeepFirstProcessor( + height, + width, + *block_size[::-1], + *step[::-1], + band, + inval=invalid_value) + elif merge_strategy == 'keep_last': + overlap_processor = KeepLastProcessor(height, width, *block_size[::-1], + *step[::-1]) + elif merge_strategy == 'accum': + overlap_processor = AccumProcessor(height, width, *block_size[::-1], + *step[::-1]) + else: + raise ValueError("{} is not a supported stragegy for block merging.". 
+ format(merge_strategy)) + + xsize, ysize = block_size + num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0]) + cnt = 0 + if show_progress: + pb = tqdm(total=num_blocks) + batch_data = [] + batch_offsets = [] + for yoff in range(0, height, step[1]): + for xoff in range(0, width, step[0]): + if xoff + xsize > width: + xoff = width - xsize + is_end_of_row = True + else: + is_end_of_row = False + if yoff + ysize > height: + yoff = height - ysize + is_end_of_col = True + else: + is_end_of_col = False + + # Read + im = read_block(src_data, xoff, yoff, xsize, ysize) + + if isinstance(img_file, tuple): + im2 = read_block(src2_data, xoff, yoff, xsize, ysize) + batch_data.append((im, im2)) + else: + batch_data.append(im) + + batch_offsets.append((xoff, yoff)) + + len_batch = len(batch_data) + + if is_end_of_row and is_end_of_col and len_batch < batch_size: + # Pad `batch_data` by repeating the last element + batch_data = batch_data + [batch_data[-1]] * (batch_size - + len_batch) + # While keeping `len(batch_offsets)` the number of valid elements in the batch + + if len(batch_data) == batch_size: + # Predict + batch_out = predict_func(batch_data, transforms=transforms) + + for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): + # Get processed result + pred = overlap_processor.process_pred(out, xoff_, yoff_) + # Write to file + band.WriteArray(pred, xoff_, yoff_) + + dst_data.FlushCache() + batch_data.clear() + batch_offsets.clear() + + cnt += 1 + + if show_progress: + pb.update(1) + pb.set_description("{} out of {} blocks processed.".format( + cnt, num_blocks)) + + dst_data = None + logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/paddlers/transforms/__init__.py b/paddlers/transforms/__init__.py index aafc4dd..ec470fb 100644 --- a/paddlers/transforms/__init__.py +++ b/paddlers/transforms/__init__.py @@ -25,7 +25,9 @@ def decode_image(im_path, to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + 
read_geo_info=False, + use_stretch=False, + read_raw=False): """ Decode an image. @@ -37,11 +39,16 @@ def decode_image(im_path, uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. + read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8` + to False. Setting `read_raw` takes precedence over setting `to_rgb` and + `to_uint8`. Defaults to False. Returns: np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. @@ -53,12 +60,16 @@ def decode_image(im_path, # Do a presence check. osp.exists() assumes `im_path` is a path-like object. if not osp.exists(im_path): raise ValueError(f"{im_path} does not exist!") + if read_raw: + to_rgb = False + to_uint8 = False decoder = T.DecodeImg( to_rgb=to_rgb, to_uint8=to_uint8, decode_bgr=decode_bgr, decode_sar=decode_sar, - read_geo_info=read_geo_info) + read_geo_info=read_geo_info, + use_stretch=use_stretch) # Deepcopy to avoid inplace modification sample = {'image': copy.deepcopy(im_path)} sample = decoder(sample) diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 5550e33..8518de5 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp): return rle -def to_uint8(im, is_linear=False): +def to_uint8(im, stretch=False): """ Convert raster data to uint8 type. 
Args: im (np.ndarray): Input raster image. - is_linear (bool, optional): Use 2% linear stretch or not. Default is False. + stretch (bool, optional): Use 2% linear stretch or not. Default is False. Returns: np.ndarray: Image data with unit8 type. @@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False): dtype = im.dtype.name if dtype != "uint8": im = _sample_norm(im) - if is_linear: + if stretch: im = _two_percent_linear(im) return im diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index 66cffbb..205a4f1 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -203,11 +203,13 @@ class DecodeImg(Transform): uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. 
""" def __init__(self, @@ -215,13 +217,15 @@ class DecodeImg(Transform): to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + read_geo_info=False, + use_stretch=False): super(DecodeImg, self).__init__() self.to_rgb = to_rgb self.to_uint8 = to_uint8 self.decode_bgr = decode_bgr self.decode_sar = decode_sar - self.read_geo_info = False + self.read_geo_info = read_geo_info + self.use_stretch = use_stretch def read_img(self, img_path): img_format = imghdr.what(img_path) @@ -251,7 +255,7 @@ class DecodeImg(Transform): im_data = im_data.transpose((1, 2, 0)) if self.read_geo_info: geo_trans = dataset.GetGeoTransform() - geo_proj = dataset.GetGeoProjection() + geo_proj = dataset.GetProjection() elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: if self.decode_bgr: im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | @@ -288,7 +292,7 @@ class DecodeImg(Transform): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.to_uint8: - image = to_uint8(image) + image = to_uint8(image, stretch=self.use_stretch) if self.read_geo_info: return image, geo_info_dict diff --git a/tests/deploy/test_predictor.py b/tests/deploy/test_predictor.py index cd67172..2038964 100644 --- a/tests/deploy/test_predictor.py +++ b/tests/deploy/test_predictor.py @@ -151,8 +151,8 @@ class TestCDPredictor(TestPredictor): # Single input (ndarrays) input_ = (decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False)) # Reuse the name `input_` + t1_path, read_raw=True), decode_image( + t2_path, read_raw=True)) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -175,8 +175,9 @@ class TestCDPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [(decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_` + t1_path, 
read_raw=True), decode_image( + t2_path, + read_raw=True))] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -217,7 +218,7 @@ class TestClasPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -241,7 +242,8 @@ class TestClasPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -282,7 +284,7 @@ class TestDetPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -301,7 +303,8 @@ class TestDetPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -343,7 +346,7 @@ class 
TestResPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -362,7 +365,8 @@ class TestResPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -400,7 +404,7 @@ class TestSegPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -423,7 +427,8 @@ class TestSegPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) diff --git a/tests/fast_tests.py b/tests/fast_tests.py index 8e7c26e..4214b5b 100644 --- a/tests/fast_tests.py +++ b/tests/fast_tests.py @@ -13,4 +13,5 @@ # limitations under the License. 
from rs_models import * +from tasks import * from transforms import * diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py index 29c8b7d..5948c0c 100644 --- a/tests/tasks/__init__.py +++ b/tests/tasks/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .test_slider_predict import * diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py new file mode 100644 index 0000000..5b4392d --- /dev/null +++ b/tests/tasks/test_slider_predict.py @@ -0,0 +1,212 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os.path as osp
import tempfile

import paddlers as pdrs
import paddlers.transforms as T
from testing_utils import CommonTest


class _TestSliderPredictNamespace:
    # The shared test case is nested inside a plain class, presumably so that
    # test loaders do not collect it on its own: it defines no `setUp()` and
    # is only meaningful through the concrete subclasses at module level.

    class TestSliderPredict(CommonTest):
        """
        Task-agnostic checks for `slider_predict()`.

        Subclasses must define the following attributes in `setUp()`:
            model: Trainer object exposing `predict()` and `slider_predict()`.
            transforms: Transforms passed to both prediction methods.
            image_path (str|tuple[str]): Input image path(s).
            ref_path (str): Path of the image whose file name and geographic
                information the output is compared against.
            basename (str): File name of the raster written by
                `slider_predict()`.
        """

        def _read_pred(self, save_dir):
            """
            Read the raster written by `slider_predict()` into `save_dir`.

            The raster is read as-is (`read_raw=True`) and SAR decoding is
            switched off, since the output has a single band.
            """
            return T.decode_image(
                osp.join(save_dir, self.basename),
                read_raw=True,
                decode_sar=False)

        def test_blocksize_and_overlap_whole(self):
            """Check the handling of `block_size` and `overlap`."""
            # Original image size (256, 256)
            with tempfile.TemporaryDirectory() as td:
                # Whole-image inference using predict()
                pred_whole = self.model.predict(self.image_path,
                                                self.transforms)
                pred_whole = pred_whole['label_map']

                # Whole-image inference using slider_predict()
                save_dir = osp.join(td, 'pred1')
                self.model.slider_predict(self.image_path, save_dir, 256, 0,
                                          self.transforms)
                pred1 = self._read_pred(save_dir)
                self.check_output_equal(pred1.shape, pred_whole.shape)

                # `block_size` == `overlap`
                save_dir = osp.join(td, 'pred2')
                with self.assertRaises(ValueError):
                    self.model.slider_predict(self.image_path, save_dir, 128,
                                              128, self.transforms)

                # `block_size` is a tuple
                save_dir = osp.join(td, 'pred3')
                self.model.slider_predict(self.image_path, save_dir,
                                          (128, 32), 0, self.transforms)
                pred3 = self._read_pred(save_dir)
                self.check_output_equal(pred3.shape, pred_whole.shape)

                # `block_size` and `overlap` are both tuples
                save_dir = osp.join(td, 'pred4')
                self.model.slider_predict(self.image_path, save_dir,
                                          (128, 100), (10, 5),
                                          self.transforms)
                pred4 = self._read_pred(save_dir)
                self.check_output_equal(pred4.shape, pred_whole.shape)

                # `block_size` larger than image size
                save_dir = osp.join(td, 'pred5')
                with self.assertRaises(ValueError):
                    self.model.slider_predict(self.image_path, save_dir, 512,
                                              0, self.transforms)

        def test_merge_strategy(self):
            """Check that each merge strategy yields a full-size result."""
            with tempfile.TemporaryDirectory() as td:
                # Whole-image inference using predict()
                pred_whole = self.model.predict(self.image_path,
                                                self.transforms)
                pred_whole = pred_whole['label_map']

                for strategy in ('keep_first', 'keep_last', 'accum'):
                    with self.subTest(merge_strategy=strategy):
                        # One output directory per strategy
                        save_dir = osp.join(td, strategy)
                        self.model.slider_predict(
                            self.image_path,
                            save_dir,
                            128,
                            64,
                            self.transforms,
                            merge_strategy=strategy)
                        pred = self._read_pred(save_dir)
                        self.check_output_equal(pred.shape, pred_whole.shape)

        def test_geo_info(self):
            """Check that geographic information is copied to the output."""
            with tempfile.TemporaryDirectory() as td:
                _, geo_info_in = T.decode_image(
                    self.ref_path, read_geo_info=True)
                self.model.slider_predict(self.image_path, td, 128, 0,
                                          self.transforms)
                _, geo_info_out = T.decode_image(
                    osp.join(td, self.basename), read_geo_info=True)
                self.assertEqual(geo_info_out['geo_trans'],
                                 geo_info_in['geo_trans'])
                self.assertEqual(geo_info_out['geo_proj'],
                                 geo_info_in['geo_proj'])

        def test_batch_size(self):
            """Check that the result does not depend on `batch_size`."""
            with tempfile.TemporaryDirectory() as td:

                def _predict(batch_size):
                    # Each batch size writes to its own directory, so that a
                    # raster left over from another run can never be read
                    # back by mistake.
                    save_dir = osp.join(td, f'bs{batch_size}')
                    self.model.slider_predict(
                        self.image_path,
                        save_dir,
                        128,
                        64,
                        self.transforms,
                        merge_strategy='keep_first',
                        batch_size=batch_size)
                    return self._read_pred(save_dir)

                # batch_size = 1 serves as the reference
                pred_bs1 = _predict(1)

                for batch_size in (4, 8):
                    with self.subTest(batch_size=batch_size):
                        self.check_output_equal(
                            _predict(batch_size), pred_bs1)


class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
    """Run the shared slider-predict checks on a segmentation model."""

    def setUp(self):
        self.model = pdrs.tasks.seg.UNet(in_channels=10)
        self.transforms = T.Compose([
            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
            T.ArrangeSegmenter('test')
        ])
        self.image_path = "data/ssst/multispectral.tif"
        self.ref_path = self.image_path
        self.basename = osp.basename(self.ref_path)


class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
    """Run the shared slider-predict checks on a change detection model."""

    def setUp(self):
        self.model = pdrs.tasks.cd.BIT(in_channels=10)
        self.transforms = T.Compose([
            T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
            T.ArrangeChangeDetector('test')
        ])
        self.image_path = ("data/ssmt/multispectral_t1.tif",
                           "data/ssmt/multispectral_t2.tif")
        # The output takes its name and geographic information from the
        # first temporal phase.
        self.ref_path = self.image_path[0]
        self.basename = osp.basename(self.ref_path)