diff --git a/.gitignore b/.gitignore index d1ea072..66f7c9b 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,6 @@ dmypy.json # Pyre type checker .pyre/ + +# testdata +tutorials/train/change_detection/DataSet/ diff --git a/paddlers/tasks/__init__.py b/paddlers/tasks/__init__.py index 1543dac..a4e0fbb 100644 --- a/paddlers/tasks/__init__.py +++ b/paddlers/tasks/__init__.py @@ -14,4 +14,5 @@ from . import det from .segmenter import * +from .changedetector import * from .load_model import load_model diff --git a/paddlers/tasks/changedetector.py b/paddlers/tasks/changedetector.py index 557b0c9..733c742 100644 --- a/paddlers/tasks/changedetector.py +++ b/paddlers/tasks/changedetector.py @@ -29,7 +29,7 @@ from .base import BaseModel from .utils import seg_metrics as metrics from paddlers.utils.checkpoint import seg_pretrain_weights_dict from paddlers.transforms import Decode, Resize -from paddlers.models.ppcd import CDNet +from paddlers.models.ppcd import CDNet as _CDNet __all__ = ["CDNet"] @@ -59,7 +59,7 @@ class BaseChangeDetector(BaseModel): def build_net(self, **params): # TODO: add other model - net = CDNet(num_classes=self.num_classes, **params) + net = _CDNet(num_classes=self.num_classes, **params) return net def _fix_transforms_shape(self, image_shape): @@ -174,14 +174,7 @@ class BaseChangeDetector(BaseModel): paddleseg.models.MixedLoss( losses=losses, coef=list(coef)) ] - if self.model_name == 'FastSCNN': - loss_type *= 2 - loss_coef = [1.0, 0.4] - elif self.model_name == 'BiSeNetV2': - loss_type *= 5 - loss_coef = [1.0] * 5 - else: - loss_coef = [1.0] + loss_coef = [1.0] losses = {'types': loss_type, 'coef': loss_coef} return losses @@ -584,7 +577,7 @@ class BaseChangeDetector(BaseModel): return batch_restore_list def _postprocess(self, batch_pred, batch_origin_shape, transforms): - batch_restore_list = BaseSegmenter.get_transforms_shape_info( + batch_restore_list = BaseChangeDetector.get_transforms_shape_info( batch_origin_shape, transforms) if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer': return self._infer_postprocess( @@ -665,7 +658,7 @@ class CDNet(BaseChangeDetector): **params): params.update({'in_channels': in_channels}) super(CDNet, self).__init__( - model_name='UNet', + model_name='CDNet', num_classes=num_classes, use_mixed_loss=use_mixed_loss, **params) diff --git a/paddlers/transforms/__init__.py b/paddlers/transforms/__init__.py index 36d285d..0c10e7d 100644 --- a/paddlers/transforms/__init__.py +++ b/paddlers/transforms/__init__.py @@ -25,7 +25,7 @@ def arrange_transforms(model_type, transforms, mode='train'): else: transforms.apply_im_only = False arrange_transform = ArrangeSegmenter(mode) - elif model_type == 'changedetctor': + elif model_type == 'changedetector': if mode == 'eval': transforms.apply_im_only = True else: diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index f245589..ed2729c 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -35,8 +35,8 @@ __all__ = [ "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip", "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop", "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage", - "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeClassifier", - "ArrangeDetector" + "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeChangeDetector", + "ArrangeClassifier", "ArrangeDetector" ] interp_dict = { @@ -69,7 +69,11 @@ class Transform(object): pass def apply(self, sample): - sample['image'] = self.apply_im(sample['image']) + if 'image' in sample: + sample['image'] = self.apply_im(sample['image']) + else: # image_tx + sample['image'] = self.apply_im(sample['image_t1']) + sample['image2'] = self.apply_im(sample['image_t2']) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask']) if 'gt_bbox' in sample: @@ -175,7 +179,7 @@ class Decode(Transform): return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) else: - return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH | + return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) elif ext == '.npy': return np.load(img_path) @@ -218,7 +222,11 @@ class Decode(Transform): dict: Decoded sample. """ - sample['image'] = self.apply_im(sample['image']) + if 'image' in sample: + sample['image'] = self.apply_im(sample['image']) + else: # image_tx + sample['image'] = self.apply_im(sample['image_t1']) + sample['image2'] = self.apply_im(sample['image_t2']) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask']) im_height, im_width, _ = sample['image'].shape @@ -323,6 +331,8 @@ class Resize(Transform): im_scale_x = target_w / im_w sample['image'] = self.apply_im(sample['image'], interp, target_size) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2'], interp, target_size) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask'], target_size) @@ -523,6 +533,8 @@ class RandomHorizontalFlip(Transform): if random.random() < self.prob: im_h, im_w = sample['image'].shape[:2] sample['image'] = self.apply_im(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2']) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask']) if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: @@ -576,6 +588,8 @@ class RandomVerticalFlip(Transform): if random.random() < self.prob: im_h, im_w = sample['image'].shape[:2] sample['image'] = self.apply_im(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2']) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask']) if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: @@ -636,6 +650,8 @@ class Normalize(Transform): def apply(self, sample): sample['image'] = self.apply_im(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2']) return sample @@ -665,6 +681,8 @@ class CenterCrop(Transform): def apply(self, sample): sample['image'] = self.apply_im(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2']) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask']) return sample @@ -819,6 +837,8 @@ class RandomCrop(Transform): crop_box, cropped_box, valid_ids = crop_info im_h, im_w = sample['image'].shape[:2] sample['image'] = self.apply_im(sample['image'], crop_box) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2'], crop_box) if 'gt_poly' in sample and len(sample['gt_poly']) > 0: crop_polys = self._crop_segm( sample['gt_poly'], @@ -1045,6 +1065,8 @@ class Padding(Transform): offsets = [w - im_w, h - im_h] sample['image'] = self.apply_im(sample['image'], offsets, (h, w)) + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w)) if 'mask' in sample: sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w)) if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: @@ -1239,22 +1261,33 @@ class RandomDistort(Transform): distortions = np.random.permutation(functions)[:self.count] for func in distortions: sample['image'] = func(sample['image']) + if 'image2' in sample: + sample['image2'] = func(sample['image2']) return sample sample['image'] = self.apply_brightness(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_brightness(sample['image2']) mode = np.random.randint(0, 2) if mode: sample['image'] = self.apply_contrast(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_contrast(sample['image2']) sample['image'] = self.apply_saturation(sample['image']) sample['image'] = self.apply_hue(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_saturation(sample['image2']) + sample['image2'] = self.apply_hue(sample['image2']) if not mode: sample['image'] = self.apply_contrast(sample['image']) + if 'image2' in sample: + sample['image2'] = self.apply_contrast(sample['image2']) if self.shuffle_channel: if np.random.randint(0, 2): - sample['image'] = sample['image'][..., np.random.permutation( - 3)] - + sample['image'] = sample['image'][..., np.random.permutation(3)] + if 'image2' in sample: + sample['image2'] = sample['image2'][..., np.random.permutation(3)] return sample @@ -1289,7 +1322,8 @@ class RandomBlur(Transform): if radius > 9: radius = 9 sample['image'] = self.apply_im(sample['image'], radius) - + if 'image2' in sample: + sample['image2'] = self.apply_im(sample['image2'], radius) return sample @@ -1374,6 +1408,8 @@ class _Permute(Transform): def apply(self, sample): sample['image'] = permute(sample['image'], False) + if 'image2' in sample: + sample['image2'] = permute(sample['image2'], False) return sample @@ -1415,8 +1451,8 @@ class ArrangeChangeDetector(Transform): if 'mask' in sample: mask = sample['mask'] - image_t1 = permute(sample['image_t1'], False) - image_t2 = permute(sample['image_t2'], False) + image_t1 = permute(sample['image'], False) + image_t2 = permute(sample['image2'], False) if self.mode == 'train': mask = mask.astype('int64') return image_t1, image_t2, mask diff --git a/tutorials/train/change_detection/cdnet_build.py b/tutorials/train/change_detection/cdnet_build.py new file mode 100644 index 0000000..1c9127a --- /dev/null +++ b/tutorials/train/change_detection/cdnet_build.py @@ -0,0 +1,58 @@ +import sys + +sys.path.append("E:/dataFiles/github/PaddleRS") + +import paddlers as pdrs +from paddlers import transforms as T + +# 下载aistudio的数据到当前文件夹并解压、整理 +# https://aistudio.baidu.com/aistudio/datasetdetail/53795 + +# 定义训练和验证时的transforms +# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md +train_transforms = T.Compose([ + T.Resize(target_size=512), + T.RandomHorizontalFlip(), + T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), +]) + +eval_transforms = T.Compose([ + T.Resize(target_size=512), + T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), +]) + +# 定义训练和验证所用的数据集 +# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md +train_dataset = pdrs.datasets.CDDataset( + data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet', + file_list='tutorials/train/change_detection/DataSet/train.txt', + label_list='tutorials/train/change_detection/DataSet/labels.txt', + transforms=train_transforms, + num_workers=0, + shuffle=True) + +eval_dataset = pdrs.datasets.CDDataset( + data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet', + file_list='tutorials/train/change_detection/DataSet/val.txt', + label_list='tutorials/train/change_detection/DataSet/labels.txt', + transforms=eval_transforms, + num_workers=0, + shuffle=False) + +# 初始化模型,并进行训练 +# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md +num_classes = len(train_dataset.labels) +model = pdrs.tasks.CDNet(num_classes=num_classes, in_channels=6) + +# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md +# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md +model.train( + num_epochs=1, + train_dataset=train_dataset, + train_batch_size=4, + eval_dataset=eval_dataset, + learning_rate=0.01, + pretrain_weights=None, + save_dir='output/cdnet')