[Feat] Optimize `slide_predict()` (#3)

own
Lin Manhui 2 years ago committed by GitHub
parent 7a0f5405f6
commit 8beb15c3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      docs/apis/data.md
  2. 17
      paddlers/deploy/predictor.py
  3. 10
      paddlers/tasks/change_detector.py
  4. 2
      paddlers/tasks/classifier.py
  5. 2
      paddlers/tasks/object_detector.py
  6. 2
      paddlers/tasks/restorer.py
  7. 8
      paddlers/tasks/segmenter.py
  8. 52
      paddlers/tasks/utils/slider_predict.py
  9. 17
      paddlers/transforms/__init__.py
  10. 6
      paddlers/transforms/functions.py
  11. 10
      paddlers/transforms/operators.py
  12. 29
      tests/deploy/test_predictor.py
  13. 190
      tests/tasks/test_slider_predict.py

@ -134,10 +134,12 @@
|-------|----|--------|-----| |-------|----|--------|-----|
|`im_path`|`str`|输入图像路径。|| |`im_path`|`str`|输入图像路径。||
|`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`| |`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
|`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`| |`to_uint8`|`bool`|若为`True`,则将读取的像数据量化并转换为uint8类型。|`True`|
|`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`| |`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`| |`decode_sar`|`bool`|若为`True`,则自动将通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
|`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`| |`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
|`use_stretch`|`bool`|是否对影像亮度进行2%线性拉伸。仅当`to_uint8`为`True`时有效。|`False`|
|`read_raw`|`bool`|若为`True`,等价于指定`to_rgb`和`to_uint8`为`False`,且该参数的优先级高于上述参数。|`False`|
返回格式如下: 返回格式如下:

@ -282,8 +282,8 @@ class Predictor(object):
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict, object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict,
a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks, paddlers.transforms.decode_image(..., read_raw=True)), or a list of image paths or decoded images. For change
`img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples. detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1. topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None. from `model.yml`. Defaults to None.
@ -332,7 +332,8 @@ class Predictor(object):
overlap=36, overlap=36,
transforms=None, transforms=None,
invalid_value=255, invalid_value=255,
merge_strategy='keep_last'): merge_strategy='keep_last',
batch_size=1):
""" """
Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the Do inference using sliding windows. Only semantic segmentation and change detection models are supported in the
sliding-predicting mode. sliding-predicting mode.
@ -340,9 +341,9 @@ class Predictor(object):
Args: Args:
img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file` img_file(list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For semantic segmentation tasks, `img_file`
should be either the path of the image to predict, a decoded image (a np.ndarray, which should be should be either the path of the image to predict, a decoded image (a np.ndarray, which should be
consistent with what you get from passing image path to paddlers.transforms.decode_image()), or a list of consistent with what you get from passing image path to paddlers.transforms.decode_image(..., read_raw=True)),
image paths or decoded images. For change detection tasks, `img_file` should be a tuple of image paths, a or a list of image paths or decoded images. For change detection tasks, `img_file` should be a tuple of
tuple of decoded images, or a list of tuples. image paths, a tuple of decoded images, or a list of tuples.
save_dir (str): Directory that contains saved geotiff file. save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in block_size (list[int] | tuple[int] | int): Size of block. If `block_size` is a list or tuple, it should be in
(W, H) format. (W, H) format.
@ -355,6 +356,7 @@ class Predictor(object):
{'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and {'keep_first', 'keep_last', 'accum'}. 'keep_first' and 'keep_last' means keeping the values of the first and
the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel the last block in traversal order, respectively. 'accum' means determining the class of an overlapping pixel
according to accumulated probabilities. Defaults to 'keep_last'. according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
""" """
slider_predict( slider_predict(
partial( partial(
@ -365,7 +367,8 @@ class Predictor(object):
overlap, overlap,
transforms, transforms,
invalid_value, invalid_value,
merge_strategy) merge_strategy,
batch_size)
def batch_predict(self, image_list, **params): def batch_predict(self, image_list, **params):
return self.predict(img_file=image_list, **params) return self.predict(img_file=image_list, **params)

@ -588,7 +588,8 @@ class BaseChangeDetector(BaseModel):
overlap=36, overlap=36,
transforms=None, transforms=None,
invalid_value=255, invalid_value=255,
merge_strategy='keep_last'): merge_strategy='keep_last',
batch_size=1):
""" """
Do inference using sliding windows. Do inference using sliding windows.
@ -611,10 +612,11 @@ class BaseChangeDetector(BaseModel):
means keeping the values of the first and the last block in traversal means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'. pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
""" """
slider_predict(self.predict, img_files, save_dir, block_size, overlap, slider_predict(self.predict, img_files, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy) transforms, invalid_value, merge_strategy, batch_size)
def preprocess(self, images, transforms, to_tensor=True): def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test') self._check_transforms(transforms, 'test')
@ -622,8 +624,8 @@ class BaseChangeDetector(BaseModel):
batch_ori_shape = list() batch_ori_shape = list()
for im1, im2 in images: for im1, im2 in images:
if isinstance(im1, str) or isinstance(im2, str): if isinstance(im1, str) or isinstance(im2, str):
im1 = decode_image(im1, to_rgb=False) im1 = decode_image(im1, read_raw=True)
im2 = decode_image(im2, to_rgb=False) im2 = decode_image(im2, read_raw=True)
ori_shape = im1.shape[:2] ori_shape = im1.shape[:2]
# XXX: sample do not contain 'image_t1' and 'image_t2'. # XXX: sample do not contain 'image_t1' and 'image_t2'.
sample = {'image': im1, 'image2': im2} sample = {'image': im1, 'image2': im2}

@ -497,7 +497,7 @@ class BaseClassifier(BaseModel):
batch_ori_shape = list() batch_ori_shape = list()
for im in images: for im in images:
if isinstance(im, str): if isinstance(im, str):
im = decode_image(im, to_rgb=False) im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2] ori_shape = im.shape[:2]
sample = {'image': im} sample = {'image': im}
im = transforms(sample) im = transforms(sample)

@ -617,7 +617,7 @@ class BaseDetector(BaseModel):
batch_samples = list() batch_samples = list()
for im in images: for im in images:
if isinstance(im, str): if isinstance(im, str):
im = decode_image(im, to_rgb=False) im = decode_image(im, read_raw=True)
sample = {'image': im} sample = {'image': im}
sample = transforms(sample) sample = transforms(sample)
batch_samples.append(sample) batch_samples.append(sample)

@ -481,7 +481,7 @@ class BaseRestorer(BaseModel):
batch_tar_shape = list() batch_tar_shape = list()
for im in images: for im in images:
if isinstance(im, str): if isinstance(im, str):
im = decode_image(im, to_rgb=False) im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2] ori_shape = im.shape[:2]
sample = {'image': im} sample = {'image': im}
im = transforms(sample)[0] im = transforms(sample)[0]

