fix(tools): fix and update some (#110)

* refactor(tools): update savegeotiff and timer

* fix(geojson2mask): add convert tiff to EPSG:4326

* fix(geojson2mask): fix use tiff's EPSG

* fix(mask2geojson): update hole and fix srs

* fix(tools): update merge raster2vector

* feat(tools): update spliter to support HSI

* fix(raster2vector): fix save geojson without ignore index

* fix(tools): name fixed
own
Yizhou Chen 2 years ago committed by GitHub
parent 3af03678d4
commit ec9d58bb0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      .gitignore
  2. 4
      tools/coco2mask.py
  3. 17
      tools/geojson2mask.py
  4. 81
      tools/mask2geojson.py
  5. 93
      tools/mask2shp.py
  6. 29
      tools/matcher.py
  7. 4
      tools/oif.py
  8. 4
      tools/pca.py
  9. 102
      tools/raster2vector.py
  10. 40
      tools/spliter.py
  11. 3
      tools/utils/__init__.py
  12. 60
      tools/utils/raster.py
  13. 14
      tools/utils/timer.py
  14. 53
      tools/utils/vector.py

3
.gitignore vendored

@ -126,6 +126,9 @@ venv.bak/
.dmypy.json
dmypy.json
# myvscode
.vscode
# Pyre type checker
.pyre/

@ -25,7 +25,7 @@ import glob
from tqdm import tqdm
from PIL import Image
from utils import Timer
from utils import timer
def _mkdir_p(path):
@ -69,7 +69,7 @@ def _read_geojson(json_path):
return annotations, sizes
@Timer
@timer
def convert_data(raw_folder, end_folder):
print("-- Initializing --")
img_folder = osp.join(raw_folder, "images")

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import codecs
import cv2
import numpy as np
import argparse
import geojson
from tqdm import tqdm
from utils import Raster, save_geotiff, Timer
from utils import Raster, save_geotiff, vector_translate, timer
def _gt_convert(x_geo, y_geo, geotf):
@ -27,12 +28,16 @@ def _gt_convert(x_geo, y_geo, geotf):
return np.round(np.linalg.solve(a, b)).tolist() # 解一元二次方程
@Timer
@timer
# TODO: update for vector2raster
def convert_data(image_path, geojson_path):
raster = Raster(image_path)
tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32)
geo_reader = codecs.open(geojson_path, "r", encoding="utf-8")
# vector to EPSG from raster
temp_geojson_path = vector_translate(geojson_path, raster.proj)
geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8")
feats = geojson.loads(geo_reader.read())["features"] # 所有图像块
geo_reader.close()
for feat in tqdm(feats):
geo = feat["geometry"]
if geo["type"] == "Polygon": # 多边形
@ -40,7 +45,8 @@ def convert_data(image_path, geojson_path):
elif geo["type"] == "MultiPolygon": # 多面
geo_points = geo["coordinates"][0][0]
else:
raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(geo["type"]))
raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(
geo["type"]))
xy_points = np.array([
_gt_convert(point[0], point[1], raster.geot)
for point in geo_points
@ -49,13 +55,14 @@ def convert_data(image_path, geojson_path):
cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充
ext = "." + geojson_path.split(".")[-1]
save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot)
os.remove(temp_geojson_path)
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--image_path", type=str, required=True, \
help="The path of original image.")
parser.add_argument("--geojson_path", type=str, required=True, \
help="The path of geojson.")
help="The path of geojson. (coordinate of geojson is WGS84)")
if __name__ == "__main__":

