Refactor and fix bugs

Bobholamovic 2 years ago
parent 3939cbff16
commit 2f60abd899
1 changed file: paddlers/tasks/utils/slider_predict.py (212 lines changed)

@@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod
 from collections import Counter, defaultdict
 import numpy as np
+from tqdm import tqdm
 import paddlers.utils.logging as logging
@@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta):
 class SlowCache(Cache):
     def __init__(self):
         super(SlowCache, self).__init__()
         self.cache = defaultdict(Counter)

     def push_pixel(self, i, j, l):
@@ -66,6 +68,7 @@ class SlowCache(Cache):
 class ProbCache(Cache):
     def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
         super(ProbCache, self).__init__()
         self.cache = None
         self.h = h
         self.w = w
@@ -116,20 +119,139 @@ class ProbCache(Cache):
             self._alloc_memory(nc)
         self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map

-    def roll_cache(self):
+    def roll_cache(self, shift):
         if self.order == 'c':
-            self.cache[:-self.sh] = self.cache[self.sh:]
-            self.cache[-self.sh:, :] = 0
+            self.cache[:-shift] = self.cache[shift:]
+            self.cache[-shift:, :] = 0
         elif self.order == 'f':
-            self.cache[:, :-self.sw] = self.cache[:, self.sw:]
-            self.cache[:, -self.sw:] = 0
+            self.cache[:, :-shift] = self.cache[:, shift:]
+            self.cache[:, -shift:] = 0

     def get_block(self, i_st, j_st, h, w):
         return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
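Note: roll_cache() previously always rolled the cache by the fixed window stride (self.sh or self.sw); it now takes the shift explicitly, so the caller can roll by exactly the number of rows the window grid actually advanced (see AccumProcessor below). A minimal standalone sketch of the 'c'-order branch, with toy values that are not taken from the commit:

import numpy as np

# Toy stand-in for ProbCache.cache: 4 cached rows, 5 columns, 2 classes.
cache = np.arange(4 * 5 * 2, dtype=np.float32).reshape(4, 5, 2)
shift = 3  # the number of rows the sliding window actually advanced (yoff - prev_yoff)

# The same two statements as roll_cache(shift) with order='c':
cache[:-shift] = cache[shift:]   # slide the surviving rows to the top
cache[-shift:, :] = 0            # clear the rows that scrolled out of view

print(cache[:, 0, 0])            # first entry comes from old row 3, the rest are zeros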

-def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
-                   transforms, invalid_value, merge_strategy, batch_size):
+class OverlapProcessor(metaclass=ABCMeta):
+    def __init__(self, h, w, ch, cw, sh, sw):
+        super(OverlapProcessor, self).__init__()
+        self.h = h
+        self.w = w
+        self.ch = ch
+        self.cw = cw
+        self.sh = sh
+        self.sw = sw
+
+    @abstractmethod
+    def process_pred(self, out, xoff, yoff):
+        pass
+
+
+class KeepFirstProcessor(OverlapProcessor):
+    def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
+        super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.ds = ds
+        self.inval = inval
+
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
+        mask = rd_block != self.inval
+        pred = np.where(mask, rd_block, pred)
+        return pred
+
+
+class KeepLastProcessor(OverlapProcessor):
+    def process_pred(self, out, xoff, yoff):
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        return pred
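The 'keep_first' strategy reduces to the np.where() call in KeepFirstProcessor: wherever the output band already holds a valid label (anything other than inval), that earlier label wins over the newly predicted block. A small self-contained illustration with made-up arrays standing in for the band read (rd_block) and the fresh prediction (pred):

import numpy as np

inval = 255
# What is already written in the output raster for this window; 255 marks "not written yet".
rd_block = np.array([[1, 1, 255],
                     [1, 1, 255]], dtype=np.uint8)
# Labels predicted for the same window by the current block.
pred = np.array([[9, 9, 9],
                 [9, 9, 9]], dtype=np.uint8)

mask = rd_block != inval               # True where an earlier window already wrote a label
merged = np.where(mask, rd_block, pred)
print(merged)                          # [[1 1 9]
                                       #  [1 1 9]]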
+class AccumProcessor(OverlapProcessor):
+    def __init__(self,
+                 h,
+                 w,
+                 ch,
+                 cw,
+                 sh,
+                 sw,
+                 dtype=np.float16,
+                 assign_weight=True):
+        super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
+        self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
+        self.prev_yoff = None
+        self.assign_weight = assign_weight
+
+    def process_pred(self, out, xoff, yoff):
+        if self.prev_yoff is not None and yoff != self.prev_yoff:
+            if yoff < self.prev_yoff:
+                raise RuntimeError
+            self.cache.roll_cache(yoff - self.prev_yoff)
+        pred = out['label_map']
+        pred = pred[:self.ch, :self.cw]
+        prob = out['score_map']
+        prob = prob[:self.ch, :self.cw]
+        if self.assign_weight:
+            prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
+        self.cache.update_block(0, xoff, self.ch, self.cw, prob)
+        pred = self.cache.get_block(0, xoff, self.ch, self.cw)
+        self.prev_yoff = yoff
+        return pred
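AccumProcessor implements the 'accum' strategy: per-class probabilities from every window that covers a pixel are summed in the ProbCache, and the class is decided by argmax only when the block is read back. A stripped-down sketch of that idea with two overlapping windows and invented probabilities, leaving out the GDAL and ProbCache plumbing:

import numpy as np

num_classes = 2
acc = np.zeros((1, 6, num_classes), dtype=np.float32)       # one image row, 6 pixels

# Window 1 covers columns 0..3, window 2 covers columns 2..5 (overlap at 2..3).
prob1 = np.tile([0.4, 0.6], (1, 4, 1)).astype(np.float32)   # mildly favours class 1
prob2 = np.tile([0.9, 0.1], (1, 4, 1)).astype(np.float32)   # strongly favours class 0

acc[:, 0:4] += prob1
acc[:, 2:6] += prob2

labels = np.argmax(acc, axis=2)
print(labels)   # [[1 1 0 0 0 0]]: the overlap follows the stronger accumulated evidence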
+def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
+    if not inplace:
+        array = array.copy()
+    h, w = array.shape[:2]
+    hm, wm = int(h * border_ratio), int(w * border_ratio)
+    array[:hm] *= weight
+    array[-hm:] *= weight
+    array[:, :wm] *= weight
+    array[:, -wm:] *= weight
+    return array
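assign_border_weights scales down the outer border_ratio fraction of rows and columns of a score map before it is accumulated, so predictions made near a window edge carry less weight than predictions for the same pixels made near the centre of a neighbouring window. A quick illustration of the scaling on a uniform 8x8 map (the numbers are only there to make the effect visible):

import numpy as np

prob = np.ones((8, 8, 1), dtype=np.float32)
weight, border_ratio = 0.5, 0.25
h, w = prob.shape[:2]
hm, wm = int(h * border_ratio), int(w * border_ratio)   # a 2-pixel frame for an 8x8 map

# The same four in-place scalings as assign_border_weights:
prob[:hm] *= weight
prob[-hm:] *= weight
prob[:, :wm] *= weight
prob[:, -wm:] *= weight

print(prob[0, 0, 0])   # 0.25: corners lie in both a row border and a column border
print(prob[0, 4, 0])   # 0.5:  plain edge pixels are scaled once
print(prob[4, 4, 0])   # 1.0:  interior pixels keep their full weight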
+def read_block(ds,
+               xoff,
+               yoff,
+               xsize,
+               ysize,
+               tar_xsize=None,
+               tar_ysize=None,
+               pad_val=0):
+    if tar_xsize is None:
+        tar_xsize = xsize
+    if tar_ysize is None:
+        tar_ysize = ysize
+    # Read data from dataset
+    block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
+    c, real_ysize, real_xsize = block.shape
+    assert real_ysize == ysize and real_xsize == xsize
+    # [c, h, w] -> [h, w, c]
+    block = block.transpose((1, 2, 0))
+    if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
+        if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
+            raise ValueError
+        padded_block = np.full(
+            (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
+        # Fill
+        padded_block[:real_ysize, :real_xsize] = block
+        return padded_block
+    else:
+        return block
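read_block wraps the GDAL ReadAsArray call, converts the block from CHW to HWC layout, and zero-pads it up to a requested target size. As far as this diff goes, the padding branch is only taken when tar_xsize/tar_ysize are passed explicitly; the main loop below reads full-size blocks at clamped offsets. A rough sketch of just the padding step, with a random NumPy array standing in for what GDAL would return:

import numpy as np

# Pretend GDAL returned a 3-band block of 80 x 100 pixels (c, h, w) near the border.
block = np.random.rand(3, 80, 100).astype(np.float32)
tar_ysize, tar_xsize, pad_val = 256, 256, 0

c, real_ysize, real_xsize = block.shape
block = block.transpose((1, 2, 0))            # [c, h, w] -> [h, w, c]

padded = np.full(
    (tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
padded[:real_ysize, :real_xsize] = block      # original data sits in the top-left corner

print(padded.shape)   # (256, 256, 3)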
+def slider_predict(predict_func,
+                   img_file,
+                   save_dir,
+                   block_size,
+                   overlap,
+                   transforms,
+                   invalid_value,
+                   merge_strategy,
+                   batch_size,
+                   show_progress=False):
     """
     Do inference using sliding windows.
@@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
             traversal order, respectively. 'accum' means determining the class
             of an overlapping pixel according to accumulated probabilities.
         batch_size (int): Batch size used in inference.
+        show_progress (bool, optional): Whether to show prediction progress with a
+            progress bar. Defaults to False.
     """
     try:
@@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         raise ValueError(
             "`overlap` must be a tuple/list of length 2 or an integer.")

-    if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
-        raise ValueError("{} is not a supported strategy for block merging.".
-                         format(merge_strategy))
-
     step = np.array(
         block_size, dtype=np.int32) - np.array(
             overlap, dtype=np.int32)
@@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
         # When there is no overlap or the whole image is used as input,
         # use 'keep_last' strategy as it introduces least overheads
         merge_strategy = 'keep_last'

-    if merge_strategy == 'accum':
-        cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
+    if merge_strategy == 'keep_first':
+        overlap_processor = KeepFirstProcessor(
+            height,
+            width,
+            *block_size[::-1],
+            *step[::-1],
+            band,
+            inval=invalid_value)
+    elif merge_strategy == 'keep_last':
+        overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
+                                              *step[::-1])
+    elif merge_strategy == 'accum':
+        overlap_processor = AccumProcessor(height, width, *block_size[::-1],
+                                           *step[::-1])
+    else:
+        raise ValueError("{} is not a supported strategy for block merging.".
+                         format(merge_strategy))
+
+    xsize, ysize = block_size
+
+    num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
+    cnt = 0
+    if show_progress:
+        pb = tqdm(total=num_blocks)
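The grid bookkeeping above is simple arithmetic: the stride between windows is block_size minus overlap, and num_blocks is the number of strides needed to cover each dimension, which is what the progress bar counts. A worked example with hypothetical sizes:

import math
import numpy as np

width, height = 1500, 1000
block_size = (512, 512)
overlap = (128, 128)

step = np.array(block_size, dtype=np.int32) - np.array(overlap, dtype=np.int32)
num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])

print(step)         # [384 384]
print(num_blocks)   # ceil(1000 / 384) * ceil(1500 / 384) = 3 * 4 = 12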
     batch_data = []
     batch_offsets = []
     for yoff in range(0, height, step[1]):
         for xoff in range(0, width, step[0]):
-            xsize, ysize = block_size
             if xoff + xsize > width:
                 xoff = width - xsize
-                is_end_of_row = True
-            else:
-                is_end_of_row = False
             if yoff + ysize > height:
                 yoff = height - ysize
-                is_end_of_col = True
-            else:
-                is_end_of_col = False
+            is_end_of_col = yoff + ysize >= height
+            is_end_of_row = xoff + xsize >= width

-            # Read and fill
-            im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                (1, 2, 0))
+            # Read
+            im = read_block(src_data, xoff, yoff, xsize, ysize)

             if isinstance(img_file, tuple):
-                im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
-                    (1, 2, 0))
+                im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
                 batch_data.append((im, im2))
             else:
                 batch_data.append(im)
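Near the right and bottom edges the loop clamps xoff/yoff so that the last window stays inside the raster (windows therefore overlap their neighbours more than usual there), and the is_end_of_row/is_end_of_col flags are now derived directly from the clamped offsets. A small sketch of the offsets this produces along one dimension, with invented sizes:

width, xsize, step_x = 1000, 512, 384

offsets = []
for xoff in range(0, width, step_x):
    if xoff + xsize > width:
        xoff = width - xsize          # clamp the last window to the image edge
    is_end_of_row = xoff + xsize >= width
    offsets.append((xoff, is_end_of_row))

print(offsets)   # [(0, False), (384, False), (488, True)]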
@@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
                 batch_out = predict_func(batch_data, transforms=transforms)

                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-                    pred = out['label_map'].astype('uint8')
-                    pred = pred[:ysize, :xsize]
-
-                    # Deal with overlapping pixels
-                    if merge_strategy == 'keep_first':
-                        rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
-                        mask = rd_block != invalid_value
-                        pred = np.where(mask, rd_block, pred)
-                    elif merge_strategy == 'keep_last':
-                        pass
-                    elif merge_strategy == 'accum':
-                        prob = out['score_map']
-                        prob = prob[:ysize, :xsize]
-                        cache.update_block(0, xoff_, ysize, xsize, prob)
-                        pred = cache.get_block(0, xoff_, ysize, xsize)
-                        if xoff_ + xsize >= width:
-                            cache.roll_cache()
+                    # Get processed result
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)

                     # Write to file
                     band.WriteArray(pred, xoff_, yoff_)
@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
batch_data.clear()
batch_offsets.clear()
cnt += 1
if show_progress:
pb.update(1)
pb.set_description("{} out of {} blocks processed.".format(
cnt, num_blocks))
dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file))
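Finally, the progress reporting added here is plain tqdm: a bar sized to num_blocks, advanced once per window, with the description updated to the processed/total count. A tiny self-contained sketch of that pattern outside slider_predict:

import time
from tqdm import tqdm

num_blocks = 12
pb = tqdm(total=num_blocks)
for cnt in range(1, num_blocks + 1):
    time.sleep(0.01)   # stand-in for predicting one block
    pb.update(1)
    pb.set_description("{} out of {} blocks processed.".format(cnt, num_blocks))
pb.close()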
