[Clean] clean raster to transforms func

own
geoyee 3 years ago
parent 89197a668e
commit a7194612bf
  1. 1
      paddlers/datasets/__init__.py
  2. 135
      paddlers/transforms/functions.py
  3. 1
      paddlers/utils/__init__.py
  4. 95
      paddlers/utils/convert.py
  5. 7
      tools/mask2shp.py
  6. 7
      tools/spliter.py
  7. 19
      tools/utils/__init__.py
  8. 2
      tools/utils/raster.py

@ -16,4 +16,3 @@ from .voc import VOCDetection
from .seg_dataset import SegDataset from .seg_dataset import SegDataset
from .cd_dataset import CDDataset from .cd_dataset import CDDataset
from .clas_dataset import ClasDataset from .clas_dataset import ClasDataset
from .raster import Raster

@ -14,10 +14,11 @@
import cv2 import cv2
import numpy as np import numpy as np
import copy
import operator
import shapely.ops import shapely.ops
from shapely.geometry import Polygon, MultiPolygon, GeometryCollection from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
import copy from functools import reduce
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
@ -194,6 +195,122 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
return rle return rle
def to_uint8(im):
""" Convert raster to uint8.
Args:
im (np.ndarray): The image.
Returns:
np.ndarray: Image on uint8.
"""
# 2% linear stretch
def _two_percentLinear(image, max_out=255, min_out=0):
def _gray_process(gray, maxout=max_out, minout=min_out):
# get the corresponding gray level at 98% histogram
high_value = np.percentile(gray, 98)
low_value = np.percentile(gray, 2)
truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
(maxout - minout)
return processed_gray
if len(image.shape) == 3:
processes = []
for b in range(image.shape[-1]):
processes.append(_gray_process(image[:, :, b]))
result = np.stack(processes, axis=2)
else: # if len(image.shape) == 2
result = _gray_process(image)
return np.uint8(result)
# simple image standardization
def _sample_norm(image, NUMS=65536):
stretches = []
if len(image.shape) == 3:
for b in range(image.shape[-1]):
stretched = _stretch(image[:, :, b], NUMS)
stretched /= float(NUMS)
stretches.append(stretched)
stretched_img = np.stack(stretches, axis=2)
else: # if len(image.shape) == 2
stretched_img = _stretch(image, NUMS)
return np.uint8(stretched_img * 255)
# histogram equalization
def _stretch(ima, NUMS):
hist = _histogram(ima, NUMS)
lut = []
for bt in range(0, len(hist), NUMS):
# step size
step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
# create balanced lookup table
n = 0
for i in range(NUMS):
lut.append(n / step)
n += hist[i + bt]
np.take(lut, ima, out=ima)
return ima
# calculate histogram
def _histogram(ima, NUMS):
bins = list(range(0, NUMS))
flat = ima.flat
n = np.searchsorted(np.sort(flat), bins)
n = np.concatenate([n, [len(flat)]])
hist = n[1:] - n[:-1]
return hist
dtype = im.dtype.name
dtypes = ["uint8", "uint16", "float32"]
if dtype not in dtypes:
raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
if dtype == "uint8":
return im
else:
if dtype == "float32":
im = _sample_norm(im)
return _two_percentLinear(im)
def to_intensity(im):
""" calculate SAR data's intensity diagram.
Args:
im (np.ndarray): The SAR image.
Returns:
np.ndarray: Intensity diagram.
"""
if len(im.shape) != 2:
raise ValueError("im's shape must be 2.")
# the type is complex means this is a SAR data
if isinstance(type(im[0, 0]), complex):
im = abs(im)
return im
def select_bands(im, band_list=[1, 2, 3]):
""" Select bands.
Args:
im (np.ndarray): The image.
band_list (list, optional): Bands of selected (Start with 1). Defaults to [1, 2, 3].
Returns:
np.ndarray: The image after band selected.
"""
total_band = im.shape[-1]
result = []
for band in band_list:
band = int(band - 1)
if band < 0 or band >= total_band:
raise ValueError(
"The element in band_list must > 1 and <= {}.".format(str(total_band)))
result.append()
ima = np.stack(result, axis=0)
return ima
def matching(im1, im2): def matching(im1, im2):
""" Match two images, used change detection. (Just RGB) """ Match two images, used change detection. (Just RGB)
@ -214,8 +331,10 @@ def matching(im1, im2):
for m, n in mathces: for m, n in mathces:
if m.distance < 0.75 * n.distance: if m.distance < 0.75 * n.distance:
good_matches.append([m]) good_matches.append([m])
src_automatic_points = np.float32([kp1[m[0].queryIdx].pt for m in good_matches]).reshape(-1, 1, 2) src_automatic_points = np.float32([kp1[m[0].queryIdx].pt \
den_automatic_points = np.float32([kp2[m[0].trainIdx].pt for m in good_matches]).reshape(-1, 1, 2) for m in good_matches]).reshape(-1, 1, 2)
den_automatic_points = np.float32([kp2[m[0].trainIdx].pt \
for m in good_matches]).reshape(-1, 1, 2)
H, _ = cv2.findHomography(src_automatic_points, den_automatic_points, cv2.RANSAC, 5.0) H, _ = cv2.findHomography(src_automatic_points, den_automatic_points, cv2.RANSAC, 5.0)
im1_t = cv2.warpPerspective(im1, H, (im2.shape[1], im2.shape[0])) im1_t = cv2.warpPerspective(im1, H, (im2.shape[1], im2.shape[0]))
return im1_t, im2 return im1_t, im2
@ -231,7 +350,7 @@ def de_haze(im, gamma=False):
Returns: Returns:
np.ndarray: The image after defogged. np.ndarray: The image after defogged.
""" """
def guided_filter(I, p, r, eps): def _guided_filter(I, p, r, eps):
m_I = cv2.boxFilter(I, -1, (r, r)) m_I = cv2.boxFilter(I, -1, (r, r))
m_p = cv2.boxFilter(p, -1, (r, r)) m_p = cv2.boxFilter(p, -1, (r, r))
m_Ip = cv2.boxFilter(I * p, -1, (r, r)) m_Ip = cv2.boxFilter(I * p, -1, (r, r))
@ -244,11 +363,11 @@ def de_haze(im, gamma=False):
m_b = cv2.boxFilter(b, -1, (r, r)) m_b = cv2.boxFilter(b, -1, (r, r))
return m_a * I + m_b return m_a * I + m_b
def de_fog(im, r, w, maxatmo_mask, eps): def _de_fog(im, r, w, maxatmo_mask, eps):
# im is RGB and range[0, 1] # im is RGB and range[0, 1]
atmo_mask = np.min(im, 2) atmo_mask = np.min(im, 2)
dark_channel = cv2.erode(atmo_mask, np.ones((15, 15))) dark_channel = cv2.erode(atmo_mask, np.ones((15, 15)))
atmo_mask = guided_filter(atmo_mask, dark_channel, r, eps) atmo_mask = _guided_filter(atmo_mask, dark_channel, r, eps)
bins = 2000 bins = 2000
ht = np.histogram(atmo_mask, bins) ht = np.histogram(atmo_mask, bins)
d = np.cumsum(ht[0]) / float(atmo_mask.size) d = np.cumsum(ht[0]) / float(atmo_mask.size)
@ -262,7 +381,7 @@ def de_haze(im, gamma=False):
if np.max(im) > 1: if np.max(im) > 1:
im = im / 255. im = im / 255.
result = np.zeros(im.shape) result = np.zeros(im.shape)
mask_img, atmo_illum = de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8) mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
for k in range(3): for k in range(3):
result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum) result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
result = np.clip(result, 0, 1) result = np.clip(result, 0, 1)