@ -560,7 +560,8 @@ class BaseSegmenter(BaseModel):
overlap=36, overlap=36,
transforms=None, transforms=None,
invalid_value=255, invalid_value=255,
merge_strategy='keep_last'): merge_strategy='keep_last',
batch_size=1):
""" """
Do inference using sliding windows. Do inference using sliding windows.
@ -583,10 +584,11 @@ class BaseSegmenter(BaseModel):
means keeping the values of the first and the last block in traversal means keeping the values of the first and the last block in traversal
order, respectively. 'accum' means determining the class of an overlapping order, respectively. 'accum' means determining the class of an overlapping
pixel according to accumulated probabilities. Defaults to 'keep_last'. pixel according to accumulated probabilities. Defaults to 'keep_last'.
batch_size (int, optional): Batch size used in inference. Defaults to 1.
""" """
slider_predict(self.predict, img_file, save_dir, block_size, overlap, slider_predict(self.predict, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy) transforms, invalid_value, merge_strategy, batch_size)
def preprocess(self, images, transforms, to_tensor=True): def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test') self._check_transforms(transforms, 'test')
@ -594,7 +596,7 @@ class BaseSegmenter(BaseModel):
batch_ori_shape = list() batch_ori_shape = list()
for im in images: for im in images:
if isinstance(im, str): if isinstance(im, str):
im = decode_image(im, to_rgb=False) im = decode_image(im, read_raw=True)
ori_shape = im.shape[:2] ori_shape = im.shape[:2]
sample = {'image': im} sample = {'image': im}
im = transforms(sample)[0] im = transforms(sample)[0]

