[Feature] pass change detection run test

geoyee committed 3 years ago (branch: own)
parent 39796e2de4 · commit 26ec44f1d9
Changed files:
1. .gitignore (3 changed lines)
2. paddlers/tasks/__init__.py (1 changed line)
3. paddlers/tasks/changedetector.py (17 changed lines)
4. paddlers/transforms/__init__.py (2 changed lines)
5. paddlers/transforms/operators.py (58 changed lines)
6. tutorials/train/change_detection/cdnet_build.py (58 changed lines)

.gitignore
@@ -128,3 +128,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
+# testdata
+tutorials/train/change_detection/DataSet/

paddlers/tasks/__init__.py
@@ -14,4 +14,5 @@
 from . import det
 from .segmenter import *
+from .changedetector import *
 from .load_model import load_model

paddlers/tasks/changedetector.py
@@ -29,7 +29,7 @@ from .base import BaseModel
 from .utils import seg_metrics as metrics
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.transforms import Decode, Resize
-from paddlers.models.ppcd import CDNet
+from paddlers.models.ppcd import CDNet as _CDNet

 __all__ = ["CDNet"]
@@ -59,7 +59,7 @@ class BaseChangeDetector(BaseModel):
     def build_net(self, **params):
         # TODO: add other models
-        net = CDNet(num_classes=self.num_classes, **params)
+        net = _CDNet(num_classes=self.num_classes, **params)
         return net

     def _fix_transforms_shape(self, image_shape):
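
Note: the alias in the two hunks above matters because this module also defines a trainer class named CDNet. Without it, the name CDNet inside build_net would resolve to the trainer class itself rather than the ppcd network. A minimal sketch of the shadowing problem (class bodies abbreviated from the hunks above):

# Sketch only: why the import is aliased to _CDNet.
from paddlers.models.ppcd import CDNet as _CDNet  # the actual network

class CDNet(BaseChangeDetector):  # the trainer API re-uses the name "CDNet"
    def build_net(self, **params):
        # With a plain "import CDNet", this call would construct the trainer
        # recursively; the alias keeps the network class reachable.
        return _CDNet(num_classes=self.num_classes, **params)
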
@@ -174,14 +174,7 @@ class BaseChangeDetector(BaseModel):
                 paddleseg.models.MixedLoss(
                     losses=losses, coef=list(coef))
             ]
-            if self.model_name == 'FastSCNN':
-                loss_type *= 2
-                loss_coef = [1.0, 0.4]
-            elif self.model_name == 'BiSeNetV2':
-                loss_type *= 5
-                loss_coef = [1.0] * 5
-            else:
-                loss_coef = [1.0]
+            loss_coef = [1.0]
             losses = {'types': loss_type, 'coef': loss_coef}
             return losses
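
Note: the deleted FastSCNN/BiSeNetV2 branches came from the segmenter this class was adapted from; those models emit auxiliary heads and need per-head loss weights, while the change detectors here have a single output, so the coefficient list collapses to [1.0]. Illustrative shape of the resulting losses dict (MixedLoss appears in the hunk above; the inner CrossEntropyLoss/DiceLoss mix and its coefficients are assumed defaults, not shown in this commit):

# Illustrative only.
losses = {
    'types': [paddleseg.models.MixedLoss(
        losses=[paddleseg.models.CrossEntropyLoss(),
                paddleseg.models.DiceLoss()],
        coef=[0.8, 0.2])],
    'coef': [1.0],  # one output head, one weight
}
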
@@ -584,7 +577,7 @@ class BaseChangeDetector(BaseModel):
         return batch_restore_list

     def _postprocess(self, batch_pred, batch_origin_shape, transforms):
-        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
+        batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
             batch_origin_shape, transforms)
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
@@ -665,7 +658,7 @@ class CDNet(BaseChangeDetector):
                  **params):
         params.update({'in_channels': in_channels})
         super(CDNet, self).__init__(
-            model_name='UNet',
+            model_name='CDNet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)

paddlers/transforms/__init__.py
@@ -25,7 +25,7 @@ def arrange_transforms(model_type, transforms, mode='train'):
         else:
             transforms.apply_im_only = False
         arrange_transform = ArrangeSegmenter(mode)
-    elif model_type == 'changedetctor':
+    elif model_type == 'changedetector':
        if mode == 'eval':
            transforms.apply_im_only = True
        else:
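
Note: this one-character fix is likely what lets the run test in the commit title pass at all. The misspelled 'changedetctor' could never equal a real model_type, so change-detection pipelines never had ArrangeChangeDetector appended. A hedged usage sketch (the signature is taken from the hunk header above; train_transforms is as defined in the tutorial below):

# Hypothetical call showing the now-working dispatch.
arrange_transforms(
    model_type='changedetector', transforms=train_transforms, mode='train')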

paddlers/transforms/operators.py
@@ -35,8 +35,8 @@ __all__ = [
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
-    "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeClassifier",
-    "ArrangeDetector"
+    "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeChangeDetector",
+    "ArrangeClassifier", "ArrangeDetector"
 ]
 interp_dict = {
@@ -69,7 +69,11 @@ class Transform(object):
         pass

     def apply(self, sample):
-        sample['image'] = self.apply_im(sample['image'])
+        if 'image' in sample:
+            sample['image'] = self.apply_im(sample['image'])
+        else:  # bi-temporal input: 'image_t1' / 'image_t2'
+            sample['image'] = self.apply_im(sample['image_t1'])
+            sample['image2'] = self.apply_im(sample['image_t2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         if 'gt_bbox' in sample:
@@ -175,7 +179,7 @@ class Decode(Transform):
                 return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
             else:
-                return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR)
         elif ext == '.npy':
             return np.load(img_path)
@@ -218,7 +222,11 @@ class Decode(Transform):
            dict: Decoded sample.
        """
-        sample['image'] = self.apply_im(sample['image'])
+        if 'image' in sample:
+            sample['image'] = self.apply_im(sample['image'])
+        else:  # bi-temporal input: 'image_t1' / 'image_t2'
+            sample['image'] = self.apply_im(sample['image_t1'])
+            sample['image2'] = self.apply_im(sample['image_t2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         im_height, im_width, _ = sample['image'].shape
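
Note: the Transform and Decode hunks above establish the bi-temporal convention used throughout this commit: a change-detection sample enters the pipeline under 'image_t1'/'image_t2', Decode maps the decoded arrays onto 'image' and 'image2', and every later operator only needs a small "if 'image2' in sample" guard. A minimal walk-through, assuming Transform subclasses are callable and dispatch to apply (file names are made up):

sample = {
    'image_t1': 'A/0001.png',   # epoch-1 image path
    'image_t2': 'B/0001.png',   # epoch-2 image path
    'mask': 'label/0001.png',   # change-mask path
}
sample = Decode()(sample)                 # both epochs decoded
assert 'image' in sample and 'image2' in sample
sample = Resize(target_size=512)(sample)  # resizes image, image2, and mask alike
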
@@ -323,6 +331,8 @@ class Resize(Transform):
         im_scale_x = target_w / im_w
         sample['image'] = self.apply_im(sample['image'], interp, target_size)
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
@@ -523,6 +533,8 @@ class RandomHorizontalFlip(Transform):
         if random.random() < self.prob:
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -576,6 +588,8 @@ class RandomVerticalFlip(Transform):
         if random.random() < self.prob:
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -636,6 +650,8 @@ class Normalize(Transform):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
         return sample
@@ -665,6 +681,8 @@ class CenterCrop(Transform):
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         return sample
@@ -819,6 +837,8 @@ class RandomCrop(Transform):
             crop_box, cropped_box, valid_ids = crop_info
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'], crop_box)
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'], crop_box)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 crop_polys = self._crop_segm(
                     sample['gt_poly'],
@@ -1045,6 +1065,8 @@ class Padding(Transform):
             offsets = [w - im_w, h - im_h]
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -1239,22 +1261,33 @@ class RandomDistort(Transform):
             distortions = np.random.permutation(functions)[:self.count]
             for func in distortions:
                 sample['image'] = func(sample['image'])
+                if 'image2' in sample:
+                    sample['image2'] = func(sample['image2'])
             return sample
         sample['image'] = self.apply_brightness(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_brightness(sample['image2'])
         mode = np.random.randint(0, 2)
         if mode:
             sample['image'] = self.apply_contrast(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_contrast(sample['image2'])
         sample['image'] = self.apply_saturation(sample['image'])
         sample['image'] = self.apply_hue(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_saturation(sample['image2'])
+            sample['image2'] = self.apply_hue(sample['image2'])
         if not mode:
             sample['image'] = self.apply_contrast(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_contrast(sample['image2'])
         if self.shuffle_channel:
             if np.random.randint(0, 2):
-                sample['image'] = sample['image'][..., np.random.permutation(
-                    3)]
+                sample['image'] = sample['image'][..., np.random.permutation(3)]
+                if 'image2' in sample:
+                    sample['image2'] = sample['image2'][..., np.random.permutation(3)]
         return sample
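
Note: if each apply_* helper draws its own random magnitude internally (which this hunk does not show), the two epochs receive independent photometric jitter; for change detection that is arguably desirable, since real bi-temporal pairs differ radiometrically. If identical jitter were ever wanted, the factor would have to be drawn once and shared, roughly:

# Hypothetical variant with a shared draw; adjust_brightness and the two
# bounds are illustrative names, not this module's real API.
factor = np.random.uniform(brightness_low, brightness_high)
sample['image'] = adjust_brightness(sample['image'], factor)
if 'image2' in sample:
    sample['image2'] = adjust_brightness(sample['image2'], factor)
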
@@ -1289,7 +1322,8 @@ class RandomBlur(Transform):
         if radius > 9:
             radius = 9
         sample['image'] = self.apply_im(sample['image'], radius)
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'], radius)
         return sample
@@ -1374,6 +1408,8 @@ class _Permute(Transform):
     def apply(self, sample):
         sample['image'] = permute(sample['image'], False)
+        if 'image2' in sample:
+            sample['image2'] = permute(sample['image2'], False)
         return sample
@@ -1415,8 +1451,8 @@ class ArrangeChangeDetector(Transform):
         if 'mask' in sample:
             mask = sample['mask']
-        image_t1 = permute(sample['image_t1'], False)
-        image_t2 = permute(sample['image_t2'], False)
+        image_t1 = permute(sample['image'], False)
+        image_t2 = permute(sample['image2'], False)
         if self.mode == 'train':
             mask = mask.astype('int64')
             return image_t1, image_t2, mask
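
Note: this hunk closes the loop on the key renaming: Decode consumes 'image_t1'/'image_t2', and the arranger now reads the 'image'/'image2' keys the pipeline actually produces. A sketch of one arranged training sample (CHW layout after permute and the 512x512 size from the tutorial below are assumptions):

arrange = ArrangeChangeDetector('train')
image_t1, image_t2, mask = arrange(sample)
# image_t1.shape == (3, 512, 512)  float32
# image_t2.shape == (3, 512, 512)  float32
# mask.shape     == (512, 512)     int64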

tutorials/train/change_detection/cdnet_build.py (new file)
@@ -0,0 +1,58 @@
import sys
sys.path.append("E:/dataFiles/github/PaddleRS")

import paddlers as pdrs
from paddlers import transforms as T

# Download the dataset from AI Studio to the current folder, then unzip and organize it:
# https://aistudio.baidu.com/aistudio/datasetdetail/53795

# Define the transforms for training and validation
# API reference: https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md
train_transforms = T.Compose([
    T.Resize(target_size=512),
    T.RandomHorizontalFlip(),
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

eval_transforms = T.Compose([
    T.Resize(target_size=512),
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Define the training and validation datasets
# API reference: https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md
train_dataset = pdrs.datasets.CDDataset(
    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet',
    file_list='tutorials/train/change_detection/DataSet/train.txt',
    label_list='tutorials/train/change_detection/DataSet/labels.txt',
    transforms=train_transforms,
    num_workers=0,
    shuffle=True)

eval_dataset = pdrs.datasets.CDDataset(
    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet',
    file_list='tutorials/train/change_detection/DataSet/val.txt',
    label_list='tutorials/train/change_detection/DataSet/labels.txt',
    transforms=eval_transforms,
    num_workers=0,
    shuffle=False)

# Initialize the model and start training
# Training metrics can be viewed with VisualDL; see https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md
num_classes = len(train_dataset.labels)
model = pdrs.tasks.CDNet(num_classes=num_classes, in_channels=6)

# API reference: https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md
# Parameter descriptions and tuning notes: https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md
model.train(
    num_epochs=1,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    pretrain_weights=None,
    save_dir='output/cdnet')
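
Note: paddlers.tasks also re-exports load_model (see the tasks/__init__.py hunk above), so a finished run can presumably be restored from save_dir afterwards. The best_model subdirectory and the evaluate() call below follow the PaddleX-style trainer convention and are assumptions, not something this commit shows:

import paddlers as pdrs

# Hypothetical restore-and-evaluate sketch.
model = pdrs.tasks.load_model('output/cdnet/best_model')
print(model.evaluate(eval_dataset))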