[Feat] Optimize `slider_predict()` (#3)

Lin Manhui 2 years ago committed by GitHub
parent 7a0f5405f6
commit 8beb15c3c8
  1. docs/apis/data.md (6)
  2. paddlers/deploy/predictor.py (17)
  3. paddlers/tasks/change_detector.py (10)
  4. paddlers/tasks/classifier.py (2)
  5. paddlers/tasks/object_detector.py (2)
  6. paddlers/tasks/restorer.py (2)
  7. paddlers/tasks/segmenter.py (8)
  8. paddlers/tasks/utils/slider_predict.py (52)
  9. paddlers/transforms/__init__.py (17)
  10. paddlers/transforms/functions.py (6)
  11. paddlers/transforms/operators.py (10)
  12. tests/deploy/test_predictor.py (29)
  13. tests/tasks/test_slider_predict.py (190)

@ -134,10 +134,12 @@
|-------|----|--------|-----|
|`im_path`|`str`|Path of the input image.||
|`to_rgb`|`bool`|If `True`, convert from BGR to RGB format.|`True`|
|`to_uint8`|`bool`|If `True`, quantize the pixel data read and convert it to uint8 type.|`True`|
|`decode_bgr`|`bool`|If `True`, automatically decode non-geo-format images (e.g. jpeg images) as BGR images.|`True`|
|`decode_sar`|`bool`|If `True`, automatically interpret 2-channel geo-format images (e.g. GeoTiff images) as SAR images.|`True`|
|`decode_sar`|`bool`|If `True`, automatically interpret single-channel geo-format images (e.g. GeoTiff images) as SAR images.|`True`|
|`read_geo_info`|`bool`|If `True`, read geographical information from the image.|`False`|
|`use_stretch`|`bool`|Whether to apply a 2% linear stretch to image brightness. Valid only when `to_uint8` is `True`.|`False`|
|`read_raw`|`bool`|If `True`, equivalent to setting both `to_rgb` and `to_uint8` to `False`; this argument takes precedence over those two.|`False`|
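A minimal usage sketch of the parameters above (the file path is hypothetical; `decode_image` is `paddlers.transforms.decode_image`):

```python
from paddlers.transforms import decode_image

# Hypothetical GeoTiff path: read the geographic metadata and quantize to
# uint8 with a 2% linear stretch applied.
im, geo_info = decode_image(
    "data/scene.tif",
    to_uint8=True,        # quantize to uint8...
    use_stretch=True,     # ...with a 2% linear stretch (valid only when to_uint8 is True)
    read_geo_info=True)   # also return a dict of geographical information
print(im.dtype, im.shape, geo_info["geo_trans"], geo_info["geo_proj"])
```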
The return format is as follows:

@ -282,8 +282,8 @@ class Predictor(object):
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
`img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change
detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
@ -332,7 +332,8 @@ class Predictor(object):
overlap=36,
transforms=None,
invalid_value=255,
merge_strategy='keep_last'):
merge_strategy='keep_last',
batch_size=1):
"""
Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the
sliding-predicting mode.
@ -340,9 +341,9 @@ class Predictor(object):
Args:
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file`
should be either the path of the image to predict, a decoded image (a np.ndarray, which should be
consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of
image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a
tuple of decoded images, or a list of tuples.
consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)),
or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of
image paths, a tuple of decoded images, or a list of tuples.
save_dir (str): Directory where the output GeoTiff file is saved.
block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in
(W, H) format.
@ -355,6 +356,7 @@ class Predictor(object):
{'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and
the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel
according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
"""
slider_predict(
partial(
@ -365,7 +367,8 @@ class Predictor(object):
overlap,
transforms,
invalid_value,
merge_strategy)
merge_strategy,
batch_size)
def batch_predict(self, image_list, **params):
return self.predict(img_file=image_list, **params)
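
A hedged usage sketch of the new `batch_size` argument on the deployment side (the exported model directory and input path are hypothetical):

```python
from paddlers.deploy import Predictor

# Hypothetical exported model directory and input GeoTiff.
predictor = Predictor("./inference_model")
predictor.slider_predict(
    "large_scene.tif",
    save_dir="./output",
    block_size=512,
    overlap=36,
    merge_strategy="keep_last",
    batch_size=4)  # newly added: number of sliding-window blocks inferred per batch
```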

@ -588,7 +588,8 @@ class BaseChangeDetector(BaseModel):
overlap=36,
transforms=None,
invalid_value=255,
merge_strategy='keep_last'):
merge_strategy='keep_last',
batch_size=1):
"""
Do inference using sliding windows.
@ -611,10 +612,11 @@ class BaseChangeDetector(BaseModel):
means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
"""
slider_predict(self.predict, img_files, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy)
transforms, invalid_value, merge_strategy, batch_size)
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')
@ -622,8 +624,8 @@ class BaseChangeDetector(BaseModel):
batch_ori_shape = list()
for im1, im2 in images:
if isinstance(im1, str) or isinstance(im2, str):
im1 = decode_image(im1, to_rgb=False)
im2 = decode_image(im2, to_rgb=False)
im1 = decode_image(im1, read_raw=True)
im2 = decode_image(im2, read_raw=True)
ori_shape = im1.shape[:2]
# XXX: sample do not contain 'image_t1' and 'image_t2'.
sample = {'image': im1, 'image2': im2}
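
On the trainer side, the bitemporal input is a tuple of image paths and the same `batch_size` argument is available; a minimal sketch (paths are hypothetical, and a real run would restore trained weights first):

```python
import paddlers as pdrs
import paddlers.transforms as T

model = pdrs.tasks.cd.BIT(in_channels=3)
transforms = T.Compose([
    T.DecodeImg(), T.Normalize([0.5] * 3, [0.5] * 3),
    T.ArrangeChangeDetector('test')
])
model.slider_predict(
    ("t1.tif", "t2.tif"),   # change detection takes a tuple of image paths
    save_dir="./cd_output",
    block_size=256,
    overlap=36,
    transforms=transforms,
    batch_size=4)           # blocks are now inferred in batches
```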

@ -497,7 +497,7 @@ class BaseClassifier(BaseModel):
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)

@ -617,7 +617,7 @@ class BaseDetector(BaseModel):
batch_samples = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
sample = {'image': im}
sample = transforms(sample)
batch_samples.append(sample)

@ -481,7 +481,7 @@ class BaseRestorer(BaseModel):
batch_tar_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)[0]

@ -560,7 +560,8 @@ class BaseSegmenter(BaseModel):
overlap=36,
transforms=None,
invalid_value=255,
merge_strategy='keep_last'):
merge_strategy='keep_last',
batch_size=1):
"""
Do inference using sliding windows.
@ -583,10 +584,11 @@ class BaseSegmenter(BaseModel):
means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
"""
slider_predict(self.predict, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy)
transforms, invalid_value, merge_strategy, batch_size)
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')
@ -594,7 +596,7 @@ class BaseSegmenter(BaseModel):
batch_ori_shape = list()
for im in images:
if isinstance(im, str):
im = decode_image(im, to_rgb=False)
im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2]
sample = {'image': im}
im = transforms(sample)[0]

@ -14,6 +14,7 @@
import os
import os.path as osp
import math
from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict
@ -117,10 +118,10 @@ class ProbCache(Cache):
def roll_cache(self):
if self.order == 'c':
self.cache = np.roll(self.cache, -self.sh, axis=0)
self.cache[:-self.sh] = self.cache[self.sh:]
self.cache[-self.sh:, :] = 0
elif self.order == 'f':
self.cache = np.roll(self.cache, -self.sw, axis=1)
self.cache[:, :-self.sw] = self.cache[:, self.sw:]
self.cache[:, -self.sw:] = 0
def get_block(self, i_st, j_st, h, w):
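
The `roll_cache()` change above replaces `np.roll()` with in-place slice shifts; a standalone sketch (not the `ProbCache` class itself) of why the two forms agree once the vacated rows are zeroed:

```python
import numpy as np

# np.roll allocates a new array and wraps the shifted-out rows back to the
# end, so they still have to be cleared; slice assignment shifts the cache in
# place and simply zeroes the vacated rows. Both yield the same result.
sh = 2  # vertical step, mirroring `self.sh`
cache = np.arange(16, dtype="float32").reshape(4, 4)

rolled = np.roll(cache, -sh, axis=0)
rolled[-sh:] = 0

shifted = cache.copy()
shifted[:-sh] = shifted[sh:]
shifted[-sh:] = 0

assert np.array_equal(rolled, shifted)
```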
@ -128,7 +129,7 @@ class ProbCache(Cache):
def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy):
transforms, invalid_value, merge_strategy, batch_size):
"""
Do inference using sliding windows.
@ -151,6 +152,7 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
means keeping the values of the first and the last block in
traversal order, respectively. 'accum' means determining the class
of an overlapping pixel according to accumulated probabilities.
batch_size (int): Batch size used in inference.
"""
try:
@ -200,6 +202,13 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
height = src_data.RasterYSize
bands = src_data.RasterCount
# XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
# except for SAR images.
if bands == 1:
logging.warning(
f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
)
if block_size[0] > width or block_size[1] > height:
raise ValueError("`block_size` should not be larger than image size.")
@ -228,6 +237,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if merge_strategy == 'accum':
cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
batch_data = []
batch_offsets = []
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
@ -236,6 +247,9 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if yoff + ysize > height:
yoff = height - ysize
is_end_of_col = yoff + ysize >= height
is_end_of_row = xoff + xsize >= width
# Read and fill
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
@ -243,18 +257,31 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if isinstance(img_file, tuple):
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Predict
out = predict_func((im, im2), transforms=transforms)
batch_data.append((im, im2))
else:
batch_data.append(im)
batch_offsets.append((xoff, yoff))
len_batch = len(batch_data)
if is_end_of_row and is_end_of_col and len_batch < batch_size:
# Pad `batch_data` by repeating the last element
batch_data = batch_data + [batch_data[-1]] * (batch_size -
len_batch)
# `len(batch_offsets)` still gives the number of valid (non-padded) elements
if len(batch_data) == batch_size:
# Predict
out = predict_func(im, transforms=transforms)
batch_out = predict_func(batch_data, transforms=transforms)
for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
pred = out['label_map'].astype('uint8')
pred = pred[:ysize, :xsize]
# Deal with overlapping pixels
if merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
@ -262,14 +289,17 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
elif merge_strategy == 'accum':
prob = out['score_map']
prob = prob[:ysize, :xsize]
cache.update_block(0, xoff, ysize, xsize, prob)
pred = cache.get_block(0, xoff, ysize, xsize)
if xoff + xsize >= width:
cache.update_block(0, xoff_, ysize, xsize, prob)
pred = cache.get_block(0, xoff_, ysize, xsize)
if xoff_ + xsize >= width:
cache.roll_cache()
# Write to file
band.WriteArray(pred, xoff, yoff)
band.WriteArray(pred, xoff_, yoff_)
dst_data.FlushCache()
batch_data.clear()
batch_offsets.clear()
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))
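
A simplified, self-contained sketch of the batching logic added above (a generic helper, not the function itself): window offsets are accumulated until `batch_size` is reached, the final short batch is padded by repeating its last element, and the count of valid entries is kept so padded outputs can be ignored on write-back.

```python
def batched_windows(offsets, batch_size):
    batch = []
    for idx, off in enumerate(offsets):
        batch.append(off)
        is_last = idx == len(offsets) - 1
        if is_last and len(batch) < batch_size:
            # Pad with copies of the last window so the predictor always
            # receives a full batch; padded entries are ignored later.
            n_valid = len(batch)
            batch = batch + [batch[-1]] * (batch_size - n_valid)
            yield batch, n_valid
            batch = []
        elif len(batch) == batch_size:
            yield batch, len(batch)
            batch = []

# Usage: 3 windows with batch_size=2 -> one full batch and one padded batch.
for blocks, n_valid in batched_windows([(0, 0), (128, 0), (0, 128)], 2):
    print(blocks, "valid:", n_valid)
```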

@ -25,7 +25,9 @@ def decode_image(im_path,
to_uint8=True,
decode_bgr=True,
decode_sar=True,
read_geo_info=False):
read_geo_info=False,
use_stretch=False,
read_raw=False):
"""
Decode an image.
@ -37,11 +39,16 @@ def decode_image(im_path,
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo
image (e.g. jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Defaults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
to False. Setting `read_raw` takes precedence over setting `to_rgb` and
`to_uint8`. Defaults to False.
Returns:
np.ndarray|tuple: If `read_geo_info` is False, return the decoded image.
@ -53,12 +60,16 @@ def decode_image(im_path,
# Do a presence check. osp.exists() assumes `im_path` is a path-like object.
if not osp.exists(im_path):
raise ValueError(f"{im_path} does not exist!")
if read_raw:
to_rgb = False
to_uint8 = False
decoder = T.DecodeImg(
to_rgb=to_rgb,
to_uint8=to_uint8,
decode_bgr=decode_bgr,
decode_sar=decode_sar,
read_geo_info=read_geo_info)
read_geo_info=read_geo_info,
use_stretch=use_stretch)
# Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)}
sample = decoder(sample)
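
A minimal sketch of the new `read_raw` shortcut (hypothetical path); it should be interchangeable with spelling out the two flags it overrides:

```python
import numpy as np
import paddlers.transforms as T

raw = T.decode_image("scene.tif", read_raw=True)
explicit = T.decode_image("scene.tif", to_rgb=False, to_uint8=False)
assert np.array_equal(raw, explicit)
```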

@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
return rle
def to_uint8(im, is_linear=False):
def to_uint8(im, stretch=False):
"""
Convert raster data to uint8 type.
Args:
im (np.ndarray): Input raster image.
is_linear (bool, optional): Use 2% linear stretch or not. Default is False.
stretch (bool, optional): Use 2% linear stretch or not. Default is False.
Returns:
np.ndarray: Image data with uint8 type.
@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False):
dtype = im.dtype.name
if dtype != "uint8":
im = _sample_norm(im)
if is_linear:
if stretch:
im = _two_percent_linear(im)
return im
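
A small sketch of the renamed `stretch` argument, importing from the module shown above; the synthetic array merely stands in for a raw raster:

```python
import numpy as np
from paddlers.transforms.functions import to_uint8

# Synthetic 16-bit raster standing in for raw satellite data.
im = np.random.randint(0, 4096, size=(64, 64, 3), dtype="uint16")
plain = to_uint8(im)                    # quantize to uint8
stretched = to_uint8(im, stretch=True)  # quantize with the 2% linear stretch
print(plain.dtype, stretched.dtype)     # both uint8
```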

@ -179,11 +179,13 @@ class DecodeImg(Transform):
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image
(e.g., jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Defaults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
"""
def __init__(self,
@ -191,13 +193,15 @@ class DecodeImg(Transform):
to_uint8=True,
decode_bgr=True,
decode_sar=True,
read_geo_info=False):
read_geo_info=False,
use_stretch=False):
super(DecodeImg, self).__init__()
self.to_rgb = to_rgb
self.to_uint8 = to_uint8
self.decode_bgr = decode_bgr
self.decode_sar = decode_sar
self.read_geo_info = read_geo_info
self.use_stretch = use_stretch
def read_img(self, img_path):
img_format = imghdr.what(img_path)
@ -264,7 +268,7 @@ class DecodeImg(Transform):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8:
image = to_uint8(image)
image = to_uint8(image, stretch=self.use_stretch)
if self.read_geo_info:
return image, geo_info_dict

@ -145,8 +145,8 @@ class TestCDPredictor(TestPredictor):
# Single input (ndarrays)
input_ = (decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False)) # Reuse the name `input_`
t1_path, read_raw=True), decode_image(
t2_path, read_raw=True)) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -169,8 +169,9 @@ class TestCDPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [(decode_image(
t1_path, to_rgb=False), decode_image(
t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_`
t1_path, read_raw=True), decode_image(
t2_path,
read_raw=True))] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -211,7 +212,7 @@ class TestClasPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -235,7 +236,8 @@ class TestClasPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -276,7 +278,7 @@ class TestDetPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
@ -295,7 +297,8 @@ class TestDetPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -329,7 +332,7 @@ class TestResPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
@ -348,7 +351,8 @@ class TestResPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -386,7 +390,7 @@ class TestSegPredictor(TestPredictor):
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -409,7 +413,8 @@ class TestSegPredictor(TestPredictor):
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)

@ -20,21 +20,14 @@ import paddlers.transforms as T
from testing_utils import CommonTest
class TestSegSliderPredict(CommonTest):
def setUp(self):
self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeSegmenter('test')
])
self.image_path = "data/ssst/multispectral.tif"
self.basename = osp.basename(self.image_path)
class _TestSliderPredictNamespace:
class TestSliderPredict(CommonTest):
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms)
pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
@ -43,23 +36,23 @@ class TestSegSliderPredict(CommonTest):
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_path, save_dir, 128, 128,
self.transforms)
self.model.slider_predict(self.image_path, save_dir, 128,
128, self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_path, save_dir, (128, 32), 0,
self.transforms)
self.model.slider_predict(self.image_path, save_dir, (128, 32),
0, self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
@ -69,7 +62,7 @@ class TestSegSliderPredict(CommonTest):
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
@ -82,7 +75,8 @@ class TestSegSliderPredict(CommonTest):
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms)
pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
@ -96,7 +90,7 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='keep_first')
pred_keepfirst = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
@ -111,7 +105,7 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='keep_last')
pred_keeplast = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
@ -126,141 +120,93 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='accum')
pred_accum = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
_, geo_info_in = T.decode_image(
self.ref_path, read_geo_info=True)
self.model.slider_predict(self.image_path, td, 128, 0,
self.transforms)
_, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
self.assertEqual(geo_info_out['geo_proj'],
geo_info_in['geo_proj'])
class TestCDSliderPredict(CommonTest):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_paths = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.basename = osp.basename(self.image_paths[0])
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
def test_batch_size(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_paths, self.transforms)
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
save_dir = osp.join(td, 'pred1')
self.model.slider_predict(self.image_paths, save_dir, 256, 0,
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_paths, save_dir, 128, 128,
self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0,
self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
# `block_size` and `overlap` are both tuples
save_dir = osp.join(td, 'pred4')
self.model.slider_predict(self.image_paths, save_dir, (128, 100),
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
# `block_size` larger than image size
save_dir = osp.join(td, 'pred5')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_paths, save_dir, 512, 0,
self.transforms)
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_paths, self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
save_dir = osp.join(td, 'keep_first')
# batch_size = 1
save_dir = osp.join(td, 'bs1')
self.model.slider_predict(
self.image_paths,
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first')
pred_keepfirst = T.decode_image(
merge_strategy='keep_first',
batch_size=1)
pred_bs1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
# 'keep_last'
save_dir = osp.join(td, 'keep_last')
# batch_size = 4
save_dir = osp.join(td, 'bs4')
self.model.slider_predict(
self.image_paths,
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_last')
pred_keeplast = T.decode_image(
merge_strategy='keep_first',
batch_size=4)
pred_bs4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
self.check_output_equal(pred_bs4, pred_bs1)
# 'accum'
save_dir = osp.join(td, 'accum')
# batch_size = 8
save_dir = osp.join(td, 'bs8')
self.model.slider_predict(
self.image_paths,
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='accum')
pred_accum = T.decode_image(
merge_strategy='keep_first',
batch_size=8)
pred_bs8 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
read_raw=True,
decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape)
self.check_output_equal(pred_bs8, pred_bs1)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(
self.image_paths[0], read_geo_info=True)
self.model.slider_predict(self.image_paths, td, 128, 0,
self.transforms)
_, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
def setUp(self):
self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeSegmenter('test')
])
self.image_path = "data/ssst/multispectral.tif"
self.ref_path = self.image_path
self.basename = osp.basename(self.ref_path)
class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_path = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.ref_path = self.image_path[0]
self.basename = osp.basename(self.ref_path)
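
The restructuring above relies on a common unittest pattern: the shared `TestSliderPredict` case is nested inside `_TestSliderPredictNamespace` so test discovery never runs the abstract base itself, only the concrete subclasses that define `setUp()`. A minimal, self-contained sketch of that pattern:

```python
import unittest

class _Namespace:
    class SharedTests(unittest.TestCase):
        # Shared checks; relies on attributes set by concrete subclasses.
        def test_value(self):
            self.assertEqual(self.value, self.expected)

class TestConcrete(_Namespace.SharedTests):
    def setUp(self):
        self.value = 1 + 1
        self.expected = 2

if __name__ == "__main__":
    # Discovery collects TestConcrete only; _Namespace is not a TestCase,
    # so its nested SharedTests class is never run on its own.
    unittest.main()
```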
