You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

197 lines
7.4 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import os.path as osp
from pathlib import Path
from typing import List, Tuple, Union, Optional
import webbrowser
import numpy as np
from folium import folium, Map, LayerControl
from folium.raster_layers import TileLayer, ImageOverlay
from paddlers.transforms.functions import to_uint8
try:
from osgeo import gdal, osr
except:
import gdal
import osr
# Names of the base-map tilesets that `map_display` accepts for its `tiles`
# argument; any other value is rejected there with a ValueError.
CHINA_TILES = (
    "GeoQ China Community", "GeoQ China Street", "AMAP China",
    "TencentMap China", "BaiduMaps China")
def map_display(mask_path: str,
                img_path: Optional[str]=None,
                band_list: Union[List[int], Tuple[int, ...], None]=None,
                save_path: Optional[str]=None,
                tiles: str="GeoQ China Community") -> folium.Map:
    """
    Overlay a mask (and, optionally, its original image) on an online map.

    Args:
        mask_path (str): Path of the predicted or ground-truth mask.
        img_path (str|None, optional): Path of the original image. When given,
            the image layer is added to the map before the mask layer.
            Defaults to None.
        band_list (list[int]|tuple[int]|None, optional): Bands of the original
            image to display (band indices start from 1). Only applied to
            `img_path`; the mask is always opened with the default band
            selection. If None, use the default selection. Defaults to None.
        save_path (str, optional): Path of the .html file in which to store
            the visualization. When set, the map is saved there and the file
            is then opened through `webbrowser.open`. In Jupyter Notebook
            environments, leave it as None to display the returned map
            directly in the notebook. Defaults to None.
        tiles (str): Map tileset to use. Must be one of:
            - "GeoQ China Community", "GeoQ China Street" (from http://www.geoq.cn/)
            - "AMAP China" (from https://www.amap.com/)
            - "TencentMap China" (from https://map.qq.com/)
            - "BaiduMaps China" (from https://map.baidu.com/)
            Defaults to "GeoQ China Community".

    * All tilesets have been corrected through public algorithms from the Internet.
    * Please read the relevant terms of use carefully:
        - GeoQ [GISUNI] (http://geoq.cn/useragreement.html)
        - AMap [AutoNavi] (https://wap.amap.com/doc/serviceitem.html)
        - Tencent Map (https://ugc.map.qq.com/AppBox/Landlord/serveagreement.html)
        - Baidu Map (https://map.baidu.com/zt/client/service/index.html)

    Returns:
        folium.Map: The map holding the raster overlay(s) and a layer control.

    Raises:
        ValueError: If `tiles` is not listed in `CHINA_TILES`.
    """
    if tiles not in CHINA_TILES:
        raise ValueError(
            "The `tiles` must in {}, not {}.".format(CHINA_TILES, tiles))

    online_map = Map(tiles=tiles, min_zoom=1, max_zoom=24)

    # Optional original image: added first, the mask layer is added after it.
    if img_path is not None:
        image_overlay, _ = Raster(img_path, band_list).get_layer()
        image_overlay.add_to(online_map)

    mask_overlay, mask_center = Raster(mask_path).get_layer()
    mask_overlay.add_to(online_map)

    # Centre and zoom the view on the mask.
    if mask_center is not None:
        online_map.location = mask_center
        online_map.fit_bounds(mask_overlay.bounds)

    LayerControl().add_to(online_map)

    if save_path:
        online_map.save(save_path)
        webbrowser.open(save_path)
    return online_map
class OpenAsEPSG4326Error(Exception):
    """Raised when a raster file cannot be opened in EPSG:4326."""
