[Feat] Add display (#41)
parent
820e5e66d7
commit
f403abcadd
5 changed files with 228 additions and 18 deletions
Binary file not shown.
Before Width: | Height: | Size: 171 KiB |
@ -0,0 +1,201 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import warnings |
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning) |
||||
|
||||
import os.path as osp |
||||
from pathlib import Path |
||||
from typing import List, Tuple, Union, Optional |
||||
import webbrowser |
||||
import numpy as np |
||||
from folium import folium, Map, LayerControl |
||||
from folium.raster_layers import TileLayer, ImageOverlay |
||||
from paddlers.transforms.functions import to_uint8 |
||||
|
||||
try: |
||||
from osgeo import gdal, osr |
||||
except: |
||||
import gdal |
||||
import osr |
||||
|
||||
CHINATILES = ( |
||||
"GeoQ China Community", |
||||
"GeoQ China Street", |
||||
"AMAP China", |
||||
"TencentMap China", |
||||
"BaiduMaps China", ) |
||||
|
||||
|
||||
def map_display(mask_path: str, |
||||
img_path: Optional[str]=None, |
||||
band_list: Union[List[int], Tuple[int, ...], None]=None, |
||||
save_path: Optional[str]=None, |
||||
tiles: str="GeoQ China Community") -> folium.Map: |
||||
""" |
||||
Show mask (and original image) on an online map. |
||||
|
||||
Args: |
||||
mask_path (str): Path of predicted or ground-truth masks. |
||||
img_path (str|None, optional): Path of the original image. Defaults to None. |
||||
band_list (list[int]|tuple[int]|None, optional): |
||||
Bands to select from the original image for display (the band index starts from 1). |
||||
If None, use all bands. Defaults to None. |
||||
save_path (str, optional): Path of the .html file to save the visualization results. |
||||
In Jupyter Notebook environments, |
||||
leave `save_path` as None to display the result immediately in the notebook. |
||||
Defaults to None. |
||||
tiles (str): Map tileset to use. Due to boundary question, |
||||
the map tileset provided by folium cannot be selected. |
||||
* Only can choose from this list: |
||||
- "GeoQ China Community", "GeoQ China Street" (from http://www.geoq.cn/) |
||||
- "AMAP China" (from https://www.amap.com/) |
||||
- "TencentMap China" (from https://map.qq.com/) |
||||
- "BaiduMaps China" (from https://map.baidu.com/) { |
||||
It is worked in some cases, but it is not recommended. |
||||
} |
||||
|
||||
Defaults to "GeoQ China Community". |
||||
|
||||
These tileset have been corrected through the public algorithm on the Internet. |
||||
* Please read the relevant terms of use carefully: |
||||
- GeoQ [GISUNI] (http://geoq.cn/useragreement.html) |
||||
- AMap [AutoNavi] (https://wap.amap.com/doc/serviceitem.html) |
||||
- Tencent Map (https://ugc.map.qq.com/AppBox/Landlord/serveagreement.html) |
||||
- Baidu Map (https://map.baidu.com/zt/client/service/index.html) |
||||
|
||||
Returns: |
||||
folium.Map: An example of folium map. |
||||
""" |
||||
|
||||
if tiles not in CHINATILES: |
||||
raise ValueError("The `tiles` must in {}, not {}.".format(CHINATILES, |
||||
tiles)) |
||||
fmap = Map( |
||||
tiles=tiles, |
||||
min_zoom=1, |
||||
max_zoom=24, ) |
||||
if img_path is not None: |
||||
layer, _ = Raster(img_path, band_list).get_layer() |
||||
layer.add_to(fmap) |
||||
layer, center = Raster(mask_path).get_layer() |
||||
layer.add_to(fmap) |
||||
if center is not None: |
||||
fmap.location = center |
||||
fmap.fit_bounds(layer.bounds) |
||||
LayerControl().add_to(fmap) |
||||
if save_path: |
||||
fmap.save(save_path) |
||||
webbrowser.open(save_path) |
||||
return fmap |
||||
|
||||
|
||||
class OpenAsEPSG4326Error(Exception): |
||||
pass |
||||
|
||||
|
||||
class Raster: |
||||
def __init__( |
||||
self, |
||||
path: str, |
||||
band_list: Union[List[int], Tuple[int, ...], None]=None) -> None: |
||||
self.src_data = Converter.open_as_WGS84(path) |
||||
if self.src_data is None: |
||||
raise OpenAsEPSG4326Error("Faild to open {} in EPSG:4326.".format( |
||||
path)) |
||||
self.name = Path(path).stem |
||||
self.set_bands(band_list) |
||||
self._get_info() |
||||
|
||||
def get_layer(self) -> ImageOverlay: |
||||
layer = ImageOverlay(self._get_array(), self.wgs_range, name=self.name) |
||||
return layer, self.wgs_center |
||||
|
||||
def set_bands(self, |
||||
band_list: Union[List[int], Tuple[int, ...], None]) -> None: |
||||
self.bands = self.src_data.RasterCount |
||||
if band_list is None: |
||||
if self.bands == 3: |
||||
band_list = [1, 2, 3] |
||||
else: |
||||
band_list = [1] |
||||
band_list_lens = len(band_list) |
||||
if band_list_lens not in (1, 3): |
||||
raise ValueError("The lenght of band_list must be 1 or 3, not {}.". |
||||
format(str(band_list_lens))) |
||||
if max(band_list) > self.bands or min(band_list) < 1: |
||||
raise ValueError("The range of band_list must within [1, {}].". |
||||
format(str(self.bands))) |
||||
self.band_list = band_list |
||||
|
||||
def _get_info(self) -> None: |
||||
self.width = self.src_data.RasterXSize |
||||
self.height = self.src_data.RasterYSize |
||||
self.geotf = self.src_data.GetGeoTransform() |
||||
self.proj = self.src_data.GetProjection() # WGS84 |
||||
self.wgs_range = self._get_WGS84_range() |
||||
self.wgs_center = self._get_WGS84_center() |
||||
|
||||
def _get_WGS84_range(self) -> List[List[float]]: |
||||
converter = Converter(self.proj, self.geotf) |
||||
lat1, lon1 = converter.xy2latlon(self.height - 1, 0) |
||||
lat2, lon2 = converter.xy2latlon(0, self.width - 1) |
||||
return [[lon1, lat1], [lon2, lat2]] |
||||
|
||||
def _get_WGS84_center(self) -> List[float]: |
||||
clat = (self.wgs_range[0][0] + self.wgs_range[1][0]) / 2 |
||||
clon = (self.wgs_range[0][1] + self.wgs_range[1][1]) / 2 |
||||
return [clat, clon] |
||||
|
||||
def _get_array(self) -> np.ndarray: |
||||
band_array = [] |
||||
for b in self.band_list: |
||||
band_i = self.src_data.GetRasterBand(b).ReadAsArray() |
||||
band_array.append(band_i) |
||||
ima = np.stack(band_array, axis=0) |
||||
if self.bands == 1: |
||||
# the type is complex means this is a SAR data |
||||
if isinstance(type(ima[0, 0]), complex): |
||||
ima = abs(ima) |
||||
ima = ima.squeeze() |
||||
else: |
||||
ima = ima.transpose((1, 2, 0)) |
||||
ima = to_uint8(ima, True) |
||||
return ima |
||||
|
||||
|
||||
class Converter: |
||||
def __init__(self, proj: str, geotf: tuple) -> None: |
||||
# source data |
||||
self.source = osr.SpatialReference() |
||||
self.source.ImportFromWkt(proj) |
||||
self.geotf = geotf |
||||
# target data |
||||
self.target = osr.SpatialReference() |
||||
self.target.ImportFromEPSG(4326) |
||||
|
||||
@classmethod |
||||
def open_as_WGS84(self, path: str) -> gdal.Dataset: |
||||
if not osp.exists(path): |
||||
raise FileNotFoundError("{} not found.".format(path)) |
||||
result = gdal.Warp("", path, dstSRS="EPSG:4326", format="VRT") |
||||
return result |
||||
|
||||
def xy2latlon(self, row: int, col: int) -> List[float]: |
||||
px = self.geotf[0] + col * self.geotf[1] + row * self.geotf[2] |
||||
py = self.geotf[3] + col * self.geotf[4] + row * self.geotf[5] |
||||
ct = osr.CoordinateTransformation(self.source, self.target) |
||||
coords = ct.TransformPoint(px, py) |
||||
return coords[:2] |
@ -1,23 +1,25 @@ |
||||
paddleslim >= 2.2.1,<2.3.3 |
||||
visualdl >= 2.1.1 |
||||
opencv-contrib-python == 4.3.0.38 |
||||
numba == 0.53.1 |
||||
scikit-learn == 0.23.2 |
||||
scikit-image >= 0.14.0 |
||||
pandas |
||||
scipy |
||||
chardet |
||||
colorama |
||||
cython |
||||
pycocotools |
||||
shapely |
||||
easydict |
||||
filelock |
||||
foliume |
||||
# GDAL >= 3.1.3 |
||||
geojson |
||||
lap |
||||
motmetrics |
||||
chardet |
||||
openpyxl |
||||
easydict |
||||
munch |
||||
natsort |
||||
geojson |
||||
colorama |
||||
filelock |
||||
# # Self installation |
||||
# GDAL >= 3.1.3 |
||||
numba == 0.53.1 |
||||
opencv-contrib-python == 4.3.0.38 |
||||
openpyxl |
||||
# paddlepaddle >= 2.2.0 |
||||
# paddlepaddle-gpu >= 2.2.0 |
||||
paddleslim >= 2.2.1,<2.3.3 |
||||
pandas |
||||
pycocotools |
||||
scikit-learn == 0.23.2 |
||||
scikit-image >= 0.14.0 |
||||
scipy |
||||
shapely |
||||
visualdl >= 2.1.1 |
||||
|
Loading…
Reference in new issue