You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

308 lines
11 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import argparse
from collections import deque
from functools import reduce
import paddlers
import numpy as np
import cv2
try:
from osgeo import gdal
except:
import gdal
from tqdm import tqdm
from utils import time_it
# Label value treated as "ignore": its pixels are excluded from the per-block
# class counts computed while building the quadtree.
IGN_CLS = 255
# File-name template for saved patches: `idx` is the index of the patch and
# `ext` is the extension (including the dot) of the source file.
FMT = "im_{idx}{ext}"
class QuadTreeNode(object):
    """A rectangular region of the image stored as a node of a quadtree.

    Attributes:
        i, j: Row and column of the region's top-left corner.
        h, w: Height and width of the region.
        level: Depth of the node in the tree (the root is built at level 0).
        cls_info: Per-class pixel counts indexed by class id, or None when no
            class information is attached to the node.
        children: List of four slots, each holding a child node or None.
    """

    def __init__(self, i, j, h, w, level, cls_info=None):
        super().__init__()
        self.i = i
        self.j = j
        self.h = h
        self.w = w
        self.level = level
        self.cls_info = cls_info
        self.reset_children()

    @property
    def area(self):
        """Number of pixels covered by the region (`h * w`)."""
        return self.h * self.w

    @property
    def is_bg_node(self):
        """True when the node carries no class information."""
        return self.cls_info is None

    @property
    def coords(self):
        """The region as an `(i, j, h, w)` tuple."""
        return (self.i, self.j, self.h, self.w)

    def get_cls_cnt(self, cls):
        """Return the pixel count recorded for class `cls`.

        Returns 0 when the node has no class information or when `cls` lies
        outside the recorded range. Negative indices are rejected explicitly:
        otherwise they would wrap around and report the count of an
        unrelated class.
        """
        if self.cls_info is None or not 0 <= cls < len(self.cls_info):
            return 0
        return self.cls_info[cls]

    def get_children(self):
        """Yield the child nodes that are present, skipping empty slots."""
        for child in self.children:
            if child is not None:
                yield child

    def reset_children(self):
        """Drop all children, turning the node into a leaf."""
        self.children = [None, None, None, None]

    def __repr__(self):
        return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"