class Raster:
    """
    A raster file opened in EPSG:4326 that can be rendered as a folium overlay.

    Attributes (all set during construction):
        src_data: Dataset returned by `Converter.open_as_WGS84`.
        name (str): Stem of the file name; used as the overlay's layer name.
        bands (int): Number of bands in the dataset (`RasterCount`).
        band_list: 1-based indices of the bands selected for display.
        width, height (int): Raster size in pixels.
        geotf: Geotransform reported by the dataset.
        proj: Projection reported by the dataset.
        wgs_range: Two corner coordinates, passed to `ImageOverlay` as bounds.
        wgs_center: Midpoint of the two corners in `wgs_range`.
    """

    def __init__(
            self,
            path: str,
            band_list: Union[List[int], Tuple[int, ...], None]=None) -> None:
        """
        Open `path` in EPSG:4326 and prepare everything needed for display.

        Args:
            path (str): Path of the raster file.
            band_list (list[int]|tuple[int]|None, optional): Bands to display
                (1-based). See `set_bands` for the default selection.

        Raises:
            OpenAsEPSG4326Error: If the dataset could not be opened
                (`Converter.open_as_WGS84` returned None).
            ValueError: If `band_list` is rejected by `set_bands`.
        """
        self.src_data = Converter.open_as_WGS84(path)
        if self.src_data is None:
            raise OpenAsEPSG4326Error("Faild to open {} in EPSG:4326.".format(
                path))
        self.name = Path(path).stem
        self.set_bands(band_list)
        self._get_info()

    def get_layer(self) -> "Tuple[ImageOverlay, List[float]]":
        """
        Build the folium overlay for this raster.

        Returns:
            tuple: `(layer, center)` where `layer` is an `ImageOverlay` built
                from the selected bands and `center` is `self.wgs_center`.
        """
        layer = ImageOverlay(self._get_array(), self.wgs_range, name=self.name)
        return layer, self.wgs_center

    def set_bands(self,
                  band_list: Union[List[int], Tuple[int, ...], None]) -> None:
        """
        Select the bands used for display.

        Args:
            band_list (list[int]|tuple[int]|None): 1-based band indices. If
                None, bands 1-3 are used for a 3-band raster and band 1
                otherwise.

        Raises:
            ValueError: If `band_list` does not hold exactly 1 or 3 indices,
                or if any index lies outside [1, number of bands].
        """
        self.bands = self.src_data.RasterCount
        if band_list is None:
            band_list = [1, 2, 3] if self.bands == 3 else [1]
        band_list_lens = len(band_list)
        if band_list_lens not in (1, 3):
            raise ValueError("The lenght of band_list must be 1 or 3, not {}.".
                             format(str(band_list_lens)))
        if max(band_list) > self.bands or min(band_list) < 1:
            raise ValueError("The range of band_list must within [1, {}].".
                             format(str(self.bands)))
        self.band_list = band_list

    def _get_info(self) -> None:
        """Cache size, geotransform, projection and derived bounds/center."""
        self.width = self.src_data.RasterXSize
        self.height = self.src_data.RasterYSize
        self.geotf = self.src_data.GetGeoTransform()
        self.proj = self.src_data.GetProjection()  # WGS84
        self.wgs_range = self._get_WGS84_range()
        self.wgs_center = self._get_WGS84_center()

    def _get_WGS84_range(self) -> List[List[float]]:
        """
        Compute the two corner coordinates used as overlay bounds.

        The corners are taken at pixel (row=height-1, col=0) and
        (row=0, col=width-1).

        NOTE(review): the two values unpacked from `Converter.xy2latlon` are
        swapped when the result is assembled, so the effective axis order of
        the result depends on what `xy2latlon` actually returns -- confirm
        against the GDAL version in use before changing this.
        """
        converter = Converter(self.proj, self.geotf)
        lat1, lon1 = converter.xy2latlon(self.height - 1, 0)
        lat2, lon2 = converter.xy2latlon(0, self.width - 1)
        return [[lon1, lat1], [lon2, lat2]]

    def _get_WGS84_center(self) -> List[float]:
        """Return the component-wise midpoint of the corners in `wgs_range`."""
        first, second = self.wgs_range[0], self.wgs_range[1]
        clat = (first[0] + second[0]) / 2
        clon = (first[1] + second[1]) / 2
        return [clat, clon]

    @staticmethod
    def _magnitude_if_complex(arr: np.ndarray) -> np.ndarray:
        """Return `abs(arr)` for complex-valued data, else `arr` unchanged."""
        if np.iscomplexobj(arr):
            return np.abs(arr)
        return arr

    def _get_array(self) -> np.ndarray:
        """
        Read the selected bands and convert them for display.

        Returns:
            np.ndarray: The result of `to_uint8` applied to the band data. For
                a single-band raster the data is squeezed; otherwise the bands
                are moved to the last axis.
        """
        band_array = [
            self.src_data.GetRasterBand(b).ReadAsArray()
            for b in self.band_list
        ]
        ima = np.stack(band_array, axis=0)
        if self.bands == 1:
            # Complex-valued data means this is SAR data: display its
            # magnitude. The check is made on the array dtype; the previous
            # `isinstance(type(...), complex)` test could never be true.
            ima = self._magnitude_if_complex(ima)
            ima = ima.squeeze()
        else:
            ima = ima.transpose((1, 2, 0))
        ima = to_uint8(ima, True)
        return ima
class Converter:
    """
    Map raster pixel positions to coordinates in EPSG:4326.

    Attributes:
        source: Spatial reference parsed from the WKT projection passed in.
        target: Spatial reference for EPSG:4326.
        geotf: The 6-element geotransform used to turn (row, col) into
            coordinates of the source reference system.
    """

    def __init__(self, proj: str, geotf: tuple) -> None:
        """
        Args:
            proj (str): Projection of the source data in WKT form.
            geotf (tuple): Geotransform of the source data.
        """
        self.geotf = geotf

        # Source reference system, taken from the raster itself.
        source_srs = osr.SpatialReference()
        source_srs.ImportFromWkt(proj)
        self.source = source_srs

        # Target reference system: EPSG:4326.
        target_srs = osr.SpatialReference()
        target_srs.ImportFromEPSG(4326)
        self.target = target_srs

    @classmethod
    def open_as_WGS84(cls, path: str) -> gdal.Dataset:
        """
        Open a raster file warped to EPSG:4326 as a VRT.

        Args:
            path (str): Path of the raster file.

        Returns:
            The value returned by `gdal.Warp`.

        Raises:
            FileNotFoundError: If `path` does not exist.
        """
        if osp.exists(path):
            return gdal.Warp("", path, dstSRS="EPSG:4326", format="VRT")
        raise FileNotFoundError("{} not found.".format(path))

    def xy2latlon(self, row: int, col: int) -> List[float]:
        """
        Transform a pixel position into the target reference system.

        Args:
            row (int): Pixel row.
            col (int): Pixel column.

        Returns:
            The first two components of the transformed point.
        """
        gt = self.geotf
        # Apply the affine geotransform to get source-system coordinates.
        x_geo = gt[0] + col * gt[1] + row * gt[2]
        y_geo = gt[3] + col * gt[4] + row * gt[5]
        transform = osr.CoordinateTransformation(self.source, self.target)
        return transform.TransformPoint(x_geo, y_geo)[:2]