[Feature] Add raster class

own
geoyee 3 years ago
parent 1bf3b139d0
commit 0c35d0e108
  1. 1
      paddlers/datasets/__init__.py
  2. 140
      paddlers/datasets/raster.py
  3. 1
      paddlers/utils/__init__.py
  4. 95
      paddlers/utils/convert.py

@ -1,2 +1,3 @@
from .voc import VOCDetection
from .seg_dataset import SegDataset
from .raster import Raster

@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import numpy as np
from typing import List, Tuple, Union
from paddlers.utils import raster2uint8
try:
from osgeo import gdal
except:
import gdal
class Raster:
    """ Read a raster file through GDAL and return its data as ndarrays.

    Attributes:
        path (str): The path the raster was opened from.
        bands (int): Number of bands reported by the dataset.
        width (int): Raster width in pixels.
        height (int): Raster height in pixels.
        band_list: Selected band indices (start with 1) or None (all of bands).
        is_sar (bool): Whether the raster is treated as SAR.
        is_src (bool): Whether raw data is returned without uint8 conversion.
    """

    def __init__(self,
                 path: str,
                 band_list: Union[List[int], Tuple[int, ...], None]=None,
                 is_sar: bool=False,  # TODO: Remove this param
                 is_src: bool=False) -> None:
        """ Open the raster and record its size and band selection.

        Args:
            path (str): The path of raster.
            band_list (Union[List[int], Tuple[int], None], optional):
                band list (start with 1) or None (all of bands). Defaults to None.
            is_sar (bool, optional): The raster is SAR or not. Defaults to False.
            is_src (bool, optional):
                Return raw data or not (convert uint8/float32). Defaults to False.

        Raises:
            ValueError: If the path does not exist, GDAL cannot open it,
                or band_list is invalid.
        """
        super(Raster, self).__init__()
        if not osp.exists(path):
            raise ValueError("The path {0} not exists.".format(path))
        self.path = path
        self.__src_data = gdal.Open(path)
        if self.__src_data is None:
            # gdal.Open reports failure by returning None unless GDAL
            # exceptions are enabled; fail here with a clear message.
            raise ValueError(
                "The path {0} cannot be opened as a raster.".format(path))
        self.__getInfo()
        self.is_sar = is_sar
        self.is_src = is_src
        self.setBands(band_list)

    def setBands(self,
                 band_list: Union[List[int], Tuple[int, ...], None]) -> None:
        """ Set band of data.

        Args:
            band_list (Union[List[int], Tuple[int], None]):
                band list (start with 1) or None (all of bands).

        Raises:
            ValueError: If band_list is empty, longer than the number of
                bands, or contains an index outside [1, bands].
        """
        if band_list is not None:
            if len(band_list) == 0:
                raise ValueError("band_list must not be empty.")
            if len(band_list) > self.bands:
                raise ValueError(
                    "The length of band_list must not exceed {0}.".format(
                        str(self.bands)))
            if max(band_list) > self.bands or min(band_list) < 1:
                raise ValueError(
                    "The range of band_list must within [1, {0}].".format(
                        str(self.bands)))
        self.band_list = band_list

    def getArray(self,
                 start_loc: Union[List[int], Tuple[int, ...], None]=None,
                 block_size: Union[List[int], Tuple[int, ...]]=(512, 512)
                 ) -> np.ndarray:
        """ Get ndarray data

        Args:
            start_loc (Union[List[int], Tuple[int], None], optional):
                Coordinates (x, y) of the upper left corner of the block,
                if None means return full image.
            block_size (Union[List[int], Tuple[int]], optional):
                Block size as (x size, y size). Defaults to (512, 512).

        Returns:
            np.ndarray: data's ndarray.

        Raises:
            ValueError: If start_loc/block_size are malformed or start_loc
                lies outside the raster.
        """
        if start_loc is None:
            return self.__getAarray()
        return self.__getBlock(start_loc, block_size)

    def __getInfo(self) -> None:
        """ Cache the band count and pixel size of the opened dataset. """
        self.bands = self.__src_data.RasterCount
        self.width = self.__src_data.RasterXSize
        self.height = self.__src_data.RasterYSize

    def __getAarray(self,
                    window: Union[None, List[int], Tuple[int, ...]]=None
                    ) -> np.ndarray:
        """ Read the whole raster, or the (xoff, yoff, xsize, ysize) window.

        Single-band rasters give a 2-D array; multi-band rasters give a
        channel-last 3-D array. Unless `is_src` is True the result is passed
        through `raster2uint8`.
        """
        if window is None:
            read_args = ()
        else:
            xoff, yoff, xsize, ysize = window
            read_args = (xoff, yoff, xsize, ysize)
        if self.band_list is None:
            ima = self.__src_data.ReadAsArray(*read_args)
        else:
            ima = np.stack(
                [
                    self.__src_data.GetRasterBand(b).ReadAsArray(*read_args)
                    for b in self.band_list
                ],
                axis=0)
        if self.bands == 1:
            if ima.ndim == 3:
                # An explicit band_list on a single-band raster stacks one
                # band on axis 0; drop it so the result matches the
                # band_list=None case.
                ima = ima[0]
            if self.is_sar:
                ima = abs(ima)
        else:
            # Bands are on axis 0 here; move them to the last axis.
            ima = ima.transpose((1, 2, 0))
        if self.is_src is False:
            ima = raster2uint8(ima)
        return ima

    def __getBlock(self,
                   start_loc: Union[List[int], Tuple[int, ...]],
                   block_size: Union[List[int], Tuple[int, ...]]=(512, 512)
                   ) -> np.ndarray:
        """ Read one block, zero-padded to block_size at the raster border.

        The returned array has block_size[1] rows and block_size[0] columns.
        """
        if len(start_loc) != 2 or len(block_size) != 2:
            raise ValueError("The length start_loc/block_size must be 2.")
        xoff, yoff = start_loc
        block_w, block_h = int(block_size[0]), int(block_size[1])
        if block_w <= 0 or block_h <= 0:
            raise ValueError("block_size must be positive.")
        # The offset must address an existing pixel; an offset equal to the
        # width/height would request an empty window.
        if not (0 <= xoff < self.width) or not (0 <= yoff < self.height):
            raise ValueError("start_loc must be within [0-{0}, 0-{1}].".format(
                str(self.width - 1), str(self.height - 1)))
        xoff, yoff = int(xoff), int(yoff)
        # Clamp the read window to the raster extent.
        xsize = min(block_w, self.width - xoff)
        ysize = min(block_h, self.height - yoff)
        ima = self.__getAarray([xoff, yoff, xsize, ysize])
        h, w = ima.shape[:2]
        # Rows follow the y size and columns the x size; any trailing channel
        # axis is taken from the data actually read, so band subsets fit.
        tmp = np.zeros((block_h, block_w) + tuple(ima.shape[2:]),
                       dtype=ima.dtype)
        tmp[:h, :w] = ima
        return tmp

