|
|
|
@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod |
|
|
|
|
from collections import Counter, defaultdict |
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
import paddlers.utils.logging as logging |
|
|
|
|
|
|
|
|
@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta): |
|
|
|
|
|
|
|
|
|
class SlowCache(Cache): |
|
|
|
|
def __init__(self): |
|
|
|
|
super(SlowCache, self).__init__() |
|
|
|
|
self.cache = defaultdict(Counter) |
|
|
|
|
|
|
|
|
|
def push_pixel(self, i, j, l): |
|
|
|
@ -66,6 +68,7 @@ class SlowCache(Cache): |
|
|
|
|
|
|
|
|
|
class ProbCache(Cache): |
|
|
|
|
def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): |
|
|
|
|
super(ProbCache, self).__init__() |
|
|
|
|
self.cache = None |
|
|
|
|
self.h = h |
|
|
|
|
self.w = w |
|
|
|
@ -116,20 +119,139 @@ class ProbCache(Cache): |
|
|
|
|
self._alloc_memory(nc) |
|
|
|
|
self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map |
|
|
|
|
|
|
|
|
|
def roll_cache(self): |
|
|
|
|
def roll_cache(self, shift): |
|
|
|
|
if self.order == 'c': |
|
|
|
|
self.cache[:-self.sh] = self.cache[self.sh:] |
|
|
|
|
self.cache[-self.sh:, :] = 0 |
|
|
|
|
self.cache[:-shift] = self.cache[shift:] |
|
|
|
|
self.cache[-shift:, :] = 0 |
|
|
|
|
elif self.order == 'f': |
|
|
|
|
self.cache[:, :-self.sw] = self.cache[:, self.sw:] |
|
|
|
|
self.cache[:, -self.sw:] = 0 |
|
|
|
|
self.cache[:, :-shift] = self.cache[:, shift:] |
|
|
|
|
self.cache[:, -shift:] = 0 |
|
|
|
|
|
|
|
|
|
def get_block(self, i_st, j_st, h, w): |
|
|
|
|
return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
transforms, invalid_value, merge_strategy, batch_size): |
|
|
|
|
class OverlapProcessor(metaclass=ABCMeta): |
|
|
|
|
def __init__(self, h, w, ch, cw, sh, sw): |
|
|
|
|
super(OverlapProcessor, self).__init__() |
|
|
|
|
self.h = h |
|
|
|
|
self.w = w |
|
|
|
|
self.ch = ch |
|
|
|
|
self.cw = cw |
|
|
|
|
self.sh = sh |
|
|
|
|
self.sw = sw |
|
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
|
def process_pred(self, out, xoff, yoff): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KeepFirstProcessor(OverlapProcessor): |
|
|
|
|
def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255): |
|
|
|
|
super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw) |
|
|
|
|
self.ds = ds |
|
|
|
|
self.inval = inval |
|
|
|
|
|
|
|
|
|
def process_pred(self, out, xoff, yoff): |
|
|
|
|
pred = out['label_map'] |
|
|
|
|
pred = pred[:self.ch, :self.cw] |
|
|
|
|
rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch) |
|
|
|
|
mask = rd_block != self.inval |
|
|
|
|
pred = np.where(mask, rd_block, pred) |
|
|
|
|
return pred |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KeepLastProcessor(OverlapProcessor): |
|
|
|
|
def process_pred(self, out, xoff, yoff): |
|
|
|
|
pred = out['label_map'] |
|
|
|
|
pred = pred[:self.ch, :self.cw] |
|
|
|
|
return pred |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccumProcessor(OverlapProcessor): |
|
|
|
|
def __init__(self, |
|
|
|
|
h, |
|
|
|
|
w, |
|
|
|
|
ch, |
|
|
|
|
cw, |
|
|
|
|
sh, |
|
|
|
|
sw, |
|
|
|
|
dtype=np.float16, |
|
|
|
|
assign_weight=True): |
|
|
|
|
super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw) |
|
|
|
|
self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c') |
|
|
|
|
self.prev_yoff = None |
|
|
|
|
self.assign_weight = assign_weight |
|
|
|
|
|
|
|
|
|
def process_pred(self, out, xoff, yoff): |
|
|
|
|
if self.prev_yoff is not None and yoff != self.prev_yoff: |
|
|
|
|
if yoff < self.prev_yoff: |
|
|
|
|
raise RuntimeError |
|
|
|
|
self.cache.roll_cache(yoff - self.prev_yoff) |
|
|
|
|
pred = out['label_map'] |
|
|
|
|
pred = pred[:self.ch, :self.cw] |
|
|
|
|
prob = out['score_map'] |
|
|
|
|
prob = prob[:self.ch, :self.cw] |
|
|
|
|
if self.assign_weight: |
|
|
|
|
prob = assign_border_weights(prob, border_ratio=0.25, inplace=True) |
|
|
|
|
self.cache.update_block(0, xoff, self.ch, self.cw, prob) |
|
|
|
|
pred = self.cache.get_block(0, xoff, self.ch, self.cw) |
|
|
|
|
self.prev_yoff = yoff |
|
|
|
|
return pred |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True): |
|
|
|
|
if not inplace: |
|
|
|
|
array = array.copy() |
|
|
|
|
h, w = array.shape[:2] |
|
|
|
|
hm, wm = int(h * border_ratio), int(w * border_ratio) |
|
|
|
|
array[:hm] *= weight |
|
|
|
|
array[-hm:] *= weight |
|
|
|
|
array[:, :wm] *= weight |
|
|
|
|
array[:, -wm:] *= weight |
|
|
|
|
return array |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_block(ds, |
|
|
|
|
xoff, |
|
|
|
|
yoff, |
|
|
|
|
xsize, |
|
|
|
|
ysize, |
|
|
|
|
tar_xsize=None, |
|
|
|
|
tar_ysize=None, |
|
|
|
|
pad_val=0): |
|
|
|
|
if tar_xsize is None: |
|
|
|
|
tar_xsize = xsize |
|
|
|
|
if tar_ysize is None: |
|
|
|
|
tar_ysize = ysize |
|
|
|
|
# Read data from dataset |
|
|
|
|
block = ds.ReadAsArray(xoff, yoff, xsize, ysize) |
|
|
|
|
c, real_ysize, real_xsize = block.shape |
|
|
|
|
assert real_ysize == ysize and real_xsize == xsize |
|
|
|
|
# [c, h, w] -> [h, w, c] |
|
|
|
|
block = block.transpose((1, 2, 0)) |
|
|
|
|
if (real_ysize, real_xsize) != (tar_ysize, tar_xsize): |
|
|
|
|
if real_ysize >= tar_ysize or real_xsize >= tar_xsize: |
|
|
|
|
raise ValueError |
|
|
|
|
padded_block = np.full( |
|
|
|
|
(tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype) |
|
|
|
|
# Fill |
|
|
|
|
padded_block[:real_ysize, :real_xsize] = block |
|
|
|
|
return padded_block |
|
|
|
|
else: |
|
|
|
|
return block |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def slider_predict(predict_func, |
|
|
|
|
img_file, |
|
|
|
|
save_dir, |
|
|
|
|
block_size, |
|
|
|
|
overlap, |
|
|
|
|
transforms, |
|
|
|
|
invalid_value, |
|
|
|
|
merge_strategy, |
|
|
|
|
batch_size, |
|
|
|
|
show_progress=False): |
|
|
|
|
""" |
|
|
|
|
Do inference using sliding windows. |
|
|
|
|
|
|
|
|
@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
traversal order, respectively. 'accum' means determining the class |
|
|
|
|
of an overlapping pixel according to accumulated probabilities. |
|
|
|
|
batch_size (int): Batch size used in inference. |
|
|
|
|
show_progress (bool, optional): Whether to show prediction progress with a |
|
|
|
|
progress bar. Defaults to True. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
raise ValueError( |
|
|
|
|
"`overlap` must be a tuple/list of length 2 or an integer.") |
|
|
|
|
|
|
|
|
|
if merge_strategy not in ('keep_first', 'keep_last', 'accum'): |
|
|
|
|
raise ValueError("{} is not a supported stragegy for block merging.". |
|
|
|
|
format(merge_strategy)) |
|
|
|
|
|
|
|
|
|
step = np.array( |
|
|
|
|
block_size, dtype=np.int32) - np.array( |
|
|
|
|
overlap, dtype=np.int32) |
|
|
|
@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
# When there is no overlap or the whole image is used as input, |
|
|
|
|
# use 'keep_last' strategy as it introduces least overheads |
|
|
|
|
merge_strategy = 'keep_last' |
|
|
|
|
if merge_strategy == 'accum': |
|
|
|
|
cache = ProbCache(height, width, *block_size[::-1], *step[::-1]) |
|
|
|
|
|
|
|
|
|
if merge_strategy == 'keep_first': |
|
|
|
|
overlap_processor = KeepFirstProcessor( |
|
|
|
|
height, |
|
|
|
|
width, |
|
|
|
|
*block_size[::-1], |
|
|
|
|
*step[::-1], |
|
|
|
|
band, |
|
|
|
|
inval=invalid_value) |
|
|
|
|
elif merge_strategy == 'keep_last': |
|
|
|
|
overlap_processor = KeepLastProcessor(height, width, *block_size[::-1], |
|
|
|
|
*step[::-1]) |
|
|
|
|
elif merge_strategy == 'accum': |
|
|
|
|
overlap_processor = AccumProcessor(height, width, *block_size[::-1], |
|
|
|
|
*step[::-1]) |
|
|
|
|
else: |
|
|
|
|
raise ValueError("{} is not a supported stragegy for block merging.". |
|
|
|
|
format(merge_strategy)) |
|
|
|
|
|
|
|
|
|
xsize, ysize = block_size |
|
|
|
|
num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0]) |
|
|
|
|
cnt = 0 |
|
|
|
|
if show_progress: |
|
|
|
|
pb = tqdm(total=num_blocks) |
|
|
|
|
batch_data = [] |
|
|
|
|
batch_offsets = [] |
|
|
|
|
for yoff in range(0, height, step[1]): |
|
|
|
|
for xoff in range(0, width, step[0]): |
|
|
|
|
xsize, ysize = block_size |
|
|
|
|
if xoff + xsize > width: |
|
|
|
|
xoff = width - xsize |
|
|
|
|
is_end_of_row = True |
|
|
|
|
else: |
|
|
|
|
is_end_of_row = False |
|
|
|
|
if yoff + ysize > height: |
|
|
|
|
yoff = height - ysize |
|
|
|
|
is_end_of_col = True |
|
|
|
|
else: |
|
|
|
|
is_end_of_col = False |
|
|
|
|
|
|
|
|
|
is_end_of_col = yoff + ysize >= height |
|
|
|
|
is_end_of_row = xoff + xsize >= width |
|
|
|
|
|
|
|
|
|
# Read and fill |
|
|
|
|
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( |
|
|
|
|
(1, 2, 0)) |
|
|
|
|
# Read |
|
|
|
|
im = read_block(src_data, xoff, yoff, xsize, ysize) |
|
|
|
|
|
|
|
|
|
if isinstance(img_file, tuple): |
|
|
|
|
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( |
|
|
|
|
(1, 2, 0)) |
|
|
|
|
im2 = read_block(src2_data, xoff, yoff, xsize, ysize) |
|
|
|
|
batch_data.append((im, im2)) |
|
|
|
|
else: |
|
|
|
|
batch_data.append(im) |
|
|
|
@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
batch_out = predict_func(batch_data, transforms=transforms) |
|
|
|
|
|
|
|
|
|
for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): |
|
|
|
|
pred = out['label_map'].astype('uint8') |
|
|
|
|
pred = pred[:ysize, :xsize] |
|
|
|
|
|
|
|
|
|
# Deal with overlapping pixels |
|
|
|
|
if merge_strategy == 'keep_first': |
|
|
|
|
rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize) |
|
|
|
|
mask = rd_block != invalid_value |
|
|
|
|
pred = np.where(mask, rd_block, pred) |
|
|
|
|
elif merge_strategy == 'keep_last': |
|
|
|
|
pass |
|
|
|
|
elif merge_strategy == 'accum': |
|
|
|
|
prob = out['score_map'] |
|
|
|
|
prob = prob[:ysize, :xsize] |
|
|
|
|
cache.update_block(0, xoff_, ysize, xsize, prob) |
|
|
|
|
pred = cache.get_block(0, xoff_, ysize, xsize) |
|
|
|
|
if xoff_ + xsize >= width: |
|
|
|
|
cache.roll_cache() |
|
|
|
|
|
|
|
|
|
# Get processed result |
|
|
|
|
pred = overlap_processor.process_pred(out, xoff_, yoff_) |
|
|
|
|
# Write to file |
|
|
|
|
band.WriteArray(pred, xoff_, yoff_) |
|
|
|
|
|
|
|
|
@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap, |
|
|
|
|
batch_data.clear() |
|
|
|
|
batch_offsets.clear() |
|
|
|
|
|
|
|
|
|
cnt += 1 |
|
|
|
|
|
|
|
|
|
if show_progress: |
|
|
|
|
pb.update(1) |
|
|
|
|
pb.set_description("{} out of {} blocks processed.".format( |
|
|
|
|
cnt, num_blocks)) |
|
|
|
|
|
|
|
|
|
dst_data = None |
|
|
|
|
logging.info("GeoTiff file saved in {}.".format(save_file)) |
|
|
|
|