Merge branch 'enhance_slide' of https://github.com/Bobholamovic/PaddleRS into enhance_slide

own
Bobholamovic 2 years ago
commit c9457a0d99
  1. 4
      docs/apis/infer.md
  2. 137
      paddlers/tasks/change_detector.py
  3. 137
      paddlers/tasks/segmenter.py
  4. 254
      paddlers/tasks/utils/slider_predict.py
  5. 50
      tests/tasks/test_slider_predict.py

@ -170,9 +170,9 @@ def slider_predict(self,
|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`| |`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定宽度、高度或以一个整数指定相同的宽高)。|`36`|
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`| |`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
|`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`| |`invalid_value`|`int`|输出影像中用于标记无效像素的数值。|`255`|
|`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测值;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测值;`'vote'`表示使用投票策略,即对于每个像素,最终预测值为所有覆盖该像素的滑窗给出的预测值中出现频率最高者。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间,但给出的预测结果在窗口的接缝处相比其它两种策略将更加平滑。|`'keep_last'`| |`merge_strategy`|`str`|合并滑窗重叠区域使用的策略。`'keep_first'`表示保留遍历顺序(从左至右,从上往下,列优先)最靠前的窗口的预测类别;`'keep_last'`表示保留遍历顺序最靠后的窗口的预测类别;`'vote'`表示使用投票策略,即对于每个像素,最终预测类别为所有覆盖该像素的滑窗给出的预测类别中出现频率最高者;`'accum'`表示通过将各窗口在重叠区域给出的预测概率累加,计算最终预测类别。需要注意的是,在对大尺寸影像进行`overlap`较大的密集推理时,使用`'vote'`策略可能导致较长的推理时间。|`'keep_last'`|
变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。 变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准,存储滑窗推理结果的文件名也与第一时相影像文件相同
## 静态图推理API ## 静态图推理API

@ -35,7 +35,7 @@ from paddlers.utils.checkpoint import cd_pretrain_weights_dict
from .base import BaseModel from .base import BaseModel
from .utils import seg_metrics as metrics from .utils import seg_metrics as metrics
from .utils.infer_nets import InferCDNet from .utils.infer_nets import InferCDNet
from .utils.slider_predict import SlowCache as Cache from .utils.slider_predict import slider_predict
__all__ = [ __all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@ -613,139 +613,8 @@ class BaseChangeDetector(BaseModel):
there are conflicts in the overlapping pixels. Defaults to 'keep_last'. there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
""" """
try: slider_predict(self, img_files, save_dir, block_size, overlap,
from osgeo import gdal transforms, invalid_value, merge_strategy)
except:
import gdal
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
raise ValueError(
"{} is not a supported stragegy for block merging.".format(
merge_strategy))
if overlap == (0, 0):
# When there is no overlap, use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'vote':
logging.warning(
"Currently, a naive Python-implemented cache is used for aggregating voting results. "
"For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
"'keep_last'.")
cache = Cache()
src1_data = gdal.Open(img_files[0])
src2_data = gdal.Open(img_files[1])
# Assume that two input images have the same size
width = src1_data.RasterXSize
height = src1_data.RasterYSize
bands = src1_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
# Output name is the same as the name of the first image
file_name = osp.basename(osp.normpath(img_files[0]))
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information (consistent with the first image)
dst_data.SetGeoTransform(src1_data.GetGeoTransform())
dst_data.SetProjection(src1_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
prev_yoff, prev_xoff = None, None
prev_h, prev_w = None, None
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
xoff = int(xoff)
yoff = int(yoff)
im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Fill
h, w = im1.shape[:2]
im1_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im1.dtype)
im1_fill[:h, :w, :] = im1
im2_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im2.dtype)
im2_fill[:h, :w, :] = im2
# Predict
pred = self.predict((im1_fill, im2_fill), transforms)
pred = pred["label_map"].astype('uint8')
pred = pred[:h, :w]
# Deal with overlapping pixels
if merge_strategy == 'vote':
cache.push_block(yoff, xoff, h, w, pred)
pred = cache.get_block(yoff, xoff, h, w)
pred = pred.astype('uint8')
if prev_yoff is not None:
pop_h = yoff - prev_yoff
else:
pop_h = 0
if prev_xoff is not None:
if xoff < prev_xoff:
pop_w = prev_w
else:
pop_w = xoff - prev_xoff
else:
pop_w = 0
cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
elif merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
# Write to file
band.WriteArray(pred, xoff, yoff)
dst_data.FlushCache()
prev_xoff = xoff
prev_w = w
prev_yoff = yoff
prev_h = h
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))
def preprocess(self, images, transforms, to_tensor=True): def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test') self._check_transforms(transforms, 'test')

@ -34,7 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel from .base import BaseModel
from .utils import seg_metrics as metrics from .utils import seg_metrics as metrics
from .utils.infer_nets import InferSegNet from .utils.infer_nets import InferSegNet
from .utils.slider_predict import SlowCache as Cache from .utils.slider_predict import slider_predict
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
@ -579,135 +579,16 @@ class BaseSegmenter(BaseModel):
invalid_value (int, optional): Value that marks invalid pixels in output invalid_value (int, optional): Value that marks invalid pixels in output
image. Defaults to 255. image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last' are {'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and
means keeping the values of the first and the last block in traversal 'keep_last' means keeping the values of the first and the last block in
order, respectively. 'vote' means applying a simple voting strategy when traversal order, respectively. 'vote' means applying a simple voting
there are conflicts in the overlapping pixels. Defaults to 'keep_last'. strategy when there are conflicts in the overlapping pixels. 'accum'
means determining the class of an overlapping pixel according to
accumulated probabilities. Defaults to 'keep_last'.
""" """
try: slider_predict(self, img_file, save_dir, block_size, overlap,
from osgeo import gdal transforms, invalid_value, merge_strategy)
except:
import gdal
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
raise ValueError(
"{} is not a supported stragegy for block merging.".format(
merge_strategy))
if overlap == (0, 0):
# When there is no overlap, use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'vote':
logging.warning(
"Currently, a naive Python-implemented cache is used for aggregating voting results. "
"For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
"'keep_last'.")
cache = Cache()
src_data = gdal.Open(img_file)
width = src_data.RasterXSize
height = src_data.RasterYSize
bands = src_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
file_name = osp.basename(osp.normpath(img_file))
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information
dst_data.SetGeoTransform(src_data.GetGeoTransform())
dst_data.SetProjection(src_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
prev_yoff, prev_xoff = None, None
prev_h, prev_w = None, None
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
xoff = int(xoff)
yoff = int(yoff)
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Fill
h, w = im.shape[:2]
im_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im.dtype)
im_fill[:h, :w, :] = im
# Predict
pred = self.predict(im_fill, transforms)
pred = pred["label_map"].astype('uint8')
pred = pred[:h, :w]
# Deal with overlapping pixels
if merge_strategy == 'vote':
cache.push_block(yoff, xoff, h, w, pred)
pred = cache.get_block(yoff, xoff, h, w)
pred = pred.astype('uint8')
if prev_yoff is not None:
pop_h = yoff - prev_yoff
else:
pop_h = 0
if prev_xoff is not None:
if xoff < prev_xoff:
pop_w = prev_w
else:
pop_w = xoff - prev_xoff
else:
pop_w = 0
cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
elif merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
# Write to file
band.WriteArray(pred, xoff, yoff)
dst_data.FlushCache()
prev_xoff = xoff
prev_w = w
prev_yoff = yoff
prev_h = h
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))
def preprocess(self, images, transforms, to_tensor=True): def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test') self._check_transforms(transforms, 'test')

@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import os.path as osp
from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict from collections import Counter, defaultdict
import numpy as np import numpy as np
import paddlers.utils.logging as logging
class SlowCache(object):
class Cache(metaclass=ABCMeta):
@abstractmethod
def get_block(self, i_st, j_st, h, w):
pass
class SlowCache(Cache):
def __init__(self): def __init__(self):
self.cache = defaultdict(Counter) self.cache = defaultdict(Counter)
@ -50,3 +61,244 @@ class SlowCache(object):
row.append(self.get_pixel(i, j)) row.append(self.get_pixel(i, j))
block.append(row) block.append(row)
return np.asarray(block) return np.asarray(block)
class ProbCache(Cache):
def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
self.cache = None
self.h = h
self.w = w
self.ch = ch
self.cw = cw
self.sh = sh
self.sw = sw
if not issubclass(dtype, np.floating):
raise TypeError("`dtype` must be one of the floating types.")
self.dtype = dtype
order = order.lower()
if order not in ('c', 'f'):
raise ValueError("`order` other than 'c' and 'f' is not supported.")
self.order = order
def _alloc_memory(self, nc):
if self.order == 'c':
# Colomn-first order (C-style)
#
# <-- cw -->
# |--------|---------------------|^ ^
# | || | sh
# |--------|---------------------|| ch v
# | ||
# |--------|---------------------|v
# <------------ w --------------->
self.cache = np.zeros((self.ch, self.w, nc), dtype=self.dtype)
elif self.order == 'f':
# Row-first order (Fortran-style)
#
# <-- sw -->
# <---- cw ---->
# |--------|---|^ ^
# | | || |
# | | || ch
# | | || |
# |--------|---|| h v
# | | ||
# | | ||
# | | ||
# |--------|---|v
self.cache = np.zeros((self.h, self.cw, nc), dtype=self.dtype)
def update_block(self, i_st, j_st, h, w, prob_map):
if self.cache is None:
nc = prob_map.shape[2]
# Lazy allocation of memory
self._alloc_memory(nc)
self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
def roll_cache(self):
if self.order == 'c':
self.cache = np.roll(self.cache, -self.sh, axis=0)
self.cache[self.sh:self.ch, :] = 0
elif self.order == 'f':
self.cache = np.roll(self.cache, -self.sw, axis=1)
self.cache[:, self.sw:self.cw] = 0
def get_block(self, i_st, j_st, h, w):
return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
def slider_predict(predictor, img_file, save_dir, block_size, overlap,
transforms, invalid_value, merge_strategy):
"""
Do inference using sliding windows.
Args:
predictor (object): Object that implements `predict()` method.
img_file (str|tuple[str]): Image path(s).
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int):
Size of block. If `block_size` is list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int):
Overlap between two blocks. If `overlap` is list or tuple, it should
be in (W, H) format.
transforms (paddlers.transforms.Compose|None): Transforms for inputs. If
None, the transforms for evaluation process will be used.
invalid_value (int): Value that marks invalid pixels in output image.
Defaults to 255.
merge_strategy (str): Strategy to merge overlapping blocks. Choices are
{'keep_first', 'keep_last', 'vote', 'accum'}. 'keep_first' and
'keep_last' means keeping the values of the first and the last block in
traversal order, respectively. 'vote' means applying a simple voting
strategy when there are conflicts in the overlapping pixels. 'accum'
means determining the class of an overlapping pixel according to
accumulated probabilities.
"""
try:
from osgeo import gdal
except:
import gdal
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError(
"`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
if merge_strategy not in ('keep_first', 'keep_last', 'vote', 'accum'):
raise ValueError("{} is not a supported stragegy for block merging.".
format(merge_strategy))
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
if isinstance(img_file, tuple):
if len(img_file) != 2:
raise ValueError("Tuple `img_file` must have the length of two.")
# Assume that two input images have the same size
src_data = gdal.Open(img_file[0])
src2_data = gdal.Open(img_file[1])
# Output name is the same as the name of the first image
file_name = osp.basename(osp.normpath(img_file[0]))
else:
src_data = gdal.Open(img_file)
file_name = osp.basename(osp.normpath(img_file))
# Get size of original raster
width = src_data.RasterXSize
height = src_data.RasterYSize
bands = src_data.RasterCount
if block_size[0] > width or block_size[1] > height:
raise ValueError("`block_size` should not be larger than image size.")
driver = gdal.GetDriverByName("GTiff")
if not osp.exists(save_dir):
os.makedirs(save_dir)
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information
dst_data.SetGeoTransform(src_data.GetGeoTransform())
dst_data.SetProjection(src_data.GetProjection())
# Initialize raster with `invalid_value`
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
if overlap == (0, 0) or block_size == (width, height):
# When there is no overlap or the whole image is used as input,
# use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'vote':
logging.warning(
"Currently, a naive Python-implemented cache is used for aggregating voting results. "
"For higher performance in inferring large images, please set `merge_strategy` to 'keep_first', "
"'keep_last', or 'accum'.")
cache = SlowCache()
elif merge_strategy == 'accum':
cache = ProbCache(height, width, *block_size, *step)
prev_yoff, prev_xoff = None, None
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xoff = width - xsize
if yoff + ysize > height:
yoff = height - ysize
# Read and fill
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
if isinstance(img_file, tuple):
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Predict
out = predictor.predict((im, im2), transforms)
else:
# Predict
out = predictor.predict(im, transforms)
pred = out['label_map'].astype('uint8')
pred = pred[:ysize, :xsize]
# Deal with overlapping pixels
if merge_strategy == 'vote':
cache.push_block(yoff, xoff, ysize, xsize, pred)
pred = cache.get_block(yoff, xoff, ysize, xsize)
pred = pred.astype('uint8')
if prev_yoff is not None:
pop_h = yoff - prev_yoff
else:
pop_h = 0
if prev_xoff is not None:
if xoff < prev_xoff:
pop_w = xsize
else:
pop_w = xoff - prev_xoff
else:
pop_w = 0
cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
elif merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
elif merge_strategy == 'accum':
prob = out['score_map']
prob = prob[:ysize, :xsize]
cache.update_block(0, yoff, ysize, xsize, prob)
pred = cache.get_block(0, yoff, ysize, xsize)
if xoff + step[0] >= width:
cache.roll_cache()
# Write to file
band.WriteArray(pred, xoff, yoff)
dst_data.FlushCache()
prev_xoff = xoff
prev_yoff = yoff
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))

