Enhance slider_predict()

own
Bobholamovic 2 years ago
parent fc1613b4a4
commit f7a4ebc58d
  1. 15
      paddlers/tasks/base.py
  2. 141
      paddlers/tasks/change_detector.py
  3. 111
      paddlers/tasks/segmenter.py
  4. 52
      paddlers/tasks/utils/slider_predict.py
  5. 4
      paddlers/transforms/operators.py
  6. 1
      tests/fast_tests.py
  7. 2
      tests/tasks/__init__.py
  8. 274
      tests/tasks/test_slider_predict.py

@ -677,3 +677,18 @@ class BaseModel(metaclass=ModelMeta):
raise ValueError(
f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
)
def run(self, net, inputs, mode):
raise NotImplementedError
def train(self, *args, **kwargs):
raise NotImplementedError
def evaluate(self, *args, **kwargs):
raise NotImplementedError
def preprocess(self, images, transforms, to_tensor):
raise NotImplementedError
def postprocess(self, *args, **kwargs):
raise NotImplementedError

@ -35,6 +35,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferCDNet
from .utils.slider_predict import SlowCache as Cache
__all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
@ -574,22 +575,35 @@ class BaseChangeDetector(BaseModel):
return prediction
def slider_predict(self,
img_file,
img_files,
save_dir,
block_size,
overlap=36,
transforms=None):
transforms=None,
invalid_value=255,
merge_strategy='keep_last'):
"""
Do inference.
Do inference using sliding windows.
Args:
img_file (tuple[str]): Tuple of image paths.
img_files (tuple[str]): Tuple of image paths.
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int, optional): Size of block.
overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks.
Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
If None, the transforms for evaluation process will be used. Defaults to None.
block_size (list[int] | tuple[int] | int):
Size of block. If `block_size` is a list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int, optional):
Overlap between two blocks. If `overlap` is a list or tuple, it should
be in (W, H) format. Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
invalid_value (int, optional): Value that marks invalid pixels in output
image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last'
means keeping the values of the first and the last block in traversal
order, respectively. 'vote' means applying a simple voting strategy when
there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
"""
try:
@ -597,8 +611,6 @@ class BaseChangeDetector(BaseModel):
except:
import gdal
if not isinstance(img_file, tuple) or len(img_file) != 2:
raise ValueError("`img_file` must be a tuple of length 2.")
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
@ -614,25 +626,54 @@ class BaseChangeDetector(BaseModel):
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
src1_data = gdal.Open(img_file[0])
src2_data = gdal.Open(img_file[1])
if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
raise ValueError(
"{} is not a supported stragegy for block merging.".format(
merge_strategy))
if overlap == (0, 0):
# When there is no overlap, use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'vote':
logging.warning(
"Currently, a naive Python-implemented cache is used for aggregating voting results. "
"For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
"'keep_last'.")
cache = Cache()
src1_data = gdal.Open(img_files[0])
src2_data = gdal.Open(img_files[1])
# Assume that two input images have the same size
width = src1_data.RasterXSize
height = src1_data.RasterYSize
bands = src1_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[
0] + ".tif"
# Output name is the same as the name of the first image
file_name = osp.basename(osp.normpath(img_files[0]))
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information (consistent with the first image)
dst_data.SetGeoTransform(src1_data.GetGeoTransform())
dst_data.SetProjection(src1_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
step = np.array(block_size) - np.array(overlap)
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
prev_yoff, prev_xoff = None, None
prev_h, prev_w = None, None
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
@ -640,30 +681,64 @@ class BaseChangeDetector(BaseModel):
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
im1 = src1_data.ReadAsArray(
int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
im2 = src2_data.ReadAsArray(
int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
xoff = int(xoff)
yoff = int(yoff)
im1 = src1_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Fill
h, w = im1.shape[:2]
im1_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im1.dtype)
im2_fill = im1_fill.copy()
im1_fill[:h, :w, :] = im1
im2_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im2.dtype)
im2_fill[:h, :w, :] = im2
im_fill = (im1_fill, im2_fill)
# Predict
pred = self.predict(im_fill,
transforms)["label_map"].astype("uint8")
# Overlap
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
temp = pred[:h, :w].copy()
temp[mask == False] = 0
band.WriteArray(temp, int(xoff), int(yoff))
pred = self.predict((im1_fill, im2_fill), transforms)
pred = pred["label_map"].astype('uint8')
pred = pred[:h, :w]
# Deal with overlapping pixels
if merge_strategy == 'vote':
cache.push_block(yoff, xoff, h, w, pred)
pred = cache.get_block(yoff, xoff, h, w)
pred = pred.astype('uint8')
if prev_yoff is not None:
pop_h = yoff - prev_yoff
else:
pop_h = 0
if prev_xoff is not None:
if xoff < prev_xoff:
pop_w = prev_w
else:
pop_w = xoff - prev_xoff
else:
pop_w = 0
cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
elif merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
# Write to file
band.WriteArray(pred, xoff, yoff)
dst_data.FlushCache()
prev_xoff = xoff
prev_w = w
prev_yoff = yoff
prev_h = h
dst_data = None
print("GeoTiff saved in {}.".format(save_file))
logging.info("GeoTiff file saved in {}.".format(save_file))
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')

@ -34,6 +34,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferSegNet
from .utils.slider_predict import SlowCache as Cache
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
@ -550,20 +551,31 @@ class BaseSegmenter(BaseModel):
save_dir,
block_size,
overlap=36,
transforms=None):
transforms=None,
invalid_value=255,
merge_strategy='keep_last'):
"""
Do inference.
Do inference using sliding windows.
Args:
img_file (str): Image path.
save_dir (str): Directory that contains saved geotiff file.
block_size (list[int] | tuple[int] | int):
Size of block.
Size of block. If `block_size` is list or tuple, it should be in
(W, H) format.
overlap (list[int] | tuple[int] | int, optional):
Overlap between two blocks. Defaults to 36.
Overlap between two blocks. If `overlap` is list or tuple, it should
be in (W, H) format. Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
invalid_value (int, optional): Value that marks invalid pixels in output
image. Defaults to 255.
merge_strategy (str, optional): Strategy to merge overlapping blocks. Choices
are {'keep_first', 'keep_last', 'vote'}. 'keep_first' and 'keep_last'
means keeping the values of the first and the last block in traversal
order, respectively. 'vote' means applying a simple voting strategy when
there are conflicts in the overlapping pixels. Defaults to 'keep_last'.
"""
try:
@ -586,24 +598,50 @@ class BaseSegmenter(BaseModel):
raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.")
if merge_strategy not in ('keep_first', 'keep_last', 'vote'):
raise ValueError(
"{} is not a supported stragegy for block merging.".format(
merge_strategy))
if overlap == (0, 0):
# When there is no overlap, use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last'
if merge_strategy == 'vote':
logging.warning(
"Currently, a naive Python-implemented cache is used for aggregating voting results. "
"For higher performance in inferring large images, please set `merge_strategy` to 'keep_first' or "
"'keep_last'.")
cache = Cache()
src_data = gdal.Open(img_file)
width = src_data.RasterXSize
height = src_data.RasterYSize
bands = src_data.RasterCount
driver = gdal.GetDriverByName("GTiff")
file_name = osp.splitext(osp.normpath(img_file).split(os.sep)[-1])[
0] + ".tif"
file_name = osp.basename(osp.normpath(img_file))
# Replace extension name with '.tif'
file_name = osp.splitext(file_name)[0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
# Set meta-information
dst_data.SetGeoTransform(src_data.GetGeoTransform())
dst_data.SetProjection(src_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
step = np.array(block_size) - np.array(overlap)
band = dst_data.GetRasterBand(1)
band.WriteArray(
np.full(
(height, width), fill_value=invalid_value, dtype="uint8"))
prev_yoff, prev_xoff = None, None
prev_h, prev_w = None, None
step = np.array(
block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32)
if step[0] == 0 or step[1] == 0:
raise ValueError("`block_size` and `overlap` should not be equal.")
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
@ -611,25 +649,58 @@ class BaseSegmenter(BaseModel):
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
im = src_data.ReadAsArray(int(xoff), int(yoff), xsize,
ysize).transpose((1, 2, 0))
xoff = int(xoff)
yoff = int(yoff)
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
# Fill
h, w = im.shape[:2]
im_fill = np.zeros(
(block_size[1], block_size[0], bands), dtype=im.dtype)
im_fill[:h, :w, :] = im
# Predict
pred = self.predict(im_fill,
transforms)["label_map"].astype("uint8")
# Overlap
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
temp = pred[:h, :w].copy()
temp[mask == False] = 0
band.WriteArray(temp, int(xoff), int(yoff))
pred = self.predict(im_fill, transforms)
pred = pred["label_map"].astype('uint8')
pred = pred[:h, :w]
# Deal with overlapping pixels
if merge_strategy == 'vote':
cache.push_block(yoff, xoff, h, w, pred)
pred = cache.get_block(yoff, xoff, h, w)
pred = pred.astype('uint8')
if prev_yoff is not None:
pop_h = yoff - prev_yoff
else:
pop_h = 0
if prev_xoff is not None:
if xoff < prev_xoff:
pop_w = prev_w
else:
pop_w = xoff - prev_xoff
else:
pop_w = 0
cache.pop_block(prev_yoff, prev_xoff, pop_h, pop_w)
elif merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff, yoff, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
# Write to file
band.WriteArray(pred, xoff, yoff)
dst_data.FlushCache()
prev_xoff = xoff
prev_w = w
prev_yoff = yoff
prev_h = h
dst_data = None
print("GeoTiff saved in {}.".format(save_file))
logging.info("GeoTiff file saved in {}.".format(save_file))
def preprocess(self, images, transforms, to_tensor=True):
self._check_transforms(transforms, 'test')

@ -0,0 +1,52 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter, defaultdict
import numpy as np
class SlowCache(object):
def __init__(self):
self.cache = defaultdict(Counter)
def push_pixel(self, i, j, l):
self.cache[(i, j)][l] += 1
def push_block(self, i_st, j_st, h, w, data):
for i in range(0, h):
for j in range(0, w):
self.push_pixel(i_st + i, j_st + j, data[i, j])
def pop_pixel(self, i, j):
self.cache.pop((i, j))
def pop_block(self, i_st, j_st, h, w):
for i in range(0, h):
for j in range(0, w):
self.pop_pixel(i_st + i, j_st + j)
def get_pixel(self, i, j):
winners = self.cache[(i, j)].most_common(1)
winner = winners[0]
return winner[0]
def get_block(self, i_st, j_st, h, w):
block = []
for i in range(i_st, i_st + h):
row = []
for j in range(j_st, j_st + w):
row.append(self.get_pixel(i, j))
block.append(row)
return np.asarray(block)

@ -197,7 +197,7 @@ class DecodeImg(Transform):
self.to_uint8 = to_uint8
self.decode_bgr = decode_bgr
self.decode_sar = decode_sar
self.read_geo_info = False
self.read_geo_info = read_geo_info
def read_img(self, img_path):
img_format = imghdr.what(img_path)
@ -227,7 +227,7 @@ class DecodeImg(Transform):
im_data = im_data.transpose((1, 2, 0))
if self.read_geo_info:
geo_trans = dataset.GetGeoTransform()
geo_proj = dataset.GetGeoProjection()
geo_proj = dataset.GetProjection()
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
if self.decode_bgr:
im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |

@ -13,4 +13,5 @@
# limitations under the License.
from rs_models import *
from tasks import *
from transforms import *

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .test_slider_predict import *

@ -0,0 +1,274 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import tempfile
import paddlers as pdrs
import paddlers.transforms as T
from testing_utils import CommonTest
class TestSegSliderPredict(CommonTest):
def setUp(self):
self.model = pdrs.tasks.seg.UNet(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeSegmenter('test')
])
self.image_path = "data/ssst/multispectral.tif"
self.basename = osp.basename(self.image_path)
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms)
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
save_dir = osp.join(td, 'pred1')
self.model.slider_predict(self.image_path, save_dir, 256, 0,
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_path, save_dir, 128, 128,
self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_path, save_dir, (128, 32), 0,
self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
# `block_size` and `overlap` are both tuples
save_dir = osp.join(td, 'pred4')
self.model.slider_predict(self.image_path, save_dir, (128, 100),
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
# `block_size` larger than image size
save_dir = osp.join(td, 'pred5')
self.model.slider_predict(self.image_path, save_dir, 512, 0,
self.transforms)
pred5 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred5.shape, pred_whole.shape)
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_path, self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
save_dir = osp.join(td, 'keep_first')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first')
pred_keepfirst = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
# 'keep_last'
save_dir = osp.join(td, 'keep_last')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_last')
pred_keeplast = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
# 'vote'
save_dir = osp.join(td, 'vote')
self.model.slider_predict(
self.image_path,
save_dir,
128,
64,
self.transforms,
merge_strategy='vote')
pred_vote = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_vote.shape, pred_whole.shape)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(self.image_path, read_geo_info=True)
self.model.slider_predict(self.image_path, td, 128, 0,
self.transforms)
_, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
class TestCDSliderPredict(CommonTest):
def setUp(self):
self.model = pdrs.tasks.cd.BIT(in_channels=10)
self.transforms = T.Compose([
T.DecodeImg(), T.Normalize([0.5] * 10, [0.5] * 10),
T.ArrangeChangeDetector('test')
])
self.image_paths = ("data/ssmt/multispectral_t1.tif",
"data/ssmt/multispectral_t2.tif")
self.basename = osp.basename(self.image_paths[0])
def test_blocksize_and_overlap_whole(self):
# Original image size (256, 256)
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_paths, self.transforms)
pred_whole = pred_whole['label_map']
# Whole-image inference using slider_predict()
save_dir = osp.join(td, 'pred1')
self.model.slider_predict(self.image_paths, save_dir, 256, 0,
self.transforms)
pred1 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred1.shape, pred_whole.shape)
# `block_size` == `overlap`
save_dir = osp.join(td, 'pred2')
with self.assertRaises(ValueError):
self.model.slider_predict(self.image_paths, save_dir, 128, 128,
self.transforms)
# `block_size` is a tuple
save_dir = osp.join(td, 'pred3')
self.model.slider_predict(self.image_paths, save_dir, (128, 32), 0,
self.transforms)
pred3 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred3.shape, pred_whole.shape)
# `block_size` and `overlap` are both tuples
save_dir = osp.join(td, 'pred4')
self.model.slider_predict(self.image_paths, save_dir, (128, 100),
(10, 5), self.transforms)
pred4 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred4.shape, pred_whole.shape)
# `block_size` larger than image size
save_dir = osp.join(td, 'pred5')
self.model.slider_predict(self.image_paths, save_dir, 512, 0,
self.transforms)
pred5 = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred5.shape, pred_whole.shape)
def test_merge_strategy(self):
with tempfile.TemporaryDirectory() as td:
# Whole-image inference using predict()
pred_whole = self.model.predict(self.image_paths, self.transforms)
pred_whole = pred_whole['label_map']
# 'keep_first'
save_dir = osp.join(td, 'keep_first')
self.model.slider_predict(
self.image_paths,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_first')
pred_keepfirst = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_keepfirst.shape, pred_whole.shape)
# 'keep_last'
save_dir = osp.join(td, 'keep_last')
self.model.slider_predict(
self.image_paths,
save_dir,
128,
64,
self.transforms,
merge_strategy='keep_last')
pred_keeplast = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_keeplast.shape, pred_whole.shape)
# 'vote'
save_dir = osp.join(td, 'vote')
self.model.slider_predict(
self.image_paths,
save_dir,
128,
64,
self.transforms,
merge_strategy='vote')
pred_vote = T.decode_image(
osp.join(save_dir, self.basename),
to_uint8=False,
decode_sar=False)
self.check_output_equal(pred_vote.shape, pred_whole.shape)
def test_geo_info(self):
with tempfile.TemporaryDirectory() as td:
_, geo_info_in = T.decode_image(
self.image_paths[0], read_geo_info=True)
self.model.slider_predict(self.image_paths, td, 128, 0,
self.transforms)
_, geo_info_out = T.decode_image(
osp.join(td, self.basename), read_geo_info=True)
self.assertEqual(geo_info_out['geo_trans'],
geo_info_in['geo_trans'])
self.assertEqual(geo_info_out['geo_proj'], geo_info_in['geo_proj'])
Loading…
Cancel
Save