@ -1,81 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import argparse
import cv2
import numpy as np
import geojson
from geojson import Polygon, Feature, FeatureCollection
from utils import Raster, Timer
def _gt_convert(x, y, geotf):
x_geo = geotf[0] + x * geotf[1] + y * geotf[2]
y_geo = geotf[3] + x * geotf[4] + y * geotf[5]
return x_geo, y_geo
@Timer
def convert_data(mask_path, save_path, epsilon=0):
raster = Raster(mask_path)
img = raster.getArray()
ext = save_path.split(".")[-1]
if ext != "json" and ext != "geojson":
raise ValueError("The ext of `save_path` must be `json` or `geojson`, not {}.".format(ext))
geo_writer = codecs.open(save_path, "w", encoding="utf-8")
clas = np.unique(img)
cv2_v = (cv2.__version__.split(".")[0] == "3")
feats = []
if not isinstance(epsilon, (int, float)):
epsilon = 0
for iclas in range(1, len(clas)):
tmp = np.zeros_like(img).astype("uint8")
tmp[img == iclas] = 1
# TODO: Detect internal and external contour
results = cv2.findContours(tmp, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_TC89_KCOS)
contours = results[1] if cv2_v else results[0]
# hierarchys = results[2] if cv2_v else results[1]
if len(contours) == 0:
continue
for contour in contours:
contour = cv2.approxPolyDP(contour, epsilon, True)
polys = []
for point in contour:
x, y = point[0]
xg, yg = _gt_convert(x, y, raster.geot)
polys.append((xg, yg))
polys.append(polys[0])
feat = Feature(
geometry=Polygon([polys]), properties={"class": int(iclas)})
feats.append(feat)
gjs = FeatureCollection(feats)
geo_writer.write(geojson.dumps(gjs))
geo_writer.close()
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask tif.")
parser.add_argument("--save_path", type=str, required=True, \
help="The path to save the results, file suffix is `*.json/geojson`.")
parser.add_argument("--epsilon", type=float, default=0, \
help="The CV2 simplified parameters, `0` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
convert_data(args.mask_path, args.save_path, args.epsilon)

@ -1,93 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import argparse
import numpy as np
from PIL import Image
try:
from osgeo import gdal, ogr, osr
except ImportError:
import gdal
import ogr
import osr
from utils import Raster, Timer
def _mask2tif(mask_path, tmp_path, proj, geot):
mask = np.asarray(Image.open(mask_path))
if len(mask.shape) == 3:
mask = mask[:, :, 0]
row, columns = mask.shape[:2]
driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(tmp_path, columns, row, 1, gdal.GDT_UInt16)
dst_ds.SetGeoTransform(geot)
dst_ds.SetProjection(proj)
dst_ds.GetRasterBand(1).WriteArray(mask)
dst_ds.FlushCache()
return dst_ds
def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index):
tmp_path = shp_save_path.replace(".shp", ".tif")
ds = _mask2tif(mask_path, tmp_path, proj, geot)
srcband = ds.GetRasterBand(1)
maskband = srcband.GetMaskBand()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
ogr.RegisterAll()
drv = ogr.GetDriverByName("ESRI Shapefile")
if osp.exists(shp_save_path):
os.remove(shp_save_path)
dst_ds = drv.CreateDataSource(shp_save_path)
prosrs = osr.SpatialReference(wkt=ds.GetProjection())
dst_layer = dst_ds.CreateLayer(
"Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs)
dst_fieldname = "DN"
fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
dst_layer.CreateField(fd)
gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
lyr = dst_ds.GetLayer()
lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
for holes in lyr:
lyr.DeleteFeature(holes.GetFID())
dst_ds.Destroy()
ds = None
os.remove(tmp_path)
@Timer
def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255):
src = Raster(srcimg_path)
_polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index)
src = None
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--srcimg_path", type=str, required=True, \
help="The path of original data with geoinfos.")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask data.")
parser.add_argument("--save_path", type=str, default="output", \
help="The path to save the results shapefile, `output` is the default.")
parser.add_argument("--ignore_index", type=int, default=255, \
help="It will not be converted to the value of SHP, `255` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
raster2shp(args.srcimg_path, args.mask_path, args.save_path,
args.ignore_index)

@ -16,12 +16,8 @@ import argparse
import numpy as np
import cv2
try:
from osgeo import gdal
except ImportError:
import gdal
from utils import Raster, raster2uint8, Timer
from utils import Raster, raster2uint8, save_geotiff, timer
class MatchError(Exception):
def __str__(self):
@ -64,26 +60,7 @@ def _get_match_img(raster, bands):
return ima
def _img2tif(ima, save_path, proj, geot, dtype):
if len(ima.shape) == 3:
row, columns, bands = ima.shape
else:
row, columns = ima.shape
bands = 1
driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(save_path, columns, row, bands, dtype)
dst_ds.SetGeoTransform(geot)
dst_ds.SetProjection(proj)
if bands != 1:
for b in range(bands):
dst_ds.GetRasterBand(b + 1).WriteArray(ima[:, :, b])
else:
dst_ds.GetRasterBand(1).WriteArray(ima)
dst_ds.FlushCache()
return dst_ds
@Timer
@timer
def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
im1_ras = Raster(im1_path)
im2_ras = Raster(im2_path)
@ -96,7 +73,7 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]):
im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H,
(im1_ras.width, im1_ras.height))
save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif")
_img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype)
parser = argparse.ArgumentParser(description="input parameters")

