From 05e1eeee6eb7a31ef5e1299f32f09b067e2f7e85 Mon Sep 17 00:00:00 2001 From: Yizhou Chen Date: Wed, 13 Jul 2022 11:32:45 +0800 Subject: [PATCH] [Feature] Add slider in cd (#92) * [Feature] Add cd slider * [Fix] Tuple instead of list * [Fix] Spell repair * [Fix] Spell repair --- paddlers/tasks/change_detector.py | 83 +++++++++++++++++++++++++++++++ paddlers/tasks/segmenter.py | 6 +-- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 7fbc425..9035127 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import os import os.path as osp from collections import OrderedDict from operator import attrgetter @@ -545,6 +546,88 @@ class BaseChangeDetector(BaseModel): } return prediction + def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=None): + """ + Do inference. + Args: + Args: + img_file(List[str]): + List of image paths. + save_dir(str): + Directory that contains saved geotiff file. + block_size(List[int] or Tuple[int], int): + The size of block. + overlap(List[int] or Tuple[int], int): + The overlap between two blocks. Defaults to 36. + transforms(paddlers.transforms.Compose or None, optional): + Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None. + """ + try: + from osgeo import gdal + except: + import gdal + + if len(img_file) != 2: + raise ValueError("`img_file` must be a list of length 2.") + if isinstance(block_size, int): + block_size = (block_size, block_size) + elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: + block_size = tuple(block_size) + else: + raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.") + if isinstance(overlap, int): + overlap = (overlap, overlap) + elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: + overlap = tuple(overlap) + else: + raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.") + + src1_data = gdal.Open(img_file[0]) + src2_data = gdal.Open(img_file[1]) + width = src1_data.RasterXSize + height = src1_data.RasterYSize + bands = src1_data.RasterCount + + driver = gdal.GetDriverByName("GTiff") + file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[0] + ".tif" + if not osp.exists(save_dir): + os.makedirs(save_dir) + save_file = osp.join(save_dir, file_name) + dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) + dst_data.SetGeoTransform(src1_data.GetGeoTransform()) + dst_data.SetProjection(src1_data.GetProjection()) + band = dst_data.GetRasterBand(1) + band.WriteArray(255 * np.ones((height, width), dtype="uint8")) + + step = np.array(block_size) - np.array(overlap) + for yoff in range(0, height, step[1]): + for xoff in range(0, width, step[0]): + xsize, ysize = block_size + if xoff + xsize > width: + xsize = int(width - xoff) + if yoff + ysize > height: + ysize = int(height - yoff) + im1 = src1_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) + im2 = src2_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) + # fill + h, w = im1.shape[:2] + im1_fill = np.zeros((block_size[1], block_size[0], bands), dtype=im1.dtype) + im2_fill = im1_fill.copy() + im1_fill[:h, :w, :] = im1 + im2_fill[:h, :w, :] = im2 + im_fill = (im1_fill, im2_fill) + # predict + pred = self.predict(im_fill, transforms)["label_map"].astype("uint8") + # overlap + rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) + mask = (rd_block == pred[:h, :w]) | (rd_block == 255) + temp = pred[:h, :w].copy() + temp[mask == False] = 0 + band.WriteArray(temp, int(xoff), int(yoff)) + dst_data.FlushCache() + dst_data = None + print("GeoTiff saved in {}.".format(save_file)) + def _preprocess(self, images, transforms, to_tensor=True): arrange_transforms( model_type=self.model_type, transforms=transforms, mode='test') diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 5462094..32b7dd0 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -527,7 +527,7 @@ class BaseSegmenter(BaseModel): img_file(str): Image path. save_dir(str): - Folder of geotiff saved. + Directory that contains saved geotiff file. block_size(List[int] or Tuple[int], int): The size of block. overlap(List[int] or Tuple[int], int): @@ -545,13 +545,13 @@ class BaseSegmenter(BaseModel): elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: block_size = tuple(block_size) else: - raise ValueError("`block_size` must be a tuple/list of length 2 or a integer.") + raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.") if isinstance(overlap, int): overlap = (overlap, overlap) elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: overlap = tuple(overlap) else: - raise ValueError("`overlap` must be a tuple/list of length 2 or a integer.") + raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.") src_data = gdal.Open(img_file) width = src_data.RasterXSize