Add prepare_isaid.py

Bobholamovic 2 years ago
parent d530252c69
commit 62f34b3b68
  1. tools/prepare_dataset/common.py (76 changed lines)
  2. tools/prepare_dataset/prepare_isaid.py (136 changed lines)

tools/prepare_dataset/common.py
@@ -3,11 +3,13 @@ import random
import copy
import os
import os.path as osp
import shutil
from glob import glob
from itertools import count
from functools import partial
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from skimage.io import imread, imsave
from tqdm import tqdm
@@ -57,20 +59,54 @@ def add_crop_options(parser):
return parser
def crop_and_save(path, out_subdir, crop_size, stride):
def crop_and_save(path,
out_subdir,
crop_size,
stride,
keep_last=False,
pad=True,
pad_val=0):
name, ext = osp.splitext(osp.basename(path))
out_subsubdir = osp.join(out_subdir, name)
if not osp.exists(out_subsubdir):
os.makedirs(out_subsubdir)
img = imread(path)
w, h = img.shape[:2]
h, w = img.shape[:2]
if h < crop_size or w < crop_size:
if not pad:
raise ValueError(
f"`crop_size` must be smaller than image size. `crop_size` is {crop_size}, but got image size {h}x{w}."
)
padded_img = np.full(
shape=(max(h, crop_size), max(w, crop_size)) + img.shape[2:],
fill_value=pad_val,
dtype=img.dtype)
padded_img[:h, :w] = img
h, w = padded_img.shape[:2]
img = padded_img
counter = count()
for i in range(0, h - crop_size + 1, stride):
for j in range(0, w - crop_size + 1, stride):
for i in range(0, h, stride):
i_st = i
i_ed = i_st + crop_size
if i_ed > h:
if keep_last:
i_st = h - crop_size
i_ed = h
else:
continue
for j in range(0, w, stride):
j_st = j
j_ed = j_st + crop_size
if j_ed > w:
if keep_last:
j_st = w - crop_size
j_ed = w
else:
continue
imsave(
osp.join(out_subsubdir, '{}_{}{}'.format(name,
next(counter), ext)),
img[i:i + crop_size, j:j + crop_size],
img[i_st:i_ed, j_st:j_ed],
check_contrast=False)
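
Note: the hunk above replaces the plain `range(0, h - crop_size + 1, stride)` loops with explicit start/end bookkeeping so that, when `keep_last=True`, the last window in each row and column is shifted back to the image border instead of being dropped. A standalone sketch of that per-axis computation (not part of the commit; `window_starts` is a hypothetical helper name used only for illustration):

def window_starts(length, crop_size, stride, keep_last=False):
    # Start offsets of sliding windows along one axis, mirroring the loop
    # above: a window that would run past the border is skipped, unless
    # keep_last is set, in which case it is shifted back so that it still
    # spans crop_size pixels and ends exactly at the border.
    starts = []
    for st in range(0, length, stride):
        if st + crop_size > length:
            if keep_last:
                starts.append(length - crop_size)
            continue
        starts.append(st)
    return starts

# window_starts(11, 4, 3)                 -> [0, 3, 6]
# window_starts(11, 4, 3, keep_last=True) -> [0, 3, 6, 7]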
@@ -81,7 +117,8 @@ def crop_patches(crop_size,
subsets=('train', 'val', 'test'),
subdirs=('A', 'B', 'label'),
glob_pattern='*',
max_workers=0):
max_workers=0,
keep_last=False):
"""
Crop patches from images in specific directories.
@@ -102,6 +139,9 @@ def crop_patches(crop_size,
Defaults to '*', which matches any file.
max_workers (int, optional): Number of worker threads used to perform the cropping
operation. Defaults to 0.
keep_last (bool, optional): If True, keep the last patch in each row and each
column. The left and upper borders of the last patch are shifted so that
the patch size remains `crop_size`. Defaults to False.
"""
if max_workers < 0:
@@ -110,6 +150,8 @@ def crop_patches(crop_size,
if subsets is None:
subsets = ('', )
print("Cropping patches...")
if max_workers == 0:
for subset in subsets:
for subdir in subdirs:
@@ -122,7 +164,8 @@
p,
out_subdir=out_subdir,
crop_size=crop_size,
stride=stride)
stride=stride,
keep_last=keep_last)
else:
# Concurrently crop image patches
with ThreadPoolExecutor(max_workers=max_workers) as executor:
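
The diff is truncated here, so the body of the concurrent branch is not shown. Assuming it simply distributes the same per-image `crop_and_save` call over the worker threads (an assumption, not the commit's actual code), it would look roughly like:

# Hypothetical sketch of the concurrent branch; names mirror the sequential
# branch above, but this is not the code from the commit.
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    for subset in subsets:
        for subdir in subdirs:
            out_subdir = osp.join(out_dir, subset, subdir)
            paths = glob(osp.join(data_dir, subset, subdir, glob_pattern))
            executor.map(
                partial(
                    crop_and_save,
                    out_subdir=out_subdir,
                    crop_size=crop_size,
                    stride=stride,
                    keep_last=keep_last),
                paths)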
@@ -232,6 +275,25 @@ def link_dataset(src, dst):
os.symlink(src, osp.join(dst, name), target_is_directory=True)
def copy_dataset(src, dst):
"""
Make a copy of a dataset.
Args:
src (str): Path of the original dataset.
dst (str): Path to copy to.
"""
if osp.exists(dst) and not osp.isdir(dst):
raise ValueError(f"{dst} exists and is not a directory.")
elif not osp.exists(dst):
os.makedirs(dst)
src = osp.realpath(src)
name = osp.basename(osp.normpath(src))
shutil.copytree(src, osp.join(dst, name))
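
A quick usage note for `copy_dataset`: the destination keeps the source directory's base name, so with illustrative paths like the ones below the copy ends up at `work/iSAID`.

# Illustrative only: copies the tree at data/iSAID to work/iSAID,
# creating work/ first if it does not exist.
copy_dataset("data/iSAID", "work")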
def random_split(samples,
ratios=(0.7, 0.2, 0.1),
inplace=True,

tools/prepare_dataset/prepare_isaid.py
@@ -0,0 +1,136 @@
#!/usr/bin/env python
import os.path as osp
from glob import glob
from PIL import Image
from tqdm import tqdm
from common import (get_default_parser, add_crop_options, crop_patches,
create_file_list, copy_dataset, create_label_list,
get_path_tuples)
# According to the official doc (https://github.com/CAPTAIN-WHU/iSAID_Devkit),
# the files should be organized as follows:
#
# iSAID
# ├── test
# │   └── images
# │       ├── P0006.png
# │       ├── ...
# │       └── P0009.png
# ├── train
# │   └── images
# │       ├── P0002_instance_color_RGB.png
# │       ├── P0002_instance_id_RGB.png
# │       ├── P0002.png
# │       ├── ...
# │       ├── P0010_instance_color_RGB.png
# │       ├── P0010_instance_id_RGB.png
# │       └── P0010.png
# └── val
#     └── images
#         ├── P0003_instance_color_RGB.png
#         ├── P0003_instance_id_RGB.png
#         ├── P0003.png
#         ├── ...
#         ├── P0004_instance_color_RGB.png
#         ├── P0004_instance_id_RGB.png
#         └── P0004.png
CLASSES = ('background', 'ship', 'storage_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'ground_track_field', 'bridge',
'large_vehicle', 'small_vehicle', 'helicopter', 'swimming_pool',
'roundabout', 'soccer_ball_field', 'plane', 'harbor')
# Refer to https://github.com/Z-Zheng/FarSeg/blob/master/data/isaid.py
COLOR_MAP = [[0, 0, 0], [0, 0, 63], [0, 191, 127], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 63, 63], [0, 127, 191],
[0, 127, 255], [0, 100, 155]]
SUBSETS = ('train', 'val')
SUBDIR = 'images'
FILE_LIST_PATTERN = "{subset}.txt"
LABEL_LIST_NAME = "labels.txt"
URL = ""
def flatten(nested_list):
flattened_list = []
for ele in nested_list:
if isinstance(ele, list):
flattened_list.extend(flatten(ele))
else:
flattened_list.append(ele)
return flattened_list
def rgb2mask(rgb):
palette = flatten(COLOR_MAP)
# Pad with zero
palette = palette + [0] * (256 * 3 - len(palette))
ref = Image.new(mode='P', size=(1, 1))
ref.putpalette(palette)
mask = rgb.quantize(palette=ref, dither=0)
return mask
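
`rgb2mask` relies on Pillow's palette quantization: the reference 'P'-mode image carries `COLOR_MAP` as its palette (zero-padded to 256 entries), and `Image.quantize(palette=ref, dither=0)` maps every RGB pixel to the index of the nearest palette color, i.e. to its class id. A small illustrative check (not part of the script):

# A 1x1 RGB image painted with the 'ship' color [0, 0, 63] should map to
# class index 1, the position of that color in COLOR_MAP.
ship_rgb = Image.new(mode='RGB', size=(1, 1), color=(0, 0, 63))
assert rgb2mask(ship_rgb).getpixel((0, 0)) == 1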
if __name__ == '__main__':
parser = get_default_parser()
parser.add_argument(
'--crop_size', type=int, help="Size of cropped patches.", default=800)
parser.add_argument(
'--crop_stride',
type=int,
help="Stride of sliding windows when cropping patches. `crop_size` will be used only if `crop_size` is not None.",
default=600)
args = parser.parse_args()
out_dir = osp.join(args.out_dataset_dir,
osp.basename(osp.normpath(args.in_dataset_dir)))
assert args.crop_size is not None
# Set keep_last=True, following
# https://github.com/CAPTAIN-WHU/iSAID_Devkit/blob/master/preprocess/split.py
crop_patches(
args.crop_size,
args.crop_stride,
data_dir=args.in_dataset_dir,
out_dir=out_dir,
subsets=SUBSETS,
subdirs=(SUBDIR, ),
glob_pattern='*.png',
max_workers=8,
keep_last=True)
for subset in SUBSETS:
path_tuples = []
print(f"Processing {subset} labels...")
for im_subdir in tqdm(glob(osp.join(out_dir, subset, SUBDIR, "*/"))):
im_name = osp.basename(im_subdir[:-1]) # Strip trailing '/'
if '_' in im_name:
# Do not process labels
continue
mask_subdir = osp.join(out_dir, subset, SUBDIR,
im_name + '_instance_color_RGB')
for mask_path in glob(osp.join(mask_subdir, '*.png')):
# Convert RGB files to mask files (pseudo color)
rgb = Image.open(mask_path).convert('RGB')
mask = rgb2mask(rgb)
# Write to the original location
mask.save(mask_path)
path_tuples.extend(
get_path_tuples(
im_subdir,
mask_subdir,
glob_pattern='*.png',
data_dir=args.out_dataset_dir))
path_tuples.sort()
file_list = osp.join(
args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
create_file_list(file_list, path_tuples)
print(f"Write file list to {file_list}.")
label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
create_label_list(label_list, CLASSES)
print(f"Write label list to {label_list}.")