@ -22,4 +22,3 @@ from .env import get_environ_info, get_num_workers, init_parallel_env
from .download import download_and_decompress, decompress from .download import download_and_decompress, decompress
from .stats import SmoothedValue, TrainingStats from .stats import SmoothedValue, TrainingStats
from .shm import _get_shared_memory_size_in_M from .shm import _get_shared_memory_size_in_M
from .convert import raster2uint8

@ -1,95 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import operator
from functools import reduce
def raster2uint8(image: np.ndarray) -> np.ndarray:
""" Convert raster to uint8.
Args:
image (np.ndarray): image.
Returns:
np.ndarray: image on uint8.
"""
dtype = image.dtype.name
dtypes = ["uint8", "uint16", "float32"]
if dtype not in dtypes:
raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
if dtype == "uint8":
return image
else:
if dtype == "float32":
image = _sample_norm(image)
return _two_percentLinear(image)
# 2% linear stretch
def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
def _gray_process(gray, maxout=max_out, minout=min_out):
# get the corresponding gray level at 98% histogram
high_value = np.percentile(gray, 98)
low_value = np.percentile(gray, 2)
truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
return processed_gray
if len(image.shape) == 3:
processes = []
for b in range(image.shape[-1]):
processes.append(_gray_process(image[:, :, b]))
result = np.stack(processes, axis=2)
else: # if len(image.shape) == 2
result = _gray_process(image)
return np.uint8(result)
# simple image standardization
def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
stretches = []
if len(image.shape) == 3:
for b in range(image.shape[-1]):
stretched = _stretch(image[:, :, b], NUMS)
stretched /= float(NUMS)
stretches.append(stretched)
stretched_img = np.stack(stretches, axis=2)
else: # if len(image.shape) == 2
stretched_img = _stretch(image, NUMS)
return np.uint8(stretched_img * 255)
# histogram equalization
def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
hist = _histogram(ima, NUMS)
lut = []
for bt in range(0, len(hist), NUMS):
# step size
step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
# create balanced lookup table
n = 0
for i in range(NUMS):
lut.append(n / step)
n += hist[i + bt]
np.take(lut, ima, out=ima)
return ima
# calculate histogram
def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
bins = list(range(0, NUMS))
flat = ima.flat
n = np.searchsorted(np.sort(flat), bins)
n = np.concatenate([n, [len(flat)]])
hist = n[1:] - n[:-1]
return hist

@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import os.path as osp
sys.path.insert(0, osp.abspath("..")) # add workspace
import os import os
import os.path as osp
import numpy as np import numpy as np
import argparse import argparse
from PIL import Image from PIL import Image
from paddlers.datasets.raster import Raster from utils import Raster
try: try:
from osgeo import gdal, ogr, osr from osgeo import gdal, ogr, osr

@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import os.path as osp
sys.path.insert(0, osp.abspath("..")) # add workspace
import os import os
import os.path as osp
import argparse import argparse
from math import ceil from math import ceil
from PIL import Image from PIL import Image
from paddlers.datasets.raster import Raster from utils import Raster
def split_data(image_path, block_size, save_folder): def split_data(image_path, block_size, save_folder):

@ -0,0 +1,19 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os.path as osp
sys.path.insert(0, osp.abspath("..")) # add workspace
from .raster import Raster

@ -15,7 +15,7 @@
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from typing import List, Tuple, Union from typing import List, Tuple, Union
from paddlers.utils import raster2uint8 from paddlers.transforms.functions import to_uint8 as raster2uint8
try: try:
from osgeo import gdal from osgeo import gdal
Loading…
Cancel
Save