Merge branch 'develop' into update_ppseg

Bobholamovic 2 years ago
commit 9fd3b7b00e
  1. docs/apis/data.md (6)
  2. docs/apis/infer.md (21)
  3. paddlers/deploy/predictor.py (63)
  4. paddlers/tasks/base.py (17)
  5. paddlers/tasks/change_detector.py (117)
  6. paddlers/tasks/classifier.py (4)
  7. paddlers/tasks/object_detector.py (4)
  8. paddlers/tasks/restorer.py (4)
  9. paddlers/tasks/segmenter.py (95)
  10. paddlers/tasks/utils/slider_predict.py (437)
  11. paddlers/transforms/__init__.py (17)
  12. paddlers/transforms/functions.py (6)
  13. paddlers/transforms/operators.py (14)
  14. tests/deploy/test_predictor.py (29)
  15. tests/fast_tests.py (1)
  16. tests/tasks/__init__.py (2)
  17. tests/tasks/test_slider_predict.py (212)

@ -134,10 +134,12 @@
|-------|----|--------|-----|
|`im_path`|`str`|Path of the input image.||
|`to_rgb`|`bool`|If `True`, convert from BGR to RGB format.|`True`|
|`to_uint8`|`bool`|If `True`, quantize the image data that is read and convert it to uint8 type.|`True`|
|`decode_bgr`|`bool`|If `True`, automatically interpret non-geo images (e.g. jpeg images) as BGR images.|`True`|
|`decode_sar`|`bool`|If `True`, automatically interpret two-channel geo images (e.g. GeoTiff images) as SAR images.|`True`|
|`decode_sar`|`bool`|If `True`, automatically interpret single-channel geo images (e.g. GeoTiff images) as SAR images.|`True`|
|`read_geo_info`|`bool`|If `True`, read geographic information from the image.|`False`|
|`use_stretch`|`bool`|Whether to apply a 2% linear stretch to the image brightness. Effective only when `to_uint8` is `True`.|`False`|
|`read_raw`|`bool`|If `True`, equivalent to setting `to_rgb` and `to_uint8` to `False`; this parameter takes precedence over the two above.|`False`|
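For example, a minimal usage sketch of `decode_image()` (the input path below is hypothetical):

```python
import paddlers.transforms as T

# Read raw band values (skip BGR-to-RGB conversion and uint8 quantization)
# and also fetch the geographic metadata stored in the image.
im, geo_info = T.decode_image(
    "demo.tif",          # hypothetical GeoTIFF path
    read_raw=True,
    read_geo_info=True)

print(im.shape)               # raw band values
print(geo_info["geo_trans"])  # geo-transform of the image
print(geo_info["geo_proj"])   # projection of the image
```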
The return format is as follows:

@ -155,7 +155,11 @@ def slider_predict(self,
save_dir,
block_size,
overlap=36,
transforms=None):
transforms=None,
invalid_value=255,
merge_strategy='keep_last',
batch_size=1,
quiet=False):
```
The input parameters are listed below:
@ -164,11 +168,15 @@ def slider_predict(self,
|-------|----|--------|-----|
|`img_file`|`str`|Path of the input image.||
|`save_dir`|`str`|Output directory for the prediction results.||
|`block_size`|`list[int]` \| `tuple[int]` \| `int`|Size of the sliding window (specified as a list or tuple of length and width, or a single integer for both).||
|`overlap`|`list[int]` \| `tuple[int]` \| `int`|Overlap between two adjacent sliding windows (specified as a list or tuple of length and width, or a single integer for both).|`36`|
|`block_size`|`list[int]` \| `tuple[int]` \| `int`|Size of the sliding window (specified as a list or tuple of width and height, or a single integer for both).||
|`overlap`|`list[int]` \| `tuple[int]` \| `int`|Overlap between two adjacent sliding windows (specified as a list or tuple of width and height, or a single integer for both).|`36`|
|`transforms`|`paddlers.transforms.Compose` \| `None`|Data transformation operators applied to the input data. If `None`, the operators used by the trainer during validation are used.|`None`|
|`invalid_value`|`int`|Value used to mark invalid pixels in the output image.|`255`|
|`merge_strategy`|`str`|Strategy used to merge the overlapping areas of sliding windows. `'keep_first'` keeps the predicted classes of the window that comes first in traversal order (left to right, top to bottom); `'keep_last'` keeps the predicted classes of the window that comes last in traversal order; `'accum'` determines the final classes by accumulating the probabilities predicted by all windows in the overlapping area. Note that dense inference on large images with a large `overlap` using the `'accum'` strategy may take considerably longer, but it generally performs better around window boundaries.|`'keep_last'`|
|`batch_size`|`int`|Mini-batch size used during prediction.|`1`|
|`quiet`|`bool`|If `True`, do not display the prediction progress.|`False`|
The sliding-window inference API for change detection tasks is similar to that for image segmentation tasks. Note, however, that the geo-transform, projection, and other information stored in the output is taken from the image of the first temporal phase.
The sliding-window inference API for change detection tasks is similar to that for image segmentation tasks. Note, however, that the geo-transform, projection, and other information stored in the output is taken from the image of the first temporal phase, and the file storing the sliding-window inference result is named after the first-phase image file.
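Below is a minimal sketch of dynamic-graph sliding-window inference; the checkpoint directory and image path are hypothetical:

```python
import paddlers as pdrs

# Load a trained segmentation model (hypothetical checkpoint directory).
model = pdrs.tasks.load_model("output/unet/best_model")

model.slider_predict(
    img_file="large_scene.tif",   # hypothetical large GeoTIFF
    save_dir="output/slider_pred",
    block_size=512,               # 512x512 sliding window
    overlap=64,                   # 64-pixel overlap in both directions
    merge_strategy="accum",       # accumulate probabilities in overlapping areas
    batch_size=4)
```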
## Static Graph Inference API
@ -216,5 +224,10 @@ def predict(self,
|`transforms`|`paddlers.transforms.Compose`\|`None`|Data transformation operators applied to the input data. If `None`, the operators read from `model.yml` are used.|`None`|
|`warmup_iters`|`int`|Number of warm-up iterations, used to evaluate the model inference and pre/post-processing speed. If greater than 1, inference is repeated `warmup_iters` times in advance before the formal prediction and its speed evaluation begin.|`0`|
|`repeats`|`int`|Number of repetitions, used to evaluate the model inference and pre/post-processing speed. If greater than 1, the prediction is run `repeats` times and the average time is reported.|`1`|
|`quiet`|`bool`|If `True`, do not print timing information.|`False`|
The return format of `Predictor.predict()` is exactly the same as that of the corresponding dynamic graph inference API; see [Dynamic Graph Inference API](#动态图推理api) for details.
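A hedged sketch of timed inference with `Predictor.predict()`, assuming an exported segmentation model (the model directory and image path are hypothetical):

```python
from paddlers.deploy import Predictor

# Hypothetical directory of an exported inference model.
predictor = Predictor("./inference_model")

# Warm up for 5 iterations, then report the average time over 10 repeats.
result = predictor.predict(
    "demo.tif",     # hypothetical input image
    warmup_iters=5,
    repeats=10)
print(result["label_map"].shape)
```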
### `Predictor.slider_predict()`
Implements sliding-window inference. The usage is the same as that of the `slider_predict()` method of `BaseSegmenter` and `BaseChangeDetector`.
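A hedged sketch of deployment-time sliding-window inference for a change detection model (all paths below are hypothetical):

```python
from paddlers.deploy import Predictor

# Hypothetical directory of an exported change detection model.
predictor = Predictor("./inference_model")

predictor.slider_predict(
    ("scene_t1.tif", "scene_t2.tif"),  # bi-temporal image paths
    save_dir="output/slider_pred",
    block_size=(512, 512),             # (W, H) of the sliding window
    overlap=(64, 64),                  # (W, H) overlap between windows
    merge_strategy="keep_first",
    quiet=True)
```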

@ -14,6 +14,7 @@
import os.path as osp
from operator import itemgetter
from functools import partial
import numpy as np
import paddle
@ -23,6 +24,7 @@ from paddle.inference import PrecisionType
from paddlers.tasks import load_model
from paddlers.utils import logging, Timer
from paddlers.tasks.utils.slider_predict import slider_predict
class Predictor(object):
@ -271,22 +273,24 @@ class Predictor(object):
topk=1,
transforms=None,
warmup_iters=0,
repeats=1):
repeats=1,
quiet=False):
"""
Do prediction.
Do inference.
Args:
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change
detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
1, the reported time consumption is the average of all repeats. Defaults to 1.
quiet (bool, optional): If True, do not display the timing information. Defaults to False.
"""
if repeats < 1:
@ -313,12 +317,61 @@ class Predictor(object):
self.timer.repeats = repeats
self.timer.img_num = len(images)
self.timer.info(average=True)
if not quiet:
self.timer.info(average=True)
if isinstance(img_file, (str, np.ndarray, tuple)):
results = results[0]
return results
def slider_predict(self,
img_file,
save_dir,
block_size,
overlap=36,
transforms=None,
invalid_value=255,
merge_strategy='keep_last',
batch_size=1,
quiet=False):
"""
Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the
sliding-predicting mode.
Args:
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file`
should be either the path of the image to predict, a decoded image (a np.ndarray, which should be
consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)),
or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of
image paths, a tuple of decoded images, or a list of tuples.
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks. If `overlap` is a list or tuple,
it should be in (W, H) format. Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
invalid_value (int, optional): Value that marks invalid pixels in output image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices are
{'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and
the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel
according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
quiet (bool, optional): If True, disable the progress bar. Defaults to False.
"""
slider_predict(
partial(
self.predict, quiet=True),
img_file,
save_dir,
block_size,
overlap,
transforms,
invalid_value,
merge_strategy,
batch_size,
not quiet)
def batch_predict(self, image_list, **params):
return self.predict(img_file=image_list, **params)

@ -86,7 +86,7 @@ class BaseModel(metaclass=ModelMeta):
self.quant_config = None
self.fixed_input_shape = None
def net_initialize(self,
def initialize_net(self,
pretrain_weights=None,
save_dir='.',
resume_checkpoint=None,
@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta):
raise ValueError(
f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
)
def run(self, net, inputs, mode):
raise NotImplementedError
def train(self, *args, **kwargs):
raise NotImplementedError
def evaluate(self, *args, **kwargs):
raise NotImplementedError
def preprocess(self, images, transforms, to_tensor):
raise NotImplementedError
def postprocess(self, *args, **kwargs):
raise NotImplementedError

@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import cd_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferCDNet
from .utils.slider_predict import slider_predict
__all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@ -315,7 +316,7 @@ class BaseChangeDetector(BaseModel):
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(
self.initialize_net(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,
resume_checkpoint=resume_checkpoint,
@ -581,96 +582,44 @@ class BaseChangeDetector(BaseModel):
return prediction
def slider_predict(self,
img_file,
img_files,
save_dir,
block_size,
overlap=36,
transforms=None):
transforms=None,
invalid_value=255,
merge_strategy='keep_last',
batch_size=1,
quiet=False):
"""
Do inference.
Do inference using sliding windows.
Args:
img_file (tuple[str]): Tuple of image paths.
img_files (tuple[str]): Tuple of image paths.
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int, optional): Size of block.
overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks.
Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
If None, the transforms for evaluation process will be used. Defaults to None.
block_size (list[int] | tuple[int] | int):
Size of block. If `block_size` is a list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int, optional):
Overlap between two blocks. If `overlap` is a list or tuple, it should
be in (W, H) format. Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
invalid_value (int, optional): Value that marks invalid pixels in output
image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
quiet (bool, optional): If True, disable the progress bar. Defaults to False.
"""
try:
from osgeo import gdal
except:
import gdal
if not isinstance(img_file, tuple) or len(img_file) != 2:
raise ValueError("`img_file` must be a tuple of length 2.")
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
src1_data = gdal.Open(img_file[0])
src2_data = gdal.Open(img_file[1])
width = src1_data.RasterXSize
height = src1_data.RasterYSize
bands = src1_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
dst_data.SetGeoTransform(src1_data.GetGeoTransform())
dst_data.SetProjection(src1_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
step = np.array(block_size) - np.array(overlap)
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
im1 = src1_data.ReadAsArray(
int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
im2 = src2_data.ReadAsArray(
int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
# Fill
h, w = im1.shape[:2]
im1_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im1.dtype)
im2_fill = im1_fill.copy()
im1_fill[:h, :w, :] = im1
im2_fill[:h, :w, :] = im2
im_fill = (im1_fill, im2_fill)
# Predict
pred = self.predict(im_fill,
transforms)["label_map"].astype("uint8")
# Overlap
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
temp = pred[:h, :w].copy()
temp[mask == False] = 0
band.WriteArray(temp, int(xoff), int(yoff))
dst_data.FlushCache()
dst_data = None
print("GeoTiff saved in {}.".format(save_file))
slider_predict(self.predict, img_files, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy, batch_size,
not quiet)
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')
@ -678,8 +627,8 @@ class BaseChangeDetector(BaseModel):
batch_ori_shape = list()
for im1, im2 in images:
if isinstance(im1, str) or isinstance(im2, str):
im1 = decode_image(im1, to_rgb=False)
im2 = decode_image(im2, to_rgb=False)
im1 = decode_image(im1, read_raw=True)
im2 = decode_image(im2, read_raw=True)
ori_shape = im1.shape[:2]
# XXX: sample do not contain 'image_t1' and 'image_t2'.
sample = {'image': im1, 'image2': im2}

@ -286,7 +286,7 @@ class BaseClassifier(BaseModel):
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = False
self.net_initialize(
self.initialize_net(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,
resume_checkpoint=resume_checkpoint,
@ -495,7 +495,7 @@ class BaseClassifier(BaseModel):
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)

@ -347,7 +347,7 @@ class BaseDetector(BaseModel):
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
self.net_initialize(
self.initialize_net(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,
resume_checkpoint=resume_checkpoint,
@ -617,7 +617,7 @@ class BaseDetector(BaseModel):
batch_samples = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
sample = {'image': im}
sample = transforms(sample)
batch_samples.append(sample)

@ -283,7 +283,7 @@ class BaseRestorer(BaseModel):
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(
self.initialize_net(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,
resume_checkpoint=resume_checkpoint,
@ -481,7 +481,7 @@ class BaseRestorer(BaseModel):
batch_tar_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)[0]

@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferSegNet
from .utils.slider_predict import slider_predict
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
@ -300,7 +301,7 @@ class BaseSegmenter(BaseModel):
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(
self.initialize_net(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,
resume_checkpoint=resume_checkpoint,
@ -550,86 +551,40 @@ class BaseSegmenter(BaseModel):
save_dir,
block_size,
overlap=36,
transforms=None):
transforms=None,
invalid_value=255,
merge_strategy='keep_last',
batch_size=1,
quiet=False):
"""
Do inference.
Do inference using sliding windows.
Args:
img_file (str): Image path.
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int):
Size of block.
Size of block. If `block_size` is list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int, optional):
Overlap between two blocks. Defaults to 36.
Overlap between two blocks. If `overlap` is list or tuple, it should
be in (W, H) format. Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
invalid_value (int, optional): Value that marks invalid pixels in output
image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
are {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
quiet (bool, optional): If True, disable the progress bar. Defaults to False.
"""
try:
from osgeo import gdal
except:
import gdal
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
src_data = gdal.Open(img_file)
width = src_data.RasterXSize
height = src_data.RasterYSize
bands = src_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
dst_data.SetGeoTransform(src_data.GetGeoTransform())
dst_data.SetProjection(src_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
step = np.array(block_size) - np.array(overlap)
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
ysize).transpose((1, 2, 0))
# Fill
h, w = im.shape[:2]
im_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im.dtype)
im_fill[:h, :w, :] = im
# Predict
pred = self.predict(im_fill,
transforms)["label_map"].astype("uint8")
# Overlap
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
temp = pred[:h, :w].copy()
temp[mask == False] = 0
band.WriteArray(temp, int(xoff), int(yoff))
dst_data.FlushCache()
dst_data = None
print("GeoTiff saved in {}.".format(save_file))
slider_predict(self.predict, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy, batch_size,
not quiet)
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')
@ -637,7 +592,7 @@ class BaseSegmenter(BaseModel):
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)[0]

@ -0,0 +1,437 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import math
from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict
import numpy as np
from tqdm import tqdm
import paddlers.utils.logging as logging
class Cache(metaclass=ABCMeta):
@abstractmethod
def get_block(self, i_st, j_st, h, w):
pass
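# SlowCache keeps a per-pixel Counter of predicted labels and resolves each pixel by majority vote.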
class SlowCache(Cache):
def __init__(self):
super(SlowCache, self).__init__()
self.cache = defaultdict(Counter)
def push_pixel(self, i, j, l):
self.cache[(i, j)][l] += 1
def push_block(self, i_st, j_st, h, w, data):
for i in range(0, h):
for j in range(0, w):
self.push_pixel(i_st + i, j_st + j, data[i, j])
def pop_pixel(self, i, j):
self.cache.pop((i, j))
def pop_block(self, i_st, j_st, h, w):
for i in range(0, h):
for j in range(0, w):
self.pop_pixel(i_st + i, j_st + j)
def get_pixel(self, i, j):
winners = self.cache[(i, j)].most_common(1)
winner = winners[0]
return winner[0]
def get_block(self, i_st, j_st, h, w):
block = []
for i in range(i_st, i_st + h):
row = []
for j in range(j_st, j_st + w):
row.append(self.get_pixel(i, j))
block.append(row)
return np.asarray(block)
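# ProbCache accumulates class probabilities in a rolling buffer that covers only a band of rows
# (order 'c') or columns (order 'f') around the current sliding position, keeping memory bounded.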
class ProbCache(Cache):
def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
super(ProbCache, self).__init__()
self.cache = None
self.h = h
self.w = w
self.ch = ch
self.cw = cw
self.sh = sh
self.sw = sw
if not issubclass(dtype, np.floating):
raise TypeError("`dtype` must be one of the floating types.")
self.dtype = dtype
order = order.lower()
if order not in ('c', 'f'):
raise ValueError("`order` other than 'c' and 'f' is not supported.")
self.order = order
def _alloc_memory(self, nc):
if self.order == 'c':
# Column-first order (C-style)
#
# <-- cw -->
# |--------|---------------------|^ ^
# | || | sh
# |--------|---------------------|| ch v
# | ||
# |--------|---------------------|v
# <------------ w --------------->
self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
elif self.order == 'f':
# Row-first order (Fortran-style)
#
# <-- sw -->
# <---- cw ---->
# |--------|---|^ ^
# | | || |
# | | || ch
# | | || |
# |--------|---|| h v
# | | ||
# | | ||
# | | ||
# |--------|---|v
self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
def update_block(self, i_st, j_st, h, w, prob_map):
if self.cache is None:
nc = prob_map.shape[2]
# Lazy allocation of memory
self._alloc_memory(nc)
self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
def roll_cache(self, shift):
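# Shift the cached band by `shift` rows (order 'c') or columns (order 'f') as the
# sliding window advances, zeroing the newly exposed region.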
if self.order == 'c':
self.cache[:-shift] = self.cache[shift:]
self.cache[-shift:, :] = 0
elif self.order == 'f':
self.cache[:, :-shift] = self.cache[:, shift:]
self.cache[:, -shift:] = 0
def get_block(self, i_st, j_st, h, w):
return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
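# An OverlapProcessor determines the final label map of each block before it is written
# to the output raster, implementing one of the merge strategies.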
class OverlapProcessor(metaclass=ABCMeta):
def __init__(self, h, w, ch, cw, sh, sw):
super(OverlapProcessor, self).__init__()
self.h = h
self.w = w
self.ch = ch
self.cw = cw
self.sh = sh
self.sw = sw
@abstractmethod
def process_pred(self, out, xoff, yoff):
pass
class KeepFirstProcessor(OverlapProcessor):
def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
self.ds = ds
self.inval = inval
def process_pred(self, out, xoff, yoff):
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
mask = rd_block != self.inval
pred = np.where(mask, rd_block, pred)
return pred
class KeepLastProcessor(OverlapProcessor):
def process_pred(self, out, xoff, yoff):
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
return pred
class AccumProcessor(OverlapProcessor):
def __init__(self,
h,
w,
ch,
cw,
sh,
sw,
dtype=np.float16,
assign_weight=True):
super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
self.prev_yoff = None
self.assign_weight = assign_weight
def process_pred(self, out, xoff, yoff):
if self.prev_yoff is not None and yoff != self.prev_yoff:
if yoff < self.prev_yoff:
raise RuntimeError
self.cache.roll_cache(yoff - self.prev_yoff)
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
prob = out['score_map']
prob = prob[:self.ch, :self.cw]
if self.assign_weight:
prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
self.cache.update_block(0, xoff, self.ch, self.cw, prob)
pred = self.cache.get_block(0, xoff, self.ch, self.cw)
self.prev_yoff = yoff
return pred
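# Down-weight probabilities near block borders (by default, the outer 25% on each side is halved)
# so that block interiors dominate when probabilities are accumulated.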
def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
if not inplace:
array = array.copy()
h, w = array.shape[:2]
hm, wm = int(h * border_ratio), int(w * border_ratio)
array[:hm] *= weight
array[-hm:] *= weight
array[:, :wm] *= weight
array[:, -wm:] *= weight
return array
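# Read a (ysize, xsize) window from a GDAL dataset in HWC layout, padding with `pad_val`
# up to (tar_ysize, tar_xsize) when a target size is given.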
def read_block(ds,
xoff,
yoff,
xsize,
ysize,
tar_xsize=None,
tar_ysize=None,
pad_val=0):
if tar_xsize is None:
tar_xsize = xsize
if tar_ysize is None:
tar_ysize = ysize
# Read data from dataset
block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
c, real_ysize, real_xsize = block.shape
assert real_ysize == ysize and real_xsize == xsize
# [c, h, w] -> [h, w, c]
block = block.transpose((1, 2, 0))
if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
raise ValueError
padded_block = np.full(
(tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
# Fill
padded_block[:real_ysize, :real_xsize] = block
return padded_block
else:
return block
def slider_predict(predict_func,
img_file,
save_dir,
block_size,
overlap,
transforms,
invalid_value,
merge_strategy,
batch_size,
show_progress=False):
"""
Do inference using sliding windows.
Args:
predict_func (callable): A callable object that makes the prediction.
img_file (str|tuple[str]): Image path(s).
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int):
Size of block. If `block_size` is list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int):
Overlap between two blocks. If `overlap` is list or tuple, it should
be in (W, H) format.
transforms (paddlers.transforms.Compose|None): Transforms for inputs. If
None, the transforms for evaluation process will be used.
invalid_value (int): Value that marks invalid pixels in the output image.
merge_strategy (str): Strategy to merge overlapping blocks. Choices are
{'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last'
means keeping the values of the first and the last block in
traversal order, respectively. 'accum' means determining the class
of an overlapping pixel according to accumulated probabilities.
batch_size (int): Batch size used in inference.
show_progress (bool, optional): Whether to show prediction progress with a
progress bar. Defaults to False.
"""
try:
from osgeo import gdal
except:
import gdal
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
if isinstance(img_file, tuple):
if len(img_file) != 2:
raise ValueError("Tuple `img_file` must have the length of two.")
# Assume that two input images have the same size
src_data = gdal.Open(img_file[0])
src2_data = gdal.Open(img_file[1])
# Output name is the same as the name of the first image
file_name = osp.basename(osp.normpath(img_file[0]))
else:
src_data = gdal.Open(img_file)
file_name = osp.basename(osp.normpath(img_file))
# Get size of original raster
width = src_data.RasterXSize
height = src_data.RasterYSize
bands = src_data.RasterCount
# XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
# except for SAR images.
if bands == 1:
logging.warning(
f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
)
if block_size[0] > width or block_size[1] > height:
raise ValueError("`block_size` should not be larger than image size.")
driver = gdal.GetDriverByName("GTiff")
if not osp.exists(save_dir):
os.makedirs(save_dir)
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information
dst_data.SetGeoTransform(src_data.GetGeoTransform())
dst_data.SetProjection(src_data.GetProjection())
# Initialize raster with `invalid_value`
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
if overlap == (0, 0) or block_size == (width, height):
# When there is no overlap or the whole image is used as input,
# use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'keep_first':
overlap_processor = KeepFirstProcessor(
height,
width,
*block_size[::-1],
*step[::-1],
band,
inval=invalid_value)
elif merge_strategy == 'keep_last':
overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
*step[::-1])
elif merge_strategy == 'accum':
overlap_processor = AccumProcessor(height, width, *block_size[::-1],
*step[::-1])
else:
raise ValueError("{} is not a supported strategy for block merging.".
format(merge_strategy))
xsize, ysize = block_size
num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
cnt = 0
if show_progress:
pb = tqdm(total=num_blocks)
batch_data = []
batch_offsets = []
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
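# Clamp the window to the image extent: the last window in each row/column is
# shifted back so that it still covers a full block.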
if xoff + xsize > width:
xoff = width - xsize
is_end_of_row = True
else:
is_end_of_row = False
if yoff + ysize > height:
yoff = height - ysize
is_end_of_col = True
else:
is_end_of_col = False
# Read
im = read_block(src_data, xoff, yoff, xsize, ysize)
if isinstance(img_file, tuple):
im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
batch_data.append((im, im2))
else:
batch_data.append(im)
batch_offsets.append((xoff, yoff))
len_batch = len(batch_data)
if is_end_of_row and is_end_of_col and len_batch < batch_size:
# Pad `batch_data` by repeating the last element
batch_data = batch_data + [batch_data[-1]] * (batch_size -
len_batch)
# `batch_offsets` is not padded, so `len(batch_offsets)` still gives the number of valid elements in the batch
if len(batch_data) == batch_size:
# Predict
batch_out = predict_func(batch_data, transforms=transforms)
for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
# Get processed result
pred = overlap_processor.process_pred(out, xoff_, yoff_)
# Write to file
band.WriteArray(pred, xoff_, yoff_)
dst_data.FlushCache()
batch_data.clear()
batch_offsets.clear()
cnt += 1
if show_progress:
pb.update(1)
pb.set_description("{} out of {} blocks processed.".format(
cnt, num_blocks))
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))

@ -25,7 +25,9 @@ def decode_image(im_path,
to_uint8=True,
decode_bgr=True,
decode_sar=True,
read_geo_info=False):
read_geo_info=False,
use_stretch=False,
read_raw=False):
"""
Decode an image.
@ -37,11 +39,16 @@ def decode_image(im_path,
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo
image (e.g. jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Defaults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
to False. Setting `read_raw` takes precedence over setting `to_rgb` and
`to_uint8`. Defaults to False.
Returns:
np.ndarray|tuple: If `read_geo_info` is False, return the decoded image.
@ -53,12 +60,16 @@ def decode_image(im_path,
# Do a presence check. osp.exists() assumes `im_path` is a path-like object.
if not osp.exists(im_path):
raise ValueError(f"{im_path} does not exist!")
if read_raw:
to_rgb = False
to_uint8 = False
decoder = T.DecodeImg(
to_rgb=to_rgb,
to_uint8=to_uint8,
decode_bgr=decode_bgr,
decode_sar=decode_sar,
read_geo_info=read_geo_info)
read_geo_info=read_geo_info,
use_stretch=use_stretch)
# Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)}
sample = decoder(sample)

@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
return rle
def to_uint8(im, is_linear=False):
def to_uint8(im, stretch=False):
"""
Convert raster data to uint8 type.
Args:
im (np.ndarray): Input raster image.
is_linear (bool, optional): Use 2% linear stretch or not. Default is False.
stretch (bool, optional): Use 2% linear stretch or not. Default is False.
Returns:
np.ndarray: Image data of uint8 type.
@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False):
dtype = im.dtype.name
if dtype != "uint8":
im = _sample_norm(im)
if is_linear:
if stretch:
im = _two_percent_linear(im)
return im

@ -203,11 +203,13 @@ class DecodeImg(Transform):
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image
(e.g., jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Defaults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
"""
def __init__(self,
@ -215,13 +217,15 @@ class DecodeImg(Transform):
to_uint8=True,
decode_bgr=True,
decode_sar=True,
read_geo_info=False):
read_geo_info=False,
use_stretch=False):
super(DecodeImg, self).__init__()
self.to_rgb = to_rgb
self.to_uint8 = to_uint8
self.decode_bgr = decode_bgr
self.decode_sar = decode_sar
self.read_geo_info = False
self.read_geo_info = read_geo_info
self.use_stretch = use_stretch
def read_img(self, img_path):
img_format = imghdr.what(img_path)
@ -251,7 +255,7 @@ class DecodeImg(Transform):
im_data = im_data.transpose((1, 2, 0))
if self.read_geo_info:
geo_trans = dataset.GetGeoTransform()
geo_proj = dataset.GetGeoProjection()
geo_proj = dataset.GetProjection()
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
if self.decode_bgr:
im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
@ -288,7 +292,7 @@ class DecodeImg(Transform):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8:
image = to_uint8(image)
image = to_uint8(image, stretch=self.use_stretch)
if self.read_geo_info:
return image, geo_info_dict

@ -151,8 +151,8 @@ class TestCDPredictor(TestPredictor):
# Single input (ndarrays)
input_ = (decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False)) # Reuse the name `input_`
t1_path, read_raw=True), decode_image(
t2_path, read_raw=True)) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -175,8 +175,9 @@ class TestCDPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [(decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_`
t1_path, read_raw=True), decode_image(
t2_path,
read_raw=True))] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -217,7 +218,7 @@ class TestClasPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -241,7 +242,8 @@ class TestClasPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -282,7 +284,7 @@ class TestDetPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
@ -301,7 +303,8 @@ class TestDetPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -343,7 +346,7 @@ class TestResPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
@ -362,7 +365,8 @@ class TestResPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -400,7 +404,7 @@ class TestSegPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -423,7 +427,8 @@ class TestSegPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)

@ -13,4 +13,5 @@
# limitations under the License.
from rs_models import *
from tasks import *
from transforms import *

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .test_slider_predict import *

@ -0,0 +1,212 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import tempfile
import paddlers as pdrs
import paddlers.transforms as T
from testing_utils import CommonTest
class _TestSliderPredictNamespace:
class TestSliderPredict(CommonTest):
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
save_dir = osp.join(td, 'pred1')
self.model.slider_predict(self.image_path, save_dir, 256, 0,
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_path, save_dir, 128,
128, self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_path, save_dir, (128, 32),
0, self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
# `block_size` and `overlap` are both tuples
save_dir = osp.join(td, 'pred4')
self.model.slider_predict(self.image_path, save_dir, (128, 100),
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
# `block_size` larger than image size
save_dir = osp.join(td, 'pred5')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_path, save_dir, 512, 0,
self.transforms)
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
save_dir = osp.join(td, 'keep_first')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first')
pred_keepfirst = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
# 'keep_last'
save_dir = osp.join(td, 'keep_last')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_last')
pred_keeplast = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
# 'accum'
save_dir = osp.join(td, 'accum')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='accum')
pred_accum = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(
self.ref_path, read_geo_info=True)
self.model.slider_predict(self.image_path, td, 128, 0,
self.transforms)
_, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'],
geo_info_in['geo_proj'])
def test_batch_size(self):
with tempfile.TemporaryDirectory() as td:
# batch_size = 1
save_dir = osp.join(td, 'bs1')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first',
batch_size=1)
pred_bs1 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
# batch_size = 4
save_dir = osp.join(td, 'bs4')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first',
batch_size=4)
pred_bs4 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_bs4, pred_bs1)
# batch_size = 8
save_dir = osp.join(td, 'bs8')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first',
batch_size=8)
pred_bs8 = T.decode_image(
osp.join(save_dir, self.basename),
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_bs8, pred_bs1)
class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
def setUp(self):
self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeSegmenter('test')
])
self.image_path = "data/ssst/multispectral.tif"
self.ref_path = self.image_path
self.basename = osp.basename(self.ref_path)
class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_path = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.ref_path = self.image_path[0]
self.basename = osp.basename(self.ref_path)