From f7a4ebc58db9d0e1aec4756630f2b68584d70539 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Fri, 2 Sep 2022 20:02:57 +0800 Subject: [PATCH] Enhance slider_predict() --- paddlers/tasks/base.py | 15 ++ paddlers/tasks/change_detector.py | 141 ++++++++++--- paddlers/tasks/segmenter.py | 111 ++++++++-- paddlers/tasks/utils/slider_predict.py | 52 +++++ paddlers/transforms/operators.py | 4 +- tests/fast_tests.py | 1 + tests/tasks/__init__.py | 2 + tests/tasks/test_slider_predict.py | 274 +++++++++++++++++++++++++ 8 files changed, 545 insertions(+), 55 deletions(-) create mode 100644 paddlers/tasks/utils/slider_predict.py create mode 100644 tests/tasks/test_slider_predict.py diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 34e684f..5950691 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta): raise ValueError( f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}." ) + + def run(self, net, inputs, mode): + raise NotImplementedError + + def train(self, *args, **kwargs): + raise NotImplementedError + + def evaluate(self, *args, **kwargs): + raise NotImplementedError + + def preprocess(self, images, transforms, to_tensor): + raise NotImplementedError + + def postprocess(self, *args, **kwargs): + raise NotImplementedError \ No newline at end of file diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 7a45172..30babb5 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet +from .utils.slider_predict import SlowCache as Cache __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", @@ -574,22 +575,35 @@ class BaseChangeDetector(BaseModel): return prediction def slider_predict(self, - img_file, + img_files, save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): """ - Do inference. + Do inference using sliding windows. Args: - img_file (tuple[str]): Tuple of image paths. + img_files (tuple[str]): Tuple of image paths. save_dir (str): Directory that contains saved geotiff file. - block_size (list[int] | tuple[int] | int, optional): Size of block. - overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. - Defaults to 36. - transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. - If None, the transforms for evaluation process will be used. Defaults to None. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int, optional): + Overlap between two blocks. If `overlap` is a list or tuple, it should + be in (W, H) format. Defaults to 36. + transforms (paddlers.transforms.Compose|None, optional): Transforms for + inputs. If None, the transforms for evaluation process will be used. + Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'vote' means applying a simple voting strategy when + there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ try: @@ -597,8 +611,6 @@ class BaseChangeDetector(BaseModel): except: import gdal - if not isinstance(img_file, tuple) or len(img_file) != 2: - raise ValueError("`img_file` must be a tuple of length 2.") if isinstance(block_size, int): block_size = (block_size, block_size) elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: @@ -614,25 +626,54 @@ class BaseChangeDetector(BaseModel): raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") - src1_data = gdal.Open(img_file[0]) - src2_data = gdal.Open(img_file[1]) + if merge_strategy not in ('keep_first', 'keep_last', 'vote'): + raise ValueError( + "{} is not a supported stragegy for block merging.".format( + merge_strategy)) + if overlap == (0, 0): + # When there is no overlap, use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " + "'keep_last'.") + cache = Cache() + + src1_data = gdal.Open(img_files[0]) + src2_data = gdal.Open(img_files[1]) + + # Assume that two input images have the same size width = src1_data.RasterXSize height = src1_data.RasterYSize bands = src1_data.RasterCount driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[ - 0] + ".tif" + # Output name is the same as the name of the first image + file_name = osp.basename(osp.normpath(img_files[0])) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" if not osp.exists(save_dir): os.makedirs(save_dir) save_file = osp.join(save_dir, file_name) dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information (consistent with the first image) dst_data.SetGeoTransform(src1_data.GetGeoTransform()) dst_data.SetProjection(src1_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - step = np.array(block_size) - np.array(overlap) + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + prev_yoff, prev_xoff = None, None + prev_h, prev_w = None, None + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -640,30 +681,64 @@ class BaseChangeDetector(BaseModel): xsize = int(width - xoff) if yoff + ysize > height: ysize = int(height - yoff) - im1 = src1_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - im2 = src2_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) + + xoff = int(xoff) + yoff = int(yoff) + im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) # Fill h, w = im1.shape[:2] im1_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im1.dtype) - im2_fill = im1_fill.copy() im1_fill[:h, :w, :] = im1 + + im2_fill = np.zeros( + (block_size[1], block_size[0], bands), dtype=im2.dtype) im2_fill[:h, :w, :] = im2 - im_fill = (im1_fill, im2_fill) + # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - band.WriteArray(temp, int(xoff), int(yoff)) + pred = self.predict((im1_fill, im2_fill), transforms) + pred = pred["label_map"].astype('uint8') + pred = pred[:h, :w] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, h, w, pred) + pred = cache.get_block(yoff, xoff, h, w) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = prev_w + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + + # Write to file + band.WriteArray(pred, xoff, yoff) dst_data.FlushCache() + + prev_xoff = xoff + prev_w = w + + prev_yoff = yoff + prev_h = h + dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + logging.info("GeoTiff file saved in {}.".format(save_file)) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index b9c586f..1037041 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferSegNet +from .utils.slider_predict import SlowCache as Cache __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -550,20 +551,31 @@ class BaseSegmenter(BaseModel): save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): """ - Do inference. + Do inference using sliding windows. Args: img_file (str): Image path. save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): - Size of block. + Size of block. If `block_size` is list or tuple, it should be in + (W, H) format. overlap (list[int] | tuple[int] | int, optional): - Overlap between two blocks. Defaults to 36. + Overlap between two blocks. If `overlap` is list or tuple, it should + be in (W, H) format. Defaults to 36. transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'vote' means applying a simple voting strategy when + there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ try: @@ -586,24 +598,50 @@ class BaseSegmenter(BaseModel): raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") + if merge_strategy not in ('keep_first', 'keep_last', 'vote'): + raise ValueError( + "{} is not a supported stragegy for block merging.".format( + merge_strategy)) + if overlap == (0, 0): + # When there is no overlap, use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " + "'keep_last'.") + cache = Cache() + src_data = gdal.Open(img_file) width = src_data.RasterXSize height = src_data.RasterYSize bands = src_data.RasterCount driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[ - 0] + ".tif" + file_name = osp.basename(osp.normpath(img_file)) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" if not osp.exists(save_dir): os.makedirs(save_dir) save_file = osp.join(save_dir, file_name) dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information dst_data.SetGeoTransform(src_data.GetGeoTransform()) dst_data.SetProjection(src_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - step = np.array(block_size) - np.array(overlap) + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + prev_yoff, prev_xoff = None, None + prev_h, prev_w = None, None + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -611,25 +649,58 @@ class BaseSegmenter(BaseModel): xsize = int(width - xoff) if yoff + ysize > height: ysize = int(height - yoff) - im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, - ysize).transpose((1, 2, 0)) + + xoff = int(xoff) + yoff = int(yoff) + im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) # Fill h, w = im.shape[:2] im_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im.dtype) im_fill[:h, :w, :] = im + # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - band.WriteArray(temp, int(xoff), int(yoff)) + pred = self.predict(im_fill, transforms) + pred = pred["label_map"].astype('uint8') + pred = pred[:h, :w] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, h, w, pred) + pred = cache.get_block(yoff, xoff, h, w) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = prev_w + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + + # Write to file + band.WriteArray(pred, xoff, yoff) dst_data.FlushCache() + + prev_xoff = xoff + prev_w = w + + prev_yoff = yoff + prev_h = h + dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + logging.info("GeoTiff file saved in {}.".format(save_file)) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py new file mode 100644 index 0000000..254882b --- /dev/null +++ b/paddlers/tasks/utils/slider_predict.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter, defaultdict + +import numpy as np + + +class SlowCache(object): + def __init__(self): + self.cache = defaultdict(Counter) + + def push_pixel(self, i, j, l): + self.cache[(i, j)][l] += 1 + + def push_block(self, i_st, j_st, h, w, data): + for i in range(0, h): + for j in range(0, w): + self.push_pixel(i_st + i, j_st + j, data[i, j]) + + def pop_pixel(self, i, j): + self.cache.pop((i, j)) + + def pop_block(self, i_st, j_st, h, w): + for i in range(0, h): + for j in range(0, w): + self.pop_pixel(i_st + i, j_st + j) + + def get_pixel(self, i, j): + winners = self.cache[(i, j)].most_common(1) + winner = winners[0] + return winner[0] + + def get_block(self, i_st, j_st, h, w): + block = [] + for i in range(i_st, i_st + h): + row = [] + for j in range(j_st, j_st + w): + row.append(self.get_pixel(i, j)) + block.append(row) + return np.asarray(block) diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index dd21c7a..6110bdb 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -197,7 +197,7 @@ class DecodeImg(Transform): self.to_uint8 = to_uint8 self.decode_bgr = decode_bgr self.decode_sar = decode_sar - self.read_geo_info = False + self.read_geo_info = read_geo_info def read_img(self, img_path): img_format = imghdr.what(img_path) @@ -227,7 +227,7 @@ class DecodeImg(Transform): im_data = im_data.transpose((1, 2, 0)) if self.read_geo_info: geo_trans = dataset.GetGeoTransform() - geo_proj = dataset.GetGeoProjection() + geo_proj = dataset.GetProjection() elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: if self.decode_bgr: im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | diff --git a/tests/fast_tests.py b/tests/fast_tests.py index 8e7c26e..4214b5b 100644 --- a/tests/fast_tests.py +++ b/tests/fast_tests.py @@ -13,4 +13,5 @@ # limitations under the License. from rs_models import * +from tasks import * from transforms import * diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py index 29c8b7d..5948c0c 100644 --- a/tests/tasks/__init__.py +++ b/tests/tasks/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .test_slider_predict import * diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py new file mode 100644 index 0000000..bc3bd5a --- /dev/null +++ b/tests/tasks/test_slider_predict.py @@ -0,0 +1,274 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path as osp +import tempfile + +import paddlers as pdrs +import paddlers.transforms as T +from testing_utils import CommonTest + + +class TestSegSliderPredict(CommonTest): + def setUp(self): + self.model = pdrs.tasks.seg.UNet(in_channels=10) + self.transforms = T.Compose([ + T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), + T.ArrangeSegmenter('test') + ]) + self.image_path = "data/ssst/multispectral.tif" + self.basename = osp.basename(self.image_path) + + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_path, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 128, 128, + self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_path, save_dir, (128, 32), 0, + self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_path, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) + pred5 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred5.shape, pred_whole.shape) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'vote' + save_dir = osp.join(td, 'vote') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_vote = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_vote.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) + self.model.slider_predict(self.image_path, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) + + +class TestCDSliderPredict(CommonTest): + def setUp(self): + self.model = pdrs.tasks.cd.BIT(in_channels=10) + self.transforms = T.Compose([ + T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), + T.ArrangeChangeDetector('test') + ]) + self.image_paths = ("data/ssmt/multispectral_t1.tif", + "data/ssmt/multispectral_t2.tif") + self.basename = osp.basename(self.image_paths[0]) + + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_paths, self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_paths, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_paths, save_dir, 128, 128, + self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0, + self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_paths, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + self.model.slider_predict(self.image_paths, save_dir, 512, 0, + self.transforms) + pred5 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred5.shape, pred_whole.shape) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_paths, self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'vote' + save_dir = osp.join(td, 'vote') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_vote = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_vote.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image( + self.image_paths[0], read_geo_info=True) + self.model.slider_predict(self.image_paths, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])