diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index c0d86fc..7f2c384 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -15,10 +15,8 @@ import cv2 import numpy as np import copy -import operator import shapely.ops from shapely.geometry import Polygon, MultiPolygon, GeometryCollection -from functools import reduce from sklearn.decomposition import PCA from sklearn.linear_model import LinearRegression from skimage import exposure @@ -383,18 +381,19 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp): return rle -def to_uint8(im): +def to_uint8(im, is_linear=False): """ Convert raster to uint8. Args: im (np.ndarray): The image. + is_linear (bool, optional): Use 2% linear stretch or not. Default is False. Returns: np.ndarray: Image on uint8. """ # 2% linear stretch - def _two_percentLinear(image, max_out=255, min_out=0): + def _two_percent_linear(image, max_out=255, min_out=0): def _gray_process(gray, maxout=max_out, minout=min_out): # get the corresponding gray level at 98% histogram high_value = np.percentile(gray, 98) @@ -402,7 +401,7 @@ def to_uint8(im): truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value) processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \ (maxout - minout) - return processed_gray + return np.uint8(processed_gray) if len(image.shape) == 3: processes = [] @@ -414,52 +413,28 @@ def to_uint8(im): return np.uint8(result) # simple image standardization - def _sample_norm(image, NUMS=65536): + def _sample_norm(image): stretches = [] if len(image.shape) == 3: for b in range(image.shape[-1]): - stretched = _stretch(image[:, :, b], NUMS) - stretched /= float(NUMS) + stretched = exposure.equalize_hist(image[:, :, b]) + stretched /= float(np.max(stretched)) stretches.append(stretched) stretched_img = np.stack(stretches, axis=2) else: # if len(image.shape) == 2 - stretched_img = _stretch(image, NUMS) + stretched_img = exposure.equalize_hist(image) return np.uint8(stretched_img * 255) - # histogram equalization - def _stretch(ima, NUMS): - hist = _histogram(ima, NUMS) - lut = [] - for bt in range(0, len(hist), NUMS): - # step size - step = reduce(operator.add, hist[bt:bt + NUMS]) / (NUMS - 1) - # create balanced lookup table - n = 0 - for i in range(NUMS): - lut.append(n / step) - n += hist[i + bt] - np.take(lut, ima, out=ima) - return ima - - # calculate histogram - def _histogram(ima, NUMS): - bins = list(range(0, NUMS)) - flat = ima.flat - n = np.searchsorted(np.sort(flat), bins) - n = np.concatenate([n, [len(flat)]]) - hist = n[1:] - n[:-1] - return hist - dtype = im.dtype.name - dtypes = ["uint8", "uint16", "float32"] + dtypes = ["uint8", "uint16", "uint32", "float32"] if dtype not in dtypes: - raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.") - if dtype == "uint8": - return im - else: - if dtype == "float32": - im = _sample_norm(im) - return _two_percentLinear(im) + raise ValueError( + f"'dtype' must be uint8/uint16/uint32/float32, not {dtype}.") + if dtype != "uint8": + im = _sample_norm(im) + if is_linear: + im = _two_percent_linear(im) + return im def to_intensity(im):