commit
66894e3adb
11 changed files with 442 additions and 84 deletions
@ -0,0 +1,23 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import tempfile |
||||
|
||||
from testing_utils import run_script |
||||
|
||||
if __name__ == '__main__': |
||||
with tempfile.TemporaryDirectory() as td: |
||||
run_script( |
||||
f"python extract_ms_patches.py --im_paths ../tests/data/ssst/multispectral.tif --mask_path ../tests/data/ssst/multiclass_gt2.png --min_patch_size 32 --save_dir {td}", |
||||
wd="../tools") |
@ -0,0 +1,308 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import os.path as osp |
||||
import argparse |
||||
from collections import deque |
||||
from functools import reduce |
||||
|
||||
import paddlers |
||||
import numpy as np |
||||
import cv2 |
||||
try: |
||||
from osgeo import gdal |
||||
except: |
||||
import gdal |
||||
from tqdm import tqdm |
||||
|
||||
from utils import time_it |
||||
|
||||
IGN_CLS = 255 |
||||
FMT = "im_{idx}{ext}" |
||||
|
||||
|
||||
class QuadTreeNode(object): |
||||
def __init__(self, i, j, h, w, level, cls_info=None): |
||||
super().__init__() |
||||
self.i = i |
||||
self.j = j |
||||
self.h = h |
||||
self.w = w |
||||
self.level = level |
||||
self.cls_info = cls_info |
||||
self.reset_children() |
||||
|
||||
@property |
||||
def area(self): |
||||
return self.h * self.w |
||||
|
||||
@property |
||||
def is_bg_node(self): |
||||
return self.cls_info is None |
||||
|
||||
@property |
||||
def coords(self): |
||||
return (self.i, self.j, self.h, self.w) |
||||
|
||||
def get_cls_cnt(self, cls): |
||||
if self.cls_info is None or cls >= len(self.cls_info): |
||||
return 0 |
||||
return self.cls_info[cls] |
||||
|
||||
def get_children(self): |
||||
for child in self.children: |
||||
if child is not None: |
||||
yield child |
||||
|
||||
def reset_children(self): |
||||
self.children = [None, None, None, None] |
||||
|
||||
def __repr__(self): |
||||
return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})" |
||||
|
||||
|
||||
class QuadTree(object): |
||||
def __init__(self, min_blk_size=256): |
||||
super().__init__() |
||||
self.min_blk_size = min_blk_size |
||||
self.h = None |
||||
self.w = None |
||||
self.root = None |
||||
|
||||
def build_tree(self, mask_band, bg_cls=0): |
||||
cls_info_table = self.preprocess(mask_band, bg_cls) |
||||
n_rows = len(cls_info_table) |
||||
if n_rows == 0: |
||||
return None |
||||
n_cols = len(cls_info_table[0]) |
||||
self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0, |
||||
n_cols - 1, 0) |
||||
return self.root |
||||
|
||||
def preprocess(self, mask_ds, bg_cls): |
||||
h, w = mask_ds.RasterYSize, mask_ds.RasterXSize |
||||
s = self.min_blk_size |
||||
if s >= h or s >= w: |
||||
raise ValueError("`min_blk_size` must be smaller than image size.") |
||||
cls_info_table = [] |
||||
for i in range(0, h, s): |
||||
cls_info_row = [] |
||||
for j in range(0, w, s): |
||||
if i + s > h: |
||||
ch = h - i |
||||
else: |
||||
ch = s |
||||
if j + s > w: |
||||
cw = w - j |
||||
else: |
||||
cw = s |
||||
arr = mask_ds.ReadAsArray(j, i, cw, ch) |
||||
bins = np.bincount(arr.ravel()) |
||||
if len(bins) > IGN_CLS: |
||||
bins = np.delete(bins, IGN_CLS) |
||||
if len(bins) > bg_cls and bins.sum() == bins[bg_cls]: |
||||
cls_info_row.append(None) |
||||
else: |
||||
cls_info_row.append(bins) |
||||
cls_info_table.append(cls_info_row) |
||||
return cls_info_table |
||||
|
||||
def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0): |
||||
if i_ed < i_st or j_ed < j_st: |
||||
return None |
||||
|
||||
i = i_st * self.min_blk_size |
||||
j = j_st * self.min_blk_size |
||||
h = (i_ed - i_st + 1) * self.min_blk_size |
||||
w = (j_ed - j_st + 1) * self.min_blk_size |
||||
|
||||
if i_ed == i_st and j_ed == j_st: |
||||
return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st]) |
||||
|
||||
i_mid = (i_ed + i_st) // 2 |
||||
j_mid = (j_ed + j_st) // 2 |
||||
|
||||
root = QuadTreeNode(i, j, h, w, level) |
||||
|
||||
root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st, |
||||
j_mid, level + 1) |
||||
root.children[1] = self._build_tree(cls_info_table, i_st, i_mid, |
||||
j_mid + 1, j_ed, level + 1) |
||||
root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed, |
||||
j_st, j_mid, level + 1) |
||||
root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed, |
||||
j_mid + 1, j_ed, level + 1) |
||||
|
||||
bins_list = [ |
||||
node.cls_info for node in root.get_children() |
||||
if node.cls_info is not None |
||||
] |
||||
if len(bins_list) > 0: |
||||
merged_bins = reduce(merge_bins, bins_list) |
||||
root.cls_info = merged_bins |
||||
else: |
||||
# Merge nodes |
||||
root.reset_children() |
||||
|
||||
return root |
||||
|
||||
def get_nodes(self, tar_cls=None, max_level=None, include_bg=True): |
||||
nodes = [] |
||||
q = deque() |
||||
q.append(self.root) |
||||
while q: |
||||
node = q.popleft() |
||||
if max_level is None or node.level < max_level: |
||||
for child in node.get_children(): |
||||
if not include_bg and child.is_bg_node: |
||||
continue |
||||
if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0: |
||||
continue |
||||
nodes.append(child) |
||||
q.append(child) |
||||
return nodes |
||||
|
||||
def visualize_regions(self, im_path, save_path='./vis_quadtree.png'): |
||||
im = paddlers.transforms.decode_image(im_path) |
||||
if im.ndim == 2: |
||||
im = np.stack([im] * 3, axis=2) |
||||
elif im.ndim == 3: |
||||
c = im.shape[2] |
||||
if c < 3: |
||||
raise ValueError( |
||||
"For multi-spectral images, the number of bands should not be less than 3." |
||||
) |
||||
else: |
||||
# Take first three bands as R, G, and B |
||||
im = im[..., :3] |
||||
else: |
||||
raise ValueError("Unrecognized data format.") |
||||
nodes = self.get_nodes(include_bg=True) |
||||
vis = np.ascontiguousarray(im) |
||||
for node in nodes: |
||||
i, j, h, w = node.coords |
||||
vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2) |
||||
cv2.imwrite(save_path, vis) |
||||
return save_path |
||||
|
||||
def print_tree(self, node=None, level=0): |
||||
if node is None: |
||||
node = self.root |
||||
print(' ' * level + '-', node) |
||||
for child in node.get_children(): |
||||
self.print_tree(child, level + 1) |
||||
|
||||
|
||||
def merge_bins(bins1, bins2): |
||||
if len(bins1) < len(bins2): |
||||
return merge_bins(bins2, bins1) |
||||
elif len(bins1) == len(bins2): |
||||
return bins1 + bins2 |
||||
else: |
||||
return bins1 + np.concatenate( |
||||
[bins2, np.zeros(len(bins1) - len(bins2))]) |
||||
|
||||
|
||||
@time_it |
||||
def extract_ms_patches(im_paths, |
||||
mask_path, |
||||
save_dir, |
||||
min_patch_size=256, |
||||
bg_class=0, |
||||
target_class=None, |
||||
max_level=None, |
||||
include_bg=False, |
||||
nonzero_ratio=None, |
||||
visualize=False): |
||||
def _save_patch(src_path, i, j, h, w, subdir=None): |
||||
src_path = osp.normpath(src_path) |
||||
src_name, src_ext = osp.splitext(osp.basename(src_path)) |
||||
subdir = subdir if subdir is not None else src_name |
||||
dst_dir = osp.join(save_dir, subdir) |
||||
if not osp.exists(dst_dir): |
||||
os.makedirs(dst_dir) |
||||
dst_name = FMT.format(idx=idx, ext=src_ext) |
||||
dst_path = osp.join(dst_dir, dst_name) |
||||
gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h)) |
||||
return dst_path |
||||
|
||||
if nonzero_ratio is not None: |
||||
print( |
||||
"`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches." |
||||
) |
||||
|
||||
mask_ds = gdal.Open(mask_path) |
||||
quad_tree = QuadTree(min_blk_size=min_patch_size) |
||||
if mask_ds.RasterCount != 1: |
||||
raise ValueError("The mask image has more than 1 band.") |
||||
print("Start building quad tree...") |
||||
quad_tree.build_tree(mask_ds, bg_class) |
||||
if visualize: |
||||
print("Start drawing rectangles...") |
||||
save_path = quad_tree.visualize_regions(im_paths[0]) |
||||
print(f"The visualization result is saved in {save_path} .") |
||||
print("Quad tree has been built. Now start collecting nodes...") |
||||
nodes = quad_tree.get_nodes( |
||||
tar_cls=target_class, max_level=max_level, include_bg=include_bg) |
||||
print("Nodes collected. Saving patches...") |
||||
for idx, node in enumerate(tqdm(nodes)): |
||||
i, j, h, w = node.coords |
||||
real_h = min(h, mask_ds.RasterYSize - i) |
||||
real_w = min(w, mask_ds.RasterXSize - j) |
||||
if real_h < h or real_w < w: |
||||
# Skip incomplete patches |
||||
continue |
||||
is_valid = True |
||||
if nonzero_ratio is not None: |
||||
for src_path in im_paths: |
||||
im_ds = gdal.Open(src_path) |
||||
arr = im_ds.ReadAsArray(j, i, real_w, real_h) |
||||
if np.count_nonzero(arr) / arr.size < nonzero_ratio: |
||||
is_valid = False |
||||
break |
||||
if is_valid: |
||||
for src_path in im_paths: |
||||
_save_patch(src_path, i, j, real_h, real_w) |
||||
_save_patch(mask_path, i, j, real_h, real_w, 'mask') |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument("--im_paths", type=str, required=True, nargs='+', \ |
||||
help="Path of images. Different images must have unique file names.") |
||||
parser.add_argument("--mask_path", type=str, required=True, \ |
||||
help="Path of mask.") |
||||
parser.add_argument("--save_dir", type=str, default='output', \ |
||||
help="Path to save the extracted patches.") |
||||
parser.add_argument("--min_patch_size", type=int, default=256, \ |
||||
help="Minimum patch size (height and width).") |
||||
parser.add_argument("--bg_class", type=int, default=0, \ |
||||
help="Index of the background category.") |
||||
parser.add_argument("--target_class", type=int, default=None, \ |
||||
help="Index of the category of interest.") |
||||
parser.add_argument("--max_level", type=int, default=None, \ |
||||
help="Maximum level of hierarchical patches.") |
||||
parser.add_argument("--include_bg", action='store_true', \ |
||||
help="Include patches that contains only background pixels.") |
||||
parser.add_argument("--nonzero_ratio", type=float, default=None, \ |
||||
help="Threshold for filtering out less informative patches.") |
||||
parser.add_argument("--visualize", action='store_true', \ |
||||
help="Visualize the quadtree.") |
||||
args = parser.parse_args() |
||||
|
||||
extract_ms_patches(args.im_paths, args.mask_path, args.save_dir, |
||||
args.min_patch_size, args.bg_class, args.target_class, |
||||
args.max_level, args.include_bg, args.nonzero_ratio, |
||||
args.visualize) |
Loading…
Reference in new issue