class QuadTree(object):
    """Quadtree over a single-band label mask, built on a grid of blocks.

    The mask is cut into blocks of `min_blk_size` x `min_blk_size` pixels
    (blocks on the bottom/right border may be smaller). Each block gets a
    per-class pixel histogram, and the grid is then split recursively into
    quadrants. Regions that contain only background pixels are collapsed into
    a single leaf without class information.

    Attributes:
        min_blk_size: Side length, in pixels, of the smallest block.
        h, w: Initialised to None; not populated by any method of this class.
        root: Root node of the tree, or None until `build_tree` has run.
    """

    def __init__(self, min_blk_size=256):
        super().__init__()
        self.min_blk_size = min_blk_size
        self.h = None
        self.w = None
        self.root = None

    def build_tree(self, mask_band, bg_cls=0):
        """Build the tree from a mask and return its root node.

        Args:
            mask_band: Object passed straight to `preprocess`; it must expose
                `RasterYSize`, `RasterXSize` and `ReadAsArray(x, y, w, h)`
                (the caller in this file passes a GDAL dataset).
            bg_cls: Class id regarded as background.

        Returns:
            The root node, or None if the block table is empty.
        """
        cls_info_table = self.preprocess(mask_band, bg_cls)
        n_rows = len(cls_info_table)
        if n_rows == 0:
            return None
        n_cols = len(cls_info_table[0])
        self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
                                     n_cols - 1, 0)
        return self.root

    def preprocess(self, mask_ds, bg_cls):
        """Compute a per-block class histogram table for the mask.

        Args:
            mask_ds: Object exposing `RasterYSize`, `RasterXSize` and
                `ReadAsArray(x_off, y_off, x_size, y_size)`.
            bg_cls: Class id regarded as background.

        Returns:
            A list of rows; each entry is either None (the block holds only
            background and/or ignored pixels) or an array of per-class pixel
            counts indexed by class id.

        Raises:
            ValueError: If `min_blk_size` is not smaller than both image
                dimensions.
        """
        h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
        s = self.min_blk_size
        if s >= h or s >= w:
            raise ValueError("`min_blk_size` must be smaller than image size.")

        cls_info_table = []
        for i in range(0, h, s):
            # Blocks on the bottom border are clipped to the image.
            ch = min(s, h - i)
            cls_info_row = []
            for j in range(0, w, s):
                # Blocks on the right border are clipped to the image.
                cw = min(s, w - j)
                arr = mask_ds.ReadAsArray(j, i, cw, ch)
                bins = np.bincount(arr.ravel())
                if len(bins) > IGN_CLS:
                    # Zero the ignore bin instead of deleting it: deleting
                    # would shift every class id above IGN_CLS down by one
                    # and attribute their counts to the wrong class.
                    bins[IGN_CLS] = 0
                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                    cls_info_row.append(None)
                else:
                    cls_info_row.append(bins)
            cls_info_table.append(cls_info_row)
        return cls_info_table

    def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
        """Recursively build the subtree for a rectangle of blocks.

        Args:
            cls_info_table: Table returned by `preprocess`.
            i_st, i_ed: First and last block row (inclusive).
            j_st, j_ed: First and last block column (inclusive).
            level: Depth assigned to the node created by this call.

        Returns:
            The node covering the given blocks, or None for an empty range.
        """
        if i_ed < i_st or j_ed < j_st:
            return None

        # Pixel geometry is expressed in whole blocks, so a node touching the
        # bottom/right border may extend past the actual image extent.
        i = i_st * self.min_blk_size
        j = j_st * self.min_blk_size
        h = (i_ed - i_st + 1) * self.min_blk_size
        w = (j_ed - j_st + 1) * self.min_blk_size

        if i_ed == i_st and j_ed == j_st:
            # A single block: leaf node carrying that block's histogram.
            return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])

        i_mid = (i_ed + i_st) // 2
        j_mid = (j_ed + j_st) // 2
        root = QuadTreeNode(i, j, h, w, level)
        # Children order: top-left, top-right, bottom-left, bottom-right.
        root.children[0] = self._build_tree(cls_info_table, i_st, i_mid, j_st,
                                            j_mid, level + 1)
        root.children[1] = self._build_tree(cls_info_table, i_st, i_mid,
                                            j_mid + 1, j_ed, level + 1)
        root.children[2] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                            j_st, j_mid, level + 1)
        root.children[3] = self._build_tree(cls_info_table, i_mid + 1, i_ed,
                                            j_mid + 1, j_ed, level + 1)

        bins_list = [
            node.cls_info for node in root.get_children()
            if node.cls_info is not None
        ]
        if bins_list:
            root.cls_info = reduce(merge_bins, bins_list)
        else:
            # Every child is background: collapse them into this node.
            root.reset_children()
        return root

    def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
        """Collect descendants of the root in breadth-first order.

        The root itself is never part of the result. A child that is filtered
        out is not descended into either.

        Args:
            tar_cls: If given, keep only nodes whose count for this class is
                non-zero.
            max_level: If given, nodes at this level or deeper are not
                expanded, so no node deeper than `max_level` is returned.
            include_bg: Whether to keep nodes without class information.

        Returns:
            A list of nodes; empty if the tree has not been built.
        """
        nodes = []
        if self.root is None:
            return nodes

        q = deque([self.root])
        while q:
            node = q.popleft()
            if max_level is not None and node.level >= max_level:
                continue
            for child in node.get_children():
                if not include_bg and child.is_bg_node:
                    continue
                if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
                    continue
                nodes.append(child)
                q.append(child)
        return nodes

    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
        """Draw the rectangle of every collected node on an image and save it.

        Args:
            im_path: Path of the image, decoded with
                `paddlers.transforms.decode_image`.
            save_path: Where the visualization is written.

        Returns:
            `save_path`.

        Raises:
            ValueError: If the decoded image has fewer than 3 bands or is
                neither 2-D nor 3-D.
        """
        im = paddlers.transforms.decode_image(im_path)
        if im.ndim == 2:
            im = np.stack([im] * 3, axis=2)
        elif im.ndim == 3:
            c = im.shape[2]
            if c < 3:
                raise ValueError(
                    "For multi-spectral images, the number of bands should not be less than 3."
                )
            else:
                # Take first three bands as R, G, and B
                im = im[..., :3]
        else:
            raise ValueError("Unrecognized data format.")

        nodes = self.get_nodes(include_bg=True)
        vis = np.ascontiguousarray(im)
        for node in nodes:
            i, j, h, w = node.coords
            # OpenCV points are (x, y), i.e. (column, row).
            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (255, 0, 0), 2)
        # Reverse the channel axis before writing (presumably RGB -> BGR for
        # OpenCV; confirm the channel order returned by decode_image).
        cv2.imwrite(save_path, vis[..., ::-1])
        return save_path

    def print_tree(self, node=None, level=0):
        """Print the subtree rooted at `node`, one node per line.

        Each line is indented by `level` spaces. When `node` is None the
        traversal starts at the root; nothing is printed for an unbuilt tree.
        """
        if node is None:
            node = self.root
        if node is None:
            return
        print(' ' * level + '-', node)
        for child in node.get_children():
            self.print_tree(child, level + 1)