@ -75,13 +75,9 @@ class TestSegSliderPredict(CommonTest):
# `block_size` larger than image size # `block_size` larger than image size
save_dir = osp.join(td, 'pred5') save_dir = osp.join(td, 'pred5')
self.model.slider_predict(self.image_path, save_dir, 512, 0, with self.assertRaises(ValueError):
self.transforms) self.model.slider_predict(self.image_path, save_dir, 512, 0,
pred5 = T.decode_image( self.transforms)
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred5.shape, pred_whole.shape)
def test_merge_strategy(self): def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
@ -134,6 +130,21 @@ class TestSegSliderPredict(CommonTest):
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_vote.shape, pred_whole.shape) self.check_output_equal(pred_vote.shape, pred_whole.shape)
# 'accum'
save_dir = osp.join(td, 'accum')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='vote')
pred_accum = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape)
def test_geo_info(self): def test_geo_info(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(self.image_path, read_geo_info=True) _, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
@ -202,13 +213,9 @@ class TestCDSliderPredict(CommonTest):
# `block_size` larger than image size # `block_size` larger than image size
save_dir = osp.join(td, 'pred5') save_dir = osp.join(td, 'pred5')
self.model.slider_predict(self.image_paths, save_dir, 512, 0, with self.assertRaises(ValueError):
self.transforms) self.model.slider_predict(self.image_paths, save_dir, 512, 0,
pred5 = T.decode_image( self.transforms)
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred5.shape, pred_whole.shape)
def test_merge_strategy(self): def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
@ -261,6 +268,21 @@ class TestCDSliderPredict(CommonTest):
decode_sar=False) decode_sar=False)
self.check_output_equal(pred_vote.shape, pred_whole.shape) self.check_output_equal(pred_vote.shape, pred_whole.shape)
# 'accum'
save_dir = osp.join(td, 'accum')
self.model.slider_predict(
self.image_paths,
save_dir,
128,
64,
self.transforms,
merge_strategy='vote')
pred_accum = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_accum.shape, pred_whole.shape)
def test_geo_info(self): def test_geo_info(self):
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image( _, geo_info_in = T.decode_image(

Loading…
Cancel
Save