diff --git a/docs/intro/data_prep.md b/docs/intro/data_prep.md index 6932e6b..07a5250 100644 --- a/docs/intro/data_prep.md +++ b/docs/intro/data_prep.md @@ -6,4 +6,6 @@ |-----|-----------|----------|----------| | 变化检测 | LEVIR-CD | https://justchenhao.github.io/LEVIR/ | [prepare_levircd.py](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tools/prepare_dataset/prepare_levircd.py) | | 变化检测 | Season-varying | https://paperswithcode.com/dataset/cdd-dataset-season-varying | [prepare_svcd.py](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tools/prepare_dataset/prepare_svcd.py) | +| 场景分类 | UC Merced | http://weegee.vision.ucmerced.edu/datasets/landuse.html | [prepare_ucmerced.py](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tools/prepare_dataset/prepare_ucmerced.py) | | 目标检测 | RSOD | https://github.com/RSIA-LIESMARS-WHU/RSOD-Dataset- | [prepare_rsod](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tools/prepare_dataset/prepare_rsod.py) | +| 图像分割 | iSAID | https://captain-whu.github.io/iSAID/ | [prepare_isaid](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tools/prepare_dataset/prepare_isaid.py) | diff --git a/paddlers/tasks/utils/slider_predict.py b/paddlers/tasks/utils/slider_predict.py index 620997f..185bb83 100644 --- a/paddlers/tasks/utils/slider_predict.py +++ b/paddlers/tasks/utils/slider_predict.py @@ -299,11 +299,12 @@ def slider_predict(predict_func, raise ValueError( "`overlap` must be a tuple/list of length 2 or an integer.") + if block_size[0] <= overlap[0] or block_size[1] <= overlap[1]: + raise ValueError("`block_size` must be larger than `overlap`.") + step = np.array( block_size, dtype=np.int32) - np.array( overlap, dtype=np.int32) - if step[0] == 0 or step[1] == 0: - raise ValueError("`block_size` and `overlap` should not be equal.") if isinstance(img_file, tuple): if len(img_file) != 2: diff --git a/tools/prepare_dataset/common.py b/tools/prepare_dataset/common.py index 3848915..75f9c13 100644 --- a/tools/prepare_dataset/common.py +++ b/tools/prepare_dataset/common.py @@ -3,11 +3,13 @@ import random import copy import os import os.path as osp +import shutil from glob import glob from itertools import count from functools import partial from concurrent.futures import ThreadPoolExecutor +import numpy as np from skimage.io import imread, imsave from tqdm import tqdm @@ -57,20 +59,54 @@ def add_crop_options(parser): return parser -def crop_and_save(path, out_subdir, crop_size, stride): +def crop_and_save(path, + out_subdir, + crop_size, + stride, + keep_last=False, + pad=True, + pad_val=0): name, ext = osp.splitext(osp.basename(path)) out_subsubdir = osp.join(out_subdir, name) if not osp.exists(out_subsubdir): os.makedirs(out_subsubdir) img = imread(path) - w, h = img.shape[:2] + h, w = img.shape[:2] + if h < crop_size or w < crop_size: + if not pad: + raise ValueError( + f"`crop_size` must be smaller than image size. `crop_size` is {crop_size}, but got image size {h}x{w}." + ) + padded_img = np.full( + shape=(max(h, crop_size), max(w, crop_size)) + img.shape[2:], + fill_value=pad_val, + dtype=img.dtype) + padded_img[:h, :w] = img + h, w = padded_img.shape[:2] + img = padded_img counter = count() - for i in range(0, h - crop_size + 1, stride): - for j in range(0, w - crop_size + 1, stride): + for i in range(0, h, stride): + i_st = i + i_ed = i_st + crop_size + if i_ed > h: + if keep_last: + i_st = h - crop_size + i_ed = h + else: + continue + for j in range(0, w, stride): + j_st = j + j_ed = j_st + crop_size + if j_ed > w: + if keep_last: + j_st = w - crop_size + j_ed = w + else: + continue imsave( osp.join(out_subsubdir, '{}_{}{}'.format(name, next(counter), ext)), - img[i:i + crop_size, j:j + crop_size], + img[i_st:i_ed, j_st:j_ed], check_contrast=False) @@ -81,7 +117,8 @@ def crop_patches(crop_size, subsets=('train', 'val', 'test'), subdirs=('A', 'B', 'label'), glob_pattern='*', - max_workers=0): + max_workers=0, + keep_last=False): """ Crop patches from images in specific directories. @@ -102,6 +139,9 @@ def crop_patches(crop_size, Defaults to '*', which matches arbitrary file. max_workers (int, optional): Number of worker threads to perform the cropping operation. Deafults to 0. + keep_last (bool, optional): If True, keep the last patch in each row and each + column. The left and upper border of the last patch will be shifted to + ensure that size of the patch be `crop_size`. Defaults to False. """ if max_workers < 0: @@ -110,6 +150,8 @@ def crop_patches(crop_size, if subsets is None: subsets = ('', ) + print("Cropping patches...") + if max_workers == 0: for subset in subsets: for subdir in subdirs: @@ -122,7 +164,8 @@ def crop_patches(crop_size, p, out_subdir=out_subdir, crop_size=crop_size, - stride=stride) + stride=stride, + keep_last=keep_last) else: # Concurrently crop image patches with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -232,6 +275,25 @@ def link_dataset(src, dst): os.symlink(src, osp.join(dst, name), target_is_directory=True) +def copy_dataset(src, dst): + """ + Make a copy a dataset. + + Args: + src (str): Path of the original dataset. + dst (str): Path to copy to. + """ + + if osp.exists(dst) and not osp.isdir(dst): + raise ValueError(f"{dst} exists and is not a directory.") + elif not osp.exists(dst): + os.makedirs(dst) + + src = osp.realpath(src) + name = osp.basename(osp.normpath(src)) + shutil.copytree(src, osp.join(dst, name)) + + def random_split(samples, ratios=(0.7, 0.2, 0.1), inplace=True, diff --git a/tools/prepare_dataset/prepare_isaid.py b/tools/prepare_dataset/prepare_isaid.py index e69de29..037ed6c 100644 --- a/tools/prepare_dataset/prepare_isaid.py +++ b/tools/prepare_dataset/prepare_isaid.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +import os.path as osp +from glob import glob + +from PIL import Image +from tqdm import tqdm + +from common import (get_default_parser, add_crop_options, crop_patches, + create_file_list, copy_dataset, create_label_list, + get_path_tuples) + +# According to the official doc(https://github.com/CAPTAIN-WHU/iSAID_Devkit), +# the files should be organized as follows: +# +# iSAID +# ├── test +# │   └── images +# │   ├── P0006.png +# │   └── ... +# │   └── P0009.png +# ├── train +# │   └── images +# │   ├── P0002_instance_color_RGB.png +# │   ├── P0002_instance_id_RGB.png +# │   ├── P0002.png +# │   ├── ... +# │   ├── P0010_instance_color_RGB.png +# │   ├── P0010_instance_id_RGB.png +# │   └── P0010.png +# └── val +# └── images +# ├── P0003_instance_color_RGB.png +# ├── P0003_instance_id_RGB.png +# ├── P0003.png +# ├── ... +# ├── P0004_instance_color_RGB.png +# ├── P0004_instance_id_RGB.png +# └── P0004.png + +CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'ground_track_field', 'bridge', + 'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool', + 'roundabout', 'soccer_ball_field', 'plane', 'harbor') +# Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py +COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], + [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] +SUBSETS = ('train', 'val') +SUBDIR = 'images' +FILE_LIST_PATTERN = "{subset}.txt" +LABEL_LIST_NAME = "labels.txt" +URL = "" + + +def flatten(nested_list): + flattened_list = [] + for ele in nested_list: + if isinstance(ele, list): + flattened_list.extend(flatten(ele)) + else: + flattened_list.append(ele) + return flattened_list + + +def rgb2mask(rgb): + palette = flatten(COLOR_MAP) + # Pad with zero + palette = palette + [0] * (256 * 3 - len(palette)) + ref = Image.new(mode='P', size=(1, 1)) + ref.putpalette(palette) + mask = rgb.quantize(palette=ref, dither=0) + return mask + + +if __name__ == '__main__': + parser = get_default_parser() + parser.add_argument( + '--crop_size', type=int, help="Size of cropped patches.", default=800) + parser.add_argument( + '--crop_stride', + type=int, + help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.", + default=600) + args = parser.parse_args() + + out_dir = osp.join(args.out_dataset_dir, + osp.basename(osp.normpath(args.in_dataset_dir))) + + assert args.crop_size is not None + # According to https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py + # Set keep_last=True + crop_patches( + args.crop_size, + args.crop_stride, + data_dir=args.in_dataset_dir, + out_dir=out_dir, + subsets=SUBSETS, + subdirs=(SUBDIR, ), + glob_pattern='*.png', + max_workers=8, + keep_last=True) + + for subset in SUBSETS: + path_tuples = [] + print(f"Processing {subset} labels...") + for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))): + im_name = osp.basename(im_subdir[:-1]) # Strip trailing '/' + if '_' in im_name: + # Do not process labels + continue + mask_subdir = osp.join(out_dir, subset, SUBDIR, + im_name + '_instance_color_RGB') + for mask_path in glob(osp.join(mask_subdir, '*.png')): + # Convert RGB files to mask files (pseudo color) + rgb = Image.open(mask_path).convert('RGB') + mask = rgb2mask(rgb) + # Write to the original location + mask.save(mask_path) + path_tuples.extend( + get_path_tuples( + im_subdir, + mask_subdir, + glob_pattern='*.png', + data_dir=args.out_dataset_dir)) + path_tuples.sort() + + file_list = osp.join( + args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset)) + create_file_list(file_list, path_tuples) + print(f"Write file list to {file_list}.") + + label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME) + create_label_list(label_list, CLASSES) + print(f"Write label list to {label_list}.") diff --git a/tools/prepare_dataset/prepare_ucmerced.py b/tools/prepare_dataset/prepare_ucmerced.py index e69de29..3286371 100644 --- a/tools/prepare_dataset/prepare_ucmerced.py +++ b/tools/prepare_dataset/prepare_ucmerced.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +import random +import os.path as osp +from glob import iglob +from functools import reduce, partial + +from common import (get_default_parser, create_file_list, link_dataset, + random_split, create_label_list) + +CLASSES = ('agricultural', 'airplane', 'baseballdiamond', 'beach', 'buildings', + 'chaparral', 'denseresidential', 'forest', 'freeway', 'golfcourse', + 'harbor', 'intersection', 'mediumresidential', 'mobilehomepark', + 'overpass', 'parkinglot', 'river', 'runway', 'sparseresidential', + 'storagetanks', 'tenniscourt') +SUBSETS = ('train', 'val', 'test') +SUBDIRS = tuple(osp.join('Images', cls) for cls in CLASSES) +FILE_LIST_PATTERN = "{subset}.txt" +LABEL_LIST_NAME = "labels.txt" +URL = "" + +if __name__ == '__main__': + parser = get_default_parser() + parser.add_argument('--seed', type=int, default=None, help="Random seed.") + parser.add_argument( + '--ratios', + type=float, + nargs='+', + default=(0.7, 0.2, 0.1), + help="Ratios of each subset (train/val or train/val/test).") + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + if len(args.ratios) not in (2, 3): + raise ValueError("Wrong number of ratios!") + + out_dir = osp.join(args.out_dataset_dir, + osp.basename(osp.normpath(args.in_dataset_dir))) + + link_dataset(args.in_dataset_dir, args.out_dataset_dir) + + splits_list = [] + for idx, (cls, subdir) in enumerate(zip(CLASSES, SUBDIRS)): + pairs = [] + for p in iglob(osp.join(out_dir, subdir, '*.tif')): + pair = (osp.relpath(p, args.out_dataset_dir), str(idx)) + pairs.append(pair) + splits = random_split(pairs, ratios=args.ratios) + splits_list.append(splits) + splits = map(partial(reduce, list.__add__), zip(*splits_list)) + + for subset, split in zip(SUBSETS, splits): + file_list = osp.join( + args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset)) + create_file_list(file_list, split) + print(f"Write file list to {file_list}.") + + label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME) + create_label_list(label_list, CLASSES) + print(f"Write label list to {label_list}.")