@ -16,14 +16,13 @@ import os.path as osp
from typing import List , Tuple , Union , Optional
import numpy as np
from paddlers . transforms . functions import to_uint8 as raster2uint8
try :
from osgeo import gdal
except :
import gdal
from paddlers . transforms . functions import to_uint8 as raster2uint8
def _get_type ( type_name : str ) - > int :
if type_name in [ " bool " , " uint8 " ] :
@ -53,7 +52,9 @@ class Raster:
gdal_obj : Optional [ gdal . Dataset ] = None ,
band_list : Union [ List [ int ] , Tuple [ int ] , None ] = None ,
to_uint8 : bool = False ) - > None :
""" Class of read raster.
"""
Class of read raster .
Args :
path ( Optional [ str ] ) : The path of raster .
gdal_obj ( Optional [ Any ] , optional ) : The object of GDAL . Defaults to None .
@ -75,22 +76,25 @@ class Raster:
# https://www.osgeo.cn/gdal/drivers/raster/index.html
self . _src_data = gdal . Open ( path )
except :
raise TypeError (
" Unsupported data format: ` {} ` " . format ( self . ext_type ) )
raise TypeError ( " Unsupported data format: ` {} ` " . format (
self . ext_type ) )
else :
raise ValueError ( " The path {0} not exists. " . format ( path ) )
else :
if gdal_obj is not None :
self . _src_data = gdal_obj
else :
raise ValueError ( " At least one of `path` and `gdal_obj` is not None. " )
raise ValueError (
" At least one of `path` and `gdal_obj` is not None. " )
self . to_uint8 = to_uint8
self . _getInfo ( )
self . setBands ( band_list )
self . _getType ( )
def setBands ( self , band_list : Union [ List [ int ] , Tuple [ int ] , None ] ) - > None :
""" Set band of data.
"""
Set band of data .
Args :
band_list ( Union [ List [ int ] , Tuple [ int ] , None ] ) :
band list ( start with 1 ) or None ( all of bands ) .
@ -105,16 +109,19 @@ class Raster:
format ( str ( self . bands ) ) )
self . band_list = band_list
def getArray (
self ,
start_loc : Union [ List [ int ] , Tuple [ int , int ] , None ] = None ,
block_size : Union [ List [ int ] , Tuple [ int , int ] ] = [ 512 , 512 ] ) - > np . ndarray :
""" Get ndarray data
def getArray ( self ,
start_loc : Union [ List [ int ] , Tuple [ int , int ] , None ] = None ,
block_size : Union [ List [ int ] , Tuple [ int , int ] ] = [ 512 , 512 ]
) - > np . ndarray :
"""
Get ndarray data
Args :
start_loc ( Union [ List [ int ] , Tuple [ int ] , None ] , optional ) :
Coordinates of the upper left corner of the block , if None means return full image .
block_size ( Union [ List [ int ] , Tuple [ int ] ] , optional ) :
Block size . Defaults to [ 512 , 512 ] .
Returns :
np . ndarray : data ' s ndarray.
"""
@ -144,7 +151,7 @@ class Raster:
self . bands = 1
self . geot = None
self . proj = None
def _getType ( self ) - > None :
d_name = self . getArray ( [ 0 , 0 ] , [ 1 , 1 ] ) . dtype . name
self . datatype = _get_type ( d_name )
@ -159,9 +166,9 @@ class Raster:
ima = np . stack ( band_array , axis = 0 )
return ima
def _getArray (
self ,
window : Union [ None , List [ int ] , Tuple [ int , int , int , int ] ] = None ) - > np . ndarray :
def _getArray ( self ,
window : Union [ None , List [ int ] , Tuple [ int , int , int , int ] ] = None
) - > np . ndarray :
if self . _src_data is None :
raise ValueError ( " The raster is None. " )
if window is not None :
@ -193,10 +200,10 @@ class Raster:
ima = raster2uint8 ( ima )
return ima
def _getBlock (
self ,
start_loc : Union [ List [ int ] , Tuple [ int , int ] ] ,
block_size : Union [ List [ int ] , Tuple [ int , int ] ] = [ 512 , 512 ] ) - > np . ndarray :
def _getBlock ( self ,
start_loc : Union [ List [ int ] , Tuple [ int , int ] ] ,
block_size : Union [ List [ int ] , Tuple [ int , int ] ] = [ 512 , 512 ]
) - > np . ndarray :
if len ( start_loc ) != 2 or len ( block_size ) != 2 :
raise ValueError ( " The length start_loc/block_size must be 2. " )
xoff , yoff = start_loc
@ -220,9 +227,9 @@ class Raster:
return tmp
def save_geotiff ( image : np . ndarray ,
save_path : str ,
proj : str ,
def save_geotiff ( image : np . ndarray ,
save_path : str ,
proj : str ,
geotf : Tuple ,
use_type : Optional [ int ] = None ,
clear_ds : bool = True ) - > None :