[Feat] Add display (#41)
parent
820e5e66d7
commit
f403abcadd
5 changed files with 228 additions and 18 deletions
Binary file not shown.
Before Width: | Height: | Size: 171 KiB |
@ -0,0 +1,201 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import warnings |
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=DeprecationWarning) |
||||||
|
|
||||||
|
import os.path as osp |
||||||
|
from pathlib import Path |
||||||
|
from typing import List, Tuple, Union, Optional |
||||||
|
import webbrowser |
||||||
|
import numpy as np |
||||||
|
from folium import folium, Map, LayerControl |
||||||
|
from folium.raster_layers import TileLayer, ImageOverlay |
||||||
|
from paddlers.transforms.functions import to_uint8 |
||||||
|
|
||||||
|
try: |
||||||
|
from osgeo import gdal, osr |
||||||
|
except: |
||||||
|
import gdal |
||||||
|
import osr |
||||||
|
|
||||||
|
CHINATILES = ( |
||||||
|
"GeoQ China Community", |
||||||
|
"GeoQ China Street", |
||||||
|
"AMAP China", |
||||||
|
"TencentMap China", |
||||||
|
"BaiduMaps China", ) |
||||||
|
|
||||||
|
|
||||||
|
def map_display(mask_path: str, |
||||||
|
img_path: Optional[str]=None, |
||||||
|
band_list: Union[List[int], Tuple[int, ...], None]=None, |
||||||
|
save_path: Optional[str]=None, |
||||||
|
tiles: str="GeoQ China Community") -> folium.Map: |
||||||
|
""" |
||||||
|
Show mask (and original image) on an online map. |
||||||
|
|
||||||
|
Args: |
||||||
|
mask_path (str): Path of predicted or ground-truth masks. |
||||||
|
img_path (str|None, optional): Path of the original image. Defaults to None. |
||||||
|
band_list (list[int]|tuple[int]|None, optional): |
||||||
|
Bands to select from the original image for display (the band index starts from 1). |
||||||
|
If None, use all bands. Defaults to None. |
||||||
|
save_path (str, optional): Path of the .html file to save the visualization results. |
||||||
|
In Jupyter Notebook environments, |
||||||
|
leave `save_path` as None to display the result immediately in the notebook. |
||||||
|
Defaults to None. |
||||||
|
tiles (str): Map tileset to use. Due to boundary question, |
||||||
|
the map tileset provided by folium cannot be selected. |
||||||
|
* Only can choose from this list: |
||||||
|
- "GeoQ China Community", "GeoQ China Street" (from http://www.geoq.cn/) |
||||||
|
- "AMAP China" (from https://www.amap.com/) |
||||||
|
- "TencentMap China" (from https://map.qq.com/) |
||||||
|
- "BaiduMaps China" (from https://map.baidu.com/) { |
||||||
|
It is worked in some cases, but it is not recommended. |
||||||
|
} |
||||||
|
|
||||||
|
Defaults to "GeoQ China Community". |
||||||
|
|
||||||
|
These tileset have been corrected through the public algorithm on the Internet. |
||||||
|
* Please read the relevant terms of use carefully: |
||||||
|
- GeoQ [GISUNI] (http://geoq.cn/useragreement.html) |
||||||
|
- AMap [AutoNavi] (https://wap.amap.com/doc/serviceitem.html) |
||||||
|
- Tencent Map (https://ugc.map.qq.com/AppBox/Landlord/serveagreement.html) |
||||||
|
- Baidu Map (https://map.baidu.com/zt/client/service/index.html) |
||||||
|
|
||||||
|
Returns: |
||||||
|
folium.Map: An example of folium map. |
||||||
|
""" |
||||||
|
|
||||||
|
if tiles not in CHINATILES: |
||||||
|
raise ValueError("The `tiles` must in {}, not {}.".format(CHINATILES, |
||||||
|
tiles)) |
||||||
|
fmap = Map( |
||||||
|
tiles=tiles, |
||||||
|
min_zoom=1, |
||||||
|
max_zoom=24, ) |
||||||
|
if img_path is not None: |
||||||
|
layer, _ = Raster(img_path, band_list).get_layer() |
||||||
|
layer.add_to(fmap) |
||||||
|
layer, center = Raster(mask_path).get_layer() |
||||||
|
layer.add_to(fmap) |
||||||
|
if center is not None: |
||||||
|
fmap.location = center |
||||||
|
fmap.fit_bounds(layer.bounds) |
||||||
|
LayerControl().add_to(fmap) |
||||||
|
if save_path: |
||||||
|
fmap.save(save_path) |
||||||
|
webbrowser.open(save_path) |
||||||
|
return fmap |
||||||
|
|
||||||
|
|
||||||
|
class OpenAsEPSG4326Error(Exception): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
class Raster: |
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
path: str, |
||||||
|
band_list: Union[List[int], Tuple[int, ...], None]=None) -> None: |
||||||
|
self.src_data = Converter.open_as_WGS84(path) |
||||||
|
if self.src_data is None: |
||||||
|
raise OpenAsEPSG4326Error("Faild to open {} in EPSG:4326.".format( |
||||||
|
path)) |
||||||
|
self.name = Path(path).stem |
||||||
|
self.set_bands(band_list) |
||||||
|
self._get_info() |
||||||
|
|
||||||
|
def get_layer(self) -> ImageOverlay: |
||||||
|
layer = ImageOverlay(self._get_array(), self.wgs_range, name=self.name) |
||||||
|
return layer, self.wgs_center |
||||||
|
|
||||||
|
def set_bands(self, |
||||||
|
band_list: Union[List[int], Tuple[int, ...], None]) -> None: |
||||||
|
self.bands = self.src_data.RasterCount |
||||||
|
if band_list is None: |
||||||
|
if self.bands == 3: |
||||||
|
band_list = [1, 2, 3] |
||||||
|
else: |
||||||
|
band_list = [1] |
||||||
|
band_list_lens = len(band_list) |
||||||
|
if band_list_lens not in (1, 3): |
||||||
|
raise ValueError("The lenght of band_list must be 1 or 3, not {}.". |
||||||
|
format(str(band_list_lens))) |
||||||
|
if max(band_list) > self.bands or min(band_list) < 1: |
||||||
|
raise ValueError("The range of band_list must within [1, {}].". |
||||||
|
format(str(self.bands))) |
||||||
|
self.band_list = band_list |
||||||
|
|
||||||
|
def _get_info(self) -> None: |
||||||
|
self.width = self.src_data.RasterXSize |
||||||
|
self.height = self.src_data.RasterYSize |
||||||
|
self.geotf = self.src_data.GetGeoTransform() |
||||||
|
self.proj = self.src_data.GetProjection() # WGS84 |
||||||
|
self.wgs_range = self._get_WGS84_range() |
||||||
|
self.wgs_center = self._get_WGS84_center() |
||||||
|
|
||||||
|
def _get_WGS84_range(self) -> List[List[float]]: |
||||||
|
converter = Converter(self.proj, self.geotf) |
||||||
|
lat1, lon1 = converter.xy2latlon(self.height - 1, 0) |
||||||
|
lat2, lon2 = converter.xy2latlon(0, self.width - 1) |
||||||
|
return [[lon1, lat1], [lon2, lat2]] |
||||||
|
|
||||||
|
def _get_WGS84_center(self) -> List[float]: |
||||||
|
clat = (self.wgs_range[0][0] + self.wgs_range[1][0]) / 2 |
||||||
|
clon = (self.wgs_range[0][1] + self.wgs_range[1][1]) / 2 |
||||||
|
return [clat, clon] |
||||||
|
|
||||||
|
def _get_array(self) -> np.ndarray: |
||||||
|
band_array = [] |
||||||
|
for b in self.band_list: |
||||||
|
band_i = self.src_data.GetRasterBand(b).ReadAsArray() |
||||||
|
band_array.append(band_i) |
||||||
|
ima = np.stack(band_array, axis=0) |
||||||
|
if self.bands == 1: |
||||||
|
# the type is complex means this is a SAR data |
||||||
|
if isinstance(type(ima[0, 0]), complex): |
||||||
|
ima = abs(ima) |
||||||
|
ima = ima.squeeze() |
||||||
|
else: |
||||||
|
ima = ima.transpose((1, 2, 0)) |
||||||
|
ima = to_uint8(ima, True) |
||||||
|
return ima |
||||||
|
|
||||||
|
|
||||||
|
class Converter: |
||||||
|
def __init__(self, proj: str, geotf: tuple) -> None: |
||||||
|
# source data |
||||||
|
self.source = osr.SpatialReference() |
||||||
|
self.source.ImportFromWkt(proj) |
||||||
|
self.geotf = geotf |
||||||
|
# target data |
||||||
|
self.target = osr.SpatialReference() |
||||||
|
self.target.ImportFromEPSG(4326) |
||||||
|
|
||||||
|
@classmethod |
||||||
|
def open_as_WGS84(self, path: str) -> gdal.Dataset: |
||||||
|
if not osp.exists(path): |
||||||
|
raise FileNotFoundError("{} not found.".format(path)) |
||||||
|
result = gdal.Warp("", path, dstSRS="EPSG:4326", format="VRT") |
||||||
|
return result |
||||||
|
|
||||||
|
def xy2latlon(self, row: int, col: int) -> List[float]: |
||||||
|
px = self.geotf[0] + col * self.geotf[1] + row * self.geotf[2] |
||||||
|
py = self.geotf[3] + col * self.geotf[4] + row * self.geotf[5] |
||||||
|
ct = osr.CoordinateTransformation(self.source, self.target) |
||||||
|
coords = ct.TransformPoint(px, py) |
||||||
|
return coords[:2] |
@ -1,23 +1,25 @@ |
|||||||
paddleslim >= 2.2.1,<2.3.3 |
chardet |
||||||
visualdl >= 2.1.1 |
colorama |
||||||
opencv-contrib-python == 4.3.0.38 |
|
||||||
numba == 0.53.1 |
|
||||||
scikit-learn == 0.23.2 |
|
||||||
scikit-image >= 0.14.0 |
|
||||||
pandas |
|
||||||
scipy |
|
||||||
cython |
cython |
||||||
pycocotools |
easydict |
||||||
shapely |
filelock |
||||||
|
foliume |
||||||
|
# GDAL >= 3.1.3 |
||||||
|
geojson |
||||||
lap |
lap |
||||||
motmetrics |
motmetrics |
||||||
chardet |
|
||||||
openpyxl |
|
||||||
easydict |
|
||||||
munch |
munch |
||||||
natsort |
natsort |
||||||
geojson |
numba == 0.53.1 |
||||||
colorama |
opencv-contrib-python == 4.3.0.38 |
||||||
filelock |
openpyxl |
||||||
# # Self installation |
# paddlepaddle >= 2.2.0 |
||||||
# GDAL >= 3.1.3 |
# paddlepaddle-gpu >= 2.2.0 |
||||||
|
paddleslim >= 2.2.1,<2.3.3 |
||||||
|
pandas |
||||||
|
pycocotools |
||||||
|
scikit-learn == 0.23.2 |
||||||
|
scikit-image >= 0.14.0 |
||||||
|
scipy |
||||||
|
shapely |
||||||
|
visualdl >= 2.1.1 |
||||||
|
Loading…
Reference in new issue