@ -22,3 +22,4 @@ from .env import get_environ_info, get_num_workers, init_parallel_env
from .download import download_and_decompress, decompress
from .stats import SmoothedValue, TrainingStats
from .shm import _get_shared_memory_size_in_M
from .convert import raster2uint8

@ -0,0 +1,95 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import operator
from functools import reduce
def raster2uint8(image: np.ndarray) -> np.ndarray:
    """ Convert a raster array to uint8.

    uint8 input is returned unchanged. float32 input is first passed through
    `_sample_norm`; uint16 input and the normalized float32 data are then
    stretched with `_two_percentLinear`.

    Args:
        image (np.ndarray): image with dtype uint8, uint16 or float32.

    Returns:
        np.ndarray: image on uint8.

    Raises:
        ValueError: If the dtype is not one of uint8/uint16/float32.
    """
    kind = image.dtype.name
    if kind not in ("uint8", "uint16", "float32"):
        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {kind}.")
    if kind == "uint8":
        # Already in the target type: hand back the same array.
        return image
    prepared = _sample_norm(image) if kind == "float32" else image
    return _two_percentLinear(prepared)
# 2% linear stretch
def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
def _gray_process(gray, maxout=max_out, minout=min_out):
# Get the corresponding gray level at 98% histogram
high_value = np.percentile(gray, 98)
low_value = np.percentile(gray, 2)
truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
return processed_gray
if len(image.shape) == 3:
processes = []
for b in range(image.shape[-1]):
processes.append(_gray_process(image[:, :, b]))
result = np.stack(processes, axis=2)
else: # if len(image.shape) == 2
result = _gray_process(image)
return np.uint8(result)
# Simple image standardization
def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
    """ Equalize an image band by band and rescale it to uint8.

    Each band is passed through `_stretch`, divided by NUMS and multiplied
    by 255 before the cast to uint8. A 3-D image is processed along its last
    axis; a 2-D image is treated as a single band.

    Args:
        image (np.ndarray): 2-D or channel-last 3-D image.
        NUMS (int, optional): Number of gray levels used by the equalization.
            Defaults to 65536.

    Returns:
        np.ndarray: normalized image on uint8 with the same shape as `image`.
    """
    # `_stretch` writes into the array it is given; work on a copy so the
    # caller's data is left untouched.
    work = np.array(image, copy=True)
    if len(work.shape) == 3:
        stretched_img = np.stack(
            [
                _stretch(work[:, :, b], NUMS) / float(NUMS)
                for b in range(work.shape[-1])
            ],
            axis=2)
    else:  # if len(image.shape) == 2
        # Scale by NUMS exactly like the multi-band branch; without it the
        # values are far outside [0, 1] and the uint8 cast overflows.
        stretched_img = _stretch(work, NUMS) / float(NUMS)
    return np.uint8(stretched_img * 255)
# Histogram equalization
def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
    """ Histogram-equalize `ima` in place over NUMS gray levels.

    A lookup table is built from the cumulative histogram so that level i
    maps to (number of counted pixels below level i) / step, where
    step = (number of counted pixels) / (NUMS - 1). Each pixel is then
    replaced by the table entry for its gray level.

    Args:
        ima (np.ndarray): image band; it is overwritten with the result.
        NUMS (int): Number of gray levels.

    Returns:
        np.ndarray: the same array object as `ima`, holding the mapped values.
    """
    hist = np.asarray(_histogram(ima, NUMS))[:NUMS]
    total = hist.sum()
    if total == 0 or NUMS < 2:
        # No pixel was counted (or there is a single level): every entry
        # maps to 0 and the division by a zero step is avoided.
        lut = np.zeros(max(int(NUMS), 1), dtype=np.float64)
    else:
        # Step size
        step = total / (NUMS - 1)
        # Balanced lookup table: pixels strictly below each level, over step.
        lut = (np.cumsum(hist) - hist) / step
    # Array indices must be integers, so turn pixel values into level
    # indices first: NaN becomes 0, values are truncated and then clamped
    # to the table range.
    levels = np.nan_to_num(np.asarray(ima, dtype=np.float64))
    levels = np.clip(levels, 0, len(lut) - 1).astype(np.intp)
    ima[...] = lut[levels]
    return ima
# Calculate histogram
def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
bins = list(range(0, NUMS))
flat = ima.flat
n = np.searchsorted(np.sort(flat), bins)
n = np.concatenate([n, [len(flat)]])
hist = n[1:] - n[:-1]
return hist
Loading…
Cancel
Save