@ -19,7 +19,7 @@ from easydict import EasyDict as edict
import numpy as np
import pandas as pd
from utils import Raster, Timer
from utils import Raster, timer
def _calcOIF(rgb, stds, rho):
r, g, b = rgb
@ -32,7 +32,7 @@ def _calcOIF(rgb, stds, rho):
return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31))
@Timer
@timer
def oif(img_path, topk=5):
raster = Raster(img_path)
img = raster.getArray()

@ -18,10 +18,10 @@ import numpy as np
import argparse
from sklearn.decomposition import PCA
from joblib import dump
from utils import Raster, Timer, save_geotiff
from utils import Raster, save_geotiff, timer
@Timer
@timer
def pca_train(img_path, save_dir="output", dim=3):
raster = Raster(img_path)
im = raster.getArray()

@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import argparse
import numpy as np
from PIL import Image
try:
from osgeo import gdal, ogr, osr
except ImportError:
import gdal
import ogr
import osr
from utils import Raster, save_geotiff, timer
def _mask2tif(mask_path, tmp_path, proj, geot):
dst_ds = save_geotiff(
np.asarray(Image.open(mask_path)),
tmp_path, proj, geot, gdal.GDT_UInt16, False)
return dst_ds
def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext):
if proj is None or geot is None:
tmp_path = None
ds = gdal.Open(mask_path)
else:
tmp_path = vec_save_path.replace("." + ext, ".tif")
ds = _mask2tif(mask_path, tmp_path, proj, geot)
srcband = ds.GetRasterBand(1)
maskband = srcband.GetMaskBand()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
ogr.RegisterAll()
drv = ogr.GetDriverByName(
"ESRI Shapefile" if ext == "shp" else "GeoJSON"
)
if osp.exists(vec_save_path):
os.remove(vec_save_path)
dst_ds = drv.CreateDataSource(vec_save_path)
prosrs = osr.SpatialReference(wkt=ds.GetProjection())
dst_layer = dst_ds.CreateLayer(
"POLYGON", geom_type=ogr.wkbPolygon, srs=prosrs)
dst_fieldname = "CLAS"
fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
dst_layer.CreateField(fd)
gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
# TODO: temporary: delete ignored values
dst_ds.Destroy()
ds = None
vec_ds = drv.Open(vec_save_path, 1)
lyr = vec_ds.GetLayer()
lyr.SetAttributeFilter("{} = '{}'".format(dst_fieldname, str(ignore_index)))
for holes in lyr:
lyr.DeleteFeature(holes.GetFID())
vec_ds.Destroy()
if tmp_path is not None:
os.remove(tmp_path)
@timer
def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
vec_ext = save_path.split(".")[-1].lower()
if vec_ext not in ["json", "geojson", "shp"]:
raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext))
ras_ext = srcimg_path.split(".")[-1].lower()
if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
src = Raster(srcimg_path)
_polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext)
src = None
else:
_polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext)
parser = argparse.ArgumentParser(description="input parameters")
parser.add_argument("--mask_path", type=str, required=True, \
help="The path of mask data.")
parser.add_argument("--save_path", type=str, required=True, \
help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.")
parser.add_argument("--srcimg_path", type=str, default="", \
help="The path of original data with geoinfos, `` is the default.")
parser.add_argument("--ignore_index", type=int, default=255, \
help="It will not be converted to the value of SHP, `255` is the default.")
if __name__ == "__main__":
args = parser.parse_args()
raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index)

