|
|
|
@ -20,6 +20,7 @@ from functools import reduce |
|
|
|
|
|
|
|
|
|
import paddlers |
|
|
|
|
import numpy as np |
|
|
|
|
import cv2 |
|
|
|
|
try: |
|
|
|
|
from osgeo import gdal |
|
|
|
|
except: |
|
|
|
@ -111,7 +112,7 @@ class QuadTree(object): |
|
|
|
|
bins = np.bincount(arr.ravel()) |
|
|
|
|
if len(bins) > IGN_CLS: |
|
|
|
|
bins = np.delete(bins, IGN_CLS) |
|
|
|
|
if bins.sum() == bins[bg_cls]: |
|
|
|
|
if len(bins) > bg_cls and bins.sum() == bins[bg_cls]: |
|
|
|
|
cls_info_row.append(None) |
|
|
|
|
else: |
|
|
|
|
cls_info_row.append(bins) |
|
|
|
@ -173,6 +174,29 @@ class QuadTree(object): |
|
|
|
|
q.append(child) |
|
|
|
|
return nodes |
|
|
|
|
|
|
|
|
|
def visualize_regions(self, im_path, save_path='./vis_quadtree.png'): |
|
|
|
|
im = paddlers.transforms.decode_image(im_path) |
|
|
|
|
if im.ndim == 2: |
|
|
|
|
im = np.stack([im] * 3, axis=2) |
|
|
|
|
elif im.ndim == 3: |
|
|
|
|
c = im.shape[2] |
|
|
|
|
if c < 3: |
|
|
|
|
raise ValueError( |
|
|
|
|
"For multi-spectral images, the number of bands should not be less than 3." |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
# Take first three bands as R, G, and B |
|
|
|
|
im = im[..., :3] |
|
|
|
|
else: |
|
|
|
|
raise ValueError("Unrecognized data format.") |
|
|
|
|
nodes = self.get_nodes(include_bg=True) |
|
|
|
|
vis = np.ascontiguousarray(im) |
|
|
|
|
for node in nodes: |
|
|
|
|
i, j, h, w = node.coords |
|
|
|
|
vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2) |
|
|
|
|
cv2.imwrite(save_path, vis) |
|
|
|
|
return save_path |
|
|
|
|
|
|
|
|
|
def print_tree(self, node=None, level=0): |
|
|
|
|
if node is None: |
|
|
|
|
node = self.root |
|
|
|
@ -200,7 +224,8 @@ def extract_ms_patches(im_paths, |
|
|
|
|
target_class=None, |
|
|
|
|
max_level=None, |
|
|
|
|
include_bg=False, |
|
|
|
|
nonzero_ratio=None): |
|
|
|
|
nonzero_ratio=None, |
|
|
|
|
visualize=False): |
|
|
|
|
def _save_patch(src_path, i, j, h, w, subdir=None): |
|
|
|
|
src_path = osp.normpath(src_path) |
|
|
|
|
src_name, src_ext = osp.splitext(osp.basename(src_path)) |
|
|
|
@ -224,26 +249,33 @@ def extract_ms_patches(im_paths, |
|
|
|
|
raise ValueError("The mask image has more than 1 band.") |
|
|
|
|
print("Start building quad tree...") |
|
|
|
|
quad_tree.build_tree(mask_ds, bg_class) |
|
|
|
|
if visualize: |
|
|
|
|
print("Start drawing rectangles...") |
|
|
|
|
save_path = quad_tree.visualize_regions(im_paths[0]) |
|
|
|
|
print(f"The visualization result is saved in {save_path} .") |
|
|
|
|
print("Quad tree has been built. Now start collecting nodes...") |
|
|
|
|
nodes = quad_tree.get_nodes( |
|
|
|
|
tar_cls=target_class, max_level=max_level, include_bg=include_bg) |
|
|
|
|
print("Nodes collected. Saving patches...") |
|
|
|
|
for idx, node in enumerate(tqdm(nodes)): |
|
|
|
|
i, j, h, w = node.coords |
|
|
|
|
h = min(h, mask_ds.RasterYSize - i) |
|
|
|
|
w = min(w, mask_ds.RasterXSize - j) |
|
|
|
|
real_h = min(h, mask_ds.RasterYSize - i) |
|
|
|
|
real_w = min(w, mask_ds.RasterXSize - j) |
|
|
|
|
if real_h < h or real_w < w: |
|
|
|
|
# Skip incomplete patches |
|
|
|
|
continue |
|
|
|
|
is_valid = True |
|
|
|
|
if nonzero_ratio is not None: |
|
|
|
|
for src_path in im_paths: |
|
|
|
|
im_ds = gdal.Open(src_path) |
|
|
|
|
arr = im_ds.ReadAsArray(j, i, w, h) |
|
|
|
|
arr = im_ds.ReadAsArray(j, i, real_w, real_h) |
|
|
|
|
if np.count_nonzero(arr) / arr.size < nonzero_ratio: |
|
|
|
|
is_valid = False |
|
|
|
|
break |
|
|
|
|
if is_valid: |
|
|
|
|
for src_path in im_paths: |
|
|
|
|
_save_patch(src_path, i, j, h, w) |
|
|
|
|
_save_patch(mask_path, i, j, h, w, 'mask') |
|
|
|
|
_save_patch(src_path, i, j, real_h, real_w) |
|
|
|
|
_save_patch(mask_path, i, j, real_h, real_w, 'mask') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
@ -266,8 +298,11 @@ if __name__ == '__main__': |
|
|
|
|
help="Include patches that contains only background pixels.") |
|
|
|
|
parser.add_argument("--nonzero_ratio", type=float, default=None, \ |
|
|
|
|
help="Threshold for filtering out less informative patches.") |
|
|
|
|
parser.add_argument("--visualize", action='store_true', \ |
|
|
|
|
help="Visualize the quadtree.") |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
extract_ms_patches(args.im_paths, args.mask_path, args.save_dir, |
|
|
|
|
args.min_patch_size, args.bg_class, args.target_class, |
|
|
|
|
args.max_level, args.include_bg, args.nonzero_ratio) |
|
|
|
|
args.max_level, args.include_bg, args.nonzero_ratio, |
|
|
|
|
args.visualize) |
|
|
|
|