|
|
|
@ -13,6 +13,7 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import math |
|
|
|
|
import os |
|
|
|
|
import os.path as osp |
|
|
|
|
from collections import OrderedDict |
|
|
|
|
from operator import attrgetter |
|
|
|
@ -545,6 +546,88 @@ class BaseChangeDetector(BaseModel): |
|
|
|
|
} |
|
|
|
|
return prediction |
|
|
|
|
|
|
|
|
|
def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=None): |
|
|
|
|
""" |
|
|
|
|
Do inference. |
|
|
|
|
Args: |
|
|
|
|
Args: |
|
|
|
|
img_file(List[str]): |
|
|
|
|
List of image paths. |
|
|
|
|
save_dir(str): |
|
|
|
|
Directory that contains saved geotiff file. |
|
|
|
|
block_size(List[int] or Tuple[int], int): |
|
|
|
|
The size of block. |
|
|
|
|
overlap(List[int] or Tuple[int], int): |
|
|
|
|
The overlap between two blocks. Defaults to 36. |
|
|
|
|
transforms(paddlers.transforms.Compose or None, optional): |
|
|
|
|
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
from osgeo import gdal |
|
|
|
|
except: |
|
|
|
|
import gdal |
|
|
|
|
|
|
|
|
|
if len(img_file) != 2: |
|
|
|
|
raise ValueError("`img_file` must be a list of length 2.") |
|
|
|
|
if isinstance(block_size, int): |
|
|
|
|
block_size = (block_size, block_size) |
|
|
|
|
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2: |
|
|
|
|
block_size = tuple(block_size) |
|
|
|
|
else: |
|
|
|
|
raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.") |
|
|
|
|
if isinstance(overlap, int): |
|
|
|
|
overlap = (overlap, overlap) |
|
|
|
|
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2: |
|
|
|
|
overlap = tuple(overlap) |
|
|
|
|
else: |
|
|
|
|
raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.") |
|
|
|
|
|
|
|
|
|
src1_data = gdal.Open(img_file[0]) |
|
|
|
|
src2_data = gdal.Open(img_file[1]) |
|
|
|
|
width = src1_data.RasterXSize |
|
|
|
|
height = src1_data.RasterYSize |
|
|
|
|
bands = src1_data.RasterCount |
|
|
|
|
|
|
|
|
|
driver = gdal.GetDriverByName("GTiff") |
|
|
|
|
file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[0] + ".tif" |
|
|
|
|
if not osp.exists(save_dir): |
|
|
|
|
os.makedirs(save_dir) |
|
|
|
|
save_file = osp.join(save_dir, file_name) |
|
|
|
|
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte) |
|
|
|
|
dst_data.SetGeoTransform(src1_data.GetGeoTransform()) |
|
|
|
|
dst_data.SetProjection(src1_data.GetProjection()) |
|
|
|
|
band = dst_data.GetRasterBand(1) |
|
|
|
|
band.WriteArray(255 * np.ones((height, width), dtype="uint8")) |
|
|
|
|
|
|
|
|
|
step = np.array(block_size) - np.array(overlap) |
|
|
|
|
for yoff in range(0, height, step[1]): |
|
|
|
|
for xoff in range(0, width, step[0]): |
|
|
|
|
xsize, ysize = block_size |
|
|
|
|
if xoff + xsize > width: |
|
|
|
|
xsize = int(width - xoff) |
|
|
|
|
if yoff + ysize > height: |
|
|
|
|
ysize = int(height - yoff) |
|
|
|
|
im1 = src1_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) |
|
|
|
|
im2 = src2_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) |
|
|
|
|
# fill |
|
|
|
|
h, w = im1.shape[:2] |
|
|
|
|
im1_fill = np.zeros((block_size[1], block_size[0], bands), dtype=im1.dtype) |
|
|
|
|
im2_fill = im1_fill.copy() |
|
|
|
|
im1_fill[:h, :w, :] = im1 |
|
|
|
|
im2_fill[:h, :w, :] = im2 |
|
|
|
|
im_fill = (im1_fill, im2_fill) |
|
|
|
|
# predict |
|
|
|
|
pred = self.predict(im_fill, transforms)["label_map"].astype("uint8") |
|
|
|
|
# overlap |
|
|
|
|
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) |
|
|
|
|
mask = (rd_block == pred[:h, :w]) | (rd_block == 255) |
|
|
|
|
temp = pred[:h, :w].copy() |
|
|
|
|
temp[mask == False] = 0 |
|
|
|
|
band.WriteArray(temp, int(xoff), int(yoff)) |
|
|
|
|
dst_data.FlushCache() |
|
|
|
|
dst_data = None |
|
|
|
|
print("GeoTiff saved in {}.".format(save_file)) |
|
|
|
|
|
|
|
|
|
def _preprocess(self, images, transforms, to_tensor=True): |
|
|
|
|
arrange_transforms( |
|
|
|
|
model_type=self.model_type, transforms=transforms, mode='test') |
|
|
|
|