diff --git a/.gitignore b/.gitignore index 43b91b0..f615de4 100644 --- a/.gitignore +++ b/.gitignore @@ -126,6 +126,9 @@ venv.bak/ .dmypy.json dmypy.json +# myvscode +.vscode + # Pyre type checker .pyre/ diff --git a/tools/coco2mask.py b/tools/coco2mask.py index 36529b1..a380dd0 100644 --- a/tools/coco2mask.py +++ b/tools/coco2mask.py @@ -25,7 +25,7 @@ import glob from tqdm import tqdm from PIL import Image -from utils import Timer +from utils import timer def _mkdir_p(path): @@ -69,7 +69,7 @@ def _read_geojson(json_path): return annotations, sizes -@Timer +@timer def convert_data(raw_folder, end_folder): print("-- Initializing --") img_folder = osp.join(raw_folder, "images") diff --git a/tools/geojson2mask.py b/tools/geojson2mask.py index 8ffcc0a..f14b328 100644 --- a/tools/geojson2mask.py +++ b/tools/geojson2mask.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import codecs import cv2 import numpy as np import argparse import geojson from tqdm import tqdm -from utils import Raster, save_geotiff, Timer +from utils import Raster, save_geotiff, vector_translate, timer def _gt_convert(x_geo, y_geo, geotf): @@ -27,12 +28,16 @@ def _gt_convert(x_geo, y_geo, geotf): return np.round(np.linalg.solve(a, b)).tolist() # 解一元二次方程 -@Timer +@timer +# TODO: update for vector2raster def convert_data(image_path, geojson_path): raster = Raster(image_path) tmp_img = np.zeros((raster.height, raster.width), dtype=np.int32) - geo_reader = codecs.open(geojson_path, "r", encoding="utf-8") + # vector to EPSG from raster + temp_geojson_path = vector_translate(geojson_path, raster.proj) + geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8") feats = geojson.loads(geo_reader.read())["features"] # 所有图像块 + geo_reader.close() for feat in tqdm(feats): geo = feat["geometry"] if geo["type"] == "Polygon": # 多边形 @@ -40,7 +45,8 @@ def convert_data(image_path, geojson_path): elif geo["type"] == "MultiPolygon": # 多面 geo_points = geo["coordinates"][0][0] else: - raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format(geo["type"])) + raise TypeError("Geometry type must be `Polygon` or `MultiPolygon`, not {}.".format( + geo["type"])) xy_points = np.array([ _gt_convert(point[0], point[1], raster.geot) for point in geo_points @@ -49,13 +55,14 @@ def convert_data(image_path, geojson_path): cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充 ext = "." + geojson_path.split(".")[-1] save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot) + os.remove(temp_geojson_path) parser = argparse.ArgumentParser(description="input parameters") parser.add_argument("--image_path", type=str, required=True, \ help="The path of original image.") parser.add_argument("--geojson_path", type=str, required=True, \ - help="The path of geojson.") + help="The path of geojson. (coordinate of geojson is WGS84)") if __name__ == "__main__": diff --git a/tools/mask2geojson.py b/tools/mask2geojson.py deleted file mode 100644 index f73f331..0000000 --- a/tools/mask2geojson.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import codecs -import argparse - -import cv2 -import numpy as np -import geojson -from geojson import Polygon, Feature, FeatureCollection - -from utils import Raster, Timer - - -def _gt_convert(x, y, geotf): - x_geo = geotf[0] + x * geotf[1] + y * geotf[2] - y_geo = geotf[3] + x * geotf[4] + y * geotf[5] - return x_geo, y_geo - - -@Timer -def convert_data(mask_path, save_path, epsilon=0): - raster = Raster(mask_path) - img = raster.getArray() - ext = save_path.split(".")[-1] - if ext != "json" and ext != "geojson": - raise ValueError("The ext of `save_path` must be `json` or `geojson`, not {}.".format(ext)) - geo_writer = codecs.open(save_path, "w", encoding="utf-8") - clas = np.unique(img) - cv2_v = (cv2.__version__.split(".")[0] == "3") - feats = [] - if not isinstance(epsilon, (int, float)): - epsilon = 0 - for iclas in range(1, len(clas)): - tmp = np.zeros_like(img).astype("uint8") - tmp[img == iclas] = 1 - # TODO: Detect internal and external contour - results = cv2.findContours(tmp, cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_TC89_KCOS) - contours = results[1] if cv2_v else results[0] - # hierarchys = results[2] if cv2_v else results[1] - if len(contours) == 0: - continue - for contour in contours: - contour = cv2.approxPolyDP(contour, epsilon, True) - polys = [] - for point in contour: - x, y = point[0] - xg, yg = _gt_convert(x, y, raster.geot) - polys.append((xg, yg)) - polys.append(polys[0]) - feat = Feature( - geometry=Polygon([polys]), properties={"class": int(iclas)}) - feats.append(feat) - gjs = FeatureCollection(feats) - geo_writer.write(geojson.dumps(gjs)) - geo_writer.close() - - -parser = argparse.ArgumentParser(description="input parameters") -parser.add_argument("--mask_path", type=str, required=True, \ - help="The path of mask tif.") -parser.add_argument("--save_path", type=str, required=True, \ - help="The path to save the results, file suffix is `*.json/geojson`.") -parser.add_argument("--epsilon", type=float, default=0, \ - help="The CV2 simplified parameters, `0` is the default.") - -if __name__ == "__main__": - args = parser.parse_args() - convert_data(args.mask_path, args.save_path, args.epsilon) diff --git a/tools/mask2shp.py b/tools/mask2shp.py deleted file mode 100644 index 926e297..0000000 --- a/tools/mask2shp.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import os.path as osp -import argparse - -import numpy as np -from PIL import Image -try: - from osgeo import gdal, ogr, osr -except ImportError: - import gdal - import ogr - import osr - -from utils import Raster, Timer - - -def _mask2tif(mask_path, tmp_path, proj, geot): - mask = np.asarray(Image.open(mask_path)) - if len(mask.shape) == 3: - mask = mask[:, :, 0] - row, columns = mask.shape[:2] - driver = gdal.GetDriverByName("GTiff") - dst_ds = driver.Create(tmp_path, columns, row, 1, gdal.GDT_UInt16) - dst_ds.SetGeoTransform(geot) - dst_ds.SetProjection(proj) - dst_ds.GetRasterBand(1).WriteArray(mask) - dst_ds.FlushCache() - return dst_ds - - -def _polygonize_raster(mask_path, shp_save_path, proj, geot, ignore_index): - tmp_path = shp_save_path.replace(".shp", ".tif") - ds = _mask2tif(mask_path, tmp_path, proj, geot) - srcband = ds.GetRasterBand(1) - maskband = srcband.GetMaskBand() - gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES") - gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8") - ogr.RegisterAll() - drv = ogr.GetDriverByName("ESRI Shapefile") - if osp.exists(shp_save_path): - os.remove(shp_save_path) - dst_ds = drv.CreateDataSource(shp_save_path) - prosrs = osr.SpatialReference(wkt=ds.GetProjection()) - dst_layer = dst_ds.CreateLayer( - "Building boundary", geom_type=ogr.wkbPolygon, srs=prosrs) - dst_fieldname = "DN" - fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger) - dst_layer.CreateField(fd) - gdal.Polygonize(srcband, maskband, dst_layer, 0, []) - lyr = dst_ds.GetLayer() - lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index))) - for holes in lyr: - lyr.DeleteFeature(holes.GetFID()) - dst_ds.Destroy() - ds = None - os.remove(tmp_path) - - -@Timer -def raster2shp(srcimg_path, mask_path, save_path, ignore_index=255): - src = Raster(srcimg_path) - _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index) - src = None - - -parser = argparse.ArgumentParser(description="input parameters") -parser.add_argument("--srcimg_path", type=str, required=True, \ - help="The path of original data with geoinfos.") -parser.add_argument("--mask_path", type=str, required=True, \ - help="The path of mask data.") -parser.add_argument("--save_path", type=str, default="output", \ - help="The path to save the results shapefile, `output` is the default.") -parser.add_argument("--ignore_index", type=int, default=255, \ - help="It will not be converted to the value of SHP, `255` is the default.") - -if __name__ == "__main__": - args = parser.parse_args() - raster2shp(args.srcimg_path, args.mask_path, args.save_path, - args.ignore_index) diff --git a/tools/matcher.py b/tools/matcher.py index 72b2574..3a04b68 100644 --- a/tools/matcher.py +++ b/tools/matcher.py @@ -16,12 +16,8 @@ import argparse import numpy as np import cv2 -try: - from osgeo import gdal -except ImportError: - import gdal -from utils import Raster, raster2uint8, Timer +from utils import Raster, raster2uint8, save_geotiff, timer class MatchError(Exception): def __str__(self): @@ -64,26 +60,7 @@ def _get_match_img(raster, bands): return ima -def _img2tif(ima, save_path, proj, geot, dtype): - if len(ima.shape) == 3: - row, columns, bands = ima.shape - else: - row, columns = ima.shape - bands = 1 - driver = gdal.GetDriverByName("GTiff") - dst_ds = driver.Create(save_path, columns, row, bands, dtype) - dst_ds.SetGeoTransform(geot) - dst_ds.SetProjection(proj) - if bands != 1: - for b in range(bands): - dst_ds.GetRasterBand(b + 1).WriteArray(ima[:, :, b]) - else: - dst_ds.GetRasterBand(1).WriteArray(ima) - dst_ds.FlushCache() - return dst_ds - - -@Timer +@timer def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]): im1_ras = Raster(im1_path) im2_ras = Raster(im2_path) @@ -96,7 +73,7 @@ def matching(im1_path, im2_path, im1_bands=[1, 2, 3], im2_bands=[1, 2, 3]): im2_arr_t = cv2.warpPerspective(im2_ras.getArray(), H, (im1_ras.width, im1_ras.height)) save_path = im2_ras.path.replace(("." + im2_ras.ext_type), "_M.tif") - _img2tif(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype) + save_geotiff(im2_arr_t, save_path, im1_ras.proj, im1_ras.geot, im1_ras.datatype) parser = argparse.ArgumentParser(description="input parameters") diff --git a/tools/oif.py b/tools/oif.py index 97e418a..959be50 100644 --- a/tools/oif.py +++ b/tools/oif.py @@ -19,7 +19,7 @@ from easydict import EasyDict as edict import numpy as np import pandas as pd -from utils import Raster, Timer +from utils import Raster, timer def _calcOIF(rgb, stds, rho): r, g, b = rgb @@ -32,7 +32,7 @@ def _calcOIF(rgb, stds, rho): return (s1 + s2 + s3) / (abs(r12) + abs(r23) + abs(r31)) -@Timer +@timer def oif(img_path, topk=5): raster = Raster(img_path) img = raster.getArray() diff --git a/tools/pca.py b/tools/pca.py index a3c3063..2bda64d 100644 --- a/tools/pca.py +++ b/tools/pca.py @@ -18,10 +18,10 @@ import numpy as np import argparse from sklearn.decomposition import PCA from joblib import dump -from utils import Raster, Timer, save_geotiff +from utils import Raster, save_geotiff, timer -@Timer +@timer def pca_train(img_path, save_dir="output", dim=3): raster = Raster(img_path) im = raster.getArray() diff --git a/tools/raster2vector.py b/tools/raster2vector.py new file mode 100644 index 0000000..ec21531 --- /dev/null +++ b/tools/raster2vector.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +import argparse + +import numpy as np +from PIL import Image +try: + from osgeo import gdal, ogr, osr +except ImportError: + import gdal + import ogr + import osr + +from utils import Raster, save_geotiff, timer + + +def _mask2tif(mask_path, tmp_path, proj, geot): + dst_ds = save_geotiff( + np.asarray(Image.open(mask_path)), + tmp_path, proj, geot, gdal.GDT_UInt16, False) + return dst_ds + + +def _polygonize_raster(mask_path, vec_save_path, proj, geot, ignore_index, ext): + if proj is None or geot is None: + tmp_path = None + ds = gdal.Open(mask_path) + else: + tmp_path = vec_save_path.replace("." + ext, ".tif") + ds = _mask2tif(mask_path, tmp_path, proj, geot) + srcband = ds.GetRasterBand(1) + maskband = srcband.GetMaskBand() + gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES") + gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8") + ogr.RegisterAll() + drv = ogr.GetDriverByName( + "ESRI Shapefile" if ext == "shp" else "GeoJSON" + ) + if osp.exists(vec_save_path): + os.remove(vec_save_path) + dst_ds = drv.CreateDataSource(vec_save_path) + prosrs = osr.SpatialReference(wkt=ds.GetProjection()) + dst_layer = dst_ds.CreateLayer( + "POLYGON", geom_type=ogr.wkbPolygon, srs=prosrs) + dst_fieldname = "CLAS" + fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger) + dst_layer.CreateField(fd) + gdal.Polygonize(srcband, maskband, dst_layer, 0, []) + # TODO: temporary: delete ignored values + dst_ds.Destroy() + ds = None + vec_ds = drv.Open(vec_save_path, 1) + lyr = vec_ds.GetLayer() + lyr.SetAttributeFilter("{} = '{}'".format(dst_fieldname, str(ignore_index))) + for holes in lyr: + lyr.DeleteFeature(holes.GetFID()) + vec_ds.Destroy() + if tmp_path is not None: + os.remove(tmp_path) + + +@timer +def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255): + vec_ext = save_path.split(".")[-1].lower() + if vec_ext not in ["json", "geojson", "shp"]: + raise ValueError("The ext of `save_path` must be `json/geojson` or `shp`, not {}.".format(vec_ext)) + ras_ext = srcimg_path.split(".")[-1].lower() + if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]: + src = Raster(srcimg_path) + _polygonize_raster(mask_path, save_path, src.proj, src.geot, ignore_index, vec_ext) + src = None + else: + _polygonize_raster(mask_path, save_path, None, None, ignore_index, vec_ext) + + +parser = argparse.ArgumentParser(description="input parameters") +parser.add_argument("--mask_path", type=str, required=True, \ + help="The path of mask data.") +parser.add_argument("--save_path", type=str, required=True, \ + help="The path to save the results, file suffix is `*.json/geojson` or `*.shp`.") +parser.add_argument("--srcimg_path", type=str, default="", \ + help="The path of original data with geoinfos, `` is the default.") +parser.add_argument("--ignore_index", type=int, default=255, \ + help="It will not be converted to the value of SHP, `255` is the default.") + +if __name__ == "__main__": + args = parser.parse_args() + raster2vector(args.srcimg_path, args.mask_path, args.save_path, args.ignore_index) diff --git a/tools/spliter.py b/tools/spliter.py index 3a32599..60fad71 100644 --- a/tools/spliter.py +++ b/tools/spliter.py @@ -16,42 +16,52 @@ import os import os.path as osp import argparse from math import ceil +from tqdm import tqdm -from PIL import Image +from utils import Raster, save_geotiff, timer -from utils import Raster, Timer +def _calc_window_tf(geot, loc): + x, hr, r1, y, r2, vr = geot + nx, ny = loc + return (x + nx * hr, hr, r1, y + ny * vr, r2, vr) -@Timer + +@timer def split_data(image_path, mask_path, block_size, save_folder): if not osp.exists(save_folder): os.makedirs(save_folder) os.makedirs(osp.join(save_folder, "images")) if mask_path is not None: os.makedirs(osp.join(save_folder, "masks")) - image_name = image_path.replace("\\", "/").split("/")[-1].split(".")[0] - image = Raster(image_path, to_uint8=True) + image_name, image_ext = image_path.replace("\\", "/").split("/")[-1].split(".") + image = Raster(image_path) mask = Raster(mask_path) if mask_path is not None else None - if image.width != mask.width or image.height != mask.height: + if mask is not None and (image.width != mask.width or image.height != mask.height): raise ValueError("image's shape must equal mask's shape.") rows = ceil(image.height / block_size) cols = ceil(image.width / block_size) total_number = int(rows * cols) - for r in range(rows): - for c in range(cols): - loc_start = (c * block_size, r * block_size) - image_title = Image.fromarray(image.getArray( - loc_start, (block_size, block_size))).convert("RGB") - image_save_path = osp.join(save_folder, "images", ( - image_name + "_" + str(r) + "_" + str(c) + ".jpg")) - image_title.save(image_save_path, "JPEG") - if mask is not None: - mask_title = Image.fromarray(mask.getArray( - loc_start, (block_size, block_size))).convert("L") - mask_save_path = osp.join(save_folder, "masks", ( - image_name + "_" + str(r) + "_" + str(c) + ".png")) - mask_title.save(mask_save_path, "PNG") - print("-- {:d}/{:d} --".format(int(r * cols + c + 1), total_number)) + + with tqdm(total=total_number) as pbar: + for r in range(rows): + for c in range(cols): + loc_start = (c * block_size, r * block_size) + image_title = image.getArray(loc_start, (block_size, block_size)) + image_save_path = osp.join(save_folder, "images", ( + image_name + "_" + str(r) + "_" + str(c) + "." + image_ext)) + window_geotf = _calc_window_tf(image.geot, loc_start) + save_geotiff( + image_title, image_save_path, image.proj, window_geotf + ) + if mask is not None: + mask_title = mask.getArray(loc_start, (block_size, block_size)) + mask_save_path = osp.join(save_folder, "masks", ( + image_name + "_" + str(r) + "_" + str(c) + "." + image_ext)) + save_geotiff( + mask_title, mask_save_path, image.proj, window_geotf + ) + pbar.update(1) parser = argparse.ArgumentParser(description="input parameters") diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py index 208b605..fd17a5e 100644 --- a/tools/utils/__init__.py +++ b/tools/utils/__init__.py @@ -17,4 +17,5 @@ import os.path as osp sys.path.insert(0, osp.abspath("..")) # add workspace from .raster import Raster, raster2uint8, save_geotiff -from .timer import Timer +from .vector import vector_translate +from .timer import timer diff --git a/tools/utils/raster.py b/tools/utils/raster.py index e6b7c93..868fd3f 100644 --- a/tools/utils/raster.py +++ b/tools/utils/raster.py @@ -13,7 +13,7 @@ # limitations under the License. import os.path as osp -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import numpy as np @@ -49,36 +49,45 @@ def _get_type(type_name: str) -> int: class Raster: def __init__(self, - path: str, + path: Optional[str], + gdal_obj: Optional[gdal.Dataset]=None, band_list: Union[List[int], Tuple[int], None]=None, to_uint8: bool=False) -> None: """ Class of read raster. Args: - path (str): The path of raster. + path (Optional[str]): The path of raster. + gdal_obj (Optional[Any], optional): The object of GDAL. Defaults to None. band_list (Union[List[int], Tuple[int], None], optional): band list (start with 1) or None (all of bands). Defaults to None. to_uint8 (bool, optional): Convert uint8 or return raw data. Defaults to False. """ super(Raster, self).__init__() - if osp.exists(path): - self.path = path - self.ext_type = path.split(".")[-1] - if self.ext_type.lower() in ["npy", "npz"]: - self._src_data = None + if path is not None: + if osp.exists(path): + self.path = path + self.ext_type = path.split(".")[-1] + if self.ext_type.lower() in ["npy", "npz"]: + self._src_data = None + else: + try: + # raster format support in GDAL: + # https://www.osgeo.cn/gdal/drivers/raster/index.html + self._src_data = gdal.Open(path) + except: + raise TypeError( + "Unsupported data format: `{}`".format(self.ext_type)) else: - try: - # raster format support in GDAL: - # https://www.osgeo.cn/gdal/drivers/raster/index.html - self._src_data = gdal.Open(path) - except: - raise TypeError("Unsupported data format: `{}`".format( - self.ext_type)) - self.to_uint8 = to_uint8 - self.setBands(band_list) - self._getInfo() + raise ValueError("The path {0} not exists.".format(path)) else: - raise ValueError("The path {0} not exists.".format(path)) + if gdal_obj is not None: + self._src_data = gdal_obj + else: + raise ValueError("At least one of `path` and `gdal_obj` is not None.") + self.to_uint8 = to_uint8 + self._getInfo() + self.setBands(band_list) + self._getType() def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None: """ Set band of data. @@ -86,7 +95,6 @@ class Raster: band_list (Union[List[int], Tuple[int], None]): band list (start with 1) or None (all of bands). """ - self.bands = self._src_data.RasterCount if band_list is not None: if len(band_list) > self.bands: raise ValueError( @@ -99,8 +107,8 @@ class Raster: def getArray( self, - start_loc: Union[List[int], Tuple[int], None]=None, - block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + start_loc: Union[List[int], Tuple[int, int], None]=None, + block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray: """ Get ndarray data Args: start_loc (Union[List[int], Tuple[int], None], optional): @@ -123,13 +131,12 @@ class Raster: if self._src_data is not None: self.width = self._src_data.RasterXSize self.height = self._src_data.RasterYSize + self.bands = self._src_data.RasterCount self.geot = self._src_data.GetGeoTransform() self.proj = self._src_data.GetProjection() - d_name = self._getBlock([0, 0], [1, 1]).dtype.name else: d_img = self._getNumpy() d_shape = d_img.shape - d_name = d_img.dtype.name if len(d_shape) == 3: self.height, self.width, self.bands = d_shape else: @@ -137,6 +144,9 @@ class Raster: self.bands = 1 self.geot = None self.proj = None + + def _getType(self) -> None: + d_name = self.getArray([0, 0], [1, 1]).dtype.name self.datatype = _get_type(d_name) def _getNumpy(self): @@ -151,7 +161,9 @@ class Raster: def _getArray( self, - window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray: + window: Union[None, List[int], Tuple[int, int, int, int]]=None) -> np.ndarray: + if self._src_data is None: + raise ValueError("The raster is None.") if window is not None: xoff, yoff, xsize, ysize = window if self.band_list is None: @@ -183,8 +195,8 @@ class Raster: def _getBlock( self, - start_loc: Union[List[int], Tuple[int]], - block_size: Union[List[int], Tuple[int]]=[512, 512]) -> np.ndarray: + start_loc: Union[List[int], Tuple[int, int]], + block_size: Union[List[int], Tuple[int, int]]=[512, 512]) -> np.ndarray: if len(start_loc) != 2 or len(block_size) != 2: raise ValueError("The length start_loc/block_size must be 2.") xoff, yoff = start_loc @@ -208,9 +220,21 @@ class Raster: return tmp -def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> None: - height, width, channel = image.shape - data_type = _get_type(image.dtype.name) +def save_geotiff(image: np.ndarray, + save_path: str, + proj: str, + geotf: Tuple, + use_type: Optional[int]=None, + clear_ds: bool=True) -> None: + if len(image.shape) == 2: + height, width = image.shape + channel = 1 + else: + height, width, channel = image.shape + if use_type is not None: + data_type = use_type + else: + data_type = _get_type(image.dtype.name) driver = gdal.GetDriverByName("GTiff") dst_ds = driver.Create(save_path, width, height, channel, data_type) dst_ds.SetGeoTransform(geotf) @@ -224,4 +248,6 @@ def save_geotiff(image: np.ndarray, save_path: str, proj: str, geotf: Tuple) -> band = dst_ds.GetRasterBand(1) band.WriteArray(image) dst_ds.FlushCache() - dst_ds = None + if clear_ds: + dst_ds = None + return dst_ds diff --git a/tools/utils/timer.py b/tools/utils/timer.py index bbec344..568d589 100644 --- a/tools/utils/timer.py +++ b/tools/utils/timer.py @@ -13,14 +13,14 @@ # limitations under the License. import time +from functools import wraps -class Timer(object): - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwds): +def timer(func): + @wraps(func) + def wrapper(*args,**kwargs): start_time = time.time() - func_t = self.func(*args, **kwds) + result = func(*args,**kwargs) print("Total time: {0}.".format(time.time() - start_time)) - return func_t + return result + return wrapper diff --git a/tools/utils/vector.py b/tools/utils/vector.py new file mode 100644 index 0000000..f6bb298 --- /dev/null +++ b/tools/utils/vector.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# reference: https://zhuanlan.zhihu.com/p/378918221 + +try: + from osgeo import gdal, ogr, osr +except: + import gdal + import ogr + import osr + + +def vector_translate(geojson_path: str, + wo_wkt: str, + g_type: str="POLYGON", + dim: str="XY") -> str: + ogr.RegisterAll() + gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES") + data = ogr.Open(geojson_path) + layer = data.GetLayer() + spatial = layer.GetSpatialRef() + layerName = layer.GetName() + data.Destroy() + dstSRS = osr.SpatialReference() + dstSRS.ImportFromWkt(wo_wkt) + ext = "." + geojson_path.split(".")[-1] + save_path = geojson_path.replace(ext, ("_tmp" + ext)) + options = gdal.VectorTranslateOptions( + srcSRS=spatial, + dstSRS=dstSRS, + reproject=True, + layerName=layerName, + geometryType=g_type, + dim=dim + ) + gdal.VectorTranslate( + save_path, + srcDS=geojson_path, + options=options + ) + return save_path \ No newline at end of file