fix(tools): fix and update some (#110)

* refactor(tools): update savegeotiff and timer

* fix(geojson2mask): add convert tiff to EPSG:4326

* fix(geojson2mask): fix use tiff's EPSG

* fix(mask2geojson): update hole and fix srs

* fix(tools): update merge raster2vector

* feat(tools): update spliter to support HSI

* fix(raster2vector): fix save geojson without ignore index

* fix(tools): name fixed
own
Yizhou Chen 3 years ago committed by GitHub
parent 3af03678d4
commit ec9d58bb0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      .gitignore
  2. 4
      tools/coco2mask.py
  3. 17
      tools/geojson2mask.py
  4. 81
      tools/mask2geojson.py
  5. 93
      tools/mask2shp.py
  6. 29
      tools/matcher.py
  7. 4
      tools/oif.py
  8. 4
      tools/pca.py
  9. 102
      tools/raster2vector.py
  10. 52
      tools/spliter.py
  11. 3
      tools/utils/__init__.py
  12. 88
      tools/utils/raster.py
  13. 14
      tools/utils/timer.py
  14. 53
      tools/utils/vector.py

3
.gitignore vendored

@ -126,6 +126,9 @@ venv.bak/
.dmypy.json .dmypy.json
dmypy.json dmypy.json
# myvscode
.vscode
# Pyre type checker # Pyre type checker
.pyre/ .pyre/

@ -25,7 +25,7 @@ import glob
from tqdm import tqdm from tqdm import tqdm
from PIL import Image from PIL import Image
from utils import Timer from utils import timer
def _mkdir_p(path): def _mkdir_p(path):
@ -69,7 +69,7 @@ def _read_geojson(json_path):
return annotations, sizes return annotations, sizes
@Timer @timer
def convert_data(raw_folder, end_folder): def convert_data(raw_folder, end_folder):
print("-- Initializing --") print("-- Initializing --")
img_folder = osp.join(raw_folder, "images") img_folder = osp.join(raw_folder, "images")

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import codecs import codecs
import cv2 import cv2
import numpy as np import numpy as np
import argparse import argparse
import geojson import geojson
from tqdm import tqdm from tqdm import tqdm
from utils import Raster, save_geotiff, Timer from utils import Raster, save_geotiff, vector_translate, timer
def _gt_convert(x_geo, y_geo, geotf): def _gt_convert(x_geo, y_geo, geotf):
@ -27,12 +28,16 @@ def _gt_convert(x_geo, y_geo, geotf):
return np.round(np.linalg.solve(a, b)).tolist() # 解一元二次方程 return np.round(np.linalg.solve(a, b)).tolist() # 解一元二次方程
@Timer @timer
# TODO: update for vector2raster
def convert_data(image_path, geojson_path): def convert_data(image_path, geojson_path):
raster = Raster(image_path) raster = Raster(image_path)
tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32) tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
geo_reader = codecs.open(geojson_path, "r", encoding="utf-8") # vector to EPSG from raster
temp_geojson_path = vector_translate(geojson_path, raster.proj)
geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
feats = geojson.loads(geo_reader.read())["features"] # 所有图像块 feats = geojson.loads(geo_reader.read())["features"] # 所有图像块
geo_reader.close()
for feat in tqdm(feats): for feat in tqdm(feats):
geo = feat["geometry"] geo = feat["geometry"]
if geo["type"] == "Polygon": # 多边形 if geo["type"] == "Polygon": # 多边形
@ -40,7 +45,8 @@ def convert_data(image_path, geojson_path):
elif geo["type"] == "MultiPolygon": # 多面 elif geo["type"] == "MultiPolygon": # 多面
geo_points = geo["coordinates"][0][0] geo_points = geo["coordinates"][0][0]
else: else:
raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(geo["type"])) raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(
geo["type"]))
xy_points = np.array([ xy_points = np.array([
_gt_convert(point[0], point[1], raster.geot) _gt_convert(point[0], point[1], raster.geot)
for point in geo_points for point in geo_points
@ -49,13 +55,14 @@ def convert_data(image_path, geojson_path):
cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充 cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充
ext = "." + geojson_path.split(".")[-1] ext = "." + geojson_path.split(".")[-1]
save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot) save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
os.remove(temp_geojson_path)
parser = argparse.ArgumentParser(description="input parameters") parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--image_path", type=str, required=True, \ parser.add_argument("--image_path", type=str, required=True, \
help="The path of original image.") help="The path of original image.")
parser.add_argument("--geojson_path", type=str, required=True, \ parser.add_argument("--geojson_path", type=str, required=True, \
help="The path of geojson.") help="The path of geojson. (coordinate of geojson is WGS84)")
if __name__ == "__main__": if __name__ == "__main__":

@ -1,81 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import argparse
import cv2
import numpy as np
import geojson
from geojson import Polygon, Feature, FeatureCollection
from utils import Raster, Timer
def _gt_convert(x, y, geotf):
x_geo = geotf[0] + x * geotf[1] + y * geotf[2]
y_geo = geotf[3] + x * geotf[4] + y * geotf[5]
return x_geo, y_geo
@Timer
def convert_data(mask_path, save_path, epsilon=0):
raster = Raster(mask_path)
img = raster.getArray()
ext = save_path.split(".")[-1]
if ext != "json" and ext != "geojson":
raise ValueError("The ext of `save_path` must be `json` or `geojson`, not {}.".format(ext))
geo_writer = codecs.open(save_path, "w", encoding="utf-8")
clas = np.unique(img)
cv2_v = (cv2.__version__.split(".")[0] == "3")
feats = []
if not isinstance(epsilon, (int, float)):
epsilon = 0
for iclas in range(1, len(clas)):
tmp = np.zeros_like(img).astype("uint8")
tmp[img == iclas] = 1
# TODO: Detect internal and external contour
results = cv2.findContours(tmp, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_TC89_KCOS)
contours = results[1] if cv2_v else results[0]
# hierarchys = results[2] if cv2_v else results[1]
if len(contours) == 0:
continue
for contour in contours:
contour = cv2.approxPolyDP(contour, epsilon, True)
polys = []
for point in contour:
x, y = point[0]
xg, yg = _gt_convert(x, y, raster.geot)
polys.append((xg, yg))
polys.append(polys[0])
feat = Feature(
geometry=Polygon([polys]), properties={"class": int(iclas)})
feats.append(feat)
gjs = FeatureCollection(feats)
geo_writer.write(geojson.dumps(gjs))
geo_writer.close()
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask tif.")
parser.add_argument("--save_path", type=str, required=True, \
help="The path to save the results, file suffix is `*.json/geojson`.")
parser.add_argument("--epsilon", type=float, default=0, \
help="The CV2 simplified parameters, `0` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
convert_data(args.mask_path, args.save_path, args.epsilon)

@ -1,93 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import argparse
import numpy as np
from PIL import Image
try:
from osgeo import gdal, ogr, osr
except ImportError:
import gdal
import ogr
import osr
from utils import Raster, Timer
def _mask2tif(mask_path, tmp_path, proj, geot):
mask = np.asarray(Image.open(mask_path))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
row, columns = mask.shape[:2]
driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(tmp_path, columns, row, 1, gdal.GDT_UInt16)
dst_ds.SetGeoTransform(geot)
dst_ds.SetProjection(proj)
dst_ds.GetRasterBand(1).WriteArray(mask)
dst_ds.FlushCache()
return dst_ds
def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
tmp_path = shp_save_path.replace(".shp", ".tif")
ds = _mask2tif(mask_path, tmp_path, proj, geot)
srcband = ds.GetRasterBand(1)
maskband = srcband.GetMaskBand()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
ogr.RegisterAll()
drv = ogr.GetDriverByName("ESRI Shapefile")
if osp.exists(shp_save_path):
os.remove(shp_save_path)
dst_ds = drv.CreateDataSource(shp_save_path)
prosrs = osr.SpatialReference(wkt=ds.GetProjection())
dst_layer = dst_ds.CreateLayer(
"Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
dst_fieldname = "DN"
fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
dst_layer.CreateField(fd)
gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
lyr = dst_ds.GetLayer()
lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
for holes in lyr:
lyr.DeleteFeature(holes.GetFID())
dst_ds.Destroy()
ds = None
os.remove(tmp_path)
@Timer
def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
src = Raster(srcimg_path)
_polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index)
src = None
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--srcimg_path", type=str, required=True, \
help="The path of original data with geoinfos.")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask data.")
parser.add_argument("--save_path", type=str, default="output", \
help="The path to save the results shapefile, `output` is the default.")
parser.add_argument("--ignore_index", type=int, default=255, \
help="It will not be converted to the value of SHP, `255` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
raster2shp(args.srcimg_path, args.mask_path, args.save_path,
args.ignore_index)

@ -16,12 +16,8 @@ import argparse
import numpy as np import numpy as np
import cv2 import cv2
try:
from osgeo import gdal
except ImportError:
import gdal
from utils import Raster, raster2uint8, Timer from utils import Raster, raster2uint8, save_geotiff, timer
class MatchError(Exception): class MatchError(Exception):
def __str__(self): def __str__(self):
@ -64,26 +60,7 @@ def _get_match_img(raster, bands):
return ima return ima
def _img2tif(ima, save_path, proj, geot, dtype): @timer
if len(ima.shape) == 3:
row, columns, bands = ima.shape
else:
row, columns = ima.shape
bands = 1
driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(save_path, columns, row, bands, dtype)
dst_ds.SetGeoTransform(geot)
dst_ds.SetProjection(proj)
if bands != 1:
for b in range(bands):
dst_ds.GetRasterBand(b + 1).WriteArray(ima[:, :, b])
else:
dst_ds.GetRasterBand(1).WriteArray(ima)
dst_ds.FlushCache()
return dst_ds
@Timer
def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]): def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
im1_ras = Raster(im1_path) im1_ras = Raster(im1_path)
im2_ras = Raster(im2_path) im2_ras = Raster(im2_path)
@ -96,7 +73,7 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H, im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
(im1_ras.width, im1_ras.height)) (im1_ras.width, im1_ras.height))
save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif") save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
_img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype) save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
parser = argparse.ArgumentParser(description="input parameters") parser = argparse.ArgumentParser(description="input parameters")

