Refactor and fix bugs

own
Bobholamovic 2 years ago
parent 3939cbff16
commit 2f60abd899
  1. 212
      paddlers/tasks/utils/slider_predict.py

@ -19,6 +19,7 @@ from abc import ABCMeta, abstractmethod
from collections import Counter, defaultdict from collections import Counter, defaultdict
import numpy as np import numpy as np
from tqdm import tqdm
import paddlers.utils.logging as logging import paddlers.utils.logging as logging
@ -31,6 +32,7 @@ class Cache(metaclass=ABCMeta):
class SlowCache(Cache): class SlowCache(Cache):
def __init__(self): def __init__(self):
super(SlowCache, self).__init__()
self.cache = defaultdict(Counter) self.cache = defaultdict(Counter)
def push_pixel(self, i, j, l): def push_pixel(self, i, j, l):
@ -66,6 +68,7 @@ class SlowCache(Cache):
class ProbCache(Cache): class ProbCache(Cache):
def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'): def __init__(self, h, w, ch, cw, sh, sw, dtype=np.float32, order='c'):
super(ProbCache, self).__init__()
self.cache = None self.cache = None
self.h = h self.h = h
self.w = w self.w = w
@ -116,20 +119,139 @@ class ProbCache(Cache):
self._alloc_memory(nc) self._alloc_memory(nc)
self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map self.cache[i_st:i_st + h, j_st:j_st + w] += prob_map
def roll_cache(self): def roll_cache(self, shift):
if self.order == 'c': if self.order == 'c':
self.cache[:-self.sh] = self.cache[self.sh:] self.cache[:-shift] = self.cache[shift:]
self.cache[-self.sh:, :] = 0 self.cache[-shift:, :] = 0
elif self.order == 'f': elif self.order == 'f':
self.cache[:, :-self.sw] = self.cache[:, self.sw:] self.cache[:, :-shift] = self.cache[:, shift:]
self.cache[:, -self.sw:] = 0 self.cache[:, -shift:] = 0
def get_block(self, i_st, j_st, h, w): def get_block(self, i_st, j_st, h, w):
return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2) return np.argmax(self.cache[i_st:i_st + h, j_st:j_st + w], axis=2)
def slider_predict(predict_func, img_file, save_dir, block_size, overlap, class OverlapProcessor(metaclass=ABCMeta):
transforms, invalid_value, merge_strategy, batch_size): def __init__(self, h, w, ch, cw, sh, sw):
super(OverlapProcessor, self).__init__()
self.h = h
self.w = w
self.ch = ch
self.cw = cw
self.sh = sh
self.sw = sw
@abstractmethod
def process_pred(self, out, xoff, yoff):
pass
class KeepFirstProcessor(OverlapProcessor):
def __init__(self, h, w, ch, cw, sh, sw, ds, inval=255):
super(KeepFirstProcessor, self).__init__(h, w, ch, cw, sh, sw)
self.ds = ds
self.inval = inval
def process_pred(self, out, xoff, yoff):
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
rd_block = self.ds.ReadAsArray(xoff, yoff, self.cw, self.ch)
mask = rd_block != self.inval
pred = np.where(mask, rd_block, pred)
return pred
class KeepLastProcessor(OverlapProcessor):
def process_pred(self, out, xoff, yoff):
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
return pred
class AccumProcessor(OverlapProcessor):
def __init__(self,
h,
w,
ch,
cw,
sh,
sw,
dtype=np.float16,
assign_weight=True):
super(AccumProcessor, self).__init__(h, w, ch, cw, sh, sw)
self.cache = ProbCache(h, w, ch, cw, sh, sw, dtype=dtype, order='c')
self.prev_yoff = None
self.assign_weight = assign_weight
def process_pred(self, out, xoff, yoff):
if self.prev_yoff is not None and yoff != self.prev_yoff:
if yoff < self.prev_yoff:
raise RuntimeError
self.cache.roll_cache(yoff - self.prev_yoff)
pred = out['label_map']
pred = pred[:self.ch, :self.cw]
prob = out['score_map']
prob = prob[:self.ch, :self.cw]
if self.assign_weight:
prob = assign_border_weights(prob, border_ratio=0.25, inplace=True)
self.cache.update_block(0, xoff, self.ch, self.cw, prob)
pred = self.cache.get_block(0, xoff, self.ch, self.cw)
self.prev_yoff = yoff
return pred
def assign_border_weights(array, weight=0.5, border_ratio=0.25, inplace=True):
if not inplace:
array = array.copy()
h, w = array.shape[:2]
hm, wm = int(h * border_ratio), int(w * border_ratio)
array[:hm] *= weight
array[-hm:] *= weight
array[:, :wm] *= weight
array[:, -wm:] *= weight
return array
def read_block(ds,
xoff,
yoff,
xsize,
ysize,
tar_xsize=None,
tar_ysize=None,
pad_val=0):
if tar_xsize is None:
tar_xsize = xsize
if tar_ysize is None:
tar_ysize = ysize
# Read data from dataset
block = ds.ReadAsArray(xoff, yoff, xsize, ysize)
c, real_ysize, real_xsize = block.shape
assert real_ysize == ysize and real_xsize == xsize
# [c, h, w] -> [h, w, c]
block = block.transpose((1, 2, 0))
if (real_ysize, real_xsize) != (tar_ysize, tar_xsize):
if real_ysize >= tar_ysize or real_xsize >= tar_xsize:
raise ValueError
padded_block = np.full(
(tar_ysize, tar_xsize, c), fill_value=pad_val, dtype=block.dtype)
# Fill
padded_block[:real_ysize, :real_xsize] = block
return padded_block
else:
return block
def slider_predict(predict_func,
img_file,
save_dir,
block_size,
overlap,
transforms,
invalid_value,
merge_strategy,
batch_size,
show_progress=False):
""" """
Do inference using sliding windows. Do inference using sliding windows.
@ -153,6 +275,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
traversal order, respectively. 'accum' means determining the class traversal order, respectively. 'accum' means determining the class
of an overlapping pixel according to accumulated probabilities. of an overlapping pixel according to accumulated probabilities.
batch_size (int): Batch size used in inference. batch_size (int): Batch size used in inference.
show_progress (bool, optional): Whether to show prediction progress with a
progress bar. Defaults to True.
""" """
try: try:
@ -175,10 +299,6 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
raise ValueError( raise ValueError(
"`overlap` must be a tuple/list of length 2 or an integer.") "`overlap` must be a tuple/list of length 2 or an integer.")
if merge_strategy not in ('keep_first', 'keep_last', 'accum'):
raise ValueError("{} is not a supported stragegy for block merging.".
format(merge_strategy))
step = np.array( step = np.array(
block_size, dtype=np.int32) - np.array( block_size, dtype=np.int32) - np.array(
overlap, dtype=np.int32) overlap, dtype=np.int32)
@ -234,29 +354,50 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
# When there is no overlap or the whole image is used as input, # When there is no overlap or the whole image is used as input,
# use 'keep_last' strategy as it introduces least overheads # use 'keep_last' strategy as it introduces least overheads
merge_strategy = 'keep_last' merge_strategy = 'keep_last'
if merge_strategy == 'accum':
cache = ProbCache(height, width, *block_size[::-1], *step[::-1])
if merge_strategy == 'keep_first':
overlap_processor = KeepFirstProcessor(
height,
width,
*block_size[::-1],
*step[::-1],
band,
inval=invalid_value)
elif merge_strategy == 'keep_last':
overlap_processor = KeepLastProcessor(height, width, *block_size[::-1],
*step[::-1])
elif merge_strategy == 'accum':
overlap_processor = AccumProcessor(height, width, *block_size[::-1],
*step[::-1])
else:
raise ValueError("{} is not a supported stragegy for block merging.".
format(merge_strategy))
xsize, ysize = block_size
num_blocks = math.ceil(height / step[1]) * math.ceil(width / step[0])
cnt = 0
if show_progress:
pb = tqdm(total=num_blocks)
batch_data = [] batch_data = []
batch_offsets = [] batch_offsets = []
for yoff in range(0, height, step[1]): for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]): for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width: if xoff + xsize > width:
xoff = width - xsize xoff = width - xsize
is_end_of_row = True
else:
is_end_of_row = False
if yoff + ysize > height: if yoff + ysize > height:
yoff = height - ysize yoff = height - ysize
is_end_of_col = True
else:
is_end_of_col = False
is_end_of_col = yoff + ysize >= height # Read
is_end_of_row = xoff + xsize >= width im = read_block(src_data, xoff, yoff, xsize, ysize)
# Read and fill
im = src_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose(
(1, 2, 0))
if isinstance(img_file, tuple): if isinstance(img_file, tuple):
im2 = src2_data.ReadAsArray(xoff, yoff, xsize, ysize).transpose( im2 = read_block(src2_data, xoff, yoff, xsize, ysize)
(1, 2, 0))
batch_data.append((im, im2)) batch_data.append((im, im2))
else: else:
batch_data.append(im) batch_data.append(im)
@ -276,24 +417,8 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
batch_out = predict_func(batch_data, transforms=transforms) batch_out = predict_func(batch_data, transforms=transforms)
for out, (xoff_, yoff_) in zip(batch_out, batch_offsets): for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
pred = out['label_map'].astype('uint8') # Get processed result
pred = pred[:ysize, :xsize] pred = overlap_processor.process_pred(out, xoff_, yoff_)
# Deal with overlapping pixels
if merge_strategy == 'keep_first':
rd_block = band.ReadAsArray(xoff_, yoff_, xsize, ysize)
mask = rd_block != invalid_value
pred = np.where(mask, rd_block, pred)
elif merge_strategy == 'keep_last':
pass
elif merge_strategy == 'accum':
prob = out['score_map']
prob = prob[:ysize, :xsize]
cache.update_block(0, xoff_, ysize, xsize, prob)
pred = cache.get_block(0, xoff_, ysize, xsize)
if xoff_ + xsize >= width:
cache.roll_cache()
# Write to file # Write to file
band.WriteArray(pred, xoff_, yoff_) band.WriteArray(pred, xoff_, yoff_)
@ -301,5 +426,12 @@ def slider_predict(predict_func, img_file, save_dir, block_size, overlap,
batch_data.clear() batch_data.clear()
batch_offsets.clear() batch_offsets.clear()
cnt += 1
if show_progress:
pb.update(1)
pb.set_description("{} out of {} blocks processed.".format(
cnt, num_blocks))
dst_data = None dst_data = None
logging.info("GeoTiff file saved in {}.".format(save_file)) logging.info("GeoTiff file saved in {}.".format(save_file))

Loading…
Cancel
Save