def merge_bins(bins1, bins2):
    """Add two per-class count arrays that may differ in length.

    The shorter array is treated as if it were padded with zeros. The result
    is a new array as long as the longer input; neither input is modified.
    The dtype is the common type of the two inputs, so integer counts stay
    integers (padding with a float zero array would silently turn them into
    floats).

    Args:
        bins1, bins2: 1-D arrays of counts indexed by class id.

    Returns:
        The element-wise sum as a new array.
    """
    longer, shorter = np.asarray(bins1), np.asarray(bins2)
    if len(longer) < len(shorter):
        longer, shorter = shorter, longer
    # `np.array` copies, so the in-place addition never touches the inputs.
    merged = np.array(longer, dtype=np.result_type(longer, shorter))
    merged[:len(shorter)] += shorter
    return merged
@time_it
def extract_ms_patches(im_paths,
                       mask_path,
                       save_dir,
                       min_patch_size=256,
                       bg_class=0,
                       target_class=None,
                       max_level=None,
                       include_bg=False,
                       nonzero_ratio=None,
                       visualize=False):
    """Cut multi-scale patches out of images and their mask using a quadtree.

    A quadtree is built over the mask. For every collected node, the
    corresponding window is cropped from each image and from the mask with
    `gdal.Translate` and written below `save_dir`. Patches of an image go to
    a sub-directory named after the image file (without extension); mask
    patches go to `mask`. Nodes that would extend past the mask border are
    skipped.

    Args:
        im_paths: Paths of the source images.
        mask_path: Path of the single-band mask.
        save_dir: Root directory for the extracted patches.
        min_patch_size: Side length of the smallest block of the quadtree.
        bg_class: Class id regarded as background.
        target_class: If given, keep only nodes containing this class.
        max_level: If given, limit the depth of the collected nodes.
        include_bg: Whether to keep nodes without class information.
        nonzero_ratio: If given, a node is dropped when, in any source image,
            the fraction of non-zero values in its window is below this
            threshold.
        visualize: Whether to draw the quadtree on the first image.

    Raises:
        ValueError: If the mask or (when `nonzero_ratio` is set) an image
            cannot be opened, or if the mask does not have exactly one band.
    """

    def _save_patch(src_path, idx, i, j, h, w, subdir=None):
        # Crop the window (row i, column j, size h x w) from `src_path` and
        # write it to `<save_dir>/<subdir>/im_<idx><ext>`.
        src_path = osp.normpath(src_path)
        src_name, src_ext = osp.splitext(osp.basename(src_path))
        subdir = subdir if subdir is not None else src_name
        dst_dir = osp.join(save_dir, subdir)
        os.makedirs(dst_dir, exist_ok=True)
        dst_name = FMT.format(idx=idx, ext=src_ext)
        dst_path = osp.join(dst_dir, dst_name)
        # srcWin is (x_off, y_off, x_size, y_size).
        gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
        return dst_path

    if nonzero_ratio is not None:
        print(
            "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
        )

    mask_ds = gdal.Open(mask_path)
    if mask_ds is None:
        raise ValueError(f"Failed to open the mask image: {mask_path}")
    quad_tree = QuadTree(min_blk_size=min_patch_size)
    if mask_ds.RasterCount != 1:
        raise ValueError("The mask image has more than 1 band.")

    print("Start building quad tree...")
    quad_tree.build_tree(mask_ds, bg_class)
    if visualize:
        print("Start drawing rectangles...")
        save_path = quad_tree.visualize_regions(im_paths[0])
        print(f"The visualization result is saved in {save_path} .")

    print("Quad tree has been built. Now start collecting nodes...")
    nodes = quad_tree.get_nodes(
        tar_cls=target_class, max_level=max_level, include_bg=include_bg)

    # Open each source image once up front rather than once per node.
    im_datasets = []
    if nonzero_ratio is not None:
        for src_path in im_paths:
            im_ds = gdal.Open(src_path)
            if im_ds is None:
                raise ValueError(f"Failed to open the image: {src_path}")
            im_datasets.append(im_ds)

    mask_h, mask_w = mask_ds.RasterYSize, mask_ds.RasterXSize

    print("Nodes collected. Saving patches...")
    for idx, node in enumerate(tqdm(nodes)):
        i, j, h, w = node.coords
        real_h = min(h, mask_h - i)
        real_w = min(w, mask_w - j)
        if real_h < h or real_w < w:
            # Skip incomplete patches
            continue

        is_valid = True
        for im_ds in im_datasets:
            arr = im_ds.ReadAsArray(j, i, real_w, real_h)
            if np.count_nonzero(arr) / arr.size < nonzero_ratio:
                is_valid = False
                break
        if not is_valid:
            continue

        # `idx` is the position in `nodes`, so skipped nodes leave gaps in
        # the numbering while image and mask patches share the same index.
        for src_path in im_paths:
            _save_patch(src_path, idx, i, j, real_h, real_w)
        _save_patch(mask_path, idx, i, j, real_h, real_w, 'mask')
