From 0c35d0e108851a362b76269ba3c7751c8e558b27 Mon Sep 17 00:00:00 2001 From: geoyee Date: Thu, 3 Mar 2022 18:26:28 +0800 Subject: [PATCH] [Feature] Add raster class --- paddlers/datasets/__init__.py | 1 + paddlers/datasets/raster.py | 140 ++++++++++++++++++++++++++++++++++ paddlers/utils/__init__.py | 1 + paddlers/utils/convert.py | 95 +++++++++++++++++++++++ 4 files changed, 237 insertions(+) create mode 100644 paddlers/datasets/raster.py create mode 100644 paddlers/utils/convert.py diff --git a/paddlers/datasets/__init__.py b/paddlers/datasets/__init__.py index 4e9e35e..4e31bee 100644 --- a/paddlers/datasets/__init__.py +++ b/paddlers/datasets/__init__.py @@ -1,2 +1,3 @@ from .voc import VOCDetection from .seg_dataset import SegDataset +from .raster import Raster \ No newline at end of file diff --git a/paddlers/datasets/raster.py b/paddlers/datasets/raster.py new file mode 100644 index 0000000..c3d9b03 --- /dev/null +++ b/paddlers/datasets/raster.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path as osp +import numpy as np +from typing import List, Tuple, Union +from paddlers.utils import raster2uint8 + +try: + from osgeo import gdal +except: + import gdal + + +class Raster: + def __init__(self, + path: str, + band_list: Union[List[int], Tuple[int], None]=None, + is_sar: bool=False, # TODO: Remove this param + is_src: bool=False) -> None: + """ Class of read raster. + + Args: + path (str): The path of raster. + band_list (Union[List[int], Tuple[int], None], optional): + band list (start with 1) or None (all of bands). Defaults to None. + is_sar (bool, optional): The raster is SAR or not. Defaults to False. + is_src (bool, optional): + Return raw data or not (convert uint8/float32). Defaults to False. + """ + super(Raster, self).__init__() + if osp.exists(path): + self.path = path + self.__src_data = gdal.Open(path) + self.__getInfo() + self.is_sar = is_sar + self.is_src = is_src + self.setBands(band_list) + else: + raise ValueError("The path {0} not exists.".format(path)) + + def setBands(self, + band_list: Union[List[int], Tuple[int], None]) -> None: + """ Set band of data. + + Args: + band_list (Union[List[int], Tuple[int], None]): + band list (start with 1) or None (all of bands). + """ + if band_list is not None: + if len(band_list) > self.bands: + raise ValueError("The lenght of band_list must be less than {0}.".format(str(self.bands))) + if max(band_list) > self.bands or min(band_list) < 1: + raise ValueError("The range of band_list must within [1, {0}].".format(str(self.bands))) + self.band_list = band_list + + def getArray(self, + start_loc: Union[List[int], Tuple[int], None]=None, + block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + """ Get ndarray data + + Args: + start_loc (Union[List[int], Tuple[int], None], optional): + Coordinates of the upper left corner of the block, if None means return full image. + block_size (Union[List[int], Tuple[int]], optional): + Block size. Defaults to [512, 512]. + + Returns: + np.ndarray: data's ndarray. + """ + if start_loc is None: + return self.__getAarray() + else: + return self.__getBlock(start_loc, block_size) + + def __getInfo(self) -> None: + self.bands = self.__src_data.RasterCount + self.width = self.__src_data.RasterXSize + self.height = self.__src_data.RasterYSize + + def __getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray: + if window is not None: + xoff, yoff, xsize, ysize = window + if self.band_list is None: + if window is None: + ima = self.__src_data.ReadAsArray() + else: + ima = self.__src_data.ReadAsArray(xoff, yoff, xsize, ysize) + else: + band_array = [] + for b in self.band_list: + if window is None: + band_i = self.__src_data.GetRasterBand(b).ReadAsArray() + else: + band_i = self.__src_data.GetRasterBand(b).ReadAsArray(xoff, yoff, xsize, ysize) + band_array.append(band_i) + ima = np.stack(band_array, axis=0) + if self.bands == 1: + if self.is_sar: + ima = abs(ima) + else: + ima = ima.transpose((1, 2, 0)) + if self.is_src is False: + ima = raster2uint8(ima) + return ima + + def __getBlock(self, + start_loc: Union[List[int], Tuple[int]], + block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + if len(start_loc) != 2 or len(block_size) != 2: + raise ValueError("The length start_loc/block_size must be 2.") + xoff, yoff = start_loc + xsize, ysize = block_size + if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height): + raise ValueError( + "start_loc must be within [0-{0}, 0-{1}].".format(str(self.width), str(self.height))) + if xoff + xsize > self.width: + xsize = self.width - xoff + if yoff + ysize > self.height: + ysize = self.height - yoff + ima = self.__getAarray([int(xoff), int(yoff), int(xsize), int(ysize)]) + h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape + if self.bands != 1: + tmp = np.zeros((block_size[0], block_size[1], self.bands), dtype=ima.dtype) + tmp[:h, :w, :] = ima + else: + tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype) + tmp[:h, :w] = ima + return tmp \ No newline at end of file diff --git a/paddlers/utils/__init__.py b/paddlers/utils/__init__.py index 842e533..427f6af 100644 --- a/paddlers/utils/__init__.py +++ b/paddlers/utils/__init__.py @@ -22,3 +22,4 @@ from .env import get_environ_info, get_num_workers, init_parallel_env from .download import download_and_decompress, decompress from .stats import SmoothedValue, TrainingStats from .shm import _get_shared_memory_size_in_M +from .convert import raster2uint8 diff --git a/paddlers/utils/convert.py b/paddlers/utils/convert.py new file mode 100644 index 0000000..4b161e1 --- /dev/null +++ b/paddlers/utils/convert.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import operator +from functools import reduce + + +def raster2uint8(image: np.ndarray) -> np.ndarray: + """ Convert raster to uint8. + Args: + image (np.ndarray): image. + Returns: + np.ndarray: image on uint8. + """ + dtype = image.dtype.name + dtypes = ["uint8", "uint16", "float32"] + if dtype not in dtypes: + raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.") + if dtype == "uint8": + return image + else: + if dtype == "float32": + image = _sample_norm(image) + return _two_percentLinear(image) + + +# 2% linear stretch +def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray: + def _gray_process(gray, maxout=max_out, minout=min_out): + # Get the corresponding gray level at 98% histogram + high_value = np.percentile(gray, 98) + low_value = np.percentile(gray, 2) + truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value) + processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout) + return processed_gray + if len(image.shape) == 3: + processes = [] + for b in range(image.shape[-1]): + processes.append(_gray_process(image[:, :, b])) + result = np.stack(processes, axis=2) + else: # if len(image.shape) == 2 + result = _gray_process(image) + return np.uint8(result) + + +# Simple image standardization +def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray: + stretches = [] + if len(image.shape) == 3: + for b in range(image.shape[-1]): + stretched = _stretch(image[:, :, b], NUMS) + stretched /= float(NUMS) + stretches.append(stretched) + stretched_img = np.stack(stretches, axis=2) + else: # if len(image.shape) == 2 + stretched_img = _stretch(image, NUMS) + return np.uint8(stretched_img * 255) + + +# Histogram equalization +def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray: + hist = _histogram(ima, NUMS) + lut = [] + for bt in range(0, len(hist), NUMS): + # Step size + step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1) + # Create balanced lookup table + n = 0 + for i in range(NUMS): + lut.append(n / step) + n += hist[i + bt] + np.take(lut, ima, out=ima) + return ima + + +# Calculate histogram +def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray: + bins = list(range(0, NUMS)) + flat = ima.flat + n = np.searchsorted(np.sort(flat), bins) + n = np.concatenate([n, [len(flat)]]) + hist = n[1:] - n[:-1] + return hist \ No newline at end of file