diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 887f1ed..11200ba 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -18,9 +18,9 @@ import copy import numpy as np import shapely.ops from shapely.geometry import Polygon, MultiPolygon, GeometryCollection -from sklearn.decomposition import PCA from sklearn.linear_model import LinearRegression from skimage import exposure +from joblib import load def normalize(im, mean, std, min_value=[0, 0, 0], max_value=[255, 255, 255]): @@ -427,10 +427,6 @@ def to_uint8(im, is_linear=False): return np.uint8(stretched_img * 255) dtype = im.dtype.name - dtypes = ["uint8", "uint16", "uint32", "float32"] - if dtype not in dtypes: - raise ValueError( - f"'dtype' must be uint8/uint16/uint32/float32, not {dtype}.") if dtype != "uint8": im = _sample_norm(im) if is_linear: @@ -533,26 +529,6 @@ def de_haze(im, gamma=False): return (result * 255).astype("uint8") -def pca(im, dim=3, whiten=True): - """ Dimensionality reduction of PCA. - - Args: - im (np.ndarray): The image. - dim (int, optional): Reserved dimensions. Defaults to 3. - whiten (bool, optional): PCA whiten or not. Defaults to True. - - Returns: - np.ndarray: The image after PCA. - """ - H, W, C = im.shape - n_im = np.reshape(im, (-1, C)) - pca = PCA(n_components=dim, whiten=whiten) - im_pca = pca.fit_transform(n_im) - result = np.reshape(im_pca, (H, W, dim)) - result = np.clip(result, 0, 1) - return (result * 255).astype("uint8") - - def match_histograms(im, ref): """ Match the cumulative histogram of one image to another. @@ -615,3 +591,22 @@ def match_by_regression(im, ref, pif_loc=None): matched = _linear_regress(im, ref, pif_loc).astype(im.dtype) return matched + + +def inv_pca(im, joblib_path): + """ + Restore PCA result. + + Args: + im (np.ndarray): The input image after PCA. + joblib_path (str): Path of *.joblib about PCA. + + Returns: + np.ndarray: The raw input image. + """ + pca = load(joblib_path) + H, W, C = im.shape + n_im = np.reshape(im, (-1, C)) + r_im = pca.inverse_transform(n_im) + r_im = np.reshape(r_im, (H, W, -1)) + return r_im diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index 1a34b99..fad74a4 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -27,11 +27,12 @@ import numpy as np import cv2 import imghdr from PIL import Image +from joblib import load import paddlers from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \ horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \ - crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, pca, select_bands, \ + crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, select_bands, \ to_intensity, to_uint8, img_flip, img_simple_rotate __all__ = [ @@ -242,7 +243,7 @@ class Compose(Transform): ValueError: Invalid length of transforms. """ - def __init__(self, transforms): + def __init__(self, transforms, to_uint8=True): super(Compose, self).__init__() if not isinstance(transforms, list): raise TypeError( @@ -253,7 +254,7 @@ class Compose(Transform): 'Length of transforms must not be less than 1, but received is {}' .format(len(transforms))) self.transforms = transforms - self.decode_image = ImgDecoder() + self.decode_image = ImgDecoder(to_uint8=to_uint8) self.arrange_outputs = None self.apply_im_only = False @@ -1552,18 +1553,22 @@ class DimReducing(Transform): Use PCA to reduce input image(s) dimension. Args: - dim (int, optional): Reserved dimensions. Defaults to 3. - whiten (bool, optional): PCA whiten or not. Defaults to True. + joblib_path (str): Path of *.joblib about PCA. """ - def __init__(self, dim=3, whiten=True): + def __init__(self, joblib_path): super(DimReducing, self).__init__() - self.dim = dim - self.whiten = whiten + ext = joblib_path.split(".")[-1] + if ext != "joblib": + raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format(ext)) + self.pca = load(joblib_path) def apply_im(self, image): - image = pca(image, self.dim, self.whiten) - return image + H, W, C = image.shape + n_im = np.reshape(image, (-1, C)) + im_pca = self.pca.transform(n_im) + result = np.reshape(im_pca, (H, W, -1)) + return result def apply(self, sample): sample['image'] = self.apply_im(sample['image']) diff --git a/tools/geojson2mask.py b/tools/geojson2mask.py index 63fb2da..8ffcc0a 100644 --- a/tools/geojson2mask.py +++ b/tools/geojson2mask.py @@ -18,7 +18,7 @@ import numpy as np import argparse import geojson from tqdm import tqdm -from utils import Raster, save_mask_geotiff, Timer +from utils import Raster, save_geotiff, Timer def _gt_convert(x_geo, y_geo, geotf): @@ -48,7 +48,7 @@ def convert_data(image_path, geojson_path): # TODO: Label category cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充 ext = "." + geojson_path.split(".")[-1] - save_mask_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot) + save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot) parser = argparse.ArgumentParser(description="input parameters") diff --git a/tools/pca.py b/tools/pca.py new file mode 100644 index 0000000..a3c3063 --- /dev/null +++ b/tools/pca.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +import numpy as np +import argparse +from sklearn.decomposition import PCA +from joblib import dump +from utils import Raster, Timer, save_geotiff + + +@Timer +def pca_train(img_path, save_dir="output", dim=3): + raster = Raster(img_path) + im = raster.getArray() + n_im = np.reshape(im, (-1, raster.bands)) + pca = PCA(n_components=dim, whiten=True) + pca_model = pca.fit(n_im) + if not osp.exists(save_dir): + os.makedirs(save_dir) + name = osp.splitext(osp.normpath(img_path).split(os.sep)[-1])[0] + model_save_path = osp.join(save_dir, (name + "_pca.joblib")) + image_save_path = osp.join(save_dir, (name + "_pca.tif")) + dump(pca_model, model_save_path) # save model + output = pca_model.transform(n_im).reshape((raster.height, raster.width, -1)) + save_geotiff(output, image_save_path, raster.proj, raster.geot) # save tiff + print("The Image and model of PCA saved in {}.".format(save_dir)) + + +parser = argparse.ArgumentParser(description="input parameters") +parser.add_argument("--im_path", type=str, required=True, \ + help="The path of HSIs image.") +parser.add_argument("--save_dir", type=str, default="output", \ + help="The params(*.joblib) saved folder, `output` is the default.") +parser.add_argument("--dim", type=int, default=3, \ + help="The dimension after reduced, `3` is the default.") + + +if __name__ == "__main__": + args = parser.parse_args() + pca_train(args.im_path, args.save_dir, args.dim) diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py index e414e00..208b605 100644 --- a/tools/utils/__init__.py +++ b/tools/utils/__init__.py @@ -1,20 +1,20 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import os.path as osp -sys.path.insert(0, osp.abspath("..")) # add workspace - -from .raster import Raster, save_mask_geotiff, raster2uint8 -from .timer import Timer +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os.path as osp +sys.path.insert(0, osp.abspath("..")) # add workspace + +from .raster import Raster, raster2uint8, save_geotiff +from .timer import Timer diff --git a/tools/utils/raster.py b/tools/utils/raster.py index 82bfffb..e6b7c93 100644 --- a/tools/utils/raster.py +++ b/tools/utils/raster.py @@ -1,207 +1,227 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os.path as osp -from typing import List, Tuple, Union - -import numpy as np - -from paddlers.transforms.functions import to_uint8 as raster2uint8 - -try: - from osgeo import gdal -except: - import gdal - - -class Raster: - def __init__(self, - path: str, - band_list: Union[List[int], Tuple[int], None]=None, - to_uint8: bool=False) -> None: - """ Class of read raster. - - Args: - path (str): The path of raster. - band_list (Union[List[int], Tuple[int], None], optional): - band list (start with 1) or None (all of bands). Defaults to None. - to_uint8 (bool, optional): - Convert uint8 or return raw data. Defaults to False. - """ - super(Raster, self).__init__() - if osp.exists(path): - self.path = path - self.ext_type = path.split(".")[-1] - if self.ext_type.lower() in ["npy", "npz"]: - self._src_data = None - else: - try: - # raster format support in GDAL: - # https://www.osgeo.cn/gdal/drivers/raster/index.html - self._src_data = gdal.Open(path) - except: - raise TypeError("Unsupported data format: `{}`".format( - self.ext_type)) - self.to_uint8 = to_uint8 - self.setBands(band_list) - self._getInfo() - else: - raise ValueError("The path {0} not exists.".format(path)) - - def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None: - """ Set band of data. - - Args: - band_list (Union[List[int], Tuple[int], None]): - band list (start with 1) or None (all of bands). - """ - self.bands = self._src_data.RasterCount - if band_list is not None: - if len(band_list) > self.bands: - raise ValueError( - "The lenght of band_list must be less than {0}.".format( - str(self.bands))) - if max(band_list) > self.bands or min(band_list) < 1: - raise ValueError("The range of band_list must within [1, {0}].". - format(str(self.bands))) - self.band_list = band_list - - def getArray( - self, - start_loc: Union[List[int], Tuple[int], None]=None, - block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: - """ Get ndarray data - - Args: - start_loc (Union[List[int], Tuple[int], None], optional): - Coordinates of the upper left corner of the block, if None means return full image. - block_size (Union[List[int], Tuple[int]], optional): - Block size. Defaults to [512, 512]. - - Returns: - np.ndarray: data's ndarray. - """ - if self._src_data is not None: - if start_loc is None: - return self._getArray() - else: - return self._getBlock(start_loc, block_size) - else: - print("Numpy doesn't support blocking temporarily.") - return self._getNumpy() - - def _getInfo(self) -> None: - if self._src_data is not None: - self.width = self._src_data.RasterXSize - self.height = self._src_data.RasterYSize - self.geot = self._src_data.GetGeoTransform() - self.proj = self._src_data.GetProjection() - d_name = self._getBlock([0, 0], [1, 1]).dtype.name - else: - d_img = self._getNumpy() - d_shape = d_img.shape - d_name = d_img.dtype.name - if len(d_shape) == 3: - self.height, self.width, self.bands = d_shape - else: - self.height, self.width = d_shape - self.bands = 1 - self.geot = None - self.proj = None - if "int8" in d_name: - self.datatype = gdal.GDT_Byte - elif "int16" in d_name: - self.datatype = gdal.GDT_UInt16 - else: - self.datatype = gdal.GDT_Float32 - - def _getNumpy(self): - ima = np.load(self.path) - if self.band_list is not None: - band_array = [] - for b in self.band_list: - band_i = ima[:, :, b - 1] - band_array.append(band_i) - ima = np.stack(band_array, axis=0) - return ima - - def _getArray( - self, - window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray: - if window is not None: - xoff, yoff, xsize, ysize = window - if self.band_list is None: - if window is None: - ima = self._src_data.ReadAsArray() - else: - ima = self._src_data.ReadAsArray(xoff, yoff, xsize, ysize) - else: - band_array = [] - for b in self.band_list: - if window is None: - band_i = self._src_data.GetRasterBand(b).ReadAsArray() - else: - band_i = self._src_data.GetRasterBand(b).ReadAsArray( - xoff, yoff, xsize, ysize) - band_array.append(band_i) - ima = np.stack(band_array, axis=0) - if self.bands == 1: - if len(ima.shape) == 3: - ima = ima.squeeze(0) - # the type is complex means this is a SAR data - if isinstance(type(ima[0, 0]), complex): - ima = abs(ima) - else: - ima = ima.transpose((1, 2, 0)) - if self.to_uint8 is True: - ima = raster2uint8(ima) - return ima - - def _getBlock( - self, - start_loc: Union[List[int], Tuple[int]], - block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: - if len(start_loc) != 2 or len(block_size) != 2: - raise ValueError("The length start_loc/block_size must be 2.") - xoff, yoff = start_loc - xsize, ysize = block_size - if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height): - raise ValueError("start_loc must be within [0-{0}, 0-{1}].".format( - str(self.width), str(self.height))) - if xoff + xsize > self.width: - xsize = self.width - xoff - if yoff + ysize > self.height: - ysize = self.height - yoff - ima = self._getArray([int(xoff), int(yoff), int(xsize), int(ysize)]) - h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape - if self.bands != 1: - tmp = np.zeros( - (block_size[0], block_size[1], self.bands), dtype=ima.dtype) - tmp[:h, :w, :] = ima - else: - tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype) - tmp[:h, :w] = ima - return tmp - - -def save_mask_geotiff(mask: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None: - height, width = mask.shape - driver = gdal.GetDriverByName("GTiff") - dst_ds = driver.Create(save_path, width, height, 1, gdal.GDT_UInt16) - dst_ds.SetGeoTransform(geotf) - dst_ds.SetProjection(proj) - band = dst_ds.GetRasterBand(1) - band.WriteArray(mask) - dst_ds.FlushCache() - dst_ds = None +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path as osp +from typing import List, Tuple, Union + +import numpy as np + +from paddlers.transforms.functions import to_uint8 as raster2uint8 + +try: + from osgeo import gdal +except: + import gdal + + +def _get_type(type_name: str) -> int: + if type_name in ["bool", "uint8"]: + gdal_type = gdal.GDT_Byte + elif type_name in ["int8", "int16"]: + gdal_type = gdal.GDT_Int16 + elif type_name == "uint16": + gdal_type = gdal.GDT_UInt16 + elif type_name == "int32": + gdal_type = gdal.GDT_Int32 + elif type_name == "uint32": + gdal_type = gdal.GDT_UInt32 + elif type_name in ["int64", "uint64", "float16", "float32"]: + gdal_type = gdal.GDT_Float32 + elif type_name == "float64": + gdal_type = gdal.GDT_Float64 + elif type_name == "complex64": + gdal_type = gdal.GDT_CFloat64 + else: + raise TypeError("Non-suported data type `{}`.".format(type_name)) + return gdal_type + + +class Raster: + def __init__(self, + path: str, + band_list: Union[List[int], Tuple[int], None]=None, + to_uint8: bool=False) -> None: + """ Class of read raster. + Args: + path (str): The path of raster. + band_list (Union[List[int], Tuple[int], None], optional): + band list (start with 1) or None (all of bands). Defaults to None. + to_uint8 (bool, optional): + Convert uint8 or return raw data. Defaults to False. + """ + super(Raster, self).__init__() + if osp.exists(path): + self.path = path + self.ext_type = path.split(".")[-1] + if self.ext_type.lower() in ["npy", "npz"]: + self._src_data = None + else: + try: + # raster format support in GDAL: + # https://www.osgeo.cn/gdal/drivers/raster/index.html + self._src_data = gdal.Open(path) + except: + raise TypeError("Unsupported data format: `{}`".format( + self.ext_type)) + self.to_uint8 = to_uint8 + self.setBands(band_list) + self._getInfo() + else: + raise ValueError("The path {0} not exists.".format(path)) + + def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None: + """ Set band of data. + Args: + band_list (Union[List[int], Tuple[int], None]): + band list (start with 1) or None (all of bands). + """ + self.bands = self._src_data.RasterCount + if band_list is not None: + if len(band_list) > self.bands: + raise ValueError( + "The lenght of band_list must be less than {0}.".format( + str(self.bands))) + if max(band_list) > self.bands or min(band_list) < 1: + raise ValueError("The range of band_list must within [1, {0}].". + format(str(self.bands))) + self.band_list = band_list + + def getArray( + self, + start_loc: Union[List[int], Tuple[int], None]=None, + block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + """ Get ndarray data + Args: + start_loc (Union[List[int], Tuple[int], None], optional): + Coordinates of the upper left corner of the block, if None means return full image. + block_size (Union[List[int], Tuple[int]], optional): + Block size. Defaults to [512, 512]. + Returns: + np.ndarray: data's ndarray. + """ + if self._src_data is not None: + if start_loc is None: + return self._getArray() + else: + return self._getBlock(start_loc, block_size) + else: + print("Numpy doesn't support blocking temporarily.") + return self._getNumpy() + + def _getInfo(self) -> None: + if self._src_data is not None: + self.width = self._src_data.RasterXSize + self.height = self._src_data.RasterYSize + self.geot = self._src_data.GetGeoTransform() + self.proj = self._src_data.GetProjection() + d_name = self._getBlock([0, 0], [1, 1]).dtype.name + else: + d_img = self._getNumpy() + d_shape = d_img.shape + d_name = d_img.dtype.name + if len(d_shape) == 3: + self.height, self.width, self.bands = d_shape + else: + self.height, self.width = d_shape + self.bands = 1 + self.geot = None + self.proj = None + self.datatype = _get_type(d_name) + + def _getNumpy(self): + ima = np.load(self.path) + if self.band_list is not None: + band_array = [] + for b in self.band_list: + band_i = ima[:, :, b - 1] + band_array.append(band_i) + ima = np.stack(band_array, axis=0) + return ima + + def _getArray( + self, + window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray: + if window is not None: + xoff, yoff, xsize, ysize = window + if self.band_list is None: + if window is None: + ima = self._src_data.ReadAsArray() + else: + ima = self._src_data.ReadAsArray(xoff, yoff, xsize, ysize) + else: + band_array = [] + for b in self.band_list: + if window is None: + band_i = self._src_data.GetRasterBand(b).ReadAsArray() + else: + band_i = self._src_data.GetRasterBand(b).ReadAsArray( + xoff, yoff, xsize, ysize) + band_array.append(band_i) + ima = np.stack(band_array, axis=0) + if self.bands == 1: + if len(ima.shape) == 3: + ima = ima.squeeze(0) + # the type is complex means this is a SAR data + if isinstance(type(ima[0, 0]), complex): + ima = abs(ima) + else: + ima = ima.transpose((1, 2, 0)) + if self.to_uint8 is True: + ima = raster2uint8(ima) + return ima + + def _getBlock( + self, + start_loc: Union[List[int], Tuple[int]], + block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + if len(start_loc) != 2 or len(block_size) != 2: + raise ValueError("The length start_loc/block_size must be 2.") + xoff, yoff = start_loc + xsize, ysize = block_size + if (xoff < 0 or xoff > self.width) or (yoff < 0 or yoff > self.height): + raise ValueError("start_loc must be within [0-{0}, 0-{1}].".format( + str(self.width), str(self.height))) + if xoff + xsize > self.width: + xsize = self.width - xoff + if yoff + ysize > self.height: + ysize = self.height - yoff + ima = self._getArray([int(xoff), int(yoff), int(xsize), int(ysize)]) + h, w = ima.shape[:2] if len(ima.shape) == 3 else ima.shape + if self.bands != 1: + tmp = np.zeros( + (block_size[0], block_size[1], self.bands), dtype=ima.dtype) + tmp[:h, :w, :] = ima + else: + tmp = np.zeros((block_size[0], block_size[1]), dtype=ima.dtype) + tmp[:h, :w] = ima + return tmp + + +def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None: + height, width, channel = image.shape + data_type = _get_type(image.dtype.name) + driver = gdal.GetDriverByName("GTiff") + dst_ds = driver.Create(save_path, width, height, channel, data_type) + dst_ds.SetGeoTransform(geotf) + dst_ds.SetProjection(proj) + if channel > 1: + for i in range(channel): + band = dst_ds.GetRasterBand(i + 1) + band.WriteArray(image[:, :, i]) + dst_ds.FlushCache() + else: + band = dst_ds.GetRasterBand(1) + band.WriteArray(image) + dst_ds.FlushCache() + dst_ds = None