From f7a4ebc58db9d0e1aec4756630f2b68584d70539 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Fri, 2 Sep 2022 20:02:57 +0800 Subject: [PATCH 01/10] Enhance slider_predict() --- paddlers/tasks/base.py | 15 ++ paddlers/tasks/change_detector.py | 141 ++++++++++--- paddlers/tasks/segmenter.py | 111 ++++++++-- paddlers/tasks/utils/slider_predict.py | 52 +++++ paddlers/transforms/operators.py | 4 +- tests/fast_tests.py | 1 + tests/tasks/__init__.py | 2 + tests/tasks/test_slider_predict.py | 274 +++++++++++++++++++++++++ 8 files changed, 545 insertions(+), 55 deletions(-) create mode 100644 paddlers/tasks/utils/slider_predict.py create mode 100644 tests/tasks/test_slider_predict.py diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 34e684f..5950691 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta): raise ValueError( f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}." ) + + def run(self, net, inputs, mode): + raise NotImplementedError + + def train(self, *args, **kwargs): + raise NotImplementedError + + def evaluate(self, *args, **kwargs): + raise NotImplementedError + + def preprocess(self, images, transforms, to_tensor): + raise NotImplementedError + + def postprocess(self, *args, **kwargs): + raise NotImplementedError \ No newline at end of file diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 7a45172..30babb5 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet +from .utils.slider_predict import SlowCache as Cache __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", @@ -574,22 +575,35 @@ class BaseChangeDetector(BaseModel): return prediction def slider_predict(self, - img_file, + img_files, save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): """ - Do inference. + Do inference using sliding windows. Args: - img_file (tuple[str]): Tuple of image paths. + img_files (tuple[str]): Tuple of image paths. save_dir (str): Directory that contains saved geotiff file. - block_size (list[int] | tuple[int] | int, optional): Size of block. - overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. - Defaults to 36. - transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. - If None, the transforms for evaluation process will be used. Defaults to None. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int, optional): + Overlap between two blocks. If `overlap` is a list or tuple, it should + be in (W, H) format. Defaults to 36. + transforms (paddlers.transforms.Compose|None, optional): Transforms for + inputs. If None, the transforms for evaluation process will be used. + Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'vote' means applying a simple voting strategy when + there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ try: @@ -597,8 +611,6 @@ class BaseChangeDetector(BaseModel): except: import gdal - if not isinstance(img_file, tuple) or len(img_file) != 2: - raise ValueError("`img_file` must be a tuple of length 2.") if isinstance(block_size, int): block_size = (block_size, block_size) elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: @@ -614,25 +626,54 @@ class BaseChangeDetector(BaseModel): raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") - src1_data = gdal.Open(img_file[0]) - src2_data = gdal.Open(img_file[1]) + if merge_strategy not in ('keep_first', 'keep_last', 'vote'): + raise ValueError( + "{} is not a supported stragegy for block merging.".format( + merge_strategy)) + if overlap == (0, 0): + # When there is no overlap, use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " + "'keep_last'.") + cache = Cache() + + src1_data = gdal.Open(img_files[0]) + src2_data = gdal.Open(img_files[1]) + + # Assume that two input images have the same size width = src1_data.RasterXSize height = src1_data.RasterYSize bands = src1_data.RasterCount driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[ - 0] + ".tif" + # Output name is the same as the name of the first image + file_name = osp.basename(osp.normpath(img_files[0])) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" if not osp.exists(save_dir): os.makedirs(save_dir) save_file = osp.join(save_dir, file_name) dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information (consistent with the first image) dst_data.SetGeoTransform(src1_data.GetGeoTransform()) dst_data.SetProjection(src1_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - step = np.array(block_size) - np.array(overlap) + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + prev_yoff, prev_xoff = None, None + prev_h, prev_w = None, None + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -640,30 +681,64 @@ class BaseChangeDetector(BaseModel): xsize = int(width - xoff) if yoff + ysize > height: ysize = int(height - yoff) - im1 = src1_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - im2 = src2_data.ReadAsArray( - int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) + + xoff = int(xoff) + yoff = int(yoff) + im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) # Fill h, w = im1.shape[:2] im1_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im1.dtype) - im2_fill = im1_fill.copy() im1_fill[:h, :w, :] = im1 + + im2_fill = np.zeros( + (block_size[1], block_size[0], bands), dtype=im2.dtype) im2_fill[:h, :w, :] = im2 - im_fill = (im1_fill, im2_fill) + # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - band.WriteArray(temp, int(xoff), int(yoff)) + pred = self.predict((im1_fill, im2_fill), transforms) + pred = pred["label_map"].astype('uint8') + pred = pred[:h, :w] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, h, w, pred) + pred = cache.get_block(yoff, xoff, h, w) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = prev_w + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + + # Write to file + band.WriteArray(pred, xoff, yoff) dst_data.FlushCache() + + prev_xoff = xoff + prev_w = w + + prev_yoff = yoff + prev_h = h + dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + logging.info("GeoTiff file saved in {}.".format(save_file)) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index b9c586f..1037041 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferSegNet +from .utils.slider_predict import SlowCache as Cache __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -550,20 +551,31 @@ class BaseSegmenter(BaseModel): save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): """ - Do inference. + Do inference using sliding windows. Args: img_file (str): Image path. save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): - Size of block. + Size of block. If `block_size` is list or tuple, it should be in + (W, H) format. overlap (list[int] | tuple[int] | int, optional): - Overlap between two blocks. Defaults to 36. + Overlap between two blocks. If `overlap` is list or tuple, it should + be in (W, H) format. Defaults to 36. transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output + image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices + are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'vote' means applying a simple voting strategy when + there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ try: @@ -586,24 +598,50 @@ class BaseSegmenter(BaseModel): raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") + if merge_strategy not in ('keep_first', 'keep_last', 'vote'): + raise ValueError( + "{} is not a supported stragegy for block merging.".format( + merge_strategy)) + if overlap == (0, 0): + # When there is no overlap, use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " + "'keep_last'.") + cache = Cache() + src_data = gdal.Open(img_file) width = src_data.RasterXSize height = src_data.RasterYSize bands = src_data.RasterCount driver = gdal.GetDriverByName("GTiff") - file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[ - 0] + ".tif" + file_name = osp.basename(osp.normpath(img_file)) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" if not osp.exists(save_dir): os.makedirs(save_dir) save_file = osp.join(save_dir, file_name) dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information dst_data.SetGeoTransform(src_data.GetGeoTransform()) dst_data.SetProjection(src_data.GetProjection()) - band = dst_data.GetRasterBand(1) - band.WriteArray(255 * np.ones((height, width), dtype="uint8")) - step = np.array(block_size) - np.array(overlap) + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + prev_yoff, prev_xoff = None, None + prev_h, prev_w = None, None + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -611,25 +649,58 @@ class BaseSegmenter(BaseModel): xsize = int(width - xoff) if yoff + ysize > height: ysize = int(height - yoff) - im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, - ysize).transpose((1, 2, 0)) + + xoff = int(xoff) + yoff = int(yoff) + im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) # Fill h, w = im.shape[:2] im_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im.dtype) im_fill[:h, :w, :] = im + # Predict - pred = self.predict(im_fill, - transforms)["label_map"].astype("uint8") - # Overlap - rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) - mask = (rd_block == pred[:h, :w]) | (rd_block == 255) - temp = pred[:h, :w].copy() - temp[mask == False] = 0 - band.WriteArray(temp, int(xoff), int(yoff)) + pred = self.predict(im_fill, transforms) + pred = pred["label_map"].astype('uint8') + pred = pred[:h, :w] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, h, w, pred) + pred = cache.get_block(yoff, xoff, h, w) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = prev_w + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + + # Write to file + band.WriteArray(pred, xoff, yoff) dst_data.FlushCache() + + prev_xoff = xoff + prev_w = w + + prev_yoff = yoff + prev_h = h + dst_data = None - print("GeoTiff saved in {}.".format(save_file)) + logging.info("GeoTiff file saved in {}.".format(save_file)) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py new file mode 100644 index 0000000..254882b --- /dev/null +++ b/paddlers/tasks/utils/slider_predict.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter, defaultdict + +import numpy as np + + +class SlowCache(object): + def __init__(self): + self.cache = defaultdict(Counter) + + def push_pixel(self, i, j, l): + self.cache[(i, j)][l] += 1 + + def push_block(self, i_st, j_st, h, w, data): + for i in range(0, h): + for j in range(0, w): + self.push_pixel(i_st + i, j_st + j, data[i, j]) + + def pop_pixel(self, i, j): + self.cache.pop((i, j)) + + def pop_block(self, i_st, j_st, h, w): + for i in range(0, h): + for j in range(0, w): + self.pop_pixel(i_st + i, j_st + j) + + def get_pixel(self, i, j): + winners = self.cache[(i, j)].most_common(1) + winner = winners[0] + return winner[0] + + def get_block(self, i_st, j_st, h, w): + block = [] + for i in range(i_st, i_st + h): + row = [] + for j in range(j_st, j_st + w): + row.append(self.get_pixel(i, j)) + block.append(row) + return np.asarray(block) diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index dd21c7a..6110bdb 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -197,7 +197,7 @@ class DecodeImg(Transform): self.to_uint8 = to_uint8 self.decode_bgr = decode_bgr self.decode_sar = decode_sar - self.read_geo_info = False + self.read_geo_info = read_geo_info def read_img(self, img_path): img_format = imghdr.what(img_path) @@ -227,7 +227,7 @@ class DecodeImg(Transform): im_data = im_data.transpose((1, 2, 0)) if self.read_geo_info: geo_trans = dataset.GetGeoTransform() - geo_proj = dataset.GetGeoProjection() + geo_proj = dataset.GetProjection() elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: if self.decode_bgr: im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | diff --git a/tests/fast_tests.py b/tests/fast_tests.py index 8e7c26e..4214b5b 100644 --- a/tests/fast_tests.py +++ b/tests/fast_tests.py @@ -13,4 +13,5 @@ # limitations under the License. from rs_models import * +from tasks import * from transforms import * diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py index 29c8b7d..5948c0c 100644 --- a/tests/tasks/__init__.py +++ b/tests/tasks/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .test_slider_predict import * diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py new file mode 100644 index 0000000..bc3bd5a --- /dev/null +++ b/tests/tasks/test_slider_predict.py @@ -0,0 +1,274 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path as osp +import tempfile + +import paddlers as pdrs +import paddlers.transforms as T +from testing_utils import CommonTest + + +class TestSegSliderPredict(CommonTest): + def setUp(self): + self.model = pdrs.tasks.seg.UNet(in_channels=10) + self.transforms = T.Compose([ + T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), + T.ArrangeSegmenter('test') + ]) + self.image_path = "data/ssst/multispectral.tif" + self.basename = osp.basename(self.image_path) + + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_path, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 128, 128, + self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_path, save_dir, (128, 32), 0, + self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_path, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) + pred5 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred5.shape, pred_whole.shape) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'vote' + save_dir = osp.join(td, 'vote') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_vote = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_vote.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) + self.model.slider_predict(self.image_path, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) + + +class TestCDSliderPredict(CommonTest): + def setUp(self): + self.model = pdrs.tasks.cd.BIT(in_channels=10) + self.transforms = T.Compose([ + T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), + T.ArrangeChangeDetector('test') + ]) + self.image_paths = ("data/ssmt/multispectral_t1.tif", + "data/ssmt/multispectral_t2.tif") + self.basename = osp.basename(self.image_paths[0]) + + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_paths, self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_paths, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_paths, save_dir, 128, 128, + self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0, + self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_paths, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + self.model.slider_predict(self.image_paths, save_dir, 512, 0, + self.transforms) + pred5 = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred5.shape, pred_whole.shape) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_paths, self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'vote' + save_dir = osp.join(td, 'vote') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_vote = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_vote.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image( + self.image_paths[0], read_geo_info=True) + self.model.slider_predict(self.image_paths, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) From 9c3c14af77ef61c133c1b4bc6b73bb0689c83a55 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Fri, 2 Sep 2022 20:37:14 +0800 Subject: [PATCH 02/10] Update docs --- docs/apis/infer.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/apis/infer.md b/docs/apis/infer.md index 2a5b62d..7c493eb 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -134,7 +134,9 @@ def slider_predict(self, save_dir, block_size, overlap=36, - transforms=None): + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): ``` 输入参数列表: @@ -143,9 +145,11 @@ def slider_predict(self, |-------|----|--------|-----| |`img_file`|`str`|输入影像路径。|| |`save_dir`|`str`|预测结果输出路径。|| -|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定长、宽或以一个整数指定相同的长宽)。|| -|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定长、宽或以一个整数指定相同的长宽)。|`36`| +|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|| +|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`| |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| +|`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| +|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测值;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测值;`'vote'`表示使用投票策略,即对于每个像素,最终预测值为所有覆盖该像素的滑窗给出的预测值中出现频率最高者。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间,但给出的预测结果在窗口的接缝处相比其它两种策略将更加平滑。|`'keep_last'`| 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。 From bf54499f5a6ab0eb28e7b01d9b9c58f5d507b520 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sun, 4 Sep 2022 03:13:09 +0800 Subject: [PATCH 03/10] Add accum strategy --- paddlers/tasks/change_detector.py | 137 +------------ paddlers/tasks/segmenter.py | 137 +------------ paddlers/tasks/utils/slider_predict.py | 254 ++++++++++++++++++++++++- tests/tasks/test_slider_predict.py | 50 +++-- 4 files changed, 301 insertions(+), 277 deletions(-) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 30babb5..790173b 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -35,7 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet -from .utils.slider_predict import SlowCache as Cache +from .utils.slider_predict import slider_predict __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", @@ -606,139 +606,8 @@ class BaseChangeDetector(BaseModel): there are conflicts in the overlapping pixels. Defaults to 'keep_last'. """ - try: - from osgeo import gdal - except: - import gdal - - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - if merge_strategy not in ('keep_first', 'keep_last', 'vote'): - raise ValueError( - "{} is not a supported stragegy for block merging.".format( - merge_strategy)) - if overlap == (0, 0): - # When there is no overlap, use 'keep_last' strategy as it introduces least overheads - merge_strategy = 'keep_last' - if merge_strategy == 'vote': - logging.warning( - "Currently, a naive Python-implemented cache is used for aggregating voting results. " - "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " - "'keep_last'.") - cache = Cache() - - src1_data = gdal.Open(img_files[0]) - src2_data = gdal.Open(img_files[1]) - - # Assume that two input images have the same size - width = src1_data.RasterXSize - height = src1_data.RasterYSize - bands = src1_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - # Output name is the same as the name of the first image - file_name = osp.basename(osp.normpath(img_files[0])) - # Replace extension name with '.tif' - file_name = osp.splitext(file_name)[0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - - # Set meta-information (consistent with the first image) - dst_data.SetGeoTransform(src1_data.GetGeoTransform()) - dst_data.SetProjection(src1_data.GetProjection()) - - band = dst_data.GetRasterBand(1) - band.WriteArray( - np.full( - (height, width), fill_value=invalid_value, dtype="uint8")) - - prev_yoff, prev_xoff = None, None - prev_h, prev_w = None, None - step = np.array( - block_size, dtype=np.int32) - np.array( - overlap, dtype=np.int32) - if step[0] == 0 or step[1] == 0: - raise ValueError("`block_size` and `overlap` should not be equal.") - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - - xoff = int(xoff) - yoff = int(yoff) - im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - # Fill - h, w = im1.shape[:2] - im1_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im1.dtype) - im1_fill[:h, :w, :] = im1 - - im2_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im2.dtype) - im2_fill[:h, :w, :] = im2 - - # Predict - pred = self.predict((im1_fill, im2_fill), transforms) - pred = pred["label_map"].astype('uint8') - pred = pred[:h, :w] - - # Deal with overlapping pixels - if merge_strategy == 'vote': - cache.push_block(yoff, xoff, h, w, pred) - pred = cache.get_block(yoff, xoff, h, w) - pred = pred.astype('uint8') - if prev_yoff is not None: - pop_h = yoff - prev_yoff - else: - pop_h = 0 - if prev_xoff is not None: - if xoff < prev_xoff: - pop_w = prev_w - else: - pop_w = xoff - prev_xoff - else: - pop_w = 0 - cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) - elif merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() - - prev_xoff = xoff - prev_w = w - - prev_yoff = yoff - prev_h = h - - dst_data = None - logging.info("GeoTiff file saved in {}.".format(save_file)) + slider_predict(self, img_files, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 1037041..8b332aa 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -34,7 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferSegNet -from .utils.slider_predict import SlowCache as Cache +from .utils.slider_predict import slider_predict __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -572,135 +572,16 @@ class BaseSegmenter(BaseModel): invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices - are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' - means keeping the values of the first and the last block in traversal - order, respectively. 'vote' means applying a simple voting strategy when - there are conflicts in the overlapping pixels. Defaults to 'keep_last'. + are {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and + 'keep_last' means keeping the values of the first and the last block in + traversal order, respectively. 'vote' means applying a simple voting + strategy when there are conflicts in the overlapping pixels. 'accum' + means determining the class of an overlapping pixel according to + accumulated probabilities. Defaults to 'keep_last'. """ - try: - from osgeo import gdal - except: - import gdal - - if isinstance(block_size, int): - block_size = (block_size, block_size) - elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: - block_size = tuple(block_size) - else: - raise ValueError( - "`block_size` must be a tuple/list of length 2 or an integer.") - if isinstance(overlap, int): - overlap = (overlap, overlap) - elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: - overlap = tuple(overlap) - else: - raise ValueError( - "`overlap` must be a tuple/list of length 2 or an integer.") - - if merge_strategy not in ('keep_first', 'keep_last', 'vote'): - raise ValueError( - "{} is not a supported stragegy for block merging.".format( - merge_strategy)) - if overlap == (0, 0): - # When there is no overlap, use 'keep_last' strategy as it introduces least overheads - merge_strategy = 'keep_last' - if merge_strategy == 'vote': - logging.warning( - "Currently, a naive Python-implemented cache is used for aggregating voting results. " - "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or " - "'keep_last'.") - cache = Cache() - - src_data = gdal.Open(img_file) - width = src_data.RasterXSize - height = src_data.RasterYSize - bands = src_data.RasterCount - - driver = gdal.GetDriverByName("GTiff") - file_name = osp.basename(osp.normpath(img_file)) - # Replace extension name with '.tif' - file_name = osp.splitext(file_name)[0] + ".tif" - if not osp.exists(save_dir): - os.makedirs(save_dir) - save_file = osp.join(save_dir, file_name) - dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) - - # Set meta-information - dst_data.SetGeoTransform(src_data.GetGeoTransform()) - dst_data.SetProjection(src_data.GetProjection()) - - band = dst_data.GetRasterBand(1) - band.WriteArray( - np.full( - (height, width), fill_value=invalid_value, dtype="uint8")) - - prev_yoff, prev_xoff = None, None - prev_h, prev_w = None, None - step = np.array( - block_size, dtype=np.int32) - np.array( - overlap, dtype=np.int32) - if step[0] == 0 or step[1] == 0: - raise ValueError("`block_size` and `overlap` should not be equal.") - for yoff in range(0, height, step[1]): - for xoff in range(0, width, step[0]): - xsize, ysize = block_size - if xoff + xsize > width: - xsize = int(width - xoff) - if yoff + ysize > height: - ysize = int(height - yoff) - - xoff = int(xoff) - yoff = int(yoff) - im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) - # Fill - h, w = im.shape[:2] - im_fill = np.zeros( - (block_size[1], block_size[0], bands), dtype=im.dtype) - im_fill[:h, :w, :] = im - - # Predict - pred = self.predict(im_fill, transforms) - pred = pred["label_map"].astype('uint8') - pred = pred[:h, :w] - - # Deal with overlapping pixels - if merge_strategy == 'vote': - cache.push_block(yoff, xoff, h, w, pred) - pred = cache.get_block(yoff, xoff, h, w) - pred = pred.astype('uint8') - if prev_yoff is not None: - pop_h = yoff - prev_yoff - else: - pop_h = 0 - if prev_xoff is not None: - if xoff < prev_xoff: - pop_w = prev_w - else: - pop_w = xoff - prev_xoff - else: - pop_w = 0 - cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) - elif merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() - - prev_xoff = xoff - prev_w = w - - prev_yoff = yoff - prev_h = h - - dst_data = None - logging.info("GeoTiff file saved in {}.".format(save_file)) + slider_predict(self, img_file, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index 254882b..a9ecdf6 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import os.path as osp +from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict import numpy as np +import paddlers.utils.logging as logging -class SlowCache(object): + +class Cache(metaclass=ABCMeta): + @abstractmethod + def get_block(self, i_st, j_st, h, w): + pass + + +class SlowCache(Cache): def __init__(self): self.cache = defaultdict(Counter) @@ -50,3 +61,244 @@ class SlowCache(object): row.append(self.get_pixel(i, j)) block.append(row) return np.asarray(block) + + +class ProbCache(Cache): + def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): + self.cache = None + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + if not issubclass(dtype, np.floating): + raise TypeError("`dtype` must be one of the floating types.") + self.dtype = dtype + order = order.lower() + if order not in ('c', 'f'): + raise ValueError("`order` other than 'c' and 'f' is not supported.") + self.order = order + + def _alloc_memory(self, nc): + if self.order == 'c': + # Colomn-first order (C-style) + # + # <-- cw --> + # |--------|---------------------|^ ^ + # | || | sh + # |--------|---------------------|| ch v + # | || + # |--------|---------------------|v + # <------------ w ---------------> + self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype) + elif self.order == 'f': + # Row-first order (Fortran-style) + # + # <-- sw --> + # <---- cw ----> + # |--------|---|^ ^ + # | | || | + # | | || ch + # | | || | + # |--------|---|| h v + # | | || + # | | || + # | | || + # |--------|---|v + self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype) + + def update_block(self, i_st, j_st, h, w, prob_map): + if self.cache is None: + nc = prob_map.shape[2] + # Lazy allocation of memory + self._alloc_memory(nc) + self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map + + def roll_cache(self): + if self.order == 'c': + self.cache = np.roll(self.cache, -self.sh, axis=0) + self.cache[self.sh:self.ch, :] = 0 + elif self.order == 'f': + self.cache = np.roll(self.cache, -self.sw, axis=1) + self.cache[:, self.sw:self.cw] = 0 + + def get_block(self, i_st, j_st, h, w): + return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) + + +def slider_predict(predictor, img_file, save_dir, block_size, overlap, + transforms, invalid_value, merge_strategy): + """ + Do inference using sliding windows. + + Args: + predictor (object): Object that implements `predict()` method. + img_file (str|tuple[str]): Image path(s). + save_dir (str): Directory that contains saved geotiff file. + block_size (list[int] | tuple[int] | int): + Size of block. If `block_size` is list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int): + Overlap between two blocks. If `overlap` is list or tuple, it should + be in (W, H) format. + transforms (paddlers.transforms.Compose|None): Transforms for inputs. If + None, the transforms for evaluation process will be used. + invalid_value (int): Value that marks invalid pixels in output image. + Defaults to 255. + merge_strategy (str): Strategy to merge overlapping blocks. Choices are + {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and + 'keep_last' means keeping the values of the first and the last block in + traversal order, respectively. 'vote' means applying a simple voting + strategy when there are conflicts in the overlapping pixels. 'accum' + means determining the class of an overlapping pixel according to + accumulated probabilities. + """ + + try: + from osgeo import gdal + except: + import gdal + + if isinstance(block_size, int): + block_size = (block_size, block_size) + elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: + block_size = tuple(block_size) + else: + raise ValueError( + "`block_size` must be a tuple/list of length 2 or an integer.") + if isinstance(overlap, int): + overlap = (overlap, overlap) + elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: + overlap = tuple(overlap) + else: + raise ValueError( + "`overlap` must be a tuple/list of length 2 or an integer.") + + if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'): + raise ValueError("{} is not a supported stragegy for block merging.". + format(merge_strategy)) + + step = np.array( + block_size, dtype=np.int32) - np.array( + overlap, dtype=np.int32) + if step[0] == 0 or step[1] == 0: + raise ValueError("`block_size` and `overlap` should not be equal.") + + if isinstance(img_file, tuple): + if len(img_file) != 2: + raise ValueError("Tuple `img_file` must have the length of two.") + # Assume that two input images have the same size + src_data = gdal.Open(img_file[0]) + src2_data = gdal.Open(img_file[1]) + # Output name is the same as the name of the first image + file_name = osp.basename(osp.normpath(img_file[0])) + else: + src_data = gdal.Open(img_file) + file_name = osp.basename(osp.normpath(img_file)) + + # Get size of original raster + width = src_data.RasterXSize + height = src_data.RasterYSize + bands = src_data.RasterCount + + if block_size[0] > width or block_size[1] > height: + raise ValueError("`block_size` should not be larger than image size.") + + driver = gdal.GetDriverByName("GTiff") + if not osp.exists(save_dir): + os.makedirs(save_dir) + # Replace extension name with '.tif' + file_name = osp.splitext(file_name)[0] + ".tif" + save_file = osp.join(save_dir, file_name) + dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + + # Set meta-information + dst_data.SetGeoTransform(src_data.GetGeoTransform()) + dst_data.SetProjection(src_data.GetProjection()) + + # Initialize raster with `invalid_value` + band = dst_data.GetRasterBand(1) + band.WriteArray( + np.full( + (height, width), fill_value=invalid_value, dtype="uint8")) + + if overlap == (0, 0) or block_size == (width, height): + # When there is no overlap or the whole image is used as input, + # use 'keep_last' strategy as it introduces least overheads + merge_strategy = 'keep_last' + if merge_strategy == 'vote': + logging.warning( + "Currently, a naive Python-implemented cache is used for aggregating voting results. " + "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', " + "'keep_last', or 'accum'.") + cache = SlowCache() + elif merge_strategy == 'accum': + cache = ProbCache(height, width, *block_size, *step) + + prev_yoff, prev_xoff = None, None + + for yoff in range(0, height, step[1]): + for xoff in range(0, width, step[0]): + xsize, ysize = block_size + if xoff + xsize > width: + xoff = width - xsize + if yoff + ysize > height: + yoff = height - ysize + + # Read and fill + im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + + if isinstance(img_file, tuple): + im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( + (1, 2, 0)) + # Predict + out = predictor.predict((im, im2), transforms) + else: + # Predict + out = predictor.predict(im, transforms) + + pred = out['label_map'].astype('uint8') + pred = pred[:ysize, :xsize] + + # Deal with overlapping pixels + if merge_strategy == 'vote': + cache.push_block(yoff, xoff, ysize, xsize, pred) + pred = cache.get_block(yoff, xoff, ysize, xsize) + pred = pred.astype('uint8') + if prev_yoff is not None: + pop_h = yoff - prev_yoff + else: + pop_h = 0 + if prev_xoff is not None: + if xoff < prev_xoff: + pop_w = xsize + else: + pop_w = xoff - prev_xoff + else: + pop_w = 0 + cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) + elif merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + elif merge_strategy == 'accum': + prob = out['score_map'] + prob = prob[:ysize, :xsize] + cache.update_block(0, yoff, ysize, xsize, prob) + pred = cache.get_block(0, yoff, ysize, xsize) + if xoff + step[0] >= width: + cache.roll_cache() + + # Write to file + band.WriteArray(pred, xoff, yoff) + dst_data.FlushCache() + + prev_xoff = xoff + prev_yoff = yoff + + dst_data = None + logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py index bc3bd5a..fce8550 100644 --- a/tests/tasks/test_slider_predict.py +++ b/tests/tasks/test_slider_predict.py @@ -75,13 +75,9 @@ class TestSegSliderPredict(CommonTest): # `block_size` larger than image size save_dir = osp.join(td, 'pred5') - self.model.slider_predict(self.image_path, save_dir, 512, 0, - self.transforms) - pred5 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred5.shape, pred_whole.shape) + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) def test_merge_strategy(self): with tempfile.TemporaryDirectory() as td: @@ -134,6 +130,21 @@ class TestSegSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_vote.shape, pred_whole.shape) + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + def test_geo_info(self): with tempfile.TemporaryDirectory() as td: _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) @@ -202,13 +213,9 @@ class TestCDSliderPredict(CommonTest): # `block_size` larger than image size save_dir = osp.join(td, 'pred5') - self.model.slider_predict(self.image_paths, save_dir, 512, 0, - self.transforms) - pred5 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred5.shape, pred_whole.shape) + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_paths, save_dir, 512, 0, + self.transforms) def test_merge_strategy(self): with tempfile.TemporaryDirectory() as td: @@ -261,6 +268,21 @@ class TestCDSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_vote.shape, pred_whole.shape) + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_paths, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='vote') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + to_uint8=False, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + def test_geo_info(self): with tempfile.TemporaryDirectory() as td: _, geo_info_in = T.decode_image( From 5a75c7d1e138bb42db7ba94a4c18e86f214d5b61 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sun, 4 Sep 2022 03:22:53 +0800 Subject: [PATCH 04/10] Update docs about accum strategy --- docs/apis/infer.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/apis/infer.md b/docs/apis/infer.md index 7c493eb..b06ee8f 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -149,9 +149,9 @@ def slider_predict(self, |`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`| |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| -|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测值;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测值;`'vote'`表示使用投票策略,即对于每个像素,最终预测值为所有覆盖该像素的滑窗给出的预测值中出现频率最高者。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间,但给出的预测结果在窗口的接缝处相比其它两种策略将更加平滑。|`'keep_last'`| +|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'vote'`表示使用投票策略,即对于每个像素,最终预测类别为所有覆盖该像素的滑窗给出的预测类别中出现频率最高者;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间。|`'keep_last'`| -变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。 +变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。 ## 静态图推理API From 0334a262c5d04d97c2834b71234dc81c0c1d51ad Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Mon, 5 Sep 2022 19:36:02 +0800 Subject: [PATCH 05/10] net_initialize->initialize_net --- paddlers/tasks/base.py | 2 +- paddlers/tasks/change_detector.py | 10 +++++----- paddlers/tasks/classifier.py | 2 +- paddlers/tasks/object_detector.py | 2 +- paddlers/tasks/restorer.py | 2 +- paddlers/tasks/segmenter.py | 14 ++++++-------- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 5950691..0c32bb5 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -86,7 +86,7 @@ class BaseModel(metaclass=ModelMeta): self.quant_config = None self.fixed_input_shape = None - def net_initialize(self, + def initialize_net(self, pretrain_weights=None, save_dir='.', resume_checkpoint=None, diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 6de48c7..afccdc4 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -316,7 +316,7 @@ class BaseChangeDetector(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -607,13 +607,13 @@ class BaseChangeDetector(BaseModel): invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices - are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' + are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and the last block in traversal - order, respectively. 'vote' means applying a simple voting strategy when - there are conflicts in the overlapping pixels. Defaults to 'keep_last'. + order, respectively. 'accum' means determining the class of an overlapping + pixel according to accumulated probabilities. Defaults to 'keep_last'. """ - slider_predict(self, img_files, save_dir, block_size, overlap, + slider_predict(self.predict, img_files, save_dir, block_size, overlap, transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 83c20fb..33ba5f3 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -288,7 +288,7 @@ class BaseClassifier(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = False - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index ca25213..313a893 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -347,7 +347,7 @@ class BaseDetector(BaseModel): "Invalid pretrained weights. Please specify a .pdparams file.", exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py index d9ce6ad..61691f0 100644 --- a/paddlers/tasks/restorer.py +++ b/paddlers/tasks/restorer.py @@ -283,7 +283,7 @@ class BaseRestorer(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 8026b23..a319fc5 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -308,7 +308,7 @@ class BaseSegmenter(BaseModel): exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' - self.net_initialize( + self.initialize_net( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, resume_checkpoint=resume_checkpoint, @@ -579,15 +579,13 @@ class BaseSegmenter(BaseModel): invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices - are {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and - 'keep_last' means keeping the values of the first and the last block in - traversal order, respectively. 'vote' means applying a simple voting - strategy when there are conflicts in the overlapping pixels. 'accum' - means determining the class of an overlapping pixel according to - accumulated probabilities. Defaults to 'keep_last'. + are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in traversal + order, respectively. 'accum' means determining the class of an overlapping + pixel according to accumulated probabilities. Defaults to 'keep_last'. """ - slider_predict(self, img_file, save_dir, block_size, overlap, + slider_predict(self.predict, img_file, save_dir, block_size, overlap, transforms, invalid_value, merge_strategy) def preprocess(self, images, transforms, to_tensor=True): From 7a0f5405f669c217a798d8d511d7538300e38248 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Mon, 5 Sep 2022 19:36:23 +0800 Subject: [PATCH 06/10] Remove vote mode and fix bugs --- docs/apis/infer.md | 7 ++- paddlers/deploy/predictor.py | 55 ++++++++++++++++++++-- paddlers/tasks/utils/slider_predict.py | 63 +++++++------------------- tests/tasks/test_slider_predict.py | 34 +------------- 4 files changed, 76 insertions(+), 83 deletions(-) diff --git a/docs/apis/infer.md b/docs/apis/infer.md index c17c98c..d3c272c 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -170,7 +170,7 @@ def slider_predict(self, |`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`| |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| -|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'vote'`表示使用投票策略,即对于每个像素,最终预测类别为所有覆盖该像素的滑窗给出的预测类别中出现频率最高者;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间。|`'keep_last'`| +|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`| 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。 @@ -220,5 +220,10 @@ def predict(self, |`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`| |`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`| |`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`| +|`quiet`|`bool`|若为`True`,不打印计时信息。|`False`| `Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。 + +### `Predictor.slider_predict()` + +实现滑窗推理功能。用法与`BaseSegmenter`和`BaseChangeDetector`的`slider_predict()`方法相同。 diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 1b2c493..157abca 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -14,6 +14,7 @@ import os.path as osp from operator import itemgetter +from functools import partial import numpy as np import paddle @@ -23,6 +24,7 @@ from paddle.inference import PrecisionType from paddlers.tasks import load_model from paddlers.utils import logging, Timer +from paddlers.tasks.utils.slider_predict import slider_predict class Predictor(object): @@ -271,22 +273,24 @@ class Predictor(object): topk=1, transforms=None, warmup_iters=0, - repeats=1): + repeats=1, + quiet=False): """ - Do prediction. + Do inference. Args: img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks, - img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples. + `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms from `model.yml`. Defaults to None. warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0. repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than 1, the reported time consumption is the average of all repeats. Defaults to 1. + quiet (bool, optional): If True, do not display the timing information. Defaults to False. """ if repeats < 1: @@ -313,12 +317,55 @@ class Predictor(object): self.timer.repeats = repeats self.timer.img_num = len(images) - self.timer.info(average=True) + if not quiet: + self.timer.info(average=True) if isinstance(img_file, (str, np.ndarray, tuple)): results = results[0] return results + def slider_predict(self, + img_file, + save_dir, + block_size, + overlap=36, + transforms=None, + invalid_value=255, + merge_strategy='keep_last'): + """ + Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the + sliding-predicting mode. + + Args: + img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` + should be either the path of the image to predict, a decoded image (a np.ndarray, which should be + consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of + image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a + tuple of decoded images, or a list of tuples. + save_dir (str): Directory that contains saved geotiff file. + block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in + (W, H) format. + overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple, + it should be in (W, H) format. Defaults to 36. + transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms + from `model.yml`. Defaults to None. + invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255. + merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are + {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and + the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel + according to accumulated probabilities. Defaults to 'keep_last'. + """ + slider_predict( + partial( + self.predict, quiet=True), + img_file, + save_dir, + block_size, + overlap, + transforms, + invalid_value, + merge_strategy) + def batch_predict(self, image_list, **params): return self.predict(img_file=image_list, **params) diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index a9ecdf6..f32f7f0 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -118,22 +118,22 @@ class ProbCache(Cache): def roll_cache(self): if self.order == 'c': self.cache = np.roll(self.cache, -self.sh, axis=0) - self.cache[self.sh:self.ch, :] = 0 + self.cache[-self.sh:, :] = 0 elif self.order == 'f': self.cache = np.roll(self.cache, -self.sw, axis=1) - self.cache[:, self.sw:self.cw] = 0 + self.cache[:, -self.sw:] = 0 def get_block(self, i_st, j_st, h, w): return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) -def slider_predict(predictor, img_file, save_dir, block_size, overlap, +def slider_predict(predict_func, img_file, save_dir, block_size, overlap, transforms, invalid_value, merge_strategy): """ Do inference using sliding windows. Args: - predictor (object): Object that implements `predict()` method. + predict_func (callable): A callable object that makes the prediction. img_file (str|tuple[str]): Image path(s). save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): @@ -147,12 +147,10 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap, invalid_value (int): Value that marks invalid pixels in output image. Defaults to 255. merge_strategy (str): Strategy to merge overlapping blocks. Choices are - {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and - 'keep_last' means keeping the values of the first and the last block in - traversal order, respectively. 'vote' means applying a simple voting - strategy when there are conflicts in the overlapping pixels. 'accum' - means determining the class of an overlapping pixel according to - accumulated probabilities. + {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' + means keeping the values of the first and the last block in + traversal order, respectively. 'accum' means determining the class + of an overlapping pixel according to accumulated probabilities. """ try: @@ -175,7 +173,7 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap, raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") - if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'): + if merge_strategy not in ('keep_first', 'keep_last', 'accum'): raise ValueError("{} is not a supported stragegy for block merging.". format(merge_strategy)) @@ -227,16 +225,8 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap, # When there is no overlap or the whole image is used as input, # use 'keep_last' strategy as it introduces least overheads merge_strategy = 'keep_last' - if merge_strategy == 'vote': - logging.warning( - "Currently, a naive Python-implemented cache is used for aggregating voting results. " - "For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', " - "'keep_last', or 'accum'.") - cache = SlowCache() - elif merge_strategy == 'accum': - cache = ProbCache(height, width, *block_size, *step) - - prev_yoff, prev_xoff = None, None + if merge_strategy == 'accum': + cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): @@ -254,32 +244,16 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap, im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( (1, 2, 0)) # Predict - out = predictor.predict((im, im2), transforms) + out = predict_func((im, im2), transforms=transforms) else: # Predict - out = predictor.predict(im, transforms) + out = predict_func(im, transforms=transforms) pred = out['label_map'].astype('uint8') pred = pred[:ysize, :xsize] # Deal with overlapping pixels - if merge_strategy == 'vote': - cache.push_block(yoff, xoff, ysize, xsize, pred) - pred = cache.get_block(yoff, xoff, ysize, xsize) - pred = pred.astype('uint8') - if prev_yoff is not None: - pop_h = yoff - prev_yoff - else: - pop_h = 0 - if prev_xoff is not None: - if xoff < prev_xoff: - pop_w = xsize - else: - pop_w = xoff - prev_xoff - else: - pop_w = 0 - cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w) - elif merge_strategy == 'keep_first': + if merge_strategy == 'keep_first': rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) mask = rd_block != invalid_value pred = np.where(mask, rd_block, pred) @@ -288,17 +262,14 @@ def slider_predict(predictor, img_file, save_dir, block_size, overlap, elif merge_strategy == 'accum': prob = out['score_map'] prob = prob[:ysize, :xsize] - cache.update_block(0, yoff, ysize, xsize, prob) - pred = cache.get_block(0, yoff, ysize, xsize) - if xoff + step[0] >= width: + cache.update_block(0, xoff, ysize, xsize, prob) + pred = cache.get_block(0, xoff, ysize, xsize) + if xoff + xsize >= width: cache.roll_cache() # Write to file band.WriteArray(pred, xoff, yoff) dst_data.FlushCache() - prev_xoff = xoff - prev_yoff = yoff - dst_data = None logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py index fce8550..d5eff91 100644 --- a/tests/tasks/test_slider_predict.py +++ b/tests/tasks/test_slider_predict.py @@ -115,21 +115,6 @@ class TestSegSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - # 'vote' - save_dir = osp.join(td, 'vote') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='vote') - pred_vote = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_vote.shape, pred_whole.shape) - # 'accum' save_dir = osp.join(td, 'accum') self.model.slider_predict( @@ -138,7 +123,7 @@ class TestSegSliderPredict(CommonTest): 128, 64, self.transforms, - merge_strategy='vote') + merge_strategy='accum') pred_accum = T.decode_image( osp.join(save_dir, self.basename), to_uint8=False, @@ -253,21 +238,6 @@ class TestCDSliderPredict(CommonTest): decode_sar=False) self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - # 'vote' - save_dir = osp.join(td, 'vote') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='vote') - pred_vote = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_vote.shape, pred_whole.shape) - # 'accum' save_dir = osp.join(td, 'accum') self.model.slider_predict( @@ -276,7 +246,7 @@ class TestCDSliderPredict(CommonTest): 128, 64, self.transforms, - merge_strategy='vote') + merge_strategy='accum') pred_accum = T.decode_image( osp.join(save_dir, self.basename), to_uint8=False, From 8beb15c3c833848d0031fd371703daa27093b8a1 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Wed, 7 Sep 2022 10:45:19 +0800 Subject: [PATCH 07/10] [Feat] Optimize `slide_predict()` (#3) --- docs/apis/data.md | 6 +- paddlers/deploy/predictor.py | 17 +- paddlers/tasks/change_detector.py | 10 +- paddlers/tasks/classifier.py | 2 +- paddlers/tasks/object_detector.py | 2 +- paddlers/tasks/restorer.py | 2 +- paddlers/tasks/segmenter.py | 8 +- paddlers/tasks/utils/slider_predict.py | 86 ++++-- paddlers/transforms/__init__.py | 17 +- paddlers/transforms/functions.py | 6 +- paddlers/transforms/operators.py | 10 +- tests/deploy/test_predictor.py | 29 +- tests/tasks/test_slider_predict.py | 404 +++++++++++-------------- 13 files changed, 302 insertions(+), 297 deletions(-) diff --git a/docs/apis/data.md b/docs/apis/data.md index 6be2e32..34afb1c 100644 --- a/docs/apis/data.md +++ b/docs/apis/data.md @@ -134,10 +134,12 @@ |-------|----|--------|-----| |`im_path`|`str`|输入图像路径。|| |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`| -|`to_uint8`|`bool`|若为`True`,则将读取的图像数据量化并转换为uint8类型。|`True`| +|`to_uint8`|`bool`|若为`True`,则将读取的影像数据量化并转换为uint8类型。|`True`| |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`| -|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| +|`decode_sar`|`bool`|若为`True`,则自动将单通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`| +|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`| +|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`| 返回格式如下: diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 157abca..ab474d2 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -282,8 +282,8 @@ class Predictor(object): img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to - paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks, - `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. + paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change + detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1. transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms from `model.yml`. Defaults to None. @@ -332,7 +332,8 @@ class Predictor(object): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the sliding-predicting mode. @@ -340,9 +341,9 @@ class Predictor(object): Args: img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` should be either the path of the image to predict, a decoded image (a np.ndarray, which should be - consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of - image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a - tuple of decoded images, or a list of tuples. + consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)), + or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of + image paths, a tuple of decoded images, or a list of tuples. save_dir (str): Directory that contains saved geotiff file. block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in (W, H) format. @@ -355,6 +356,7 @@ class Predictor(object): {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict( partial( @@ -365,7 +367,8 @@ class Predictor(object): overlap, transforms, invalid_value, - merge_strategy) + merge_strategy, + batch_size) def batch_predict(self, image_list, **params): return self.predict(img_file=image_list, **params) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index afccdc4..347e996 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -588,7 +588,8 @@ class BaseChangeDetector(BaseModel): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. @@ -611,10 +612,11 @@ class BaseChangeDetector(BaseModel): means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict(self.predict, img_files, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy) + transforms, invalid_value, merge_strategy, batch_size) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -622,8 +624,8 @@ class BaseChangeDetector(BaseModel): batch_ori_shape = list() for im1, im2 in images: if isinstance(im1, str) or isinstance(im2, str): - im1 = decode_image(im1, to_rgb=False) - im2 = decode_image(im2, to_rgb=False) + im1 = decode_image(im1, read_raw=True) + im2 = decode_image(im2, read_raw=True) ori_shape = im1.shape[:2] # XXX: sample do not contain 'image_t1' and 'image_t2'. sample = {'image': im1, 'image2': im2} diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 33ba5f3..12b6640 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -497,7 +497,7 @@ class BaseClassifier(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample) diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index 313a893..f42d7cc 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -617,7 +617,7 @@ class BaseDetector(BaseModel): batch_samples = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) sample = {'image': im} sample = transforms(sample) batch_samples.append(sample) diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py index 61691f0..ce7e931 100644 --- a/paddlers/tasks/restorer.py +++ b/paddlers/tasks/restorer.py @@ -481,7 +481,7 @@ class BaseRestorer(BaseModel): batch_tar_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index a319fc5..29f2aa4 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -560,7 +560,8 @@ class BaseSegmenter(BaseModel): overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): """ Do inference using sliding windows. @@ -583,10 +584,11 @@ class BaseSegmenter(BaseModel): means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. + batch_size (int, optional): Batch size used in inference. Defaults to 1. """ slider_predict(self.predict, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy) + transforms, invalid_value, merge_strategy, batch_size) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') @@ -594,7 +596,7 @@ class BaseSegmenter(BaseModel): batch_ori_shape = list() for im in images: if isinstance(im, str): - im = decode_image(im, to_rgb=False) + im = decode_image(im, read_raw=True) ori_shape = im.shape[:2] sample = {'image': im} im = transforms(sample)[0] diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index f32f7f0..1672eb0 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -14,6 +14,7 @@ import os import os.path as osp +import math from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict @@ -117,10 +118,10 @@ class ProbCache(Cache): def roll_cache(self): if self.order == 'c': - self.cache = np.roll(self.cache, -self.sh, axis=0) + self.cache[:-self.sh] = self.cache[self.sh:] self.cache[-self.sh:, :] = 0 elif self.order == 'f': - self.cache = np.roll(self.cache, -self.sw, axis=1) + self.cache[:, :-self.sw] = self.cache[:, self.sw:] self.cache[:, -self.sw:] = 0 def get_block(self, i_st, j_st, h, w): @@ -128,7 +129,7 @@ class ProbCache(Cache): def slider_predict(predict_func, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy): + transforms, invalid_value, merge_strategy, batch_size): """ Do inference using sliding windows. @@ -151,6 +152,7 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, means keeping the values of the first and the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. + batch_size (int): Batch size used in inference. """ try: @@ -200,6 +202,13 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, height = src_data.RasterYSize bands = src_data.RasterCount + # XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True) + # except for SAR images. + if bands == 1: + logging.warning( + f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images." + ) + if block_size[0] > width or block_size[1] > height: raise ValueError("`block_size` should not be larger than image size.") @@ -228,6 +237,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if merge_strategy == 'accum': cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) + batch_data = [] + batch_offsets = [] for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): xsize, ysize = block_size @@ -236,6 +247,9 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if yoff + ysize > height: yoff = height - ysize + is_end_of_col = yoff + ysize >= height + is_end_of_row = xoff + xsize >= width + # Read and fill im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( (1, 2, 0)) @@ -243,33 +257,49 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, if isinstance(img_file, tuple): im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( (1, 2, 0)) - # Predict - out = predict_func((im, im2), transforms=transforms) + batch_data.append((im, im2)) else: + batch_data.append(im) + + batch_offsets.append((xoff, yoff)) + + len_batch = len(batch_data) + + if is_end_of_row and is_end_of_col and len_batch < batch_size: + # Pad `batch_data` by repeating the last element + batch_data = batch_data + [batch_data[-1]] * (batch_size - + len_batch) + # While keeping `len(batch_offsets)` the number of valid elements in the batch + + if len(batch_data) == batch_size: # Predict - out = predict_func(im, transforms=transforms) - - pred = out['label_map'].astype('uint8') - pred = pred[:ysize, :xsize] - - # Deal with overlapping pixels - if merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - elif merge_strategy == 'accum': - prob = out['score_map'] - prob = prob[:ysize, :xsize] - cache.update_block(0, xoff, ysize, xsize, prob) - pred = cache.get_block(0, xoff, ysize, xsize) - if xoff + xsize >= width: - cache.roll_cache() - - # Write to file - band.WriteArray(pred, xoff, yoff) - dst_data.FlushCache() + batch_out = predict_func(batch_data, transforms=transforms) + + for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): + pred = out['label_map'].astype('uint8') + pred = pred[:ysize, :xsize] + + # Deal with overlapping pixels + if merge_strategy == 'keep_first': + rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize) + mask = rd_block != invalid_value + pred = np.where(mask, rd_block, pred) + elif merge_strategy == 'keep_last': + pass + elif merge_strategy == 'accum': + prob = out['score_map'] + prob = prob[:ysize, :xsize] + cache.update_block(0, xoff_, ysize, xsize, prob) + pred = cache.get_block(0, xoff_, ysize, xsize) + if xoff_ + xsize >= width: + cache.roll_cache() + + # Write to file + band.WriteArray(pred, xoff_, yoff_) + + dst_data.FlushCache() + batch_data.clear() + batch_offsets.clear() dst_data = None logging.info("GeoTiff file saved in {}.".format(save_file)) diff --git a/paddlers/transforms/__init__.py b/paddlers/transforms/__init__.py index aafc4dd..ec470fb 100644 --- a/paddlers/transforms/__init__.py +++ b/paddlers/transforms/__init__.py @@ -25,7 +25,9 @@ def decode_image(im_path, to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + read_geo_info=False, + use_stretch=False, + read_raw=False): """ Decode an image. @@ -37,11 +39,16 @@ def decode_image(im_path, uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. + read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8` + to False. Setting `read_raw` takes precedence over setting `to_rgb` and + `to_uint8`. Defaults to False. Returns: np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. @@ -53,12 +60,16 @@ def decode_image(im_path, # Do a presence check. osp.exists() assumes `im_path` is a path-like object. if not osp.exists(im_path): raise ValueError(f"{im_path} does not exist!") + if read_raw: + to_rgb = False + to_uint8 = False decoder = T.DecodeImg( to_rgb=to_rgb, to_uint8=to_uint8, decode_bgr=decode_bgr, decode_sar=decode_sar, - read_geo_info=read_geo_info) + read_geo_info=read_geo_info, + use_stretch=use_stretch) # Deepcopy to avoid inplace modification sample = {'image': copy.deepcopy(im_path)} sample = decoder(sample) diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 5550e33..8518de5 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp): return rle -def to_uint8(im, is_linear=False): +def to_uint8(im, stretch=False): """ Convert raster data to uint8 type. Args: im (np.ndarray): Input raster image. - is_linear (bool, optional): Use 2% linear stretch or not. Default is False. + stretch (bool, optional): Use 2% linear stretch or not. Default is False. Returns: np.ndarray: Image data with unit8 type. @@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False): dtype = im.dtype.name if dtype != "uint8": im = _sample_norm(im) - if is_linear: + if stretch: im = _two_percent_linear(im) return im diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index 6110bdb..47e6cc4 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -179,11 +179,13 @@ class DecodeImg(Transform): uint8 type. Defaults to True. decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image. Defaults to True. - decode_sar (bool, optional): If True, automatically interpret a two-channel + decode_sar (bool, optional): If True, automatically interpret a single-channel geo image (e.g. geotiff images) as a SAR image, set this argument to True. Defaults to True. read_geo_info (bool, optional): If True, read geographical information from the image. Deafults to False. + use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if + `to_uint8` is True. Defaults to False. """ def __init__(self, @@ -191,13 +193,15 @@ class DecodeImg(Transform): to_uint8=True, decode_bgr=True, decode_sar=True, - read_geo_info=False): + read_geo_info=False, + use_stretch=False): super(DecodeImg, self).__init__() self.to_rgb = to_rgb self.to_uint8 = to_uint8 self.decode_bgr = decode_bgr self.decode_sar = decode_sar self.read_geo_info = read_geo_info + self.use_stretch = use_stretch def read_img(self, img_path): img_format = imghdr.what(img_path) @@ -264,7 +268,7 @@ class DecodeImg(Transform): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.to_uint8: - image = to_uint8(image) + image = to_uint8(image, stretch=self.use_stretch) if self.read_geo_info: return image, geo_info_dict diff --git a/tests/deploy/test_predictor.py b/tests/deploy/test_predictor.py index 8675e65..081b34a 100644 --- a/tests/deploy/test_predictor.py +++ b/tests/deploy/test_predictor.py @@ -145,8 +145,8 @@ class TestCDPredictor(TestPredictor): # Single input (ndarrays) input_ = (decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False)) # Reuse the name `input_` + t1_path, read_raw=True), decode_image( + t2_path, read_raw=True)) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -169,8 +169,9 @@ class TestCDPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [(decode_image( - t1_path, to_rgb=False), decode_image( - t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_` + t1_path, read_raw=True), decode_image( + t2_path, + read_raw=True))] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -211,7 +212,7 @@ class TestClasPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -235,7 +236,8 @@ class TestClasPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -276,7 +278,7 @@ class TestDetPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -295,7 +297,8 @@ class TestDetPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -329,7 +332,7 @@ class TestResPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` predictor.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms) out_single_array_list_p = predictor.predict( @@ -348,7 +351,8 @@ class TestResPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) @@ -386,7 +390,7 @@ class TestSegPredictor(TestPredictor): # Single input (ndarray) input_ = decode_image( - single_input, to_rgb=False) # Reuse the name `input_` + single_input, read_raw=True) # Reuse the name `input_` out_single_array_p = predictor.predict(input_, transforms=transforms) self.check_dict_equal(out_single_array_p, out_single_file_p) out_single_array_t = trainer.predict(input_, transforms=transforms) @@ -409,7 +413,8 @@ class TestSegPredictor(TestPredictor): # Multiple inputs (ndarrays) input_ = [decode_image( - single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + single_input, + read_raw=True)] * num_inputs # Reuse the name `input_` out_multi_array_p = predictor.predict(input_, transforms=transforms) self.assertEqual(len(out_multi_array_p), num_inputs) out_multi_array_t = trainer.predict(input_, transforms=transforms) diff --git a/tests/tasks/test_slider_predict.py b/tests/tasks/test_slider_predict.py index d5eff91..5b4392d 100644 --- a/tests/tasks/test_slider_predict.py +++ b/tests/tasks/test_slider_predict.py @@ -20,7 +20,174 @@ import paddlers.transforms as T from testing_utils import CommonTest -class TestSegSliderPredict(CommonTest): +class _TestSliderPredictNamespace: + class TestSliderPredict(CommonTest): + def test_blocksize_and_overlap_whole(self): + # Original image size (256, 256) + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, + self.transforms) + pred_whole = pred_whole['label_map'] + + # Whole-image inference using slider_predict() + save_dir = osp.join(td, 'pred1') + self.model.slider_predict(self.image_path, save_dir, 256, 0, + self.transforms) + pred1 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred1.shape, pred_whole.shape) + + # `block_size` == `overlap` + save_dir = osp.join(td, 'pred2') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 128, + 128, self.transforms) + + # `block_size` is a tuple + save_dir = osp.join(td, 'pred3') + self.model.slider_predict(self.image_path, save_dir, (128, 32), + 0, self.transforms) + pred3 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred3.shape, pred_whole.shape) + + # `block_size` and `overlap` are both tuples + save_dir = osp.join(td, 'pred4') + self.model.slider_predict(self.image_path, save_dir, (128, 100), + (10, 5), self.transforms) + pred4 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred4.shape, pred_whole.shape) + + # `block_size` larger than image size + save_dir = osp.join(td, 'pred5') + with self.assertRaises(ValueError): + self.model.slider_predict(self.image_path, save_dir, 512, 0, + self.transforms) + + def test_merge_strategy(self): + with tempfile.TemporaryDirectory() as td: + # Whole-image inference using predict() + pred_whole = self.model.predict(self.image_path, + self.transforms) + pred_whole = pred_whole['label_map'] + + # 'keep_first' + save_dir = osp.join(td, 'keep_first') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first') + pred_keepfirst = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) + + # 'keep_last' + save_dir = osp.join(td, 'keep_last') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_last') + pred_keeplast = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_keeplast.shape, pred_whole.shape) + + # 'accum' + save_dir = osp.join(td, 'accum') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='accum') + pred_accum = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_accum.shape, pred_whole.shape) + + def test_geo_info(self): + with tempfile.TemporaryDirectory() as td: + _, geo_info_in = T.decode_image( + self.ref_path, read_geo_info=True) + self.model.slider_predict(self.image_path, td, 128, 0, + self.transforms) + _, geo_info_out = T.decode_image( + osp.join(td, self.basename), read_geo_info=True) + self.assertEqual(geo_info_out['geo_trans'], + geo_info_in['geo_trans']) + self.assertEqual(geo_info_out['geo_proj'], + geo_info_in['geo_proj']) + + def test_batch_size(self): + with tempfile.TemporaryDirectory() as td: + # batch_size = 1 + save_dir = osp.join(td, 'bs1') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=1) + pred_bs1 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + + # batch_size = 4 + save_dir = osp.join(td, 'bs4') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=4) + pred_bs4 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_bs4, pred_bs1) + + # batch_size = 8 + save_dir = osp.join(td, 'bs4') + self.model.slider_predict( + self.image_path, + save_dir, + 128, + 64, + self.transforms, + merge_strategy='keep_first', + batch_size=8) + pred_bs8 = T.decode_image( + osp.join(save_dir, self.basename), + read_raw=True, + decode_sar=False) + self.check_output_equal(pred_bs8, pred_bs1) + + +class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict): def setUp(self): self.model = pdrs.tasks.seg.UNet(in_channels=10) self.transforms = T.Compose([ @@ -28,239 +195,18 @@ class TestSegSliderPredict(CommonTest): T.ArrangeSegmenter('test') ]) self.image_path = "data/ssst/multispectral.tif" - self.basename = osp.basename(self.image_path) - - def test_blocksize_and_overlap_whole(self): - # Original image size (256, 256) - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_path, self.transforms) - pred_whole = pred_whole['label_map'] - - # Whole-image inference using slider_predict() - save_dir = osp.join(td, 'pred1') - self.model.slider_predict(self.image_path, save_dir, 256, 0, - self.transforms) - pred1 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred1.shape, pred_whole.shape) - - # `block_size` == `overlap` - save_dir = osp.join(td, 'pred2') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_path, save_dir, 128, 128, - self.transforms) - - # `block_size` is a tuple - save_dir = osp.join(td, 'pred3') - self.model.slider_predict(self.image_path, save_dir, (128, 32), 0, - self.transforms) - pred3 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred3.shape, pred_whole.shape) - - # `block_size` and `overlap` are both tuples - save_dir = osp.join(td, 'pred4') - self.model.slider_predict(self.image_path, save_dir, (128, 100), - (10, 5), self.transforms) - pred4 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred4.shape, pred_whole.shape) - - # `block_size` larger than image size - save_dir = osp.join(td, 'pred5') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_path, save_dir, 512, 0, - self.transforms) - - def test_merge_strategy(self): - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_path, self.transforms) - pred_whole = pred_whole['label_map'] - - # 'keep_first' - save_dir = osp.join(td, 'keep_first') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_first') - pred_keepfirst = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) - - # 'keep_last' - save_dir = osp.join(td, 'keep_last') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_last') - pred_keeplast = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - - # 'accum' - save_dir = osp.join(td, 'accum') - self.model.slider_predict( - self.image_path, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='accum') - pred_accum = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_accum.shape, pred_whole.shape) + self.ref_path = self.image_path + self.basename = osp.basename(self.ref_path) - def test_geo_info(self): - with tempfile.TemporaryDirectory() as td: - _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) - self.model.slider_predict(self.image_path, td, 128, 0, - self.transforms) - _, geo_info_out = T.decode_image( - osp.join(td, self.basename), read_geo_info=True) - self.assertEqual(geo_info_out['geo_trans'], - geo_info_in['geo_trans']) - self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) - -class TestCDSliderPredict(CommonTest): +class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict): def setUp(self): self.model = pdrs.tasks.cd.BIT(in_channels=10) self.transforms = T.Compose([ T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10), T.ArrangeChangeDetector('test') ]) - self.image_paths = ("data/ssmt/multispectral_t1.tif", - "data/ssmt/multispectral_t2.tif") - self.basename = osp.basename(self.image_paths[0]) - - def test_blocksize_and_overlap_whole(self): - # Original image size (256, 256) - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_paths, self.transforms) - pred_whole = pred_whole['label_map'] - - # Whole-image inference using slider_predict() - save_dir = osp.join(td, 'pred1') - self.model.slider_predict(self.image_paths, save_dir, 256, 0, - self.transforms) - pred1 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred1.shape, pred_whole.shape) - - # `block_size` == `overlap` - save_dir = osp.join(td, 'pred2') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_paths, save_dir, 128, 128, - self.transforms) - - # `block_size` is a tuple - save_dir = osp.join(td, 'pred3') - self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0, - self.transforms) - pred3 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred3.shape, pred_whole.shape) - - # `block_size` and `overlap` are both tuples - save_dir = osp.join(td, 'pred4') - self.model.slider_predict(self.image_paths, save_dir, (128, 100), - (10, 5), self.transforms) - pred4 = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred4.shape, pred_whole.shape) - - # `block_size` larger than image size - save_dir = osp.join(td, 'pred5') - with self.assertRaises(ValueError): - self.model.slider_predict(self.image_paths, save_dir, 512, 0, - self.transforms) - - def test_merge_strategy(self): - with tempfile.TemporaryDirectory() as td: - # Whole-image inference using predict() - pred_whole = self.model.predict(self.image_paths, self.transforms) - pred_whole = pred_whole['label_map'] - - # 'keep_first' - save_dir = osp.join(td, 'keep_first') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_first') - pred_keepfirst = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) - - # 'keep_last' - save_dir = osp.join(td, 'keep_last') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='keep_last') - pred_keeplast = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_keeplast.shape, pred_whole.shape) - - # 'accum' - save_dir = osp.join(td, 'accum') - self.model.slider_predict( - self.image_paths, - save_dir, - 128, - 64, - self.transforms, - merge_strategy='accum') - pred_accum = T.decode_image( - osp.join(save_dir, self.basename), - to_uint8=False, - decode_sar=False) - self.check_output_equal(pred_accum.shape, pred_whole.shape) - - def test_geo_info(self): - with tempfile.TemporaryDirectory() as td: - _, geo_info_in = T.decode_image( - self.image_paths[0], read_geo_info=True) - self.model.slider_predict(self.image_paths, td, 128, 0, - self.transforms) - _, geo_info_out = T.decode_image( - osp.join(td, self.basename), read_geo_info=True) - self.assertEqual(geo_info_out['geo_trans'], - geo_info_in['geo_trans']) - self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) + self.image_path = ("data/ssmt/multispectral_t1.tif", + "data/ssmt/multispectral_t2.tif") + self.ref_path = self.image_path[0] + self.basename = osp.basename(self.ref_path) From 3939cbff16e77e66c7ffe44aa06b889dc4708fe6 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Thu, 8 Sep 2022 10:19:30 +0800 Subject: [PATCH 08/10] Update infer.md --- docs/apis/infer.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/apis/infer.md b/docs/apis/infer.md index d3c272c..4431ed2 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -157,7 +157,8 @@ def slider_predict(self, overlap=36, transforms=None, invalid_value=255, - merge_strategy='keep_last'): + merge_strategy='keep_last', + batch_size=1): ``` 输入参数列表: @@ -171,6 +172,7 @@ def slider_predict(self, |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`| +|`batch_size`|`int`|预测时使用的mini-batch大小。|`1`| 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。 From 2f60abd89995dc0a803aeb9b579c108f682ff03b Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 8 Sep 2022 14:59:17 +0800 Subject: [PATCH 09/10] Refactor and fix bugs --- paddlers/tasks/utils/slider_predict.py | 212 ++++++++++++++++++++----- 1 file changed, 172 insertions(+), 40 deletions(-) diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index 1672eb0..620997f 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod from collections import Counter, defaultdict import numpy as np +from tqdm import tqdm import paddlers.utils.logging as logging @@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta): class SlowCache(Cache): def __init__(self): + super(SlowCache, self).__init__() self.cache = defaultdict(Counter) def push_pixel(self, i, j, l): @@ -66,6 +68,7 @@ class SlowCache(Cache): class ProbCache(Cache): def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): + super(ProbCache, self).__init__() self.cache = None self.h = h self.w = w @@ -116,20 +119,139 @@ class ProbCache(Cache): self._alloc_memory(nc) self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map - def roll_cache(self): + def roll_cache(self, shift): if self.order == 'c': - self.cache[:-self.sh] = self.cache[self.sh:] - self.cache[-self.sh:, :] = 0 + self.cache[:-shift] = self.cache[shift:] + self.cache[-shift:, :] = 0 elif self.order == 'f': - self.cache[:, :-self.sw] = self.cache[:, self.sw:] - self.cache[:, -self.sw:] = 0 + self.cache[:, :-shift] = self.cache[:, shift:] + self.cache[:, -shift:] = 0 def get_block(self, i_st, j_st, h, w): return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) -def slider_predict(predict_func, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy, batch_size): +class OverlapProcessor(metaclass=ABCMeta): + def __init__(self, h, w, ch, cw, sh, sw): + super(OverlapProcessor, self).__init__() + self.h = h + self.w = w + self.ch = ch + self.cw = cw + self.sh = sh + self.sw = sw + + @abstractmethod + def process_pred(self, out, xoff, yoff): + pass + + +class KeepFirstProcessor(OverlapProcessor): + def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255): + super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.ds = ds + self.inval = inval + + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch) + mask = rd_block != self.inval + pred = np.where(mask, rd_block, pred) + return pred + + +class KeepLastProcessor(OverlapProcessor): + def process_pred(self, out, xoff, yoff): + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + return pred + + +class AccumProcessor(OverlapProcessor): + def __init__(self, + h, + w, + ch, + cw, + sh, + sw, + dtype=np.float16, + assign_weight=True): + super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw) + self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c') + self.prev_yoff = None + self.assign_weight = assign_weight + + def process_pred(self, out, xoff, yoff): + if self.prev_yoff is not None and yoff != self.prev_yoff: + if yoff < self.prev_yoff: + raise RuntimeError + self.cache.roll_cache(yoff - self.prev_yoff) + pred = out['label_map'] + pred = pred[:self.ch, :self.cw] + prob = out['score_map'] + prob = prob[:self.ch, :self.cw] + if self.assign_weight: + prob = assign_border_weights(prob, border_ratio=0.25, inplace=True) + self.cache.update_block(0, xoff, self.ch, self.cw, prob) + pred = self.cache.get_block(0, xoff, self.ch, self.cw) + self.prev_yoff = yoff + return pred + + +def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True): + if not inplace: + array = array.copy() + h, w = array.shape[:2] + hm, wm = int(h * border_ratio), int(w * border_ratio) + array[:hm] *= weight + array[-hm:] *= weight + array[:, :wm] *= weight + array[:, -wm:] *= weight + return array + + +def read_block(ds, + xoff, + yoff, + xsize, + ysize, + tar_xsize=None, + tar_ysize=None, + pad_val=0): + if tar_xsize is None: + tar_xsize = xsize + if tar_ysize is None: + tar_ysize = ysize + # Read data from dataset + block = ds.ReadAsArray(xoff, yoff, xsize, ysize) + c, real_ysize, real_xsize = block.shape + assert real_ysize == ysize and real_xsize == xsize + # [c, h, w] -> [h, w, c] + block = block.transpose((1, 2, 0)) + if (real_ysize, real_xsize) != (tar_ysize, tar_xsize): + if real_ysize >= tar_ysize or real_xsize >= tar_xsize: + raise ValueError + padded_block = np.full( + (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype) + # Fill + padded_block[:real_ysize, :real_xsize] = block + return padded_block + else: + return block + + +def slider_predict(predict_func, + img_file, + save_dir, + block_size, + overlap, + transforms, + invalid_value, + merge_strategy, + batch_size, + show_progress=False): """ Do inference using sliding windows. @@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. batch_size (int): Batch size used in inference. + show_progress (bool, optional): Whether to show prediction progress with a + progress bar. Defaults to True. """ try: @@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") - if merge_strategy not in ('keep_first', 'keep_last', 'accum'): - raise ValueError("{} is not a supported stragegy for block merging.". - format(merge_strategy)) - step = np.array( block_size, dtype=np.int32) - np.array( overlap, dtype=np.int32) @@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, # When there is no overlap or the whole image is used as input, # use 'keep_last' strategy as it introduces least overheads merge_strategy = 'keep_last' - if merge_strategy == 'accum': - cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) + if merge_strategy == 'keep_first': + overlap_processor = KeepFirstProcessor( + height, + width, + *block_size[::-1], + *step[::-1], + band, + inval=invalid_value) + elif merge_strategy == 'keep_last': + overlap_processor = KeepLastProcessor(height, width, *block_size[::-1], + *step[::-1]) + elif merge_strategy == 'accum': + overlap_processor = AccumProcessor(height, width, *block_size[::-1], + *step[::-1]) + else: + raise ValueError("{} is not a supported stragegy for block merging.". + format(merge_strategy)) + + xsize, ysize = block_size + num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0]) + cnt = 0 + if show_progress: + pb = tqdm(total=num_blocks) batch_data = [] batch_offsets = [] for yoff in range(0, height, step[1]): for xoff in range(0, width, step[0]): - xsize, ysize = block_size if xoff + xsize > width: xoff = width - xsize + is_end_of_row = True + else: + is_end_of_row = False if yoff + ysize > height: yoff = height - ysize + is_end_of_col = True + else: + is_end_of_col = False - is_end_of_col = yoff + ysize >= height - is_end_of_row = xoff + xsize >= width - - # Read and fill - im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) + # Read + im = read_block(src_data, xoff, yoff, xsize, ysize) if isinstance(img_file, tuple): - im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( - (1, 2, 0)) + im2 = read_block(src2_data, xoff, yoff, xsize, ysize) batch_data.append((im, im2)) else: batch_data.append(im) @@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, batch_out = predict_func(batch_data, transforms=transforms) for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): - pred = out['label_map'].astype('uint8') - pred = pred[:ysize, :xsize] - - # Deal with overlapping pixels - if merge_strategy == 'keep_first': - rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize) - mask = rd_block != invalid_value - pred = np.where(mask, rd_block, pred) - elif merge_strategy == 'keep_last': - pass - elif merge_strategy == 'accum': - prob = out['score_map'] - prob = prob[:ysize, :xsize] - cache.update_block(0, xoff_, ysize, xsize, prob) - pred = cache.get_block(0, xoff_, ysize, xsize) - if xoff_ + xsize >= width: - cache.roll_cache() - + # Get processed result + pred = overlap_processor.process_pred(out, xoff_, yoff_) # Write to file band.WriteArray(pred, xoff_, yoff_) @@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, batch_data.clear() batch_offsets.clear() + cnt += 1 + + if show_progress: + pb.update(1) + pb.set_description("{} out of {} blocks processed.".format( + cnt, num_blocks)) + dst_data = None logging.info("GeoTiff file saved in {}.".format(save_file)) From db0cc05bfeb9dc8b0756d34d594154b5723e646f Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 8 Sep 2022 14:59:34 +0800 Subject: [PATCH 10/10] Add verbose mode --- docs/apis/infer.md | 4 +++- paddlers/deploy/predictor.py | 7 +++++-- paddlers/tasks/change_detector.py | 7 +++++-- paddlers/tasks/segmenter.py | 7 +++++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/docs/apis/infer.md b/docs/apis/infer.md index 4431ed2..0217a20 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -158,7 +158,8 @@ def slider_predict(self, transforms=None, invalid_value=255, merge_strategy='keep_last', - batch_size=1): + batch_size=1, + quiet=False): ``` 输入参数列表: @@ -173,6 +174,7 @@ def slider_predict(self, |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'accum'`策略可能导致较长的推理时间,但一般能够在窗口交界部分取得更好的表现。|`'keep_last'`| |`batch_size`|`int`|预测时使用的mini-batch大小。|`1`| +|`quiet`|`bool`|若为`True`,不显示预测进度。|`False`| 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同。 diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index ab474d2..24bb337 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -333,7 +333,8 @@ class Predictor(object): transforms=None, invalid_value=255, merge_strategy='keep_last', - batch_size=1): + batch_size=1, + quiet=False): """ Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the sliding-predicting mode. @@ -357,6 +358,7 @@ class Predictor(object): the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. """ slider_predict( partial( @@ -368,7 +370,8 @@ class Predictor(object): transforms, invalid_value, merge_strategy, - batch_size) + batch_size, + not quiet) def batch_predict(self, image_list, **params): return self.predict(img_file=image_list, **params) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 347e996..6df35d8 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -589,7 +589,8 @@ class BaseChangeDetector(BaseModel): transforms=None, invalid_value=255, merge_strategy='keep_last', - batch_size=1): + batch_size=1, + quiet=False): """ Do inference using sliding windows. @@ -613,10 +614,12 @@ class BaseChangeDetector(BaseModel): order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. """ slider_predict(self.predict, img_files, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy, batch_size) + transforms, invalid_value, merge_strategy, batch_size, + not quiet) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 29f2aa4..aa54f7c 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -561,7 +561,8 @@ class BaseSegmenter(BaseModel): transforms=None, invalid_value=255, merge_strategy='keep_last', - batch_size=1): + batch_size=1, + quiet=False): """ Do inference using sliding windows. @@ -585,10 +586,12 @@ class BaseSegmenter(BaseModel): order, respectively. 'accum' means determining the class of an overlapping pixel according to accumulated probabilities. Defaults to 'keep_last'. batch_size (int, optional): Batch size used in inference. Defaults to 1. + quiet (bool, optional): If True, disable the progress bar. Defaults to False. """ slider_predict(self.predict, img_file, save_dir, block_size, overlap, - transforms, invalid_value, merge_strategy, batch_size) + transforms, invalid_value, merge_strategy, batch_size, + not quiet) def preprocess(self, images, transforms, to_tensor=True): self._check_transforms(transforms, 'test')