diff --git a/paddlers/datasets/__init__.py b/paddlers/datasets/__init__.py
index 4a8afd0..f611c54 100644
--- a/paddlers/datasets/__init__.py
+++ b/paddlers/datasets/__init__.py
@@ -15,5 +15,4 @@
 from .voc import VOCDetection
 from .seg_dataset import SegDataset
 from .cd_dataset import CDDataset
-from .clas_dataset import ClasDataset
-from .raster import Raster
\ No newline at end of file
+from .clas_dataset import ClasDataset
\ No newline at end of file
diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py
index 2077421..080753a 100644
--- a/paddlers/transforms/functions.py
+++ b/paddlers/transforms/functions.py
@@ -14,10 +14,11 @@

 import cv2
 import numpy as np
-
+import copy
+import operator
 import shapely.ops
 from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
-import copy
+from functools import reduce
 from sklearn.decomposition import PCA


@@ -194,6 +195,122 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
     return rle


+def to_uint8(im):
+    """ Convert raster to uint8.
+
+    Args:
+        im (np.ndarray): The image.
+
+    Returns:
+        np.ndarray: The image in uint8.
+    """
+    # 2% linear stretch
+    def _two_percentLinear(image, max_out=255, min_out=0):
+        def _gray_process(gray, maxout=max_out, minout=min_out):
+            # get the gray levels at the 2% and 98% histogram percentiles
+            high_value = np.percentile(gray, 98)
+            low_value = np.percentile(gray, 2)
+            truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
+            processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
+                             (maxout - minout)
+            return processed_gray
+        if len(image.shape) == 3:
+            processes = []
+            for b in range(image.shape[-1]):
+                processes.append(_gray_process(image[:, :, b]))
+            result = np.stack(processes, axis=2)
+        else:  # if len(image.shape) == 2
+            result = _gray_process(image)
+        return np.uint8(result)
+
+    # simple image standardization
+    def _sample_norm(image, NUMS=65536):
+        stretches = []
+        if len(image.shape) == 3:
+            for b in range(image.shape[-1]):
+                stretched = _stretch(image[:, :, b], NUMS)
+                stretched /= float(NUMS)
+                stretches.append(stretched)
+            stretched_img = np.stack(stretches, axis=2)
+        else:  # if len(image.shape) == 2
+            stretched_img = _stretch(image, NUMS)
+        return np.uint8(stretched_img * 255)
+
+    # histogram equalization
+    def _stretch(ima, NUMS):
+        hist = _histogram(ima, NUMS)
+        lut = []
+        for bt in range(0, len(hist), NUMS):
+            # step size
+            step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
+            # create balanced lookup table
+            n = 0
+            for i in range(NUMS):
+                lut.append(n / step)
+                n += hist[i + bt]
+            np.take(lut, ima, out=ima)
+        return ima
+
+    # calculate histogram
+    def _histogram(ima, NUMS):
+        bins = list(range(0, NUMS))
+        flat = ima.flat
+        n = np.searchsorted(np.sort(flat), bins)
+        n = np.concatenate([n, [len(flat)]])
+        hist = n[1:] - n[:-1]
+        return hist
+
+    dtype = im.dtype.name
+    dtypes = ["uint8", "uint16", "float32"]
+    if dtype not in dtypes:
+        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
+    if dtype == "uint8":
+        return im
+    else:
+        if dtype == "float32":
+            im = _sample_norm(im)
+        return _two_percentLinear(im)
+
+
+def to_intensity(im):
+    """ Calculate the intensity image of SAR data.
+
+    Args:
+        im (np.ndarray): The SAR image.
+
+    Returns:
+        np.ndarray: The intensity image.
+ """ + if len(im.shape) != 2: + raise ValueError("im's shape must be 2.") + # the type is complex means this is a SAR data + if isinstance(type(im[0, 0]), complex): + im = abs(im) + return im + + +def select_bands(im, band_list=[1, 2, 3]): + """ Select bands. + + Args: + im (np.ndarray): The image. + band_list (list, optional): Bands of selected (Start with 1). Defaults to [1, 2, 3]. + + Returns: + np.ndarray: The image after band selected. + """ + total_band = im.shape[-1] + result = [] + for band in band_list: + band = int(band - 1) + if band < 0 or band >= total_band: + raise ValueError( + "The element in band_list must > 1 and <= {}.".format(str(total_band))) + result.append() + ima = np.stack(result, axis=0) + return ima + + def matching(im1, im2): """ Match two images, used change detection. (Just RGB) @@ -214,8 +331,10 @@ def matching(im1, im2): for m, n in mathces: if m.distance < 0.75 * n.distance: good_matches.append([m]) - src_automatic_points = np.float32([kp1[m[0].queryIdx].pt for m in good_matches]).reshape(-1, 1, 2) - den_automatic_points = np.float32([kp2[m[0].trainIdx].pt for m in good_matches]).reshape(-1, 1, 2) + src_automatic_points = np.float32([kp1[m[0].queryIdx].pt \ + for m in good_matches]).reshape(-1, 1, 2) + den_automatic_points = np.float32([kp2[m[0].trainIdx].pt \ + for m in good_matches]).reshape(-1, 1, 2) H, _ = cv2.findHomography(src_automatic_points, den_automatic_points, cv2.RANSAC, 5.0) im1_t = cv2.warpPerspective(im1, H, (im2.shape[1], im2.shape[0])) return im1_t, im2 @@ -231,7 +350,7 @@ def de_haze(im, gamma=False): Returns: np.ndarray: The image after defogged. """ - def guided_filter(I, p, r, eps): + def _guided_filter(I, p, r, eps): m_I = cv2.boxFilter(I, -1, (r, r)) m_p = cv2.boxFilter(p, -1, (r, r)) m_Ip = cv2.boxFilter(I * p, -1, (r, r)) @@ -244,11 +363,11 @@ def de_haze(im, gamma=False): m_b = cv2.boxFilter(b, -1, (r, r)) return m_a * I + m_b - def de_fog(im, r, w, maxatmo_mask, eps): + def _de_fog(im, r, w, maxatmo_mask, eps): # im is RGB and range[0, 1] atmo_mask = np.min(im, 2) dark_channel = cv2.erode(atmo_mask, np.ones((15, 15))) - atmo_mask = guided_filter(atmo_mask, dark_channel, r, eps) + atmo_mask = _guided_filter(atmo_mask, dark_channel, r, eps) bins = 2000 ht = np.histogram(atmo_mask, bins) d = np.cumsum(ht[0]) / float(atmo_mask.size) @@ -262,7 +381,7 @@ def de_haze(im, gamma=False): if np.max(im) > 1: im = im / 255. result = np.zeros(im.shape) - mask_img, atmo_illum = de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8) + mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8) for k in range(3): result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum) result = np.clip(result, 0, 1) diff --git a/paddlers/utils/__init__.py b/paddlers/utils/__init__.py index 427f6af..832793d 100644 --- a/paddlers/utils/__init__.py +++ b/paddlers/utils/__init__.py @@ -21,5 +21,4 @@ from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkp from .env import get_environ_info, get_num_workers, init_parallel_env from .download import download_and_decompress, decompress from .stats import SmoothedValue, TrainingStats -from .shm import _get_shared_memory_size_in_M -from .convert import raster2uint8 +from .shm import _get_shared_memory_size_in_M \ No newline at end of file diff --git a/paddlers/utils/convert.py b/paddlers/utils/convert.py deleted file mode 100644 index 2964abf..0000000 --- a/paddlers/utils/convert.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. 
diff --git a/paddlers/utils/__init__.py b/paddlers/utils/__init__.py
index 427f6af..832793d 100644
--- a/paddlers/utils/__init__.py
+++ b/paddlers/utils/__init__.py
@@ -21,5 +21,4 @@ from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkp
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats
-from .shm import _get_shared_memory_size_in_M
-from .convert import raster2uint8
+from .shm import _get_shared_memory_size_in_M
\ No newline at end of file
diff --git a/paddlers/utils/convert.py b/paddlers/utils/convert.py
deleted file mode 100644
index 2964abf..0000000
--- a/paddlers/utils/convert.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import operator
-from functools import reduce
-
-
-def raster2uint8(image: np.ndarray) -> np.ndarray:
-    """ Convert raster to uint8.
-    Args:
-        image (np.ndarray): image.
-    Returns:
-        np.ndarray: image on uint8.
-    """
-    dtype = image.dtype.name
-    dtypes = ["uint8", "uint16", "float32"]
-    if dtype not in dtypes:
-        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
-    if dtype == "uint8":
-        return image
-    else:
-        if dtype == "float32":
-            image = _sample_norm(image)
-        return _two_percentLinear(image)
-
-
-# 2% linear stretch
-def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
-    def _gray_process(gray, maxout=max_out, minout=min_out):
-        # get the corresponding gray level at 98% histogram
-        high_value = np.percentile(gray, 98)
-        low_value = np.percentile(gray, 2)
-        truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
-        processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
-        return processed_gray
-    if len(image.shape) == 3:
-        processes = []
-        for b in range(image.shape[-1]):
-            processes.append(_gray_process(image[:, :, b]))
-        result = np.stack(processes, axis=2)
-    else:  # if len(image.shape) == 2
-        result = _gray_process(image)
-    return np.uint8(result)
-
-
-# simple image standardization
-def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
-    stretches = []
-    if len(image.shape) == 3:
-        for b in range(image.shape[-1]):
-            stretched = _stretch(image[:, :, b], NUMS)
-            stretched /= float(NUMS)
-            stretches.append(stretched)
-        stretched_img = np.stack(stretches, axis=2)
-    else:  # if len(image.shape) == 2
-        stretched_img = _stretch(image, NUMS)
-    return np.uint8(stretched_img * 255)
-
-
-# histogram equalization
-def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
-    hist = _histogram(ima, NUMS)
-    lut = []
-    for bt in range(0, len(hist), NUMS):
-        # step size
-        step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
-        # create balanced lookup table
-        n = 0
-        for i in range(NUMS):
-            lut.append(n / step)
-            n += hist[i + bt]
-        np.take(lut, ima, out=ima)
-    return ima
-
-
-# calculate histogram
-def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
-    bins = list(range(0, NUMS))
-    flat = ima.flat
-    n = np.searchsorted(np.sort(flat), bins)
-    n = np.concatenate([n, [len(flat)]])
-    hist = n[1:] - n[:-1]
-    return hist
\ No newline at end of file
diff --git a/tools/mask2shp.py b/tools/mask2shp.py
index 6aec8a3..e597f9c 100644
--- a/tools/mask2shp.py
+++ b/tools/mask2shp.py
@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import sys
-import os.path as osp
-sys.path.insert(0, osp.abspath(".."))  # add workspace
-
 import os
+import os.path as osp
 import numpy as np
 import argparse
 from PIL import Image
-from paddlers.datasets.raster import Raster
+from utils import Raster

 try:
     from osgeo import gdal, ogr, osr
diff --git a/tools/spliter.py b/tools/spliter.py
index 71d8546..32d7961 100644
--- a/tools/spliter.py
+++ b/tools/spliter.py
@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import sys
-import os.path as osp
-sys.path.insert(0, osp.abspath(".."))  # add workspace
-
 import os
+import os.path as osp
 import argparse
 from math import ceil
 from PIL import Image
-from paddlers.datasets.raster import Raster
+from utils import Raster


 def split_data(image_path, block_size, save_folder):
diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py
new file mode 100644
index 0000000..00d1af8
--- /dev/null
+++ b/tools/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os.path as osp
+sys.path.insert(0, osp.abspath(".."))  # add workspace
+
+from .raster import Raster
\ No newline at end of file
diff --git a/paddlers/datasets/raster.py b/tools/utils/raster.py
similarity index 96%
rename from paddlers/datasets/raster.py
rename to tools/utils/raster.py
index 61655cc..45ab5a8 100644
--- a/paddlers/datasets/raster.py
+++ b/tools/utils/raster.py
@@ -15,7 +15,7 @@
 import os.path as osp
 import numpy as np
 from typing import List, Tuple, Union
-from paddlers.utils import raster2uint8
+from paddlers.transforms.functions import to_uint8 as raster2uint8

 try:
     from osgeo import gdal
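
With Raster relocated from paddlers/datasets/raster.py to tools/utils/raster.py, scripts under tools/ now import it from the local utils package instead of paddlers.datasets; tools/utils/__init__.py also puts the workspace root on sys.path so that paddlers itself stays importable. A minimal sketch of the new import path, assuming the script is run from inside tools/ (the file name and the constructor call are illustrative; Raster's actual signature is defined in raster.py, not in this patch):

    # run from the tools/ directory
    from utils import Raster  # previously: from paddlers.datasets.raster import Raster

    raster = Raster("example.tif")  # hypothetical call; see tools/utils/raster.py for the real arguments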