@ -14,6 +14,7 @@
import os import os
import os.path as osp import os.path as osp
import math
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict from collections import Counter, defaultdict
@ -117,10 +118,10 @@ class ProbCache(Cache):
def roll_cache(self): def roll_cache(self):
if self.order == 'c': if self.order == 'c':
self.cache = np.roll(self.cache, -self.sh, axis=0) self.cache[:-self.sh] = self.cache[self.sh:]
self.cache[-self.sh:, :] = 0 self.cache[-self.sh:, :] = 0
elif self.order == 'f': elif self.order == 'f':
self.cache = np.roll(self.cache, -self.sw, axis=1) self.cache[:, :-self.sw] = self.cache[:, self.sw:]
self.cache[:, -self.sw:] = 0 self.cache[:, -self.sw:] = 0
def get_block(self, i_st, j_st, h, w): def get_block(self, i_st, j_st, h, w):
@ -128,7 +129,7 @@ class ProbCache(Cache):
def slider_predict(predict_func, img_file, save_dir, block_size, overlap, def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy): transforms, invalid_value, merge_strategy, batch_size):
""" """
Do inference using sliding windows. Do inference using sliding windows.
@ -151,6 +152,7 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
means keeping the values of the first and the last block in means keeping the values of the first and the last block in
traversal order, respectively. 'accum' means determining the class traversal order, respectively. 'accum' means determining the class
of an overlapping pixel according to accumulated probabilities. of an overlapping pixel according to accumulated probabilities.
batch_size (int): Batch size used in inference.
""" """
try: try:
@ -200,6 +202,13 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
height = src_data.RasterYSize height = src_data.RasterYSize
bands = src_data.RasterCount bands = src_data.RasterCount
# XXX: GDAL read behavior conforms to paddlers.transforms.decode_image(read_raw=True)
# except for SAR images.
if bands == 1:
logging.warning(
f"Detected `bands=1`. Please note that currently `slider_predict()` does not properly handle SAR images."
)
if block_size[0] > width or block_size[1] > height: if block_size[0] > width or block_size[1] > height:
raise ValueError("`block_size` should not be larger than image size.") raise ValueError("`block_size` should not be larger than image size.")
@ -228,6 +237,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if merge_strategy == 'accum': if merge_strategy == 'accum':
cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
batch_data = []
batch_offsets = []
for yoff in range(0, height, step[1]): for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]): for xoff in range(0, width, step[0]):
xsize, ysize = block_size xsize, ysize = block_size
@ -236,6 +247,9 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if yoff + ysize > height: if yoff + ysize > height:
yoff = height - ysize yoff = height - ysize
is_end_of_col = yoff + ysize >= height
is_end_of_row = xoff + xsize >= width
# Read and fill # Read and fill
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0)) (1, 2, 0))
@ -243,18 +257,31 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
if isinstance(img_file, tuple): if isinstance(img_file, tuple):
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0)) (1, 2, 0))
# Predict batch_data.append((im, im2))
out = predict_func((im, im2), transforms=transforms)
else: else:
batch_data.append(im)
batch_offsets.append((xoff, yoff))
len_batch = len(batch_data)
if is_end_of_row and is_end_of_col and len_batch < batch_size:
# Pad `batch_data` by repeating the last element
batch_data = batch_data + [batch_data[-1]] * (batch_size -
len_batch)
# While keeping `len(batch_offsets)` the number of valid elements in the batch
if len(batch_data) == batch_size:
# Predict # Predict
out = predict_func(im, transforms=transforms) batch_out = predict_func(batch_data, transforms=transforms)
for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
pred = out['label_map'].astype('uint8') pred = out['label_map'].astype('uint8')
pred = pred[:ysize, :xsize] pred = pred[:ysize, :xsize]
# Deal with overlapping pixels # Deal with overlapping pixels
if merge_strategy == 'keep_first': if merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize) rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
mask = rd_block != invalid_value mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred) pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last': elif merge_strategy == 'keep_last':
@ -262,14 +289,17 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
elif merge_strategy == 'accum': elif merge_strategy == 'accum':
prob = out['score_map'] prob = out['score_map']
prob = prob[:ysize, :xsize] prob = prob[:ysize, :xsize]
cache.update_block(0, xoff, ysize, xsize, prob) cache.update_block(0, xoff_, ysize, xsize, prob)
pred = cache.get_block(0, xoff, ysize, xsize) pred = cache.get_block(0, xoff_, ysize, xsize)
if xoff + xsize >= width: if xoff_ + xsize >= width:
cache.roll_cache() cache.roll_cache()
# Write to file # Write to file
band.WriteArray(pred, xoff, yoff) band.WriteArray(pred, xoff_, yoff_)
dst_data.FlushCache() dst_data.FlushCache()
batch_data.clear()
batch_offsets.clear()
dst_data = None dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file)) logging.info("GeoTiff file saved in {}.".format(save_file))