@ -16,42 +16,52 @@ import os
import os.path as osp
import argparse
from math import ceil
from tqdm import tqdm
from PIL import Image
from utils import Raster, save_geotiff, timer
from utils import Raster, Timer
def _calc_window_tf(geot, loc):
x, hr, r1, y, r2, vr = geot
nx, ny = loc
return (x + nx * hr, hr, r1, y + ny * vr, r2, vr)
@Timer
@timer
def split_data(image_path, mask_path, block_size, save_folder):
if not osp.exists(save_folder):
os.makedirs(save_folder)
os.makedirs(osp.join(save_folder, "images"))
if mask_path is not None:
os.makedirs(osp.join(save_folder, "masks"))
image_name = image_path.replace("\\", "/").split("/")[-1].split(".")[0]
image = Raster(image_path, to_uint8=True)
image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".")
image = Raster(image_path)
mask = Raster(mask_path) if mask_path is not None else None
if image.width != mask.width or image.height != mask.height:
if mask is not None and (image.width != mask.width or image.height != mask.height):
raise ValueError("image's shape must equal mask's shape.")
rows = ceil(image.height / block_size)
cols = ceil(image.width / block_size)
total_number = int(rows * cols)
with tqdm(total=total_number) as pbar:
for r in range(rows):
for c in range(cols):
loc_start = (c * block_size, r * block_size)
image_title = Image.fromarray(image.getArray(
loc_start, (block_size, block_size))).convert("RGB")
image_title = image.getArray(loc_start, (block_size, block_size))
image_save_path = osp.join(save_folder, "images", (
image_name + "_" + str(r) + "_" + str(c) + ".jpg"))
image_title.save(image_save_path, "JPEG")
image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
window_geotf = _calc_window_tf(image.geot, loc_start)
save_geotiff(
image_title, image_save_path, image.proj, window_geotf
)
if mask is not None:
mask_title = Image.fromarray(mask.getArray(
loc_start, (block_size, block_size))).convert("L")
mask_title = mask.getArray(loc_start, (block_size, block_size))
mask_save_path = osp.join(save_folder, "masks", (
image_name + "_" + str(r) + "_" + str(c) + ".png"))
mask_title.save(mask_save_path, "PNG")
print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number))
image_name + "_" + str(r) + "_" + str(c) + "." + image_ext))
save_geotiff(
mask_title, mask_save_path, image.proj, window_geotf
)
pbar.update(1)
parser = argparse.ArgumentParser(description="input parameters")

@ -17,4 +17,5 @@ import os.path as osp
sys.path.insert(0, osp.abspath("..")) # add workspace
from .raster import Raster, raster2uint8, save_geotiff
from .timer import Timer
from .vector import vector_translate
from .timer import timer