if __name__ == '__main__':
    # Command-line options, declared as (flag, add_argument keyword arguments).
    option_specs = [
        ("--im_paths", dict(
            type=str, required=True, nargs='+',
            help="Path of images. Different images must have unique file names.")),
        ("--mask_path", dict(
            type=str, required=True,
            help="Path of mask.")),
        ("--save_dir", dict(
            type=str, default='output',
            help="Path to save the extracted patches.")),
        ("--min_patch_size", dict(
            type=int, default=256,
            help="Minimum patch size (height and width).")),
        ("--bg_class", dict(
            type=int, default=0,
            help="Index of the background category.")),
        ("--target_class", dict(
            type=int, default=None,
            help="Index of the category of interest.")),
        ("--max_level", dict(
            type=int, default=None,
            help="Maximum level of hierarchical patches.")),
        ("--include_bg", dict(
            action='store_true',
            help="Include patches that contains only background pixels.")),
        ("--nonzero_ratio", dict(
            type=float, default=None,
            help="Threshold for filtering out less informative patches.")),
        ("--visualize", dict(
            action='store_true',
            help="Visualize the quadtree.")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Arguments are forwarded positionally, in the order of the function's
    # signature.
    extract_ms_patches(
        args.im_paths,
        args.mask_path,
        args.save_dir,
        args.min_patch_size,
        args.bg_class,
        args.target_class,
        args.max_level,
        args.include_bg,
        args.nonzero_ratio,
        args.visualize,
    )