@ -25,7 +25,9 @@ def decode_image(im_path,
to_uint8=True, to_uint8=True,
decode_bgr=True, decode_bgr=True,
decode_sar=True, decode_sar=True,
read_geo_info=False): read_geo_info=False,
use_stretch=False,
read_raw=False):
""" """
Decode an image. Decode an image.
@ -37,11 +39,16 @@ def decode_image(im_path,
uint8 type. Defaults to True. uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo decode_bgr (bool, optional): If True, automatically interpret a non-geo
image (e.g. jpeg images) as a BGR image. Defaults to True. image (e.g. jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image, set this argument to geo image (e.g. geotiff images) as a SAR image, set this argument to
True. Defaults to True. True. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from read_geo_info (bool, optional): If True, read geographical information from
the image. Deafults to False. the image. Deafults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
read_raw (bool, optional): If True, equivalent to setting `to_rgb` and `to_uint8`
to False. Setting `read_raw` takes precedence over setting `to_rgb` and
`to_uint8`. Defaults to False.
Returns: Returns:
np.ndarray|tuple: If `read_geo_info` is False, return the decoded image. np.ndarray|tuple: If `read_geo_info` is False, return the decoded image.
@ -53,12 +60,16 @@ def decode_image(im_path,
# Do a presence check. osp.exists() assumes `im_path` is a path-like object. # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
if not osp.exists(im_path): if not osp.exists(im_path):
raise ValueError(f"{im_path} does not exist!") raise ValueError(f"{im_path} does not exist!")
if read_raw:
to_rgb = False
to_uint8 = False
decoder = T.DecodeImg( decoder = T.DecodeImg(
to_rgb=to_rgb, to_rgb=to_rgb,
to_uint8=to_uint8, to_uint8=to_uint8,
decode_bgr=decode_bgr, decode_bgr=decode_bgr,
decode_sar=decode_sar, decode_sar=decode_sar,
read_geo_info=read_geo_info) read_geo_info=read_geo_info,
use_stretch=use_stretch)
# Deepcopy to avoid inplace modification # Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)} sample = {'image': copy.deepcopy(im_path)}
sample = decoder(sample) sample = decoder(sample)

@ -382,13 +382,13 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
return rle return rle
def to_uint8(im, is_linear=False): def to_uint8(im, stretch=False):
""" """
Convert raster data to uint8 type. Convert raster data to uint8 type.
Args: Args:
im (np.ndarray): Input raster image. im (np.ndarray): Input raster image.
is_linear (bool, optional): Use 2% linear stretch or not. Default is False. stretch (bool, optional): Use 2% linear stretch or not. Default is False.
Returns: Returns:
np.ndarray: Image data with unit8 type. np.ndarray: Image data with unit8 type.
@ -430,7 +430,7 @@ def to_uint8(im, is_linear=False):
dtype = im.dtype.name dtype = im.dtype.name
if dtype != "uint8": if dtype != "uint8":
im = _sample_norm(im) im = _sample_norm(im)
if is_linear: if stretch:
im = _two_percent_linear(im) im = _two_percent_linear(im)
return im return im

@ -179,11 +179,13 @@ class DecodeImg(Transform):
uint8 type. Defaults to True. uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image decode_bgr (bool, optional): If True, automatically interpret a non-geo image
(e.g., jpeg images) as a BGR image. Defaults to True. (e.g., jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel decode_sar (bool, optional): If True, automatically interpret a single-channel
geo image (e.g. geotiff images) as a SAR image, set this argument to geo image (e.g. geotiff images) as a SAR image, set this argument to
True. Defaults to True. True. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from read_geo_info (bool, optional): If True, read geographical information from
the image. Deafults to False. the image. Deafults to False.
use_stretch (bool, optional): Whether to apply 2% linear stretch. Valid only if
`to_uint8` is True. Defaults to False.
""" """
def __init__(self, def __init__(self,
@ -191,13 +193,15 @@ class DecodeImg(Transform):
to_uint8=True, to_uint8=True,
decode_bgr=True, decode_bgr=True,
decode_sar=True, decode_sar=True,
read_geo_info=False): read_geo_info=False,
use_stretch=False):
super(DecodeImg, self).__init__() super(DecodeImg, self).__init__()
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.to_uint8 = to_uint8 self.to_uint8 = to_uint8
self.decode_bgr = decode_bgr self.decode_bgr = decode_bgr
self.decode_sar = decode_sar self.decode_sar = decode_sar
self.read_geo_info = read_geo_info self.read_geo_info = read_geo_info
self.use_stretch = use_stretch
def read_img(self, img_path): def read_img(self, img_path):
img_format = imghdr.what(img_path) img_format = imghdr.what(img_path)
@ -264,7 +268,7 @@ class DecodeImg(Transform):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8: if self.to_uint8:
image = to_uint8(image) image = to_uint8(image, stretch=self.use_stretch)
if self.read_geo_info: if self.read_geo_info:
return image, geo_info_dict return image, geo_info_dict

@ -145,8 +145,8 @@ class TestCDPredictor(TestPredictor):
# Single input (ndarrays) # Single input (ndarrays)
input_ = (decode_image( input_ = (decode_image(
t1_path, to_rgb=False), decode_image( t1_path, read_raw=True), decode_image(
t2_path, to_rgb=False)) # Reuse the name `input_` t2_path, read_raw=True)) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms) out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p) self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms) out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -169,8 +169,9 @@ class TestCDPredictor(TestPredictor):
# Multiple inputs (ndarrays) # Multiple inputs (ndarrays)
input_ = [(decode_image( input_ = [(decode_image(
t1_path, to_rgb=False), decode_image( t1_path, read_raw=True), decode_image(
t2_path, to_rgb=False))] * num_inputs # Reuse the name `input_` t2_path,
read_raw=True))] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms) out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs) self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms) out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -211,7 +212,7 @@ class TestClasPredictor(TestPredictor):
# Single input (ndarray) # Single input (ndarray)
input_ = decode_image( input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_` single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms) out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p) self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms) out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -235,7 +236,8 @@ class TestClasPredictor(TestPredictor):
# Multiple inputs (ndarrays) # Multiple inputs (ndarrays)
input_ = [decode_image( input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms) out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs) self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms) out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -276,7 +278,7 @@ class TestDetPredictor(TestPredictor):
# Single input (ndarray) # Single input (ndarray)
input_ = decode_image( input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_` single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms) predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict( out_single_array_list_p = predictor.predict(
@ -295,7 +297,8 @@ class TestDetPredictor(TestPredictor):
# Multiple inputs (ndarrays) # Multiple inputs (ndarrays)
input_ = [decode_image( input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms) out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs) self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms) out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -329,7 +332,7 @@ class TestResPredictor(TestPredictor):
# Single input (ndarray) # Single input (ndarray)
input_ = decode_image( input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_` single_input, read_raw=True) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms) predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms) trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict( out_single_array_list_p = predictor.predict(
@ -348,7 +351,8 @@ class TestResPredictor(TestPredictor):
# Multiple inputs (ndarrays) # Multiple inputs (ndarrays)
input_ = [decode_image( input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms) out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs) self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms) out_multi_array_t = trainer.predict(input_, transforms=transforms)
@ -386,7 +390,7 @@ class TestSegPredictor(TestPredictor):
# Single input (ndarray) # Single input (ndarray)
input_ = decode_image( input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_` single_input, read_raw=True) # Reuse the name `input_`
out_single_array_p = predictor.predict(input_, transforms=transforms) out_single_array_p = predictor.predict(input_, transforms=transforms)
self.check_dict_equal(out_single_array_p, out_single_file_p) self.check_dict_equal(out_single_array_p, out_single_file_p)
out_single_array_t = trainer.predict(input_, transforms=transforms) out_single_array_t = trainer.predict(input_, transforms=transforms)
@ -409,7 +413,8 @@ class TestSegPredictor(TestPredictor):
# Multiple inputs (ndarrays) # Multiple inputs (ndarrays)
input_ = [decode_image( input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` single_input,
read_raw=True)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms) out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs) self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms) out_multi_array_t = trainer.predict(input_, transforms=transforms)

@ -20,21 +20,14 @@ import paddlers.transforms as T
from testing_utils import CommonTest from testing_utils import CommonTest
class TestSegSliderPredict(CommonTest): class _TestSliderPredictNamespace:
def setUp(self): class TestSliderPredict(CommonTest):
self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeSegmenter('test')
])
self.image_path = "data/ssst/multispectral.tif"
self.basename = osp.basename(self.image_path)
def test_blocksize_and_overlap_whole(self): def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256) # Original image size (256, 256)
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict() # Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms) pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map'] pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict() # Whole-image inference using slider_predict()
@ -43,23 +36,23 @@ class TestSegSliderPredict(CommonTest):
self.transforms) self.transforms)
pred1 = T.decode_image( pred1 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape) self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap` # `block_size` == `overlap`
save_dir = osp.join(td, 'pred2') save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.model.slider_predict(self.image_path, save_dir, 128, 128, self.model.slider_predict(self.image_path, save_dir, 128,
self.transforms) 128, self.transforms)
# `block_size` is a tuple # `block_size` is a tuple
save_dir = osp.join(td, 'pred3') save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_path, save_dir, (128, 32), 0, self.model.slider_predict(self.image_path, save_dir, (128, 32),
self.transforms) 0, self.transforms)
pred3 = T.decode_image( pred3 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape) self.check_output_equal(pred3.shape, pred_whole.shape)
@ -69,7 +62,7 @@ class TestSegSliderPredict(CommonTest):
(10, 5), self.transforms) (10, 5), self.transforms)
pred4 = T.decode_image( pred4 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape) self.check_output_equal(pred4.shape, pred_whole.shape)
@ -82,7 +75,8 @@ class TestSegSliderPredict(CommonTest):
def test_merge_strategy(self): def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict() # Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms) pred_whole = self.model.predict(self.image_path,
self.transforms)
pred_whole = pred_whole['label_map'] pred_whole = pred_whole['label_map']
# 'keep_first' # 'keep_first'
@ -96,7 +90,7 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='keep_first') merge_strategy='keep_first')
pred_keepfirst = T.decode_image( pred_keepfirst = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape) self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
@ -111,7 +105,7 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='keep_last') merge_strategy='keep_last')
pred_keeplast = T.decode_image( pred_keeplast = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape) self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
@ -126,141 +120,93 @@ class TestSegSliderPredict(CommonTest):
merge_strategy='accum') merge_strategy='accum')
pred_accum = T.decode_image( pred_accum = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape) self.check_output_equal(pred_accum.shape, pred_whole.shape)
def test_geo_info(self): def test_geo_info(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) _, geo_info_in = T.decode_image(
self.ref_path, read_geo_info=True)
self.model.slider_predict(self.image_path, td, 128, 0, self.model.slider_predict(self.image_path, td, 128, 0,
self.transforms) self.transforms)
_, geo_info_out = T.decode_image( _, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True) osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'], self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans']) geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) self.assertEqual(geo_info_out['geo_proj'],
geo_info_in['geo_proj'])
class TestCDSliderPredict(CommonTest): def test_batch_size(self):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_paths = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.basename = osp.basename(self.image_paths[0])
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict() # batch_size = 1
pred_whole = self.model.predict(self.image_paths, self.transforms) save_dir = osp.join(td, 'bs1')
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
save_dir = osp.join(td, 'pred1')
self.model.slider_predict(self.image_paths, save_dir, 256, 0,
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_paths, save_dir, 128, 128,
self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0,
self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
# `block_size` and `overlap` are both tuples
save_dir = osp.join(td, 'pred4')
self.model.slider_predict(self.image_paths, save_dir, (128, 100),
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
# `block_size` larger than image size
save_dir = osp.join(td, 'pred5')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_paths, save_dir, 512, 0,
self.transforms)
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_paths, self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
save_dir = osp.join(td, 'keep_first')
self.model.slider_predict( self.model.slider_predict(
self.image_paths, self.image_path,
save_dir, save_dir,
128, 128,
64, 64,
self.transforms, self.transforms,
merge_strategy='keep_first') merge_strategy='keep_first',
pred_keepfirst = T.decode_image( batch_size=1)
pred_bs1 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
# 'keep_last' # batch_size = 4
save_dir = osp.join(td, 'keep_last') save_dir = osp.join(td, 'bs4')
self.model.slider_predict( self.model.slider_predict(
self.image_paths, self.image_path,
save_dir, save_dir,
128, 128,
64, 64,
self.transforms, self.transforms,
merge_strategy='keep_last') merge_strategy='keep_first',
pred_keeplast = T.decode_image( batch_size=4)
pred_bs4 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape) self.check_output_equal(pred_bs4, pred_bs1)
# 'accum' # batch_size = 8
save_dir = osp.join(td, 'accum') save_dir = osp.join(td, 'bs4')
self.model.slider_predict( self.model.slider_predict(
self.image_paths, self.image_path,
save_dir, save_dir,
128, 128,
64, 64,
self.transforms, self.transforms,
merge_strategy='accum') merge_strategy='keep_first',
pred_accum = T.decode_image( batch_size=8)
pred_bs8 = T.decode_image(
osp.join(save_dir, self.basename), osp.join(save_dir, self.basename),
to_uint8=False, read_raw=True,
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape) self.check_output_equal(pred_bs8, pred_bs1)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td: class TestSegSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
_, geo_info_in = T.decode_image( def setUp(self):
self.image_paths[0], read_geo_info=True) self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.model.slider_predict(self.image_paths, td, 128, 0, self.transforms = T.Compose([
self.transforms) T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
_, geo_info_out = T.decode_image( T.ArrangeSegmenter('test')
osp.join(td, self.basename), read_geo_info=True) ])
self.assertEqual(geo_info_out['geo_trans'], self.image_path = "data/ssst/multispectral.tif"
geo_info_in['geo_trans']) self.ref_path = self.image_path
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj']) self.basename = osp.basename(self.ref_path)
class TestCDSliderPredict(_TestSliderPredictNamespace.TestSliderPredict):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_path = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.ref_path = self.image_path[0]
self.basename = osp.basename(self.ref_path)

Loading…
Cancel
Save