@ -13,7 +13,7 @@
# limitations under the License.
import os.path as osp
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional
import numpy as np
@ -49,18 +49,21 @@ def _get_type(type_name: str) -> int:
class Raster:
def __init__(self,
path: str,
path: Optional[str],
gdal_obj: Optional[gdal.Dataset]=None,
band_list: Union[List[int], Tuple[int], None]=None,
to_uint8: bool=False) -> None:
""" Class of read raster.
Args:
path (str): The path of raster.
path (Optional[str]): The path of raster.
gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None.
band_list (Union[List[int], Tuple[int], None], optional):
band list (start with 1) or None (all of bands). Defaults to None.
to_uint8 (bool, optional):
Convert uint8 or return raw data. Defaults to False.
"""
super(Raster, self).__init__()
if path is not None:
if osp.exists(path):
self.path = path
self.ext_type = path.split(".")[-1]
@ -72,13 +75,19 @@ class Raster:
# https://www.osgeo.cn/gdal/drivers/raster/index.html
self._src_data = gdal.Open(path)
except:
raise TypeError("Unsupported data format: `{}`".format(
self.ext_type))
self.to_uint8 = to_uint8
self.setBands(band_list)
self._getInfo()
raise TypeError(
"Unsupported data format: `{}`".format(self.ext_type))
else:
raise ValueError("The path {0} not exists.".format(path))
else:
if gdal_obj is not None:
self._src_data = gdal_obj
else:
raise ValueError("At least one of `path` and `gdal_obj` is not None.")
self.to_uint8 = to_uint8
self._getInfo()
self.setBands(band_list)
self._getType()
def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
""" Set band of data.
@ -86,7 +95,6 @@ class Raster:
band_list (Union[List[int], Tuple[int], None]):
band list (start with 1) or None (all of bands).
"""
self.bands = self._src_data.RasterCount
if band_list is not None:
if len(band_list) > self.bands:
raise ValueError(
@ -99,8 +107,8 @@ class Raster:
def getArray(
self,
start_loc: Union[List[int], Tuple[int], None]=None,
block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
start_loc: Union[List[int], Tuple[int, int], None]=None,
block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
""" Get ndarray data
Args:
start_loc (Union[List[int], Tuple[int], None], optional):
@ -123,13 +131,12 @@ class Raster:
if self._src_data is not None:
self.width = self._src_data.RasterXSize
self.height = self._src_data.RasterYSize
self.bands = self._src_data.RasterCount
self.geot = self._src_data.GetGeoTransform()
self.proj = self._src_data.GetProjection()
d_name = self._getBlock([0, 0], [1, 1]).dtype.name
else:
d_img = self._getNumpy()
d_shape = d_img.shape
d_name = d_img.dtype.name
if len(d_shape) == 3:
self.height, self.width, self.bands = d_shape
else:
@ -137,6 +144,9 @@ class Raster:
self.bands = 1
self.geot = None
self.proj = None
def _getType(self) -> None:
d_name = self.getArray([0, 0], [1, 1]).dtype.name
self.datatype = _get_type(d_name)
def _getNumpy(self):
@ -151,7 +161,9 @@ class Raster:
def _getArray(
self,
window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray:
if self._src_data is None:
raise ValueError("The raster is None.")
if window is not None:
xoff, yoff, xsize, ysize = window
if self.band_list is None:
@ -183,8 +195,8 @@ class Raster:
def _getBlock(
self,
start_loc: Union[List[int], Tuple[int]],
block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray:
start_loc: Union[List[int], Tuple[int, int]],
block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray:
if len(start_loc) != 2 or len(block_size) != 2:
raise ValueError("The length start_loc/block_size must be 2.")
xoff, yoff = start_loc
@ -208,8 +220,20 @@ class Raster:
return tmp
def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None:
def save_geotiff(image: np.ndarray,
save_path: str,
proj: str,
geotf: Tuple,
use_type: Optional[int]=None,
clear_ds: bool=True) -> None:
if len(image.shape) == 2:
height, width = image.shape
channel = 1
else:
height, width, channel = image.shape
if use_type is not None:
data_type = use_type
else:
data_type = _get_type(image.dtype.name)
driver = gdal.GetDriverByName("GTiff")
dst_ds = driver.Create(save_path, width, height, channel, data_type)
@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) ->
band = dst_ds.GetRasterBand(1)
band.WriteArray(image)
dst_ds.FlushCache()
if clear_ds:
dst_ds = None
return dst_ds

@ -13,14 +13,14 @@
# limitations under the License.
import time
from functools import wraps
class Timer(object):
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwds):
def timer(func):
@wraps(func)
def wrapper(*args,**kwargs):
start_time = time.time()
func_t = self.func(*args, **kwds)
result = func(*args,**kwargs)
print("Total time: {0}.".format(time.time() - start_time))
return func_t
return result
return wrapper

@ -0,0 +1,53 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://zhuanlan.zhihu.com/p/378918221
try:
from osgeo import gdal, ogr, osr
except:
import gdal
import ogr
import osr
def vector_translate(geojson_path: str,
wo_wkt: str,
g_type: str="POLYGON",
dim: str="XY") -> str:
ogr.RegisterAll()
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
data = ogr.Open(geojson_path)
layer = data.GetLayer()
spatial = layer.GetSpatialRef()
layerName = layer.GetName()
data.Destroy()
dstSRS = osr.SpatialReference()
dstSRS.ImportFromWkt(wo_wkt)
ext = "." + geojson_path.split(".")[-1]
save_path = geojson_path.replace(ext, ("_tmp" + ext))
options = gdal.VectorTranslateOptions(
srcSRS=spatial,
dstSRS=dstSRS,
reproject=True,
layerName=layerName,
geometryType=g_type,
dim=dim
)
gdal.VectorTranslate(
save_path,
srcDS=geojson_path,
options=options
)
return save_path
Loading…
Cancel
Save