NOTE(review): Reconstructed from a whitespace-collapsed copy of this patch.
Indentation of context lines is inferred from Python syntax; confirm it
against the target file before applying. The added visualize_regions() now
has a docstring, converts RGB to BGR before using OpenCV, and checks the
result of cv2.imwrite(), so the post-image hash on the "index" line is stale.

diff --git a/tools/extract_ms_patches.py b/tools/extract_ms_patches.py
index 2d3e28d..d726ceb 100644
--- a/tools/extract_ms_patches.py
+++ b/tools/extract_ms_patches.py
@@ -20,6 +20,7 @@ from functools import reduce
 
 import paddlers
 import numpy as np
+import cv2
 try:
     from osgeo import gdal
 except:
@@ -111,7 +112,7 @@ class QuadTree(object):
                 bins = np.bincount(arr.ravel())
                 if len(bins) > IGN_CLS:
                     bins = np.delete(bins, IGN_CLS)
-                if bins.sum() == bins[bg_cls]:
+                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                     cls_info_row.append(None)
                 else:
                     cls_info_row.append(bins)
@@ -173,6 +174,48 @@ class QuadTree(object):
                     q.append(child)
         return nodes
 
+    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
+        """
+        Draw the regions of the quad-tree nodes on an image and save the result.
+
+        Args:
+            im_path (str): Path of the image to draw on.
+            save_path (str, optional): Path of the output image.
+                Defaults to './vis_quadtree.png'.
+
+        Returns:
+            str: `save_path`.
+
+        Raises:
+            ValueError: If the decoded image is neither 2-D nor 3-D, or if it
+                is 3-D with fewer than 3 bands.
+            RuntimeError: If the output image cannot be written.
+        """
+        im = paddlers.transforms.decode_image(im_path)
+        if im.ndim == 2:
+            im = np.stack([im] * 3, axis=2)
+        elif im.ndim == 3:
+            c = im.shape[2]
+            if c < 3:
+                raise ValueError(
+                    "For multi-spectral images, the number of bands should not be less than 3."
+                )
+            else:
+                # Take first three bands as R, G, and B
+                im = im[..., :3]
+        else:
+            raise ValueError("Unrecognized data format.")
+        nodes = self.get_nodes(include_bg=True)
+        # OpenCV drawing and writing functions use BGR channel order, so
+        # reverse the R, G, B bands before drawing; (0, 0, 255) is then red.
+        vis = np.ascontiguousarray(im[..., ::-1])
+        for node in nodes:
+            i, j, h, w = node.coords
+            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2)
+        if not cv2.imwrite(save_path, vis):
+            raise RuntimeError(f"Failed to write image to {save_path} .")
+        return save_path
+
     def print_tree(self, node=None, level=0):
         if node is None:
             node = self.root
@@ -200,7 +243,8 @@ def extract_ms_patches(im_paths,
                        target_class=None,
                        max_level=None,
                        include_bg=False,
-                       nonzero_ratio=None):
+                       nonzero_ratio=None,
+                       visualize=False):
     def _save_patch(src_path, i, j, h, w, subdir=None):
         src_path = osp.normpath(src_path)
         src_name, src_ext = osp.splitext(osp.basename(src_path))
@@ -224,26 +268,33 @@ def extract_ms_patches(im_paths,
         raise ValueError("The mask image has more than 1 band.")
     print("Start building quad tree...")
     quad_tree.build_tree(mask_ds, bg_class)
+    if visualize:
+        print("Start drawing rectangles...")
+        save_path = quad_tree.visualize_regions(im_paths[0])
+        print(f"The visualization result is saved in {save_path} .")
     print("Quad tree has been built. Now start collecting nodes...")
     nodes = quad_tree.get_nodes(
         tar_cls=target_class, max_level=max_level, include_bg=include_bg)
     print("Nodes collected. Saving patches...")
     for idx, node in enumerate(tqdm(nodes)):
         i, j, h, w = node.coords
-        h = min(h, mask_ds.RasterYSize - i)
-        w = min(w, mask_ds.RasterXSize - j)
+        real_h = min(h, mask_ds.RasterYSize - i)
+        real_w = min(w, mask_ds.RasterXSize - j)
+        if real_h < h or real_w < w:
+            # Skip incomplete patches
+            continue
         is_valid = True
         if nonzero_ratio is not None:
             for src_path in im_paths:
                 im_ds = gdal.Open(src_path)
-                arr = im_ds.ReadAsArray(j, i, w, h)
+                arr = im_ds.ReadAsArray(j, i, real_w, real_h)
                 if np.count_nonzero(arr) / arr.size < nonzero_ratio:
                     is_valid = False
                     break
         if is_valid:
             for src_path in im_paths:
-                _save_patch(src_path, i, j, h, w)
-            _save_patch(mask_path, i, j, h, w, 'mask')
+                _save_patch(src_path, i, j, real_h, real_w)
+            _save_patch(mask_path, i, j, real_h, real_w, 'mask')
 
 
 if __name__ == '__main__':
@@ -266,8 +317,11 @@ if __name__ == '__main__':
                         help="Include patches that contains only background pixels.")
     parser.add_argument("--nonzero_ratio", type=float, default=None, \
                         help="Threshold for filtering out less informative patches.")
+    parser.add_argument("--visualize", action='store_true', \
+                        help="Visualize the quadtree.")
     args = parser.parse_args()
 
     extract_ms_patches(args.im_paths, args.mask_path, args.save_dir,
                        args.min_patch_size, args.bg_class, args.target_class,
-                       args.max_level, args.include_bg, args.nonzero_ratio)
+                       args.max_level, args.include_bg, args.nonzero_ratio,
+                       args.visualize)