@ -19,7 +19,7 @@ from easydict import EasyDict as edict
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from utils import Raster, Timer from utils import Raster, timer
def _calcOIF(rgb, stds, rho): def _calcOIF(rgb, stds, rho):
r, g, b = rgb r, g, b = rgb
@ -32,7 +32,7 @@ def _calcOIF(rgb, stds, rho):
return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31)) return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
@Timer @timer
def oif(img_path, topk=5): def oif(img_path, topk=5):
raster = Raster(img_path) raster = Raster(img_path)
img = raster.getArray() img = raster.getArray()

@ -18,10 +18,10 @@ import numpy as np
import argparse import argparse
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from joblib import dump from joblib import dump
from utils import Raster, Timer, save_geotiff from utils import Raster, save_geotiff, timer
@Timer @timer
def pca_train(img_path, save_dir="output", dim=3): def pca_train(img_path, save_dir="output", dim=3):
raster = Raster(img_path) raster = Raster(img_path)
im = raster.getArray() im = raster.getArray()

@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import argparse
import numpy as np
from PIL import Image
try:
from osgeo import gdal, ogr, osr
except ImportError:
import gdal
import ogr
import osr
from utils import Raster, save_geotiff, timer
def _mask2tif(mask_path, tmp_path, proj, geot):
dst_ds = save_geotiff(
np.asarray(Image.open(mask_path)),
tmp_path, proj, geot, gdal.GDT_UInt16, False)
return dst_ds
def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
if proj is None or geot is None:
tmp_path = None
ds = gdal.Open(mask_path)
else:
tmp_path = vec_save_path.replace("." + ext, ".tif")
ds = _mask2tif(mask_path, tmp_path, proj, geot)
srcband = ds.GetRasterBand(1)
maskband = srcband.GetMaskBand()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
ogr.RegisterAll()
drv = ogr.GetDriverByName(
"ESRI Shapefile" if ext == "shp" else "GeoJSON"
)
if osp.exists(vec_save_path):
os.remove(vec_save_path)
dst_ds = drv.CreateDataSource(vec_save_path)
prosrs = osr.SpatialReference(wkt=ds.GetProjection())
dst_layer = dst_ds.CreateLayer(
"POLYGON", geom_type=ogr.wkbPolygon, srs=prosrs)
dst_fieldname = "CLAS"
fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
dst_layer.CreateField(fd)
gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
# TODO: temporary: delete ignored values
dst_ds.Destroy()
ds = None
vec_ds = drv.Open(vec_save_path, 1)
lyr = vec_ds.GetLayer()
lyr.SetAttributeFilter("{} = '{}'".format(dst_fieldname, str(ignore_index)))
for holes in lyr:
lyr.DeleteFeature(holes.GetFID())
vec_ds.Destroy()
if tmp_path is not None:
os.remove(tmp_path)
@timer
def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
vec_ext = save_path.split(".")[-1].lower()
if vec_ext not in ["json", "geojson", "shp"]:
raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext))
ras_ext = srcimg_path.split(".")[-1].lower()
if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
src = Raster(srcimg_path)
_polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext)
src = None
else:
_polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext)
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask data.")
parser.add_argument("--save_path", type=str, required=True, \
help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.")
parser.add_argument("--srcimg_path", type=str, default="", \
help="The path of original data with geoinfos, `` is the default.")
parser.add_argument("--ignore_index", type=int, default=255, \
help="It will not be converted to the value of SHP, `255` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)

@ -16,42 +16,52 @@ import os
import os.path as osp import os.path as osp
import argparse import argparse
from math import ceil from math import ceil
from tqdm import tqdm
from PIL import Image from utils import Raster, save_geotiff, timer
from utils import Raster, Timer
def _calc_window_tf(geot, loc):
x, hr, r1, y, r2, vr = geot
nx, ny = loc
return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
@Timer
@timer
def split_data(image_path, mask_path, block_size, save_folder): def split_data(image_path, mask_path, block_size, save_folder):
if not osp.exists(save_folder): if not osp.exists(save_folder):
os.makedirs(save_folder) os.makedirs(save_folder)
os.makedirs(osp.join(save_folder, "images")) os.makedirs(osp.join(save_folder, "images"))
if mask_path is not None: if mask_path is not None:
os.makedirs(osp.join(save_folder, "masks")) os.makedirs(osp.join(save_folder, "masks"))
image_name = image_path.replace("\\", "/").split("/")[-1].split(".")[0] image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".")
image = Raster(image_path, to_uint8=True) image = Raster(image_path)
mask = Raster(mask_path) if mask_path is not None else None mask = Raster(mask_path) if mask_path is not None else None
if image.width != mask.width or image.height != mask.height: if mask is not None and (image.width != mask.width or image.height != mask.height):
raise ValueError("image's shape must equal mask's shape.") raise ValueError("image's shape must equal mask's shape.")
rows = ceil(image.height / block_size) rows = ceil(image.height / block_size)
cols = ceil(image.width / block_size) cols = ceil(image.width / block_size)
total_number = int(rows * cols) total_number = int(rows * cols)
for r in range(rows):
for c in range(cols): with tqdm(total=total_number) as pbar:
loc_start = (c * block_size, r * block_size) for r in range(rows):
image_title = Image.fromarray(image.getArray( for c in range(cols):
loc_start, (block_size, block_size))).convert("RGB") loc_start = (c * block_size, r * block_size)
image_save_path = osp.join(save_folder, "images", ( image_title = image.getArray(loc_start, (block_size, block_size))
image_name + "_" + str(r) + "_" + str(c) + ".jpg")) image_save_path = osp.join(save_folder, "images", (
image_title.save(image_save_path, "JPEG") image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
if mask is not None: window_geotf = _calc_window_tf(image.geot, loc_start)
mask_title = Image.fromarray(mask.getArray( save_geotiff(
loc_start, (block_size, block_size))).convert("L") image_title, image_save_path, image.proj, window_geotf
mask_save_path = osp.join(save_folder, "masks", ( )
image_name + "_" + str(r) + "_" + str(c) + ".png")) if mask is not None:
mask_title.save(mask_save_path, "PNG") mask_title = mask.getArray(loc_start, (block_size, block_size))
print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number)) mask_save_path = osp.join(save_folder, "masks", (
image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
save_geotiff(
mask_title, mask_save_path, image.proj, window_geotf
)
pbar.update(1)
parser = argparse.ArgumentParser(description="input parameters") parser = argparse.ArgumentParser(description="input parameters")

@ -17,4 +17,5 @@ import os.path as osp
sys.path.insert(0, osp.abspath("..")) # add workspace sys.path.insert(0, osp.abspath("..")) # add workspace
from .raster import Raster, raster2uint8, save_geotiff from .raster import Raster, raster2uint8, save_geotiff
from .timer import Timer from .vector import vector_translate
from .timer import timer

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os.path as osp import os.path as osp
from typing import List, Tuple, Union from typing import List, Tuple, Union, Optional
import numpy as np import numpy as np
@ -49,36 +49,45 @@ def _get_type(type_name: str) -> int:
class Raster: class Raster:
def __init__(self, def __init__(self,
path: str, path: Optional[str],
gdal_obj: Optional[gdal.Dataset]=None,
band_list: Union[List[int], Tuple[int], None]=None, band_list: Union[List[int], Tuple[int], None]=None,
to_uint8: bool=False) -> None: to_uint8: bool=False) -> None:
""" Class of read raster. """ Class of read raster.
Args: Args:
path (str): The path of raster. path (Optional[str]): The path of raster.
gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
band_list (Union[List[int], Tuple[int], None], optional): band_list (Union[List[int], Tuple[int], None], optional):
band list (start with 1) or None (all of bands). Defaults to None. band list (start with 1) or None (all of bands). Defaults to None.
to_uint8 (bool, optional): to_uint8 (bool, optional):
Convert uint8 or return raw data. Defaults to False. Convert uint8 or return raw data. Defaults to False.
""" """
super(Raster, self).__init__() super(Raster, self).__init__()
if osp.exists(path): if path is not None:
self.path = path if osp.exists(path):
self.ext_type = path.split(".")[-1] self.path = path
if self.ext_type.lower() in ["npy", "npz"]: self.ext_type = path.split(".")[-1]
self._src_data = None if self.ext_type.lower() in ["npy", "npz"]:
self._src_data = None
else:
try:
# raster format support in GDAL:
# https://www.osgeo.cn/gdal/drivers/raster/index.html
self._src_data = gdal.Open(path)
except:
raise TypeError(
"Unsupported data format: `{}`".format(self.ext_type))
else: else:
try: raise ValueError("The path {0} not exists.".format(path))
# raster format support in GDAL:
# https://www.osgeo.cn/gdal/drivers/raster/index.html
self._src_data = gdal.Open(path)
except:
raise TypeError("Unsupported data format: `{}`".format(
self.ext_type))
self.to_uint8 = to_uint8
self.setBands(band_list)
self._getInfo()
else: else:
raise ValueError("The path {0} not exists.".format(path)) if gdal_obj is not None:
self._src_data = gdal_obj
else:
raise ValueError("At least one of `path` and `gdal_obj` is not None.")
self.to_uint8 = to_uint8
self._getInfo()
self.setBands(band_list)
self._getType()
def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None: def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
""" Set band of data. """ Set band of data.
@ -86,7 +95,6 @@ class Raster:
band_list (Union[List[int], Tuple[int], None]): band_list (Union[List[int], Tuple[int], None]):
band list (start with 1) or None (all of bands). band list (start with 1) or None (all of bands).
""" """
self.bands = self._src_data.RasterCount
if band_list is not None: if band_list is not None:
if len(band_list) > self.bands: if len(band_list) > self.bands:
raise ValueError( raise ValueError(
@ -99,8 +107,8 @@ class Raster:
def getArray( def getArray(
self, self,
start_loc: Union[List[int], Tuple[int], None]=None, start_loc: Union[List[int], Tuple[int, int], None]=None,
block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
""" Get ndarray data """ Get ndarray data
Args: Args:
start_loc (Union[List[int], Tuple[int], None], optional): start_loc (Union[List[int], Tuple[int], None], optional):
@ -123,13 +131,12 @@ class Raster:
if self._src_data is not None: if self._src_data is not None:
self.width = self._src_data.RasterXSize self.width = self._src_data.RasterXSize
self.height = self._src_data.RasterYSize self.height = self._src_data.RasterYSize
self.bands = self._src_data.RasterCount
self.geot = self._src_data.GetGeoTransform() self.geot = self._src_data.GetGeoTransform()
self.proj = self._src_data.GetProjection() self.proj = self._src_data.GetProjection()
d_name = self._getBlock([0, 0], [1, 1]).dtype.name
else: else:
d_img = self._getNumpy() d_img = self._getNumpy()
d_shape = d_img.shape d_shape = d_img.shape
d_name = d_img.dtype.name
if len(d_shape) == 3: if len(d_shape) == 3:
self.height, self.width, self.bands = d_shape self.height, self.width, self.bands = d_shape
else: else:
@ -137,6 +144,9 @@ class Raster:
self.bands = 1 self.bands = 1
self.geot = None self.geot = None
self.proj = None self.proj = None
def _getType(self) -> None:
d_name = self.getArray([0, 0], [1, 1]).dtype.name
self.datatype = _get_type(d_name) self.datatype = _get_type(d_name)
def _getNumpy(self): def _getNumpy(self):
@ -151,7 +161,9 @@ class Raster:
def _getArray( def _getArray(
self, self,
window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray: window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
if self._src_data is None:
raise ValueError("The raster is None.")
if window is not None: if window is not None:
xoff, yoff, xsize, ysize = window xoff, yoff, xsize, ysize = window
if self.band_list is None: if self.band_list is None:
@ -183,8 +195,8 @@ class Raster:
def _getBlock( def _getBlock(
self, self,
start_loc: Union[List[int], Tuple[int]], start_loc: Union[List[int], Tuple[int, int]],
block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
if len(start_loc) != 2 or len(block_size) != 2: if len(start_loc) != 2 or len(block_size) != 2:
raise ValueError("The length start_loc/block_size must be 2.") raise ValueError("The length start_loc/block_size must be 2.")
xoff, yoff = start_loc xoff, yoff = start_loc
@ -208,9 +220,21 @@ class Raster:
return tmp return tmp
def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None: def save_geotiff(image: np.ndarray,
height, width, channel = image.shape save_path: str,
data_type = _get_type(image.dtype.name) proj: str,
geotf: Tuple,
use_type: Optional[int]=None,
clear_ds: bool=True) -> None:
if len(image.shape) == 2:
height, width = image.shape
channel = 1
else:
height, width, channel = image.shape
if use_type is not None:
data_type = use_type
else:
data_type = _get_type(image.dtype.name)
driver = gdal.GetDriverByName("GTiff") driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(save_path, width, height, channel, data_type) dst_ds = driver.Create(save_path, width, height, channel, data_type)
dst_ds.SetGeoTransform(geotf) dst_ds.SetGeoTransform(geotf)
@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) ->
band = dst_ds.GetRasterBand(1) band = dst_ds.GetRasterBand(1)
band.WriteArray(image) band.WriteArray(image)
dst_ds.FlushCache() dst_ds.FlushCache()
dst_ds = None if clear_ds:
dst_ds = None
return dst_ds

@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import time import time
from functools import wraps
class Timer(object): def timer(func):
def __init__(self, func): @wraps(func)
self.func = func def wrapper(*args,**kwargs):
def __call__(self, *args, **kwds):
start_time = time.time() start_time = time.time()
func_t = self.func(*args, **kwds) result = func(*args,**kwargs)
print("Total time: {0}.".format(time.time() - start_time)) print("Total time: {0}.".format(time.time() - start_time))
return func_t return result
return wrapper

@ -0,0 +1,53 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://zhuanlan.zhihu.com/p/378918221
try:
from osgeo import gdal, ogr, osr
except:
import gdal
import ogr
import osr
def vector_translate(geojson_path: str,
wo_wkt: str,
g_type: str="POLYGON",
dim: str="XY") -> str:
ogr.RegisterAll()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
data = ogr.Open(geojson_path)
layer = data.GetLayer()
spatial = layer.GetSpatialRef()
layerName = layer.GetName()
data.Destroy()
dstSRS = osr.SpatialReference()
dstSRS.ImportFromWkt(wo_wkt)
ext = "." + geojson_path.split(".")[-1]
save_path = geojson_path.replace(ext, ("_tmp" + ext))
options = gdal.VectorTranslateOptions(
srcSRS=spatial,
dstSRS=dstSRS,
reproject=True,
layerName=layerName,
geometryType=g_type,
dim=dim
)
gdal.VectorTranslate(
save_path,
srcDS=geojson_path,
options=options
)
return save_path
Loading…
Cancel
Save