diff --git a/docs/apis/data.md b/docs/apis/data.md index 8f8bc4d..ab8a523 100644 --- a/docs/apis/data.md +++ b/docs/apis/data.md @@ -84,6 +84,9 @@ - file list中的每一行应该包含2个以空格分隔的项,依次表示输入影像相对`data_dir`的路径以及[Pascal VOC格式](http://host.robots.ox.ac.uk/pascal/VOC/)标注文件相对`data_dir`的路径。 +### 图像复原数据集`ResDataset` + + ### 图像分割数据集`SegDataset` `SegDataset`定义在:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/seg_dataset.py @@ -143,6 +146,7 @@ |`'aux_masks'`|图像分割/变化检测任务中的辅助标签路径或数据。| |`'gt_bbox'`|目标检测任务中的检测框标注数据。| |`'gt_poly'`|目标检测任务中的多边形标注数据。| +|`'target'`|图像复原中的目标影像路径或数据。| ### 组合数据变换算子 diff --git a/docs/apis/infer.md b/docs/apis/infer.md index c7f3d1c..2a5b62d 100644 --- a/docs/apis/infer.md +++ b/docs/apis/infer.md @@ -26,7 +26,7 @@ def predict(self, img_file, transforms=None): 若`img_file`是一个元组,则返回对象为包含下列键值对的字典: ``` -{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)} +{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)} ``` 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。 @@ -51,7 +51,7 @@ def predict(self, img_file, transforms=None): 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典: ``` -{"label map": 输出类别标签, +{"label_map": 输出类别标签, "scores_map": 输出类别概率, "label_names_map": 输出类别名称} ``` @@ -87,6 +87,10 @@ def predict(self, img_file, transforms=None): 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个由字典(键值对如上所示)构成的列表,顺序对应`img_file`中的每个元素。 +#### `BaseRestorer.predict()` + + + #### `BaseSegmenter.predict()` 接口形式: @@ -107,7 +111,7 @@ def predict(self, img_file, transforms=None): 若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典: ``` -{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)} +{"label_map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)} ``` 若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。 diff --git a/docs/apis/train.md b/docs/apis/train.md index 944b0b3..5ac48f3 100644 --- a/docs/apis/train.md +++ b/docs/apis/train.md 
@@ -18,11 +18,15 @@ - `use_mixed_loss`参将在未来被弃用,因此不建议使用。 - 不同的子类支持与模型相关的输入参数,详情请参考[模型定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/clas)和[训练器定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py)。 -### 初始化`Baseetector`子类对象 +### 初始化`BaseDetector`子类对象 - 一般支持设置`num_classes`和`backbone`参数,分别表示模型输出类别数以及所用的骨干网络类型。相比其它任务,目标检测任务的训练器支持设置的初始化参数较多,囊括网络结构、损失函数、后处理策略等方面。 - 不同的子类支持与模型相关的输入参数,详情请参考[模型定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/det)和[训练器定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/object_detector.py)。 +### 初始化`BaseRestorer`子类对象 + + + ### 初始化`BaseSegmenter`子类对象 - 一般支持设置`in_channels`、`num_classes`以及`use_mixed_loss`参数,分别表示输入通道数、输出类别数以及是否使用预置的混合损失。部分模型如`FarSeg`暂不支持对`in_channels`参数的设置。 @@ -170,6 +174,9 @@ def train(self, |`use_vdl`|`bool`|是否启用VisualDL日志。|`True`| |`resume_checkpoint`|`str` \| `None`|检查点路径。PaddleRS支持从检查点(包含先前训练过程中存储的模型权重和优化器权重)继续训练,但需注意`resume_checkpoint`与`pretrain_weights`不得同时设置为`None`以外的值。|`None`| +### `BaseRestorer.train()` + + ### `BaseSegmenter.train()` 接口形式: @@ -311,6 +318,9 @@ def evaluate(self, "mask": 预测得到的掩模图信息} ``` +### `BaseRestorer.evaluate()` + + ### `BaseSegmenter.evaluate()` 接口形式: diff --git a/docs/dev/dev_guide.md b/docs/dev/dev_guide.md index 55f3980..9f678af 100644 --- a/docs/dev/dev_guide.md +++ b/docs/dev/dev_guide.md @@ -22,7 +22,7 @@ 在子目录中新建文件,以`{模型名称小写}.py`命名。在文件中编写完整的模型定义。 -新模型必须是`paddle.nn.Layer`的子类。对于图像分割、目标检测和场景分类任务,分别需要遵循[PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg)、[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)和[PaddleClas](https://github.com/PaddlePaddle/PaddleClas)套件中制定的相关规范。**对于变化检测、场景分类和图像分割任务,模型构造时必须传入`num_classes`参数以指定输出的类别数目**。对于变化检测任务,模型定义需遵循的规范与分割模型类似,但有以下几点不同: 
+新模型必须是`paddle.nn.Layer`的子类。对于图像分割、目标检测、场景分类和图像复原任务,分别需要遵循[PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg)、[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)、[PaddleClas](https://github.com/PaddlePaddle/PaddleClas)和[PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN)套件中制定的相关规范。**对于变化检测、场景分类和图像分割任务,模型构造时必须传入`num_classes`参数以指定输出的类别数目。对于图像复原任务,模型构造时必须传入`rs_factor`参数以指定超分辨率缩放倍数(对于非超分辨率模型,将此参数设置为`None`)。**对于变化检测任务,模型定义需遵循的规范与分割模型类似,但有以下几点不同: - `forward()`方法接受3个输入参数,分别是`self`、`t1`和`t2`,其中`t1`和`t2`分别表示前、后两个时相的输入影像。 - 对于多任务变化检测模型(例如模型同时输出变化检测结果与两个时相的建筑物提取结果),需要指定类的`USE_MULTITASK_DECODER`属性为`True`,同时在`OUT_TYPES`属性中设置模型前向输出的列表中每一个元素对应的标签类型。可参考`ChangeStar`模型的定义。 @@ -64,7 +64,7 @@ Args: 2. 在`paddlers/tasks`目录中找到任务对应的训练器定义文件(例如变化检测任务对应`paddlers/tasks/change_detector.py`)。 3. 在文件尾部追加新的训练器定义。训练器需要继承自相关的基类(例如`BaseChangeDetector`),重写`__init__()`方法,并根据需要重写其他方法。对训练器`__init__()`方法编写的要求如下: - - 对于变化检测、场景分类、目标检测、图像分割任务,`__init__()`方法的第1个输入参数是`num_classes`,表示模型输出类别数;对于变化检测、场景分类、图像分割任务,第2个输入参数是`use_mixed_loss`,表示用户是否使用默认定义的混合损失。 + - 对于变化检测、场景分类、目标检测、图像分割任务,`__init__()`方法的第1个输入参数是`num_classes`,表示模型输出类别数。对于变化检测、场景分类、图像分割任务,第2个输入参数是`use_mixed_loss`,表示用户是否使用默认定义的混合损失;第3个输入参数是`losses`,表示训练时使用的损失函数。对于图像复原任务,第1个参数是`losses`,含义同上;第2个参数是`rs_factor`,表示超分辨率缩放倍数。 - `__init__()`的所有输入参数都必须有默认值,且在**取默认值的情况下,模型接收3通道RGB输入**。 - 在`__init__()`中需要更新`params`字典,该字典中的键值对将被用作模型构造时的输入参数。 @@ -78,7 +78,7 @@ Args: ### 2.2 新增数据预处理/数据增强算子 -在`paddlers/transforms/operators.py`中定义新算子,所有算子均继承自`paddlers.transforms.Transform`类。算子的`apply()`方法接收一个字典`sample`作为输入,取出其中存储的相关对象,处理后对字典进行in-place修改,最后返回修改后的字典。在定义算子时,只有极少数的情况需要重写`apply()`方法。大多数情况下,只需要重写`apply_im()`、`apply_mask()`、`apply_bbox()`和`apply_segm()`方法就分别可以实现对输入图像、分割标签、目标框以及目标多边形的处理。 
+在`paddlers/transforms/operators.py`中定义新算子,所有算子均继承自`paddlers.transforms.Transform`类。算子的`apply()`方法接收一个字典`sample`作为输入,取出其中存储的相关对象,处理后对字典进行in-place修改,最后返回修改后的字典。在定义算子时,只有极少数的情况需要重写`apply()`方法。大多数情况下,只需要重写`apply_im()`、`apply_mask()`、`apply_bbox()`和`apply_segm()`方法就分别可以实现对图像、分割标签、目标框以及目标多边形的处理。 如果处理逻辑较为复杂,建议先封装为函数,添加到`paddlers/transforms/functions.py`中,然后在算子的`apply*()`方法中调用函数。 diff --git a/docs/intro/model_zoo.md b/docs/intro/model_zoo.md index 0fbafc7..9281df8 100644 --- a/docs/intro/model_zoo.md +++ b/docs/intro/model_zoo.md @@ -10,19 +10,20 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型 |--------|---------|------| | 变化检测 | \*BIT | 是 | | 变化检测 | \*CDNet | 是 | +| 变化检测 | \*ChangeFormer | 是 | +| 变化检测 | \*ChangeStar | 否 | | 变化检测 | \*DSAMNet | 是 | | 变化检测 | \*DSIFN | 否 | -| 变化检测 | \*SNUNet | 是 | -| 变化检测 | \*STANet | 是 | | 变化检测 | \*FC-EF | 是 | | 变化检测 | \*FC-Siam-conc | 是 | | 变化检测 | \*FC-Siam-diff | 是 | -| 变化检测 | \*ChangeStar | 否 | -| 变化检测 | \*ChangeFormer | 是 | +| 变化检测 | \*FCCDN | 是 | +| 变化检测 | \*SNUNet | 是 | +| 变化检测 | \*STANet | 是 | +| 场景分类 | CondenseNetV2 | 是 | | 场景分类 | HRNet | 是 | | 场景分类 | MobileNetV3 | 是 | | 场景分类 | ResNet50-vd | 是 | -| 场景分类 | CondenseNetV2 | 是 | | 图像复原 | DRN | 否 | | 图像复原 | ESRGAN | 否 | | 图像复原 | LESRCNN | 否 | @@ -32,5 +33,5 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型 | 目标检测 | PP-YOLOv2 | 是 | | 目标检测 | YOLOv3 | 是 | | 图像分割 | DeepLab V3+ | 是 | -| 图像分割 | UNet | 是 | | 图像分割 | \*FarSeg | 否 | +| 图像分割 | UNet | 是 | diff --git a/docs/intro/transforms.md b/docs/intro/transforms.md index c7234de..15bb1a2 100644 --- a/docs/intro/transforms.md +++ b/docs/intro/transforms.md @@ -6,26 +6,26 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为 | 数据变换算子名 | 用途 | 任务 | ... | | -------------------- | ------------------------------------------------- | -------- | ---- | -| Resize | 调整输入影像大小。 | 所有任务 | ... | -| RandomResize | 随机调整输入影像大小。 | 所有任务 | ... | -| ResizeByShort | 调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... | -| RandomResizeByShort | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... 
| -| ResizeByLong | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数)。 | 所有任务 | ... | -| RandomHorizontalFlip | 随机水平翻转输入影像。 | 所有任务 | ... | -| RandomVerticalFlip | 随机垂直翻转输入影像。 | 所有任务 | ... | -| Normalize | 对输入影像应用标准化。 | 所有任务 | ... | | CenterCrop | 对输入影像进行中心裁剪。 | 所有任务 | ... | -| RandomCrop | 对输入影像进行随机中心裁剪。 | 所有任务 | ... | -| RandomScaleAspect | 裁剪输入影像并重新缩放到原始尺寸。 | 所有任务 | ... | -| RandomExpand | 根据随机偏移扩展输入影像。 | 所有任务 | ... | -| Pad | 将输入影像填充到指定的大小。 | 所有任务 | ... | +| Dehaze | 对输入图像进行去雾。 | 所有任务 | ... | | MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本。 | 目标检测 | ... | -| RandomDistort | 对输入施加随机色彩变换。 | 所有任务 | ... | +| Normalize | 对输入影像应用标准化。 | 所有任务 | ... | +| Pad | 将输入影像填充到指定的大小。 | 所有任务 | ... | | RandomBlur | 对输入施加随机模糊。 | 所有任务 | ... | -| Dehaze | 对输入图像进行去雾。 | 所有任务 | ... | +| RandomCrop | 对输入影像进行随机中心裁剪。 | 所有任务 | ... | +| RandomDistort | 对输入施加随机色彩变换。 | 所有任务 | ... | +| RandomExpand | 根据随机偏移扩展输入影像。 | 所有任务 | ... | +| RandomHorizontalFlip | 随机水平翻转输入影像。 | 所有任务 | ... | +| RandomResize | 随机调整输入影像大小。 | 所有任务 | ... | +| RandomResizeByShort | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... | +| RandomScaleAspect | 裁剪输入影像并重新缩放到原始尺寸。 | 所有任务 | ... | +| RandomSwap | 随机交换两个时相的输入影像。 | 变化检测 | ... | +| RandomVerticalFlip | 随机竖直翻转输入影像。 | 所有任务 | ... | | ReduceDim | 对输入图像进行波段降维。 | 所有任务 | ... | +| Resize | 调整输入影像大小。 | 所有任务 | ... | +| ResizeByLong | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数)。 | 所有任务 | ... | +| ResizeByShort | 调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... | | SelectBand | 对输入影像进行波段选择。 | 所有任务 | ... | -| RandomSwap | 随机交换两个时相的输入影像。 | 变化检测 | ... | | ... | ... | ... | ... 
| ## 组合算子 diff --git a/examples/rs_research/README.md b/examples/rs_research/README.md index f00b42c..73b52cf 100644 --- a/examples/rs_research/README.md +++ b/examples/rs_research/README.md @@ -187,7 +187,7 @@ python train_cd.py #### 4.2.1 配置文件编写 -本案例提供一个基于[YAML][https://yaml.org/]的轻量级配置系统,使用者可以通过修改yaml文件达到调整超参数、更换模型、更换数据集等目的,或通过编写yaml文件增加新的配置。 +本案例提供一个基于[YAML](https://yaml.org/)的轻量级配置系统,使用者可以通过修改yaml文件达到调整超参数、更换模型、更换数据集等目的,或通过编写yaml文件增加新的配置。 关于本案例中配置文件的编写规则,请参考[此项目](https://aistudio.baidu.com/aistudio/projectdetail/4203534)。 diff --git a/examples/rs_research/config_utils.py b/examples/rs_research/config_utils.py index 1effc1a..10e7129 100644 --- a/examples/rs_research/config_utils.py +++ b/examples/rs_research/config_utils.py @@ -132,7 +132,7 @@ def parse_args(*args, **kwargs): conflict_handler='resolve', parents=[cfg_parser]) # Global settings parser.add_argument('cmd', choices=['train', 'eval']) - parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg']) + parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) # Data parser.add_argument('--datasets', type=dict, default={}) diff --git a/paddlers/datasets/__init__.py b/paddlers/datasets/__init__.py index 1fecd96..1fab4c4 100644 --- a/paddlers/datasets/__init__.py +++ b/paddlers/datasets/__init__.py @@ -17,4 +17,4 @@ from .coco import COCODetDataset from .seg_dataset import SegDataset from .cd_dataset import CDDataset from .clas_dataset import ClasDataset -from .sr_dataset import SRdataset, ComposeTrans +from .res_dataset import ResDataset diff --git a/paddlers/datasets/cd_dataset.py b/paddlers/datasets/cd_dataset.py index 2a2d85a..048b23c 100644 --- a/paddlers/datasets/cd_dataset.py +++ b/paddlers/datasets/cd_dataset.py @@ -95,23 +95,23 @@ class CDDataset(BaseDataset): full_path_label))): continue if not osp.exists(full_path_im_t1): - raise IOError('Image file {} does not exist!'.format( + raise IOError("Image file {} does not exist!".format( full_path_im_t1)) if not 
osp.exists(full_path_im_t2): - raise IOError('Image file {} does not exist!'.format( + raise IOError("Image file {} does not exist!".format( full_path_im_t2)) if not osp.exists(full_path_label): - raise IOError('Label file {} does not exist!'.format( + raise IOError("Label file {} does not exist!".format( full_path_label)) if with_seg_labels: full_path_seg_label_t1 = osp.join(data_dir, items[3]) full_path_seg_label_t2 = osp.join(data_dir, items[4]) if not osp.exists(full_path_seg_label_t1): - raise IOError('Label file {} does not exist!'.format( + raise IOError("Label file {} does not exist!".format( full_path_seg_label_t1)) if not osp.exists(full_path_seg_label_t2): - raise IOError('Label file {} does not exist!'.format( + raise IOError("Label file {} does not exist!".format( full_path_seg_label_t2)) item_dict = dict( diff --git a/paddlers/datasets/res_dataset.py b/paddlers/datasets/res_dataset.py new file mode 100644 index 0000000..aaab8b2 --- /dev/null +++ b/paddlers/datasets/res_dataset.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path as osp +import copy + +from .base import BaseDataset +from paddlers.utils import logging, get_encoding, norm_path, is_pic + + +class ResDataset(BaseDataset): + """ + Dataset for image restoration tasks. + + Args: + data_dir (str): Root directory of the dataset. 
+ file_list (str): Path of the file that contains relative paths of source and target image files. + transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply. + num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto', + the number of workers will be automatically determined according to the number of CPU cores: If + there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half + the number of CPU cores. Defaults to 'auto'. + shuffle (bool, optional): Whether to shuffle the samples. Defaults to False. + sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image + restoration tasks. Defaults to None. + """ + + def __init__(self, + data_dir, + file_list, + transforms, + num_workers='auto', + shuffle=False, + sr_factor=None): + super(ResDataset, self).__init__(data_dir, None, transforms, + num_workers, shuffle) + self.batch_transforms = None + self.file_list = list() + + with open(file_list, encoding=get_encoding(file_list)) as f: + for line in f: + items = line.strip().split() + if len(items) > 2: + raise ValueError( + "A space is defined as the delimiter to separate the source and target image path, " \ + "so the space cannot be in the source image or target image path, but the line[{}] of " \ + "file_list[{}] has a space in the two paths.".format(line, file_list)) + items[0] = norm_path(items[0]) + items[1] = norm_path(items[1]) + full_path_im = osp.join(data_dir, items[0]) + full_path_tar = osp.join(data_dir, items[1]) + if not is_pic(full_path_im) or not is_pic(full_path_tar): + continue + if not osp.exists(full_path_im): + raise IOError("Source image file {} does not exist!".format( + full_path_im)) + if not osp.exists(full_path_tar): + raise IOError("Target image file {} does not exist!".format( + full_path_tar)) + sample = { + 'image': full_path_im, + 'target': full_path_tar, + } + if sr_factor 
is not None: + sample['sr_factor'] = sr_factor + self.file_list.append(sample) + self.num_samples = len(self.file_list) + logging.info("{} samples in file {}".format( + len(self.file_list), file_list)) + + def __len__(self): + return len(self.file_list) diff --git a/paddlers/datasets/seg_dataset.py b/paddlers/datasets/seg_dataset.py index 0bfab96..656777c 100644 --- a/paddlers/datasets/seg_dataset.py +++ b/paddlers/datasets/seg_dataset.py @@ -44,7 +44,7 @@ class SegDataset(BaseDataset): shuffle=False): super(SegDataset, self).__init__(data_dir, label_list, transforms, num_workers, shuffle) - # TODO batch padding + # TODO: batch padding self.batch_transforms = None self.file_list = list() self.labels = list() diff --git a/paddlers/datasets/sr_dataset.py b/paddlers/datasets/sr_dataset.py deleted file mode 100644 index 17748bf..0000000 --- a/paddlers/datasets/sr_dataset.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# 超分辨率数据集定义 -class SRdataset(object): - def __init__(self, - mode, - gt_floder, - lq_floder, - transforms, - scale, - num_workers=4, - batch_size=8): - if mode == 'train': - preprocess = [] - preprocess.append({ - 'name': 'LoadImageFromFile', - 'key': 'lq' - }) # 加载方式 - preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) - preprocess.append(transforms) # 变换方式 - self.dataset = { - 'name': 'SRDataset', - 'gt_folder': gt_floder, - 'lq_folder': lq_floder, - 'num_workers': num_workers, - 'batch_size': batch_size, - 'scale': scale, - 'preprocess': preprocess - } - - if mode == "test": - preprocess = [] - preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'}) - preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) - preprocess.append(transforms) - self.dataset = { - 'name': 'SRDataset', - 'gt_folder': gt_floder, - 'lq_folder': lq_floder, - 'scale': scale, - 'preprocess': preprocess - } - - def __call__(self): - return self.dataset - - -# 对定义的transforms处理方式组合,返回字典 -class ComposeTrans(object): - def __init__(self, input_keys, output_keys, pipelines): - if not isinstance(pipelines, list): - raise TypeError( - 'Type of transforms is invalid. 
Must be List, but received is {}' - .format(type(pipelines))) - if len(pipelines) < 1: - raise ValueError( - 'Length of transforms must not be less than 1, but received is {}' - .format(len(pipelines))) - self.transforms = pipelines - self.output_length = len(output_keys) # 当output_keys的长度为3时,是DRN训练 - self.input_keys = input_keys - self.output_keys = output_keys - - def __call__(self): - pipeline = [] - for op in self.transforms: - if op['name'] == 'SRPairedRandomCrop': - op['keys'] = ['image'] * 2 - else: - op['keys'] = ['image'] * self.output_length - pipeline.append(op) - if self.output_length == 2: - transform_dict = { - 'name': 'Transforms', - 'input_keys': self.input_keys, - 'pipeline': pipeline - } - else: - transform_dict = { - 'name': 'Transforms', - 'input_keys': self.input_keys, - 'output_keys': self.output_keys, - 'pipeline': pipeline - } - - return transform_dict diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 7579120..1b2c493 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -105,7 +105,7 @@ class Predictor(object): logging.warning( "Semantic segmentation models do not support TensorRT acceleration, " "TensorRT is forcibly disabled.") - elif 'RCNN' in self._model.__class__.__name__: + elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__: logging.warning( "RCNN models do not support TensorRT acceleration, " "TensorRT is forcibly disabled.") @@ -163,13 +163,23 @@ class Predictor(object): 'image2': preprocessed_samples[1], 'ori_shape': preprocessed_samples[2] } + elif self._model.model_type == 'restorer': + preprocessed_samples = { + 'image': preprocessed_samples[0], + 'tar_shape': preprocessed_samples[1] + } else: logging.error( "Invalid model type {}".format(self._model.model_type), exit=True) return preprocessed_samples - def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None): + def postprocess(self, + net_outputs, + topk=1, + 
ori_shape=None, + tar_shape=None, + transforms=None): if self._model.model_type == 'classifier': true_topk = min(self._model.num_classes, topk) if self._model.postprocess is None: @@ -201,6 +211,12 @@ class Predictor(object): for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs) } preds = self._model.postprocess(net_outputs) + elif self._model.model_type == 'restorer': + res_maps = self._model.postprocess( + net_outputs[0], + batch_tar_shape=tar_shape, + transforms=transforms.transforms) + preds = [{'res_map': res_map} for res_map in res_maps] else: logging.error( "Invalid model type {}.".format(self._model.model_type), @@ -244,6 +260,7 @@ class Predictor(object): net_outputs, topk, ori_shape=preprocessed_input.get('ori_shape', None), + tar_shape=preprocessed_input.get('tar_shape', None), transforms=transforms) self.timer.postprocess_time_s.end(iter_num=len(images)) diff --git a/paddlers/models/__init__.py b/paddlers/models/__init__.py index 952821f..bf2f708 100644 --- a/paddlers/models/__init__.py +++ b/paddlers/models/__init__.py @@ -16,3 +16,4 @@ from . 
import ppcls, ppdet, ppseg, ppgan import paddlers.models.ppseg.models.losses as seg_losses import paddlers.models.ppdet.modeling.losses as det_losses import paddlers.models.ppcls.loss as clas_losses +import paddlers.models.ppgan.models.criterions as res_losses diff --git a/paddlers/models/ppdet/metrics/json_results.py b/paddlers/models/ppdet/metrics/json_results.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/architectures/centernet.py b/paddlers/models/ppdet/modeling/architectures/centernet.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/architectures/fairmot.py b/paddlers/models/ppdet/modeling/architectures/fairmot.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/backbones/darknet.py b/paddlers/models/ppdet/modeling/backbones/darknet.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/backbones/dla.py b/paddlers/models/ppdet/modeling/backbones/dla.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/backbones/resnet.py b/paddlers/models/ppdet/modeling/backbones/resnet.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/backbones/vgg.py b/paddlers/models/ppdet/modeling/backbones/vgg.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/heads/centernet_head.py b/paddlers/models/ppdet/modeling/heads/centernet_head.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/losses/fairmot_loss.py b/paddlers/models/ppdet/modeling/losses/fairmot_loss.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/necks/centernet_fpn.py b/paddlers/models/ppdet/modeling/necks/centernet_fpn.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppdet/modeling/reid/fairmot_embedding_head.py b/paddlers/models/ppdet/modeling/reid/fairmot_embedding_head.py old mode 100755 new mode 100644 diff --git 
a/paddlers/models/ppseg/models/losses/focal_loss.py b/paddlers/models/ppseg/models/losses/focal_loss.py old mode 100755 new mode 100644 diff --git a/paddlers/models/ppseg/models/losses/kl_loss.py b/paddlers/models/ppseg/models/losses/kl_loss.py old mode 100755 new mode 100644 diff --git a/paddlers/rs_models/cd/__init__.py b/paddlers/rs_models/cd/__init__.py index c3d75b5..274b2e9 100644 --- a/paddlers/rs_models/cd/__init__.py +++ b/paddlers/rs_models/cd/__init__.py @@ -23,3 +23,5 @@ from .fc_ef import FCEarlyFusion from .fc_siam_conc import FCSiamConc from .fc_siam_diff import FCSiamDiff from .changeformer import ChangeFormer +from .fccdn import FCCDN +from .losses import fccdn_ssl_loss diff --git a/paddlers/rs_models/cd/fccdn.py b/paddlers/rs_models/cd/fccdn.py new file mode 100644 index 0000000..17d1673 --- /dev/null +++ b/paddlers/rs_models/cd/fccdn.py @@ -0,0 +1,478 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3 + +bn_mom = 1 - 0.0003 + + +class NLBlock(nn.Layer): + def __init__(self, in_channels): + super(NLBlock, self).__init__() + self.conv_v = BasicConv( + in_ch=in_channels, + out_ch=in_channels, + kernel_size=3, + norm=nn.BatchNorm2D( + in_channels, momentum=0.9)) + self.W = BasicConv( + in_ch=in_channels, + out_ch=in_channels, + kernel_size=3, + norm=nn.BatchNorm2D( + in_channels, momentum=0.9), + act=nn.ReLU()) + + def forward(self, x): + batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + value = self.conv_v(x) + value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]]) + value = value.transpose([0, 2, 1]) # B * (H*W) * value_channels + key = x.reshape([batch_size, c, h * w]) # B * key_channels * (H*W) + query = x.reshape([batch_size, c, h * w]) + query = query.transpose([0, 2, 1]) + + sim_map = paddle.matmul(query, key) # B * (H*W) * (H*W) + sim_map = (c**-.5) * sim_map # B * (H*W) * (H*W) + sim_map = nn.functional.softmax(sim_map, axis=-1) # B * (H*W) * (H*W) + + context = paddle.matmul(sim_map, value) + context = context.transpose([0, 2, 1]) + context = context.reshape([batch_size, c, *x.shape[2:]]) + context = self.W(context) + + return context + + +class NLFPN(nn.Layer): + """Non-local feature pyramid network""" + + def __init__(self, in_dim, reduction=True): + super(NLFPN, self).__init__() + if reduction: + self.reduction = BasicConv( + in_ch=in_dim, + out_ch=in_dim // 4, + kernel_size=1, + norm=nn.BatchNorm2D( + in_dim // 4, momentum=bn_mom), + act=nn.ReLU()) + self.re_reduction = BasicConv( + in_ch=in_dim // 4, + out_ch=in_dim, + kernel_size=1, + norm=nn.BatchNorm2D( + in_dim, momentum=bn_mom), + act=nn.ReLU()) + in_dim = in_dim // 4 + else: + self.reduction = None + self.re_reduction = None + self.conv_e1 = BasicConv( + in_dim, + in_dim, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim, 
momentum=bn_mom), + act=nn.ReLU()) + self.conv_e2 = BasicConv( + in_dim, + in_dim * 2, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim * 2, momentum=bn_mom), + act=nn.ReLU()) + self.conv_e3 = BasicConv( + in_dim * 2, + in_dim * 4, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim * 4, momentum=bn_mom), + act=nn.ReLU()) + self.conv_d1 = BasicConv( + in_dim, + in_dim, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim, momentum=bn_mom), + act=nn.ReLU()) + self.conv_d2 = BasicConv( + in_dim * 2, + in_dim, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim, momentum=bn_mom), + act=nn.ReLU()) + self.conv_d3 = BasicConv( + in_dim * 4, + in_dim * 2, + kernel_size=3, + norm=nn.BatchNorm2D( + in_dim * 2, momentum=bn_mom), + act=nn.ReLU()) + self.nl3 = NLBlock(in_dim * 2) + self.nl2 = NLBlock(in_dim) + self.nl1 = NLBlock(in_dim) + + self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2) + self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2) + + def forward(self, x): + if self.reduction is not None: + x = self.reduction(x) + e1 = self.conv_e1(x) # C,H,W + e2 = self.conv_e2(self.downsample_x2(e1)) # 2C,H/2,W/2 + e3 = self.conv_e3(self.downsample_x2(e2)) # 4C,H/4,W/4 + + d3 = self.conv_d3(e3) # 2C,H/4,W/4 + nl = self.nl3(d3) + d3 = self.upsample_x2(paddle.multiply(d3, nl)) ##2C,H/2,W/2 + d2 = self.conv_d2(e2 + d3) # C,H/2,W/2 + nl = self.nl2(d2) + d2 = self.upsample_x2(paddle.multiply(d2, nl)) # C,H,W + d1 = self.conv_d1(e1 + d2) + nl = self.nl1(d1) + d1 = paddle.multiply(d1, nl) # C,H,W + if self.re_reduction is not None: + d1 = self.re_reduction(d1) + + return d1 + + +class Cat(nn.Layer): + def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False): + super(Cat, self).__init__() + self.do_upsample = upsample + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.conv2d = BasicConv( + in_chn_high + in_chn_low, + out_chn, + kernel_size=1, + norm=nn.BatchNorm2D( + out_chn, momentum=bn_mom), + act=nn.ReLU()) + + def forward(self, x, y): + if 
self.do_upsample: + x = self.upsample(x) + + x = paddle.concat((x, y), 1) + + return self.conv2d(x) + + +class DoubleConv(nn.Layer): + def __init__(self, in_chn, out_chn, stride=1, dilation=1): + super(DoubleConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2D( + in_chn, + out_chn, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation), + nn.BatchNorm2D( + out_chn, momentum=bn_mom), + nn.ReLU(), + nn.Conv2D( + out_chn, out_chn, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2D( + out_chn, momentum=bn_mom), + nn.ReLU()) + + def forward(self, x): + x = self.conv(x) + return x + + +class SEModule(nn.Layer): + def __init__(self, channels, reduction_channels): + super(SEModule, self).__init__() + self.fc1 = nn.Conv2D( + channels, + reduction_channels, + kernel_size=1, + padding=0, + bias_attr=True) + self.ReLU = nn.ReLU() + self.fc2 = nn.Conv2D( + reduction_channels, + channels, + kernel_size=1, + padding=0, + bias_attr=True) + + def forward(self, x): + x_se = x.reshape( + [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape( + [x.shape[0], x.shape[1], 1, 1]) + + x_se = self.fc1(x_se) + x_se = self.ReLU(x_se) + x_se = self.fc2(x_se) + return x * F.sigmoid(x_se) + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + downsample=None, + use_se=False, + stride=1, + dilation=1): + super(BasicBlock, self).__init__() + first_planes = planes + outplanes = planes * self.expansion + + self.conv1 = DoubleConv(inplanes, first_planes) + self.conv2 = DoubleConv( + first_planes, outplanes, stride=stride, dilation=dilation) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.downsample = MaxPool2x2() if downsample else None + self.ReLU = nn.ReLU() + + def forward(self, x): + out = self.conv1(x) + residual = out + out = self.conv2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(residual) + + out = out + 
residual + out = self.ReLU(out) + return out + + +class DenseCatAdd(nn.Layer): + def __init__(self, in_chn, out_chn): + super(DenseCatAdd, self).__init__() + self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv_out = BasicConv( + in_chn, + out_chn, + kernel_size=1, + norm=nn.BatchNorm2D( + out_chn, momentum=bn_mom), + act=nn.ReLU()) + + def forward(self, x, y): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2 + x1) + + y1 = self.conv1(y) + y2 = self.conv2(y1) + y3 = self.conv3(y2 + y1) + + return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3) + + +class DenseCatDiff(nn.Layer): + def __init__(self, in_chn, out_chn): + super(DenseCatDiff, self).__init__() + self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) + self.conv_out = BasicConv( + in_ch=in_chn, + out_ch=out_chn, + kernel_size=1, + norm=nn.BatchNorm2D( + out_chn, momentum=bn_mom), + act=nn.ReLU()) + + def forward(self, x, y): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2 + x1) + + y1 = self.conv1(y) + y2 = self.conv2(y1) + y3 = self.conv3(y2 + y1) + out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3)) + return out + + +class DFModule(nn.Layer): + """Dense connection-based feature fusion module""" + + def __init__(self, dim_in, dim_out, reduction=True): + super(DFModule, self).__init__() + if reduction: + self.reduction = Conv1x1( + dim_in, + dim_in // 2, + norm=nn.BatchNorm2D( + dim_in // 2, momentum=bn_mom), + act=nn.ReLU()) + dim_in = dim_in // 2 + else: + self.reduction = None + self.cat1 = DenseCatAdd(dim_in, dim_out) + self.cat2 = DenseCatDiff(dim_in, dim_out) + self.conv1 = Conv3x3( + dim_out, + dim_out, + norm=nn.BatchNorm2D( + dim_out, 
momentum=bn_mom), + act=nn.ReLU()) + + def forward(self, x1, x2): + if self.reduction is not None: + x1 = self.reduction(x1) + x2 = self.reduction(x2) + x_add = self.cat1(x1, x2) + x_diff = self.cat2(x1, x2) + y = self.conv1(x_diff) + x_add + return y + + +class FCCDN(nn.Layer): + """ + The FCCDN implementation based on PaddlePaddle. + + The original article refers to + Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" + (https://arxiv.org/pdf/2105.10860.pdf). + + Args: + in_channels (int): Number of input channels. Default: 3. + num_classes (int): Number of target classes. Default: 2. + os (int): Number of output stride. Default: 16. + use_se (bool): Whether to use SEModule. Default: True. + """ + + def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True): + super(FCCDN, self).__init__() + if os >= 16: + dilation_list = [1, 1, 1, 1] + stride_list = [2, 2, 2, 2] + pool_list = [True, True, True, True] + elif os == 8: + dilation_list = [2, 1, 1, 1] + stride_list = [1, 2, 2, 2] + pool_list = [False, True, True, True] + else: + dilation_list = [2, 2, 1, 1] + stride_list = [1, 1, 2, 2] + pool_list = [False, False, True, True] + se_list = [use_se, use_se, use_se, use_se] + channel_list = [256, 128, 64, 32] + # Encoder + self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3], + se_list[3], stride_list[3], dilation_list[3]) + self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2], + se_list[2], stride_list[2], dilation_list[2]) + self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1], + se_list[1], stride_list[1], dilation_list[1]) + self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0], + se_list[0], stride_list[0], dilation_list[0]) + + # Center + self.center = NLFPN(channel_list[0], True) + + # Decoder + self.decoder3 = Cat(channel_list[0], + channel_list[1], + channel_list[1], + upsample=pool_list[0]) + self.decoder2 = Cat(channel_list[1], + channel_list[2], 
+ channel_list[2], + upsample=pool_list[1]) + self.decoder1 = Cat(channel_list[2], + channel_list[3], + channel_list[3], + upsample=pool_list[2]) + + self.df1 = DFModule(channel_list[3], channel_list[3], True) + self.df2 = DFModule(channel_list[2], channel_list[2], True) + self.df3 = DFModule(channel_list[1], channel_list[1], True) + self.df4 = DFModule(channel_list[0], channel_list[0], True) + + self.catc3 = Cat(channel_list[0], + channel_list[1], + channel_list[1], + upsample=pool_list[0]) + self.catc2 = Cat(channel_list[1], + channel_list[2], + channel_list[2], + upsample=pool_list[1]) + self.catc1 = Cat(channel_list[2], + channel_list[3], + channel_list[3], + upsample=pool_list[2]) + + self.upsample_x2 = nn.Sequential( + nn.Conv2D( + channel_list[3], 8, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2D( + 8, momentum=bn_mom), + nn.ReLU(), + nn.UpsamplingBilinear2D(scale_factor=2)) + + self.conv_out = nn.Conv2D( + 8, num_classes, kernel_size=3, stride=1, padding=1) + self.conv_out_class = nn.Conv2D( + channel_list[3], 1, kernel_size=1, stride=1, padding=0) + + def forward(self, t1, t2): + e1_1 = self.block1(t1) + e2_1 = self.block2(e1_1) + e3_1 = self.block3(e2_1) + y1 = self.block4(e3_1) + + e1_2 = self.block1(t2) + e2_2 = self.block2(e1_2) + e3_2 = self.block3(e2_2) + y2 = self.block4(e3_2) + + y1 = self.center(y1) + y2 = self.center(y2) + c = self.df4(y1, y2) + + y1 = self.decoder3(y1, e3_1) + y2 = self.decoder3(y2, e3_2) + c = self.catc3(c, self.df3(y1, y2)) + + y1 = self.decoder2(y1, e2_1) + y2 = self.decoder2(y2, e2_2) + c = self.catc2(c, self.df2(y1, y2)) + + y1 = self.decoder1(y1, e1_1) + y2 = self.decoder1(y2, e1_2) + + c = self.catc1(c, self.df1(y1, y2)) + y = self.conv_out(self.upsample_x2(c)) + + if self.training: + y1 = self.conv_out_class(y1) + y2 = self.conv_out_class(y2) + return [y, [y1, y2]] + else: + return [y] diff --git a/paddlers/rs_models/res/generators/builder.py b/paddlers/rs_models/cd/losses/__init__.py similarity index 56% rename 
from paddlers/rs_models/res/generators/builder.py rename to paddlers/rs_models/cd/losses/__init__.py index 8e4b884..49465ff 100644 --- a/paddlers/rs_models/res/generators/builder.py +++ b/paddlers/rs_models/cd/losses/__init__.py @@ -1,10 +1,10 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,15 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - -from ....models.ppgan.utils.registry import Registry - -GENERATORS = Registry("GENERATOR") - - -def build_generator(cfg): - cfg_copy = copy.deepcopy(cfg) - name = cfg_copy.pop('name') - generator = GENERATORS.get(name)(**cfg_copy) - return generator +from .fccdn_loss import fccdn_ssl_loss diff --git a/paddlers/rs_models/cd/losses/fccdn_loss.py b/paddlers/rs_models/cd/losses/fccdn_loss.py new file mode 100644 index 0000000..49d2b4c --- /dev/null +++ b/paddlers/rs_models/cd/losses/fccdn_loss.py @@ -0,0 +1,170 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class DiceLoss(nn.Layer): + def __init__(self, batch=True): + super(DiceLoss, self).__init__() + self.batch = batch + + def soft_dice_coeff(self, y_pred, y_true): + smooth = 0.00001 + if self.batch: + i = paddle.sum(y_true) + j = paddle.sum(y_pred) + intersection = paddle.sum(y_true * y_pred) + else: + i = y_true.sum(1).sum(1).sum(1) + j = y_pred.sum(1).sum(1).sum(1) + intersection = (y_true * y_pred).sum(1).sum(1).sum(1) + score = (2. * intersection + smooth) / (i + j + smooth) + return score.mean() + + def soft_dice_loss(self, y_pred, y_true): + loss = 1 - self.soft_dice_coeff(y_pred, y_true) + return loss + + def forward(self, y_pred, y_true): + return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) + + +class MultiClassDiceLoss(nn.Layer): + def __init__( + self, + weight, + batch=True, + ignore_index=-1, + do_softmax=False, + **kwargs, ): + super(MultiClassDiceLoss, self).__init__() + self.ignore_index = ignore_index + self.weight = weight + self.do_softmax = do_softmax + self.binary_diceloss = DiceLoss(batch) + + def forward(self, y_pred, y_true): + if self.do_softmax: + y_pred = paddle.nn.functional.softmax(y_pred, axis=1) + y_true = F.one_hot(y_true.astype('int64'), y_pred.shape[1]).transpose([0, 3, 1, 2]) + total_loss = 0.0 + tmp_i = 0.0 + for i in range(y_pred.shape[1]): + if i != self.ignore_index: + diceloss = self.binary_diceloss(y_pred[:, i, :, :], + y_true[:, i, :, :]) + total_loss += paddle.multiply(diceloss, self.weight[i]) + tmp_i += 1.0 + return total_loss / tmp_i + + +class DiceBCELoss(nn.Layer): + """Binary change detection task loss""" + + def __init__(self): + super(DiceBCELoss, self).__init__() + self.bce_loss = nn.BCELoss() + self.binnary_dice = DiceLoss() + + def forward(self, scores, labels, do_sigmoid=True): + if len(scores.shape) > 3: + scores = 
scores.squeeze(1) + if len(labels.shape) > 3: + labels = labels.squeeze(1) + if do_sigmoid: + scores = paddle.nn.functional.sigmoid(scores.clone()) + diceloss = self.binnary_dice(scores, labels) + bceloss = self.bce_loss(scores, labels) + return diceloss + bceloss + + +class McDiceBCELoss(nn.Layer): + """Multi-class change detection task loss""" + + def __init__(self, weight, do_sigmoid=True): + super(McDiceBCELoss, self).__init__() + self.ce_loss = nn.CrossEntropyLoss(weight) + self.dice = MultiClassDiceLoss(weight, do_sigmoid) + + def forward(self, scores, labels): + if len(scores.shape) < 4: + scores = scores.unsqueeze(1) + if len(labels.shape) < 4: + labels = labels.unsqueeze(1) + diceloss = self.dice(scores, labels) + bceloss = self.ce_loss(scores, labels) + return diceloss + bceloss + + +def fccdn_ssl_loss(logits_list, labels): + """ + Self-supervised learning loss for change detection. + + The original article refers to + Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" + (https://arxiv.org/pdf/2105.10860.pdf). + + Args: + logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases. + labels (paddle.Tensor): Binary change labels. 
+ """ + + # Create loss + criterion_ssl = DiceBCELoss() + + # Get downsampled change map + h, w = logits_list[0].shape[-2], logits_list[0].shape[-1] + labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w]) + labels_type = str(labels_downsample.dtype) + assert "int" in labels_type or "bool" in labels_type,\ + f"Expected dtype of labels to be int or bool, but got {labels_type}" + + # Seg map + out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone() + out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone() + out3 = out1.clone() + out4 = out2.clone() + + out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1) + out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2) + out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3) + out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4) + + pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5, + paddle.zeros_like(out1), + paddle.ones_like(out1)) + pred_seg_post_tmp1 = paddle.where(out2 <= 0.5, + paddle.zeros_like(out2), + paddle.ones_like(out2)) + + pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5, + paddle.zeros_like(out3), + paddle.ones_like(out3)) + pred_seg_post_tmp2 = paddle.where(out4 <= 0.5, + paddle.zeros_like(out4), + paddle.ones_like(out4)) + + # Seg loss + labels_downsample = labels_downsample.astype(paddle.float32) + loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) + loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) + loss_aux += 0.2 * criterion_ssl( + out3, labels_downsample - pred_seg_post_tmp2, False) + loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, + False) + + return loss_aux diff --git a/paddlers/rs_models/res/__init__.py b/paddlers/rs_models/res/__init__.py index 4dec1be..583dd10 100644 --- a/paddlers/rs_models/res/__init__.py +++ b/paddlers/rs_models/res/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # 
limitations under the License. -from .rcan_model import RCANModel +from .generators import * diff --git a/paddlers/rs_models/res/generators/param_init.py b/paddlers/rs_models/res/generators/param_init.py new file mode 100644 index 0000000..003c3c2 --- /dev/null +++ b/paddlers/rs_models/res/generators/param_init.py @@ -0,0 +1,27 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from paddlers.models.ppgan.modules.init import reset_parameters + + +def init_sr_weight(net): + def reset_func(m): + if hasattr(m, 'weight') and ( + not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): + reset_parameters(m) + + net.apply(reset_func) diff --git a/paddlers/rs_models/res/generators/rcan.py b/paddlers/rs_models/res/generators/rcan.py index 17f9ee8..6b32621 100644 --- a/paddlers/rs_models/res/generators/rcan.py +++ b/paddlers/rs_models/res/generators/rcan.py @@ -1,10 +1,25 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Based on https://github.com/kongdebug/RCAN-Paddle + import math import paddle import paddle.nn as nn -from .builder import GENERATORS +from .param_init import init_sr_weight def default_conv(in_channels, out_channels, kernel_size, bias=True): @@ -63,8 +78,10 @@ class RCAB(nn.Layer): bias=True, bn=False, act=nn.ReLU(), - res_scale=1): + res_scale=1, + use_init_weight=False): super(RCAB, self).__init__() + modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) @@ -74,6 +91,9 @@ class RCAB(nn.Layer): self.body = nn.Sequential(*modules_body) self.res_scale = res_scale + if use_init_weight: + init_sr_weight(self) + def forward(self, x): res = self.body(x) res += x @@ -128,21 +148,19 @@ class Upsampler(nn.Sequential): super(Upsampler, self).__init__(*m) -@GENERATORS.register() class RCAN(nn.Layer): - def __init__( - self, - scale, - n_resgroups, - n_resblocks, - n_feats=64, - n_colors=3, - rgb_range=255, - kernel_size=3, - reduction=16, - conv=default_conv, ): + def __init__(self, + sr_factor=4, + n_resgroups=10, + n_resblocks=20, + n_feats=64, + n_colors=3, + rgb_range=255, + kernel_size=3, + reduction=16, + conv=default_conv): super(RCAN, self).__init__() - self.scale = scale + self.scale = sr_factor act = nn.ReLU() n_resgroups = n_resgroups @@ -150,7 +168,6 @@ class RCAN(nn.Layer): n_feats = n_feats kernel_size = kernel_size reduction = reduction - scale = scale act = nn.ReLU() rgb_mean = (0.4488, 0.4371, 0.4040) @@ -171,7 +188,7 @@ class RCAN(nn.Layer): # Define tail module modules_tail = [ Upsampler( - 
conv, scale, n_feats, act=False), + conv, self.scale, n_feats, act=False), conv(n_feats, n_colors, kernel_size) ] diff --git a/paddlers/rs_models/res/rcan_model.py b/paddlers/rs_models/res/rcan_model.py deleted file mode 100644 index 691fb12..0000000 --- a/paddlers/rs_models/res/rcan_model.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import paddle.nn as nn - -from .generators.builder import build_generator -from ...models.ppgan.models.criterions.builder import build_criterion -from ...models.ppgan.models.base_model import BaseModel -from ...models.ppgan.models.builder import MODELS -from ...models.ppgan.utils.visual import tensor2img -from ...models.ppgan.modules.init import reset_parameters - - -@MODELS.register() -class RCANModel(BaseModel): - """ - Base SR model for single image super-resolution. - """ - - def __init__(self, generator, pixel_criterion=None, use_init_weight=False): - """ - Args: - generator (dict): config of generator. - pixel_criterion (dict): config of pixel criterion. 
- """ - super(RCANModel, self).__init__() - - self.nets['generator'] = build_generator(generator) - self.error_last = 1e8 - self.batch = 0 - if pixel_criterion: - self.pixel_criterion = build_criterion(pixel_criterion) - if use_init_weight: - init_sr_weight(self.nets['generator']) - - def setup_input(self, input): - self.lq = paddle.to_tensor(input['lq']) - self.visual_items['lq'] = self.lq - if 'gt' in input: - self.gt = paddle.to_tensor(input['gt']) - self.visual_items['gt'] = self.gt - self.image_paths = input['lq_path'] - - def forward(self): - pass - - def train_iter(self, optims=None): - optims['optim'].clear_grad() - - self.output = self.nets['generator'](self.lq) - self.visual_items['output'] = self.output - # pixel loss - loss_pixel = self.pixel_criterion(self.output, self.gt) - self.losses['loss_pixel'] = loss_pixel - - skip_threshold = 1e6 - - if loss_pixel.item() < skip_threshold * self.error_last: - loss_pixel.backward() - optims['optim'].step() - else: - print('Skip this batch {}! 
(Loss: {})'.format(self.batch + 1, - loss_pixel.item())) - self.batch += 1 - - if self.batch % 1000 == 0: - self.error_last = loss_pixel.item() / 1000 - print("update error_last:{}".format(self.error_last)) - - def test_iter(self, metrics=None): - self.nets['generator'].eval() - with paddle.no_grad(): - self.output = self.nets['generator'](self.lq) - self.visual_items['output'] = self.output - self.nets['generator'].train() - - out_img = [] - gt_img = [] - for out_tensor, gt_tensor in zip(self.output, self.gt): - out_img.append(tensor2img(out_tensor, (0., 255.))) - gt_img.append(tensor2img(gt_tensor, (0., 255.))) - - if metrics is not None: - for metric in metrics.values(): - metric.update(out_img, gt_img) - - -def init_sr_weight(net): - def reset_func(m): - if hasattr(m, 'weight') and ( - not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): - reset_parameters(m) - - net.apply(reset_func) diff --git a/paddlers/tasks/__init__.py b/paddlers/tasks/__init__.py index 5c1f428..ffa023f 100644 --- a/paddlers/tasks/__init__.py +++ b/paddlers/tasks/__init__.py @@ -16,7 +16,7 @@ import paddlers.tasks.object_detector as detector import paddlers.tasks.segmenter as segmenter import paddlers.tasks.change_detector as change_detector import paddlers.tasks.classifier as classifier -import paddlers.tasks.image_restorer as restorer +import paddlers.tasks.restorer as restorer from .load_model import load_model # Shorter aliases diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 5250058..34e684f 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -30,12 +30,11 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner import paddlers import paddlers.utils.logging as logging -from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str, - get_pretrain_weights, load_pretrain_weights, - load_checkpoint, SmoothedValue, TrainingStats, - _get_shared_memory_size_in_M, EarlyStop) +from paddlers.utils import ( + seconds_to_hms, get_single_card_bs, 
dict2str, get_pretrain_weights, + load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats, + _get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step) from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune -from .utils.infer_nets import InferNet, InferCDNet class ModelMeta(type): @@ -268,7 +267,7 @@ class BaseModel(metaclass=ModelMeta): 'The volume of dataset({}) must be larger than batch size({}).' .format(dataset.num_samples, batch_size)) batch_size_each_card = get_single_card_bs(batch_size=batch_size) - # TODO: add a check for the detection eval stage + batch_sampler = DistributedBatchSampler( dataset, batch_size=batch_size_each_card, @@ -308,7 +307,7 @@ class BaseModel(metaclass=ModelMeta): use_vdl=True): self._check_transforms(train_dataset.transforms, 'train') - if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len( + if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len( train_dataset.file_list): nranks = 1 else: @@ -321,10 +320,10 @@ class BaseModel(metaclass=ModelMeta): if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( ): paddle.distributed.init_parallel_env() - ddp_net = paddle.DataParallel( + ddp_net = to_data_parallel( self.net, find_unused_parameters=find_unused_parameters) else: - ddp_net = paddle.DataParallel( + ddp_net = to_data_parallel( self.net, find_unused_parameters=find_unused_parameters) if use_vdl: @@ -365,24 +364,14 @@ class BaseModel(metaclass=ModelMeta): for step, data in enumerate(self.train_data_loader()): if nranks > 1: - outputs = self.run(ddp_net, data, mode='train') + outputs = self.train_step(step, data, ddp_net) else: - outputs = self.run(self.net, data, mode='train') - loss = outputs['loss'] - loss.backward() - self.optimizer.step() - self.optimizer.clear_grad() - lr = self.optimizer.get_lr() - if isinstance(self.optimizer._learning_rate, - paddle.optimizer.lr.LRScheduler): - # If ReduceOnPlateau is 
used as the scheduler, use the loss value as the metric. - if isinstance(self.optimizer._learning_rate, - paddle.optimizer.lr.ReduceOnPlateau): - self.optimizer._learning_rate.step(loss.item()) - else: - self.optimizer._learning_rate.step() + outputs = self.train_step(step, data, self.net) + + scheduler_step(self.optimizer, outputs['loss']) train_avg_metrics.update(outputs) + lr = self.optimizer.get_lr() outputs['lr'] = lr if ema is not None: ema.update(self.net) @@ -622,14 +611,7 @@ class BaseModel(metaclass=ModelMeta): return pipeline_info def _build_inference_net(self): - if self.model_type in ('classifier', 'detector'): - infer_net = self.net - elif self.model_type == 'change_detector': - infer_net = InferCDNet(self.net) - else: - infer_net = InferNet(self.net, self.model_type) - infer_net.eval() - return infer_net + raise NotImplementedError def _export_inference_model(self, save_dir, image_shape=None): self.test_inputs = self._get_test_inputs(image_shape) @@ -674,6 +656,16 @@ class BaseModel(metaclass=ModelMeta): logging.info("The inference model for deployment is saved in {}.". format(save_dir)) + def train_step(self, step, data, net): + outputs = self.run(net, data, mode='train') + + loss = outputs['loss'] + loss.backward() + self.optimizer.step() + self.optimizer.clear_grad() + + return outputs + def _check_transforms(self, transforms, mode): # NOTE: Check transforms and transforms.arrange and give user-friendly error messages. 
if not isinstance(transforms, paddlers.transforms.Compose): diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 60ff25e..7a45172 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -30,14 +30,15 @@ import paddlers.rs_models.cd as cmcd import paddlers.utils.logging as logging from paddlers.models import seg_losses from paddlers.transforms import Resize, decode_image -from paddlers.utils import get_single_card_bs, DisablePrint +from paddlers.utils import get_single_card_bs from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics +from .utils.infer_nets import InferCDNet __all__ = [ "CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT", - "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer" + "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN" ] @@ -69,6 +70,11 @@ class BaseChangeDetector(BaseModel): **params) return net + def _build_inference_net(self): + infer_net = InferCDNet(self.net) + infer_net.eval() + return infer_net + def _fix_transforms_shape(self, image_shape): if hasattr(self, 'test_transforms'): if self.test_transforms is not None: @@ -399,7 +405,8 @@ class BaseChangeDetector(BaseModel): Defaults to False. Returns: - collections.OrderedDict with key-value pairs: + If `return_details` is False, return collections.OrderedDict with + key-value pairs: For binary change detection (number of classes == 2), the key-value pairs are like: {"iou": `intersection over union for the change class`, @@ -527,12 +534,12 @@ class BaseChangeDetector(BaseModel): Returns: If `img_file` is a tuple of string or np.array, the result is a dict with - key-value pairs: - {"label map": `label map`, "score_map": `score map`}. + the following key-value pairs: + label_map (np.ndarray): Predicted label map (HW). + score_map (np.ndarray): Prediction score map (HWC). 
+ If `img_file` is a list, the result is a list composed of dicts with the - corresponding fields: - label_map (np.ndarray): the predicted label map (HW) - score_map (np.ndarray): the prediction score map (HWC) + above keys. """ if transforms is None and not hasattr(self, 'test_transforms'): @@ -787,11 +794,11 @@ class BaseChangeDetector(BaseModel): elif item[0] == 'padding': x, y = item[2] if isinstance(label_map, np.ndarray): - label_map = label_map[..., y:y + h, x:x + w] - score_map = score_map[..., y:y + h, x:x + w] + label_map = label_map[y:y + h, x:x + w] + score_map = score_map[y:y + h, x:x + w] else: - label_map = label_map[:, :, y:y + h, x:x + w] - score_map = score_map[:, :, y:y + h, x:x + w] + label_map = label_map[:, y:y + h, x:x + w, :] + score_map = score_map[:, y:y + h, x:x + w, :] else: pass label_map = label_map.squeeze() @@ -1053,7 +1060,7 @@ class ChangeStar(BaseChangeDetector): if self.use_mixed_loss is False: return { # XXX: make sure the shallow copy works correctly here. 
- 'types': [seglosses.CrossEntropyLoss()] * 4, + 'types': [seg_losses.CrossEntropyLoss()] * 4, 'coef': [1.0] * 4 } else: @@ -1082,3 +1089,31 @@ class ChangeFormer(BaseChangeDetector): use_mixed_loss=use_mixed_loss, losses=losses, **params) + + +class FCCDN(BaseChangeDetector): + def __init__(self, + in_channels=3, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): + params.update({'in_channels': in_channels}) + super(FCCDN, self).__init__( + model_name='FCCDN', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + losses=losses, + **params) + + def default_loss(self): + if self.use_mixed_loss is False: + return { + 'types': + [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss], + 'coef': [1.0, 1.0] + } + else: + raise ValueError( + f"Currently `use_mixed_loss` must be set to False for {self.__class__}" + ) diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 79a5099..23ab154 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -83,6 +83,11 @@ class BaseClassifier(BaseModel): self.in_channels = 3 return net + def _build_inference_net(self): + infer_net = self.net + infer_net.eval() + return infer_net + def _fix_transforms_shape(self, image_shape): if hasattr(self, 'test_transforms'): if self.test_transforms is not None: @@ -373,7 +378,8 @@ class BaseClassifier(BaseModel): Defaults to False. Returns: - collections.OrderedDict with key-value pairs: + If `return_details` is False, return collections.OrderedDict with + key-value pairs: {"top1": `acc of top1`, "top5": `acc of top5`}. 
""" @@ -389,38 +395,37 @@ class BaseClassifier(BaseModel): ): paddle.distributed.init_parallel_env() - batch_size_each_card = get_single_card_bs(batch_size) - if batch_size_each_card > 1: - batch_size_each_card = 1 - batch_size = batch_size_each_card * paddlers.env_info['num'] + if batch_size > 1: logging.warning( - "Classifier only supports batch_size=1 for each gpu/cpu card " \ - "during evaluation, so batch_size " \ - "is forcibly set to {}.".format(batch_size)) - self.eval_data_loader = self.build_data_loader( - eval_dataset, batch_size=batch_size, mode='eval') - - logging.info( - "Start to evaluate(total_samples={}, total_steps={})...".format( - eval_dataset.num_samples, - math.ceil(eval_dataset.num_samples * 1.0 / batch_size))) - - top1s = [] - top5s = [] - with paddle.no_grad(): - for step, data in enumerate(self.eval_data_loader): - data.append(eval_dataset.transforms.transforms) - outputs = self.run(self.net, data, 'eval') - top1s.append(outputs["top1"]) - top5s.append(outputs["top5"]) - - top1 = np.mean(top1s) - top5 = np.mean(top5s) - eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5])) - if return_details: - # TODO: add details - return eval_metrics, None - return eval_metrics + "Classifier only supports single card evaluation with batch_size=1 " + "during evaluation, so batch_size is forcibly set to 1.") + batch_size = 1 + + if nranks < 2 or local_rank == 0: + self.eval_data_loader = self.build_data_loader( + eval_dataset, batch_size=batch_size, mode='eval') + logging.info( + "Start to evaluate(total_samples={}, total_steps={})...".format( + eval_dataset.num_samples, eval_dataset.num_samples)) + + top1s = [] + top5s = [] + with paddle.no_grad(): + for step, data in enumerate(self.eval_data_loader): + data.append(eval_dataset.transforms.transforms) + outputs = self.run(self.net, data, 'eval') + top1s.append(outputs["top1"]) + top5s.append(outputs["top5"]) + + top1 = np.mean(top1s) + top5 = np.mean(top5s) + eval_metrics = 
OrderedDict(zip(['top1', 'top5'], [top1, top5])) + + if return_details: + # TODO: Add details + return eval_metrics, None + + return eval_metrics def predict(self, img_file, transforms=None): """ @@ -435,16 +440,14 @@ class BaseClassifier(BaseModel): Defaults to None. Returns: - If `img_file` is a string or np.array, the result is a dict with key-value - pairs: - {"label map": `class_ids_map`, - "scores_map": `scores_map`, - "label_names_map": `label_names_map`}. + If `img_file` is a string or np.array, the result is a dict with the + following key-value pairs: + class_ids_map (np.ndarray): IDs of predicted classes. + scores_map (np.ndarray): Scores of predicted classes. + label_names_map (np.ndarray): Names of predicted classes. + If `img_file` is a list, the result is a list composed of dicts with the - corresponding fields: - class_ids_map (np.ndarray): class_ids - scores_map (np.ndarray): scores - label_names_map (np.ndarray): label_names + above keys. """ if transforms is None and not hasattr(self, 'test_transforms'): @@ -555,6 +558,26 @@ class BaseClassifier(BaseModel): raise TypeError( "`transforms.arrange` must be an ArrangeClassifier object.") + def build_data_loader(self, dataset, batch_size, mode='train'): + if dataset.num_samples < batch_size: + raise ValueError( + 'The volume of dataset({}) must be larger than batch size({}).' 
+ .format(dataset.num_samples, batch_size)) + + if mode != 'train': + return paddle.io.DataLoader( + dataset, + batch_size=batch_size, + shuffle=dataset.shuffle, + drop_last=False, + collate_fn=dataset.batch_transforms, + num_workers=dataset.num_workers, + return_list=True, + use_shared_memory=False) + else: + return super(BaseClassifier, self).build_data_loader( + dataset, batch_size, mode) + class ResNet50_vd(BaseClassifier): def __init__(self, diff --git a/paddlers/tasks/image_restorer.py b/paddlers/tasks/image_restorer.py deleted file mode 100644 index ec41dd3..0000000 --- a/paddlers/tasks/image_restorer.py +++ /dev/null @@ -1,786 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -import datetime - -import paddle -from paddle.distributed import ParallelEnv - -from ..models.ppgan.datasets.builder import build_dataloader -from ..models.ppgan.models.builder import build_model -from ..models.ppgan.utils.visual import tensor2img, save_image -from ..models.ppgan.utils.filesystem import makedirs, save, load -from ..models.ppgan.utils.timer import TimeAverager -from ..models.ppgan.utils.profiler import add_profiler_step -from ..models.ppgan.utils.logger import setup_logger - - -# 定义AttrDict类实现动态属性 -class AttrDict(dict): - def __getattr__(self, key): - try: - return self[key] - except KeyError: - raise AttributeError(key) - - def __setattr__(self, key, value): - if key in self.__dict__: - self.__dict__[key] = value - else: - self[key] = value - - -# 创建AttrDict类 -def create_attr_dict(config_dict): - from ast import literal_eval - for key, value in config_dict.items(): - if type(value) is dict: - config_dict[key] = value = AttrDict(value) - if isinstance(value, str): - try: - value = literal_eval(value) - except BaseException: - pass - if isinstance(value, AttrDict): - create_attr_dict(config_dict[key]) - else: - config_dict[key] = value - - -# 数据加载类 -class IterLoader: - def __init__(self, dataloader): - self._dataloader = dataloader - self.iter_loader = iter(self._dataloader) - self._epoch = 1 - - @property - def epoch(self): - return self._epoch - - def __next__(self): - try: - data = next(self.iter_loader) - except StopIteration: - self._epoch += 1 - self.iter_loader = iter(self._dataloader) - data = next(self.iter_loader) - - return data - - def __len__(self): - return len(self._dataloader) - - -# 基础训练类 -class Restorer: - """ - # trainer calling logic: - # - # build_model || model(BaseModel) - # | || - # build_dataloader || dataloader - # | || - # model.setup_lr_schedulers || lr_scheduler - # | || - # model.setup_optimizers || optimizers - # | || - # train loop (model.setup_input + model.train_iter) || train loop - # | || 
- # print log (model.get_current_losses) || - # | || - # save checkpoint (model.nets) \/ - """ - - def __init__(self, cfg, logger): - # base config - # self.logger = logging.getLogger(__name__) - self.logger = logger - self.cfg = cfg - self.output_dir = cfg.output_dir - self.max_eval_steps = cfg.model.get('max_eval_steps', None) - - self.local_rank = ParallelEnv().local_rank - self.world_size = ParallelEnv().nranks - self.log_interval = cfg.log_config.interval - self.visual_interval = cfg.log_config.visiual_interval - self.weight_interval = cfg.snapshot_config.interval - - self.start_epoch = 1 - self.current_epoch = 1 - self.current_iter = 1 - self.inner_iter = 1 - self.batch_id = 0 - self.global_steps = 0 - - # build model - self.model = build_model(cfg.model) - # multiple gpus prepare - if ParallelEnv().nranks > 1: - self.distributed_data_parallel() - - # build metrics - self.metrics = None - self.is_save_img = True - validate_cfg = cfg.get('validate', None) - if validate_cfg and 'metrics' in validate_cfg: - self.metrics = self.model.setup_metrics(validate_cfg['metrics']) - if validate_cfg and 'save_img' in validate_cfg: - self.is_save_img = validate_cfg['save_img'] - - self.enable_visualdl = cfg.get('enable_visualdl', False) - if self.enable_visualdl: - import visualdl - self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir) - - # evaluate only - if not cfg.is_train: - return - - # build train dataloader - self.train_dataloader = build_dataloader(cfg.dataset.train) - self.iters_per_epoch = len(self.train_dataloader) - - # build lr scheduler - # TODO: has a better way? 
- if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler: - cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch - self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler) - - # build optimizers - self.optimizers = self.model.setup_optimizers(self.lr_schedulers, - cfg.optimizer) - - self.epochs = cfg.get('epochs', None) - if self.epochs: - self.total_iters = self.epochs * self.iters_per_epoch - self.by_epoch = True - else: - self.by_epoch = False - self.total_iters = cfg.total_iters - - if self.by_epoch: - self.weight_interval *= self.iters_per_epoch - - self.validate_interval = -1 - if cfg.get('validate', None) is not None: - self.validate_interval = cfg.validate.get('interval', -1) - - self.time_count = {} - self.best_metric = {} - self.model.set_total_iter(self.total_iters) - self.profiler_options = cfg.profiler_options - - def distributed_data_parallel(self): - paddle.distributed.init_parallel_env() - find_unused_parameters = self.cfg.get('find_unused_parameters', False) - for net_name, net in self.model.nets.items(): - self.model.nets[net_name] = paddle.DataParallel( - net, find_unused_parameters=find_unused_parameters) - - def learning_rate_scheduler_step(self): - if isinstance(self.model.lr_scheduler, dict): - for lr_scheduler in self.model.lr_scheduler.values(): - lr_scheduler.step() - elif isinstance(self.model.lr_scheduler, - paddle.optimizer.lr.LRScheduler): - self.model.lr_scheduler.step() - else: - raise ValueError( - 'lr schedulter must be a dict or an instance of LRScheduler') - - def train(self): - reader_cost_averager = TimeAverager() - batch_cost_averager = TimeAverager() - - iter_loader = IterLoader(self.train_dataloader) - - # set model.is_train = True - self.model.setup_train_mode(is_train=True) - while self.current_iter < (self.total_iters + 1): - self.current_epoch = iter_loader.epoch - self.inner_iter = self.current_iter % self.iters_per_epoch - - add_profiler_step(self.profiler_options) - - start_time = 
step_start_time = time.time() - data = next(iter_loader) - reader_cost_averager.record(time.time() - step_start_time) - # unpack data from dataset and apply preprocessing - # data input should be dict - self.model.setup_input(data) - self.model.train_iter(self.optimizers) - - batch_cost_averager.record( - time.time() - step_start_time, - num_samples=self.cfg['dataset']['train'].get('batch_size', 1)) - - step_start_time = time.time() - - if self.current_iter % self.log_interval == 0: - self.data_time = reader_cost_averager.get_average() - self.step_time = batch_cost_averager.get_average() - self.ips = batch_cost_averager.get_ips_average() - self.print_log() - - reader_cost_averager.reset() - batch_cost_averager.reset() - - if self.current_iter % self.visual_interval == 0 and self.local_rank == 0: - self.visual('visual_train') - - self.learning_rate_scheduler_step() - - if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: - self.test() - - if self.current_iter % self.weight_interval == 0: - self.save(self.current_iter, 'weight', keep=-1) - self.save(self.current_iter) - - self.current_iter += 1 - - def test(self): - if not hasattr(self, 'test_dataloader'): - self.test_dataloader = build_dataloader( - self.cfg.dataset.test, is_train=False) - iter_loader = IterLoader(self.test_dataloader) - if self.max_eval_steps is None: - self.max_eval_steps = len(self.test_dataloader) - - if self.metrics: - for metric in self.metrics.values(): - metric.reset() - - # set model.is_train = False - self.model.setup_train_mode(is_train=False) - - for i in range(self.max_eval_steps): - if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: - self.logger.info('Test iter: [%d/%d]' % ( - i * self.world_size, self.max_eval_steps * self.world_size)) - - data = next(iter_loader) - self.model.setup_input(data) - self.model.test_iter(metrics=self.metrics) - - if self.is_save_img: - visual_results = {} - current_paths = self.model.get_image_paths() 
- current_visuals = self.model.get_current_visuals() - - if len(current_visuals) > 0 and list(current_visuals.values())[ - 0].shape == 4: - num_samples = list(current_visuals.values())[0].shape[0] - else: - num_samples = 1 - - for j in range(num_samples): - if j < len(current_paths): - short_path = os.path.basename(current_paths[j]) - basename = os.path.splitext(short_path)[0] - else: - basename = '{:04d}_{:04d}'.format(i, j) - for k, img_tensor in current_visuals.items(): - name = '%s_%s' % (basename, k) - if len(img_tensor.shape) == 4: - visual_results.update({name: img_tensor[j]}) - else: - visual_results.update({name: img_tensor}) - - self.visual( - 'visual_test', - visual_results=visual_results, - step=self.batch_id, - is_save_image=True) - - if self.metrics: - for metric_name, metric in self.metrics.items(): - self.logger.info("Metric {}: {:.4f}".format( - metric_name, metric.accumulate())) - - def print_log(self): - losses = self.model.get_current_losses() - - message = '' - if self.by_epoch: - message += 'Epoch: %d/%d, iter: %d/%d ' % ( - self.current_epoch, self.epochs, self.inner_iter, - self.iters_per_epoch) - else: - message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters) - - message += f'lr: {self.current_learning_rate:.3e} ' - - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - if self.enable_visualdl: - self.vdl_logger.add_scalar(k, v, step=self.global_steps) - - if hasattr(self, 'step_time'): - message += 'batch_cost: %.5f sec ' % self.step_time - - if hasattr(self, 'data_time'): - message += 'reader_cost: %.5f sec ' % self.data_time - - if hasattr(self, 'ips'): - message += 'ips: %.5f images/s ' % self.ips - - if hasattr(self, 'step_time'): - eta = self.step_time * (self.total_iters - self.current_iter) - eta = eta if eta > 0 else 0 - - eta_str = str(datetime.timedelta(seconds=int(eta))) - message += f'eta: {eta_str}' - - # print the message - self.logger.info(message) - - @property - def current_learning_rate(self): - for 
optimizer in self.model.optimizers.values(): - return optimizer.get_lr() - - def visual(self, - results_dir, - visual_results=None, - step=None, - is_save_image=False): - """ - visual the images, use visualdl or directly write to the directory - Parameters: - results_dir (str) -- directory name which contains saved images - visual_results (dict) -- the results images dict - step (int) -- global steps, used in visualdl - is_save_image (bool) -- weather write to the directory or visualdl - """ - self.model.compute_visuals() - - if visual_results is None: - visual_results = self.model.get_current_visuals() - - min_max = self.cfg.get('min_max', None) - if min_max is None: - min_max = (-1., 1.) - - image_num = self.cfg.get('image_num', None) - if (image_num is None) or (not self.enable_visualdl): - image_num = 1 - for label, image in visual_results.items(): - image_numpy = tensor2img(image, min_max, image_num) - if (not is_save_image) and self.enable_visualdl: - self.vdl_logger.add_image( - results_dir + '/' + label, - image_numpy, - step=step if step else self.global_steps, - dataformats="HWC" if image_num == 1 else "NCHW") - else: - if self.cfg.is_train: - if self.by_epoch: - msg = 'epoch%.3d_' % self.current_epoch - else: - msg = 'iter%.3d_' % self.current_iter - else: - msg = '' - makedirs(os.path.join(self.output_dir, results_dir)) - img_path = os.path.join(self.output_dir, results_dir, - msg + '%s.png' % (label)) - save_image(image_numpy, img_path) - - def save(self, epoch, name='checkpoint', keep=1): - if self.local_rank != 0: - return - - assert name in ['checkpoint', 'weight'] - - state_dicts = {} - if self.by_epoch: - save_filename = 'epoch_%s_%s.pdparams' % ( - epoch // self.iters_per_epoch, name) - else: - save_filename = 'iter_%s_%s.pdparams' % (epoch, name) - - os.makedirs(self.output_dir, exist_ok=True) - save_path = os.path.join(self.output_dir, save_filename) - for net_name, net in self.model.nets.items(): - state_dicts[net_name] = net.state_dict() - - 
if name == 'weight': - save(state_dicts, save_path) - return - - state_dicts['epoch'] = epoch - - for opt_name, opt in self.model.optimizers.items(): - state_dicts[opt_name] = opt.state_dict() - - save(state_dicts, save_path) - - if keep > 0: - try: - if self.by_epoch: - checkpoint_name_to_be_removed = os.path.join( - self.output_dir, 'epoch_%s_%s.pdparams' % ( - (epoch - keep * self.weight_interval) // - self.iters_per_epoch, name)) - else: - checkpoint_name_to_be_removed = os.path.join( - self.output_dir, 'iter_%s_%s.pdparams' % - (epoch - keep * self.weight_interval, name)) - - if os.path.exists(checkpoint_name_to_be_removed): - os.remove(checkpoint_name_to_be_removed) - - except Exception as e: - self.logger.info('remove old checkpoints error: {}'.format(e)) - - def resume(self, checkpoint_path): - state_dicts = load(checkpoint_path) - if state_dicts.get('epoch', None) is not None: - self.start_epoch = state_dicts['epoch'] + 1 - self.global_steps = self.iters_per_epoch * state_dicts['epoch'] - - self.current_iter = state_dicts['epoch'] + 1 - - for net_name, net in self.model.nets.items(): - net.set_state_dict(state_dicts[net_name]) - - for opt_name, opt in self.model.optimizers.items(): - opt.set_state_dict(state_dicts[opt_name]) - - def load(self, weight_path): - state_dicts = load(weight_path) - - for net_name, net in self.model.nets.items(): - if net_name in state_dicts: - net.set_state_dict(state_dicts[net_name]) - self.logger.info('Loaded pretrained weight for net {}'.format( - net_name)) - else: - self.logger.warning( - 'Can not find state dict of net {}. Skip load pretrained weight for net {}' - .format(net_name, net_name)) - - def close(self): - """ - when finish the training need close file handler or other. 
- """ - if self.enable_visualdl: - self.vdl_logger.close() - - -# 基础超分模型训练类 -class BasicSRNet: - def __init__(self): - self.model = {} - self.optimizer = {} - self.lr_scheduler = {} - self.min_max = '' - - def train( - self, - total_iters, - train_dataset, - test_dataset, - output_dir, - validate, - snapshot, - log, - lr_rate, - evaluate_weights='', - resume='', - pretrain_weights='', - periods=[100000], - restart_weights=[1], ): - self.lr_scheduler['learning_rate'] = lr_rate - - if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR': - self.lr_scheduler['periods'] = periods - self.lr_scheduler['restart_weights'] = restart_weights - - validate = { - 'interval': validate, - 'save_img': False, - 'metrics': { - 'psnr': { - 'name': 'PSNR', - 'crop_border': 4, - 'test_y_channel': True - }, - 'ssim': { - 'name': 'SSIM', - 'crop_border': 4, - 'test_y_channel': True - } - } - } - log_config = {'interval': log, 'visiual_interval': 500} - snapshot_config = {'interval': snapshot} - - cfg = { - 'total_iters': total_iters, - 'output_dir': output_dir, - 'min_max': self.min_max, - 'model': self.model, - 'dataset': { - 'train': train_dataset, - 'test': test_dataset - }, - 'lr_scheduler': self.lr_scheduler, - 'optimizer': self.optimizer, - 'validate': validate, - 'log_config': log_config, - 'snapshot_config': snapshot_config - } - - cfg = AttrDict(cfg) - create_attr_dict(cfg) - - cfg.is_train = True - cfg.profiler_options = None - cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) - - if cfg.model.name == 'BaseSRModel': - floderModelName = cfg.model.generator.name - else: - floderModelName = cfg.model.name - cfg.output_dir = os.path.join(cfg.output_dir, - floderModelName + cfg.timestamp) - - logger_cfg = setup_logger(cfg.output_dir) - logger_cfg.info('Configs: {}'.format(cfg)) - - if paddle.is_compiled_with_cuda(): - paddle.set_device('gpu') - else: - paddle.set_device('cpu') - - # build trainer - trainer = Restorer(cfg, logger_cfg) - - # continue train or 
evaluate, checkpoint need contain epoch and optimizer info - if len(resume) > 0: - trainer.resume(resume) - # evaluate or finute, only load generator weights - elif len(pretrain_weights) > 0: - trainer.load(pretrain_weights) - if len(evaluate_weights) > 0: - trainer.load(evaluate_weights) - trainer.test() - return - # training, when keyboard interrupt save weights - try: - trainer.train() - except KeyboardInterrupt as e: - trainer.save(trainer.current_epoch) - - trainer.close() - - -# DRN模型训练 -class DRNet(BasicSRNet): - def __init__(self, - n_blocks=30, - n_feats=16, - n_colors=3, - rgb_range=255, - negval=0.2): - super(DRNet, self).__init__() - self.min_max = '(0., 255.)' - self.generator = { - 'name': 'DRNGenerator', - 'scale': (2, 4), - 'n_blocks': n_blocks, - 'n_feats': n_feats, - 'n_colors': n_colors, - 'rgb_range': rgb_range, - 'negval': negval - } - self.pixel_criterion = {'name': 'L1Loss'} - self.model = { - 'name': 'DRN', - 'generator': self.generator, - 'pixel_criterion': self.pixel_criterion - } - self.optimizer = { - 'optimG': { - 'name': 'Adam', - 'net_names': ['generator'], - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.999 - }, - 'optimD': { - 'name': 'Adam', - 'net_names': ['dual_model_0', 'dual_model_1'], - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.999 - } - } - self.lr_scheduler = { - 'name': 'CosineAnnealingRestartLR', - 'eta_min': 1e-07 - } - - -# 轻量化超分模型LESRCNN训练 -class LESRCNNet(BasicSRNet): - def __init__(self, scale=4, multi_scale=False, group=1): - super(LESRCNNet, self).__init__() - self.min_max = '(0., 1.)' - self.generator = { - 'name': 'LESRCNNGenerator', - 'scale': scale, - 'multi_scale': False, - 'group': 1 - } - self.pixel_criterion = {'name': 'L1Loss'} - self.model = { - 'name': 'BaseSRModel', - 'generator': self.generator, - 'pixel_criterion': self.pixel_criterion - } - self.optimizer = { - 'name': 'Adam', - 'net_names': ['generator'], - 'beta1': 0.9, - 'beta2': 0.99 - } - self.lr_scheduler = { - 'name': 
'CosineAnnealingRestartLR', - 'eta_min': 1e-07 - } - - -# ESRGAN模型训练 -# 若loss_type='gan' 使用感知损失、对抗损失和像素损失 -# 若loss_type = 'pixel' 只使用像素损失 -class ESRGANet(BasicSRNet): - def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23): - super(ESRGANet, self).__init__() - self.min_max = '(0., 1.)' - self.generator = { - 'name': 'RRDBNet', - 'in_nc': in_nc, - 'out_nc': out_nc, - 'nf': nf, - 'nb': nb - } - - if loss_type == 'gan': - # 定义损失函数 - self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01} - self.discriminator = { - 'name': 'VGGDiscriminator128', - 'in_channels': 3, - 'num_feat': 64 - } - self.perceptual_criterion = { - 'name': 'PerceptualLoss', - 'layer_weights': { - '34': 1.0 - }, - 'perceptual_weight': 1.0, - 'style_weight': 0.0, - 'norm_img': False - } - self.gan_criterion = { - 'name': 'GANLoss', - 'gan_mode': 'vanilla', - 'loss_weight': 0.005 - } - # 定义模型 - self.model = { - 'name': 'ESRGAN', - 'generator': self.generator, - 'discriminator': self.discriminator, - 'pixel_criterion': self.pixel_criterion, - 'perceptual_criterion': self.perceptual_criterion, - 'gan_criterion': self.gan_criterion - } - self.optimizer = { - 'optimG': { - 'name': 'Adam', - 'net_names': ['generator'], - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.99 - }, - 'optimD': { - 'name': 'Adam', - 'net_names': ['discriminator'], - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.99 - } - } - self.lr_scheduler = { - 'name': 'MultiStepDecay', - 'milestones': [50000, 100000, 200000, 300000], - 'gamma': 0.5 - } - else: - self.pixel_criterion = {'name': 'L1Loss'} - self.model = { - 'name': 'BaseSRModel', - 'generator': self.generator, - 'pixel_criterion': self.pixel_criterion - } - self.optimizer = { - 'name': 'Adam', - 'net_names': ['generator'], - 'beta1': 0.9, - 'beta2': 0.99 - } - self.lr_scheduler = { - 'name': 'CosineAnnealingRestartLR', - 'eta_min': 1e-07 - } - - -# RCAN模型训练 -class RCANet(BasicSRNet): - def __init__( - self, - scale=2, - n_resgroups=10, - 
n_resblocks=20, ): - super(RCANet, self).__init__() - self.min_max = '(0., 255.)' - self.generator = { - 'name': 'RCAN', - 'scale': scale, - 'n_resgroups': n_resgroups, - 'n_resblocks': n_resblocks - } - self.pixel_criterion = {'name': 'L1Loss'} - self.model = { - 'name': 'RCANModel', - 'generator': self.generator, - 'pixel_criterion': self.pixel_criterion - } - self.optimizer = { - 'name': 'Adam', - 'net_names': ['generator'], - 'beta1': 0.9, - 'beta2': 0.99 - } - self.lr_scheduler = { - 'name': 'MultiStepDecay', - 'milestones': [250000, 500000, 750000, 1000000], - 'gamma': 0.5 - } diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index 8a0acaa..6531481 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -61,6 +61,11 @@ class BaseDetector(BaseModel): net = ppdet.modeling.__dict__[self.model_name](**params) return net + def _build_inference_net(self): + infer_net = self.net + infer_net.eval() + return infer_net + def _fix_transforms_shape(self, image_shape): raise NotImplementedError("_fix_transforms_shape: not implemented!") @@ -250,32 +255,18 @@ class BaseDetector(BaseModel): """ args = self._pre_train(locals()) + args.pop('self') return self._real_train(**args) def _pre_train(self, in_args): return in_args - def _real_train(self, - num_epochs, - train_dataset, - train_batch_size=64, - eval_dataset=None, - optimizer=None, - save_interval_epochs=1, - log_interval_steps=10, - save_dir='output', - pretrain_weights='IMAGENET', - learning_rate=.001, - warmup_steps=0, - warmup_start_lr=0.0, - lr_decay_epochs=(216, 243), - lr_decay_gamma=0.1, - metric=None, - use_ema=False, - early_stop=False, - early_stop_patience=5, - use_vdl=True, - resume_checkpoint=None): + def _real_train( + self, num_epochs, train_dataset, train_batch_size, eval_dataset, + optimizer, save_interval_epochs, log_interval_steps, save_dir, + pretrain_weights, learning_rate, warmup_steps, warmup_start_lr, + lr_decay_epochs, 
lr_decay_gamma, metric, use_ema, early_stop, + early_stop_patience, use_vdl, resume_checkpoint): if self.status == 'Infer': logging.error( @@ -485,7 +476,7 @@ class BaseDetector(BaseModel): Defaults to False. Returns: - collections.OrderedDict with key-value pairs: + If `return_details` is False, return collections.OrderedDict with key-value pairs: {"bbox_mmap":`mean average precision (0.50, 11point)`}. """ @@ -584,21 +575,17 @@ class BaseDetector(BaseModel): Returns: If `img_file` is a string or np.array, the result is a list of dict with - key-value pairs: - {"category_id": `category_id`, - "category": `category`, - "bbox": `[x, y, w, h]`, - "score": `score`, - "mask": `mask`}. - If `img_file` is a list, the result is a list composed of list of dicts - with the corresponding fields: - category_id(int): the predicted category ID. 0 represents the first + the following key-value pairs: + category_id (int): Predicted category ID. 0 represents the first category in the dataset, and so on. - category(str): category name - bbox(list): bounding box in [x, y, w, h] format - score(str): confidence - mask(dict): Only for instance segmentation task. Mask of the object in - RLE format + category (str): Category name. + bbox (list): Bounding box in [x, y, w, h] format. + score (float): Confidence. + mask (dict): Only for instance segmentation task. Mask of the object in + RLE format. + + If `img_file` is a list, the result is a list composed of list of dicts + with the above keys. """ if transforms is None and not hasattr(self, 'test_transforms'): @@ -926,6 +913,26 @@ class PicoDet(BaseDetector): in_args['optimizer'] = optimizer return in_args + def build_data_loader(self, dataset, batch_size, mode='train'): + if dataset.num_samples < batch_size: + raise ValueError( + 'The volume of dataset({}) must be larger than batch size({}).' 
+ .format(dataset.num_samples, batch_size)) + + if mode != 'train': + return paddle.io.DataLoader( + dataset, + batch_size=batch_size, + shuffle=dataset.shuffle, + drop_last=False, + collate_fn=dataset.batch_transforms, + num_workers=dataset.num_workers, + return_list=True, + use_shared_memory=False) + else: + return super(BaseDetector, self).build_data_loader(dataset, + batch_size, mode) + class YOLOv3(BaseDetector): def __init__(self, diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py new file mode 100644 index 0000000..fe17f82 --- /dev/null +++ b/paddlers/tasks/restorer.py @@ -0,0 +1,936 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import os.path as osp +from collections import OrderedDict + +import numpy as np +import cv2 +import paddle +import paddle.nn.functional as F +from paddle.static import InputSpec + +import paddlers +import paddlers.models.ppgan as ppgan +import paddlers.rs_models.res as cmres +import paddlers.models.ppgan.metrics as metrics +import paddlers.utils.logging as logging +from paddlers.models import res_losses +from paddlers.transforms import Resize, decode_image +from paddlers.transforms.functions import calc_hr_shape +from paddlers.utils import get_single_card_bs +from .base import BaseModel +from .utils.res_adapters import GANAdapter, OptimizerAdapter +from .utils.infer_nets import InferResNet + +__all__ = ["DRN", "LESRCNN", "ESRGAN"] + + +class BaseRestorer(BaseModel): + MIN_MAX = (0., 1.) + TEST_OUT_KEY = None + + def __init__(self, model_name, losses=None, sr_factor=None, **params): + self.init_params = locals() + if 'with_net' in self.init_params: + del self.init_params['with_net'] + super(BaseRestorer, self).__init__('restorer') + self.model_name = model_name + self.losses = losses + self.sr_factor = sr_factor + if params.get('with_net', True): + params.pop('with_net', None) + self.net = self.build_net(**params) + self.find_unused_parameters = True + + def build_net(self, **params): + # Currently, only use models from cmres. + if not hasattr(cmres, self.model_name): + raise ValueError("ERROR: There is no model named {}.".format( + self.model_name)) + net = dict(**cmres.__dict__)[self.model_name](**params) + return net + + def _build_inference_net(self): + # For GAN models, only the generator will be used for inference. 
+ if isinstance(self.net, GANAdapter): + infer_net = InferResNet( + self.net.generator, out_key=self.TEST_OUT_KEY) + else: + infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY) + infer_net.eval() + return infer_net + + def _fix_transforms_shape(self, image_shape): + if hasattr(self, 'test_transforms'): + if self.test_transforms is not None: + has_resize_op = False + resize_op_idx = -1 + normalize_op_idx = len(self.test_transforms.transforms) + for idx, op in enumerate(self.test_transforms.transforms): + name = op.__class__.__name__ + if name == 'Normalize': + normalize_op_idx = idx + if 'Resize' in name: + has_resize_op = True + resize_op_idx = idx + + if not has_resize_op: + self.test_transforms.transforms.insert( + normalize_op_idx, Resize(target_size=image_shape)) + else: + self.test_transforms.transforms[resize_op_idx] = Resize( + target_size=image_shape) + + def _get_test_inputs(self, image_shape): + if image_shape is not None: + if len(image_shape) == 2: + image_shape = [1, 3] + image_shape + self._fix_transforms_shape(image_shape[-2:]) + else: + image_shape = [None, 3, -1, -1] + self.fixed_input_shape = image_shape + input_spec = [ + InputSpec( + shape=image_shape, name='image', dtype='float32') + ] + return input_spec + + def run(self, net, inputs, mode): + outputs = OrderedDict() + + if mode == 'test': + tar_shape = inputs[1] + if self.status == 'Infer': + net_out = net(inputs[0]) + res_map_list = self.postprocess( + net_out, tar_shape, transforms=inputs[2]) + else: + if isinstance(net, GANAdapter): + net_out = net.generator(inputs[0]) + else: + net_out = net(inputs[0]) + if self.TEST_OUT_KEY is not None: + net_out = net_out[self.TEST_OUT_KEY] + pred = self.postprocess( + net_out, tar_shape, transforms=inputs[2]) + res_map_list = [] + for res_map in pred: + res_map = self._tensor_to_images(res_map) + res_map_list.append(res_map) + outputs['res_map'] = res_map_list + + if mode == 'eval': + if isinstance(net, GANAdapter): + net_out = 
net.generator(inputs[0]) + else: + net_out = net(inputs[0]) + if self.TEST_OUT_KEY is not None: + net_out = net_out[self.TEST_OUT_KEY] + tar = inputs[1] + tar_shape = [tar.shape[-2:]] + pred = self.postprocess( + net_out, tar_shape, transforms=inputs[2])[0] # NCHW + pred = self._tensor_to_images(pred) + outputs['pred'] = pred + tar = self._tensor_to_images(tar) + outputs['tar'] = tar + + if mode == 'train': + # This is used by non-GAN models. + # For GAN models, self.run_gan() should be used. + net_out = net(inputs[0]) + loss = self.losses(net_out, inputs[1]) + outputs['loss'] = loss + return outputs + + def run_gan(self, net, inputs, mode, gan_mode): + raise NotImplementedError + + def default_loss(self): + return res_losses.L1Loss() + + def default_optimizer(self, + parameters, + learning_rate, + num_epochs, + num_steps_each_epoch, + lr_decay_power=0.9): + decay_step = num_epochs * num_steps_each_epoch + lr_scheduler = paddle.optimizer.lr.PolynomialDecay( + learning_rate, decay_step, end_lr=0, power=lr_decay_power) + optimizer = paddle.optimizer.Momentum( + learning_rate=lr_scheduler, + parameters=parameters, + momentum=0.9, + weight_decay=4e-5) + return optimizer + + def train(self, + num_epochs, + train_dataset, + train_batch_size=2, + eval_dataset=None, + optimizer=None, + save_interval_epochs=1, + log_interval_steps=2, + save_dir='output', + pretrain_weights=None, + learning_rate=0.01, + lr_decay_power=0.9, + early_stop=False, + early_stop_patience=5, + use_vdl=True, + resume_checkpoint=None): + """ + Train the model. + + Args: + num_epochs (int): Number of epochs. + train_dataset (paddlers.datasets.ResDataset): Training dataset. + train_batch_size (int, optional): Total batch size among all cards used in + training. Defaults to 2. + eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. + If None, the model will not be evaluated during training process. + Defaults to None. 
+ optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in + training. If None, a default optimizer will be used. Defaults to None. + save_interval_epochs (int, optional): Epoch interval for saving the model. + Defaults to 1. + log_interval_steps (int, optional): Step interval for printing training + information. Defaults to 2. + save_dir (str, optional): Directory to save the model. Defaults to 'output'. + pretrain_weights (str|None, optional): None or name/path of pretrained + weights. If None, no pretrained weights will be loaded. + Defaults to None. + learning_rate (float, optional): Learning rate for training. Defaults to .01. + lr_decay_power (float, optional): Learning decay power. Defaults to .9. + early_stop (bool, optional): Whether to adopt early stop strategy. Defaults + to False. + early_stop_patience (int, optional): Early stop patience. Defaults to 5. + use_vdl (bool, optional): Whether to use VisualDL to monitor the training + process. Defaults to True. + resume_checkpoint (str|None, optional): Path of the checkpoint to resume + training from. If None, no training checkpoint will be resumed. At most + one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously. + Defaults to None. 
+ """ + + if self.status == 'Infer': + logging.error( + "Exported inference model does not support training.", + exit=True) + if pretrain_weights is not None and resume_checkpoint is not None: + logging.error( + "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + exit=True) + + if self.losses is None: + self.losses = self.default_loss() + + if optimizer is None: + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + if isinstance(self.net, GANAdapter): + parameters = {'params_g': [], 'params_d': []} + for net_g in self.net.generators: + parameters['params_g'].append(net_g.parameters()) + for net_d in self.net.discriminators: + parameters['params_d'].append(net_d.parameters()) + else: + parameters = self.net.parameters() + self.optimizer = self.default_optimizer( + parameters, learning_rate, num_epochs, num_steps_each_epoch, + lr_decay_power) + else: + self.optimizer = optimizer + + if pretrain_weights is not None and not osp.exists(pretrain_weights): + logging.warning("Path of pretrain_weights('{}') does not exist!". + format(pretrain_weights)) + elif pretrain_weights is not None and osp.exists(pretrain_weights): + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrain weights. 
Please specify a '.pdparams' file.", + exit=True) + pretrained_dir = osp.join(save_dir, 'pretrain') + is_backbone_weights = pretrain_weights == 'IMAGENET' + self.net_initialize( + pretrain_weights=pretrain_weights, + save_dir=pretrained_dir, + resume_checkpoint=resume_checkpoint, + is_backbone_weights=is_backbone_weights) + + self.train_loop( + num_epochs=num_epochs, + train_dataset=train_dataset, + train_batch_size=train_batch_size, + eval_dataset=eval_dataset, + save_interval_epochs=save_interval_epochs, + log_interval_steps=log_interval_steps, + save_dir=save_dir, + early_stop=early_stop, + early_stop_patience=early_stop_patience, + use_vdl=use_vdl) + + def quant_aware_train(self, + num_epochs, + train_dataset, + train_batch_size=2, + eval_dataset=None, + optimizer=None, + save_interval_epochs=1, + log_interval_steps=2, + save_dir='output', + learning_rate=0.0001, + lr_decay_power=0.9, + early_stop=False, + early_stop_patience=5, + use_vdl=True, + resume_checkpoint=None, + quant_config=None): + """ + Quantization-aware training. + + Args: + num_epochs (int): Number of epochs. + train_dataset (paddlers.datasets.ResDataset): Training dataset. + train_batch_size (int, optional): Total batch size among all cards used in + training. Defaults to 2. + eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. + If None, the model will not be evaluated during training process. + Defaults to None. + optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in + training. If None, a default optimizer will be used. Defaults to None. + save_interval_epochs (int, optional): Epoch interval for saving the model. + Defaults to 1. + log_interval_steps (int, optional): Step interval for printing training + information. Defaults to 2. + save_dir (str, optional): Directory to save the model. Defaults to 'output'. + learning_rate (float, optional): Learning rate for training. + Defaults to .0001. 
+ lr_decay_power (float, optional): Learning decay power. Defaults to .9. + early_stop (bool, optional): Whether to adopt early stop strategy. + Defaults to False. + early_stop_patience (int, optional): Early stop patience. Defaults to 5. + use_vdl (bool, optional): Whether to use VisualDL to monitor the training + process. Defaults to True. + quant_config (dict|None, optional): Quantization configuration. If None, + a default rule of thumb configuration will be used. Defaults to None. + resume_checkpoint (str|None, optional): Path of the checkpoint to resume + quantization-aware training from. If None, no training checkpoint will + be resumed. Defaults to None. + """ + + self._prepare_qat(quant_config) + self.train( + num_epochs=num_epochs, + train_dataset=train_dataset, + train_batch_size=train_batch_size, + eval_dataset=eval_dataset, + optimizer=optimizer, + save_interval_epochs=save_interval_epochs, + log_interval_steps=log_interval_steps, + save_dir=save_dir, + pretrain_weights=None, + learning_rate=learning_rate, + lr_decay_power=lr_decay_power, + early_stop=early_stop, + early_stop_patience=early_stop_patience, + use_vdl=use_vdl, + resume_checkpoint=resume_checkpoint) + + def evaluate(self, eval_dataset, batch_size=1, return_details=False): + """ + Evaluate the model. + + Args: + eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset. + batch_size (int, optional): Total batch size among all cards used for + evaluation. Defaults to 1. + return_details (bool, optional): Whether to return evaluation details. + Defaults to False. + + Returns: + If `return_details` is False, return collections.OrderedDict with + key-value pairs: + {"psnr": `peak signal-to-noise ratio`, + "ssim": `structural similarity`}. + + """ + + self._check_transforms(eval_dataset.transforms, 'eval') + + self.net.eval() + nranks = paddle.distributed.get_world_size() + local_rank = paddle.distributed.get_rank() + if nranks > 1: + # Initialize parallel environment if not done. 
+ if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( + ): + paddle.distributed.init_parallel_env() + + # TODO: Distributed evaluation + if batch_size > 1: + logging.warning( + "Restorer only supports single card evaluation with batch_size=1 " + "during evaluation, so batch_size is forcibly set to 1.") + batch_size = 1 + + if nranks < 2 or local_rank == 0: + self.eval_data_loader = self.build_data_loader( + eval_dataset, batch_size=batch_size, mode='eval') + # XXX: Hard-code crop_border and test_y_channel + psnr = metrics.PSNR(crop_border=4, test_y_channel=True) + ssim = metrics.SSIM(crop_border=4, test_y_channel=True) + logging.info( + "Start to evaluate(total_samples={}, total_steps={})...".format( + eval_dataset.num_samples, eval_dataset.num_samples)) + with paddle.no_grad(): + for step, data in enumerate(self.eval_data_loader): + data.append(eval_dataset.transforms.transforms) + outputs = self.run(self.net, data, 'eval') + psnr.update(outputs['pred'], outputs['tar']) + ssim.update(outputs['pred'], outputs['tar']) + + # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training. + assert len(psnr.results) > 0 + assert len(ssim.results) > 0 + eval_metrics = OrderedDict( + zip(['psnr', 'ssim'], + [np.mean(psnr.results), np.mean(ssim.results)])) + + if return_details: + # TODO: Add details + return eval_metrics, None + + return eval_metrics + + def predict(self, img_file, transforms=None): + """ + Do inference. + + Args: + img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded + image data, which also could constitute a list, meaning all images to be + predicted as a mini-batch. + transforms (paddlers.transforms.Compose|None, optional): Transforms for + inputs. If None, the transforms for evaluation process will be used. + Defaults to None. 
+ + Returns: + If `img_file` is a string or np.array, the result is a dict with + the following key-value pairs: + res_map (np.ndarray): Restored image (HWC). + + If `img_file` is a list, the result is a list composed of dicts with the + above keys. + """ + + if transforms is None and not hasattr(self, 'test_transforms'): + raise ValueError("transforms need to be defined, now is None.") + if transforms is None: + transforms = self.test_transforms + if isinstance(img_file, (str, np.ndarray)): + images = [img_file] + else: + images = img_file + batch_im, batch_tar_shape = self.preprocess(images, transforms, + self.model_type) + self.net.eval() + data = (batch_im, batch_tar_shape, transforms.transforms) + outputs = self.run(self.net, data, 'test') + res_map_list = outputs['res_map'] + if isinstance(img_file, list): + prediction = [{'res_map': m} for m in res_map_list] + else: + prediction = {'res_map': res_map_list[0]} + return prediction + + def preprocess(self, images, transforms, to_tensor=True): + self._check_transforms(transforms, 'test') + batch_im = list() + batch_tar_shape = list() + for im in images: + if isinstance(im, str): + im = decode_image(im, to_rgb=False) + ori_shape = im.shape[:2] + sample = {'image': im} + im = transforms(sample)[0] + batch_im.append(im) + batch_tar_shape.append(self._get_target_shape(ori_shape)) + if to_tensor: + batch_im = paddle.to_tensor(batch_im) + else: + batch_im = np.asarray(batch_im) + + return batch_im, batch_tar_shape + + def _get_target_shape(self, ori_shape): + if self.sr_factor is None: + return ori_shape + else: + return calc_hr_shape(ori_shape, self.sr_factor) + + @staticmethod + def get_transforms_shape_info(batch_tar_shape, transforms): + batch_restore_list = list() + for tar_shape in batch_tar_shape: + restore_list = list() + h, w = tar_shape[0], tar_shape[1] + for op in transforms: + if op.__class__.__name__ == 'Resize': + restore_list.append(('resize', (h, w))) + h, w = op.target_size + elif 
op.__class__.__name__ == 'ResizeByShort': + restore_list.append(('resize', (h, w))) + im_short_size = min(h, w) + im_long_size = max(h, w) + scale = float(op.short_size) / float(im_short_size) + if 0 < op.max_size < np.round(scale * im_long_size): + scale = float(op.max_size) / float(im_long_size) + h = int(round(h * scale)) + w = int(round(w * scale)) + elif op.__class__.__name__ == 'ResizeByLong': + restore_list.append(('resize', (h, w))) + im_long_size = max(h, w) + scale = float(op.long_size) / float(im_long_size) + h = int(round(h * scale)) + w = int(round(w * scale)) + elif op.__class__.__name__ == 'Pad': + if op.target_size: + target_h, target_w = op.target_size + else: + target_h = int( + (np.ceil(h / op.size_divisor) * op.size_divisor)) + target_w = int( + (np.ceil(w / op.size_divisor) * op.size_divisor)) + + if op.pad_mode == -1: + offsets = op.offsets + elif op.pad_mode == 0: + offsets = [0, 0] + elif op.pad_mode == 1: + offsets = [(target_h - h) // 2, (target_w - w) // 2] + else: + offsets = [target_h - h, target_w - w] + restore_list.append(('padding', (h, w), offsets)) + h, w = target_h, target_w + + batch_restore_list.append(restore_list) + return batch_restore_list + + def postprocess(self, batch_pred, batch_tar_shape, transforms): + batch_restore_list = BaseRestorer.get_transforms_shape_info( + batch_tar_shape, transforms) + if self.status == 'Infer': + return self._infer_postprocess( + batch_res_map=batch_pred, batch_restore_list=batch_restore_list) + results = [] + if batch_pred.dtype == paddle.float32: + mode = 'bilinear' + else: + mode = 'nearest' + for pred, restore_list in zip(batch_pred, batch_restore_list): + pred = paddle.unsqueeze(pred, axis=0) + for item in restore_list[::-1]: + h, w = item[1][0], item[1][1] + if item[0] == 'resize': + pred = F.interpolate( + pred, (h, w), mode=mode, data_format='NCHW') + elif item[0] == 'padding': + x, y = item[2] + pred = pred[:, :, y:y + h, x:x + w] + else: + pass + results.append(pred) + return 
results + + def _infer_postprocess(self, batch_res_map, batch_restore_list): + res_maps = [] + for res_map, restore_list in zip(batch_res_map, batch_restore_list): + if not isinstance(res_map, np.ndarray): + res_map = paddle.unsqueeze(res_map, axis=0) + for item in restore_list[::-1]: + h, w = item[1][0], item[1][1] + if item[0] == 'resize': + if isinstance(res_map, np.ndarray): + res_map = cv2.resize( + res_map, (w, h), interpolation=cv2.INTER_LINEAR) + else: + res_map = F.interpolate( + res_map, (h, w), + mode='bilinear', + data_format='NHWC') + elif item[0] == 'padding': + x, y = item[2] + if isinstance(res_map, np.ndarray): + res_map = res_map[y:y + h, x:x + w] + else: + res_map = res_map[:, y:y + h, x:x + w, :] + else: + pass + res_map = res_map.squeeze() + if not isinstance(res_map, np.ndarray): + res_map = res_map.numpy() + res_map = self._normalize(res_map) + res_maps.append(res_map.squeeze()) + return res_maps + + def _check_transforms(self, transforms, mode): + super()._check_transforms(transforms, mode) + if not isinstance(transforms.arrange, + paddlers.transforms.ArrangeRestorer): + raise TypeError( + "`transforms.arrange` must be an ArrangeRestorer object.") + + def build_data_loader(self, dataset, batch_size, mode='train'): + if dataset.num_samples < batch_size: + raise ValueError( + 'The volume of dataset({}) must be larger than batch size({}).' 
+ .format(dataset.num_samples, batch_size)) + + if mode != 'train': + return paddle.io.DataLoader( + dataset, + batch_size=batch_size, + shuffle=dataset.shuffle, + drop_last=False, + collate_fn=dataset.batch_transforms, + num_workers=dataset.num_workers, + return_list=True, + use_shared_memory=False) + else: + return super(BaseRestorer, self).build_data_loader(dataset, + batch_size, mode) + + def set_losses(self, losses): + self.losses = losses + + def _tensor_to_images(self, + tensor, + transpose=True, + squeeze=True, + quantize=True): + if transpose: + tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC + if squeeze: + tensor = tensor.squeeze() + images = tensor.numpy().astype('float32') + images = self._normalize( + images, copy=True, clip=True, quantize=quantize) + return images + + def _normalize(self, im, copy=False, clip=True, quantize=True): + if copy: + im = im.copy() + if clip: + im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) + im -= im.min() + im /= im.max() + 1e-32 + if quantize: + im *= 255 + im = im.astype('uint8') + return im + + +class DRN(BaseRestorer): + TEST_OUT_KEY = -1 + + def __init__(self, + losses=None, + sr_factor=4, + scales=(2, 4), + n_blocks=30, + n_feats=16, + n_colors=3, + rgb_range=1.0, + negval=0.2, + lq_loss_weight=0.1, + dual_loss_weight=0.1, + **params): + if sr_factor != max(scales): + raise ValueError(f"`sr_factor` must be equal to `max(scales)`.") + params.update({ + 'scale': scales, + 'n_blocks': n_blocks, + 'n_feats': n_feats, + 'n_colors': n_colors, + 'rgb_range': rgb_range, + 'negval': negval + }) + self.lq_loss_weight = lq_loss_weight + self.dual_loss_weight = dual_loss_weight + self.scales = scales + super(DRN, self).__init__( + model_name='DRN', losses=losses, sr_factor=sr_factor, **params) + + def build_net(self, **params): + from ppgan.modules.init import init_weights + generators = [ppgan.models.generators.DRNGenerator(**params)] + init_weights(generators[-1]) + for scale in params['scale']: + dual_model = 
ppgan.models.generators.drn.DownBlock( + params['negval'], params['n_feats'], params['n_colors'], 2) + generators.append(dual_model) + init_weights(generators[-1]) + return GANAdapter(generators, []) + + def default_optimizer(self, parameters, *args, **kwargs): + optims_g = [ + super(DRN, self).default_optimizer(params_g, *args, **kwargs) + for params_g in parameters['params_g'] + ] + return OptimizerAdapter(*optims_g) + + def run_gan(self, net, inputs, mode, gan_mode='forward_primary'): + if mode != 'train': + raise ValueError("`mode` is not 'train'.") + outputs = OrderedDict() + if gan_mode == 'forward_primary': + sr = net.generator(inputs[0]) + lr = [inputs[0]] + lr.extend([ + F.interpolate( + inputs[0], scale_factor=s, mode='bicubic') + for s in self.scales[:-1] + ]) + loss = self.losses(sr[-1], inputs[1]) + for i in range(1, len(sr)): + if self.lq_loss_weight > 0: + loss += self.losses(sr[i - 1 - len(sr)], + lr[i - len(sr)]) * self.lq_loss_weight + outputs['loss_prim'] = loss + outputs['sr'] = sr + outputs['lr'] = lr + elif gan_mode == 'forward_dual': + sr, lr = inputs[0], inputs[1] + sr2lr = [] + n_scales = len(self.scales) + for i in range(n_scales): + sr2lr_i = net.generators[1 + i](sr[i - n_scales]) + sr2lr.append(sr2lr_i) + loss = self.losses(sr2lr[0], lr[0]) + for i in range(1, n_scales): + if self.dual_loss_weight > 0.0: + loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight + outputs['loss_dual'] = loss + else: + raise ValueError("Invalid `gan_mode`!") + return outputs + + def train_step(self, step, data, net): + outputs = self.run_gan( + net, data, mode='train', gan_mode='forward_primary') + outputs.update( + self.run_gan( + net, (outputs['sr'], outputs['lr']), + mode='train', + gan_mode='forward_dual')) + self.optimizer.clear_grad() + (outputs['loss_prim'] + outputs['loss_dual']).backward() + self.optimizer.step() + return { + 'loss': outputs['loss_prim'] + outputs['loss_dual'], + 'loss_prim': outputs['loss_prim'], + 'loss_dual': 
outputs['loss_dual'] + } + + +class LESRCNN(BaseRestorer): + def __init__(self, + losses=None, + sr_factor=4, + multi_scale=False, + group=1, + **params): + params.update({ + 'scale': sr_factor, + 'multi_scale': multi_scale, + 'group': group + }) + super(LESRCNN, self).__init__( + model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params) + + def build_net(self, **params): + net = ppgan.models.generators.LESRCNNGenerator(**params) + return net + + +class ESRGAN(BaseRestorer): + def __init__(self, + losses=None, + sr_factor=4, + use_gan=True, + in_channels=3, + out_channels=3, + nf=64, + nb=23, + **params): + if sr_factor != 4: + raise ValueError("`sr_factor` must be 4.") + params.update({ + 'in_nc': in_channels, + 'out_nc': out_channels, + 'nf': nf, + 'nb': nb + }) + self.use_gan = use_gan + super(ESRGAN, self).__init__( + model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) + + def build_net(self, **params): + from ppgan.modules.init import init_weights + generator = ppgan.models.generators.RRDBNet(**params) + init_weights(generator) + if self.use_gan: + discriminator = ppgan.models.discriminators.VGGDiscriminator128( + in_channels=params['out_nc'], num_feat=64) + net = GANAdapter( + generators=[generator], discriminators=[discriminator]) + else: + net = generator + return net + + def default_loss(self): + if self.use_gan: + return { + 'pixel': res_losses.L1Loss(loss_weight=0.01), + 'perceptual': res_losses.PerceptualLoss( + layer_weights={'34': 1.0}, + perceptual_weight=1.0, + style_weight=0.0, + norm_img=False), + 'gan': res_losses.GANLoss( + gan_mode='vanilla', loss_weight=0.005) + } + else: + return res_losses.L1Loss() + + def default_optimizer(self, parameters, *args, **kwargs): + if self.use_gan: + optim_g = super(ESRGAN, self).default_optimizer( + parameters['params_g'][0], *args, **kwargs) + optim_d = super(ESRGAN, self).default_optimizer( + parameters['params_d'][0], *args, **kwargs) + return OptimizerAdapter(optim_g, optim_d) + 
else: + return super(ESRGAN, self).default_optimizer(parameters, *args, + **kwargs) + + def run_gan(self, net, inputs, mode, gan_mode='forward_g'): + if mode != 'train': + raise ValueError("`mode` is not 'train'.") + outputs = OrderedDict() + if gan_mode == 'forward_g': + loss_g = 0 + g_pred = net.generator(inputs[0]) + loss_pix = self.losses['pixel'](g_pred, inputs[1]) + loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1]) + loss_g += loss_pix + if loss_perc is not None: + loss_g += loss_perc + if loss_sty is not None: + loss_g += loss_sty + self._set_requires_grad(net.discriminator, False) + real_d_pred = net.discriminator(inputs[1]).detach() + fake_g_pred = net.discriminator(g_pred) + loss_g_real = self.losses['gan']( + real_d_pred - paddle.mean(fake_g_pred), False, + is_disc=False) * 0.5 + loss_g_fake = self.losses['gan']( + fake_g_pred - paddle.mean(real_d_pred), True, + is_disc=False) * 0.5 + loss_g_gan = loss_g_real + loss_g_fake + outputs['g_pred'] = g_pred.detach() + outputs['loss_g_pps'] = loss_g + outputs['loss_g_gan'] = loss_g_gan + elif gan_mode == 'forward_d': + self._set_requires_grad(net.discriminator, True) + # Real + fake_d_pred = net.discriminator(inputs[0]).detach() + real_d_pred = net.discriminator(inputs[1]) + loss_d_real = self.losses['gan']( + real_d_pred - paddle.mean(fake_d_pred), True, + is_disc=True) * 0.5 + # Fake + fake_d_pred = net.discriminator(inputs[0].detach()) + loss_d_fake = self.losses['gan']( + fake_d_pred - paddle.mean(real_d_pred.detach()), + False, + is_disc=True) * 0.5 + outputs['loss_d'] = loss_d_real + loss_d_fake + else: + raise ValueError("Invalid `gan_mode`!") + return outputs + + def train_step(self, step, data, net): + if self.use_gan: + optim_g, optim_d = self.optimizer + + outputs = self.run_gan( + net, data, mode='train', gan_mode='forward_g') + optim_g.clear_grad() + (outputs['loss_g_pps'] + outputs['loss_g_gan']).backward() + optim_g.step() + + outputs.update( + self.run_gan( + net, 
(outputs['g_pred'], data[1]), + mode='train', + gan_mode='forward_d')) + optim_d.clear_grad() + outputs['loss_d'].backward() + optim_d.step() + + outputs['loss'] = outputs['loss_g_pps'] + outputs[ + 'loss_g_gan'] + outputs['loss_d'] + + return { + 'loss': outputs['loss'], + 'loss_g_pps': outputs['loss_g_pps'], + 'loss_g_gan': outputs['loss_g_gan'], + 'loss_d': outputs['loss_d'] + } + else: + return super(ESRGAN, self).train_step(step, data, net) + + def _set_requires_grad(self, net, requires_grad): + for p in net.parameters(): + p.trainable = requires_grad + + +class RCAN(BaseRestorer): + def __init__(self, + losses=None, + sr_factor=4, + n_resgroups=10, + n_resblocks=20, + n_feats=64, + n_colors=3, + rgb_range=1.0, + kernel_size=3, + reduction=16, + **params): + params.update({ + 'n_resgroups': n_resgroups, + 'n_resblocks': n_resblocks, + 'n_feats': n_feats, + 'n_colors': n_colors, + 'rgb_range': rgb_range, + 'kernel_size': kernel_size, + 'reduction': reduction + }) + super(RCAN, self).__init__( + model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 9800a3c..b9c586f 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -33,6 +33,7 @@ from paddlers.utils import get_single_card_bs, DisablePrint from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics +from .utils.infer_nets import InferSegNet __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -64,11 +65,16 @@ class BaseSegmenter(BaseModel): def build_net(self, **params): # TODO: when using paddle.utils.unique_name.guard, - # DeepLabv3p and HRNet will raise a error + # DeepLabv3p and HRNet will raise an error. 
net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name]( num_classes=self.num_classes, **params) return net + def _build_inference_net(self): + infer_net = InferSegNet(self.net) + infer_net.eval() + return infer_net + def _fix_transforms_shape(self, image_shape): if hasattr(self, 'test_transforms'): if self.test_transforms is not None: @@ -472,7 +478,6 @@ class BaseSegmenter(BaseModel): conf_mat_all.append(conf_mat) class_iou, miou = ppseg.utils.metrics.mean_iou( intersect_area_all, pred_area_all, label_area_all) - # TODO: confirm whether to use oacc or macc class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all, pred_area_all) kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all, @@ -504,13 +509,13 @@ class BaseSegmenter(BaseModel): Defaults to None. Returns: - If `img_file` is a string or np.array, the result is a dict with key-value - pairs: - {"label map": `label map`, "score_map": `score map`}. + If `img_file` is a string or np.array, the result is a dict with + the following key-value pairs: + label_map (np.ndarray): Predicted label map (HW). + score_map (np.ndarray): Prediction score map (HWC). + If `img_file` is a list, the result is a list composed of dicts with the - corresponding fields: - label_map (np.ndarray): the predicted label map (HW) - score_map (np.ndarray): the prediction score map (HWC) + above keys. 
""" if transforms is None and not hasattr(self, 'test_transforms'): @@ -750,11 +755,11 @@ class BaseSegmenter(BaseModel): elif item[0] == 'padding': x, y = item[2] if isinstance(label_map, np.ndarray): - label_map = label_map[..., y:y + h, x:x + w] - score_map = score_map[..., y:y + h, x:x + w] + label_map = label_map[y:y + h, x:x + w] + score_map = score_map[y:y + h, x:x + w] else: - label_map = label_map[:, :, y:y + h, x:x + w] - score_map = score_map[:, :, y:y + h, x:x + w] + label_map = label_map[:, y:y + h, x:x + w, :] + score_map = score_map[:, y:y + h, x:x + w, :] else: pass label_map = label_map.squeeze() diff --git a/paddlers/tasks/utils/infer_nets.py b/paddlers/tasks/utils/infer_nets.py index e35731c..f20b94b 100644 --- a/paddlers/tasks/utils/infer_nets.py +++ b/paddlers/tasks/utils/infer_nets.py @@ -15,30 +15,36 @@ import paddle -class PostProcessor(paddle.nn.Layer): - def __init__(self, model_type): - super(PostProcessor, self).__init__() - self.model_type = model_type - +class SegPostProcessor(paddle.nn.Layer): def forward(self, net_outputs): # label_map [NHW], score_map [NHWC] logit = net_outputs[0] outputs = paddle.argmax(logit, axis=1, keepdim=False, dtype='int32'), \ paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1]) + return outputs + + +class ResPostProcessor(paddle.nn.Layer): + def __init__(self, out_key=None): + super(ResPostProcessor, self).__init__() + self.out_key = out_key + def forward(self, net_outputs): + if self.out_key is not None: + net_outputs = net_outputs[self.out_key] + outputs = paddle.transpose(net_outputs, perm=[0, 2, 3, 1]) return outputs -class InferNet(paddle.nn.Layer): - def __init__(self, net, model_type): - super(InferNet, self).__init__() +class InferSegNet(paddle.nn.Layer): + def __init__(self, net): + super(InferSegNet, self).__init__() self.net = net - self.postprocessor = PostProcessor(model_type) + self.postprocessor = SegPostProcessor() def forward(self, x): net_outputs = self.net(x) 
outputs = self.postprocessor(net_outputs) - return outputs @@ -46,10 +52,21 @@ class InferCDNet(paddle.nn.Layer): def __init__(self, net): super(InferCDNet, self).__init__() self.net = net - self.postprocessor = PostProcessor('change_detector') + self.postprocessor = SegPostProcessor() def forward(self, x1, x2): net_outputs = self.net(x1, x2) outputs = self.postprocessor(net_outputs) + return outputs + + +class InferResNet(paddle.nn.Layer): + def __init__(self, net, out_key=None): + super(InferResNet, self).__init__() + self.net = net + self.postprocessor = ResPostProcessor(out_key=out_key) + def forward(self, x): + net_outputs = self.net(x) + outputs = self.postprocessor(net_outputs) return outputs diff --git a/paddlers/tasks/utils/res_adapters.py b/paddlers/tasks/utils/res_adapters.py new file mode 100644 index 0000000..eba8106 --- /dev/null +++ b/paddlers/tasks/utils/res_adapters.py @@ -0,0 +1,132 @@ +from functools import wraps +from inspect import isfunction, isgeneratorfunction, getmembers +from collections.abc import Sequence +from abc import ABC + +import paddle +import paddle.nn as nn + +__all__ = ['GANAdapter', 'OptimizerAdapter'] + + +class _AttrDesc: + def __init__(self, key): + self.key = key + + def __get__(self, instance, owner): + return tuple(getattr(ele, self.key) for ele in instance) + + def __set__(self, instance, value): + for ele in instance: + setattr(ele, self.key, value) + + +def _func_deco(cls, func_name): + @wraps(getattr(cls.__ducktype__, func_name)) + def _wrapper(self, *args, **kwargs): + return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self) + + return _wrapper + + +def _generator_deco(cls, func_name): + @wraps(getattr(cls.__ducktype__, func_name)) + def _wrapper(self, *args, **kwargs): + for ele in self: + yield from getattr(ele, func_name)(*args, **kwargs) + + return _wrapper + + +class Adapter(Sequence, ABC): + __ducktype__ = object + __ava__ = () + + def __init__(self, *args): + if not all(map(self._check, args)): 
+ raise TypeError("Please check the input type.") + self._seq = tuple(args) + + def __getitem__(self, key): + return self._seq[key] + + def __len__(self): + return len(self._seq) + + def __repr__(self): + return repr(self._seq) + + @classmethod + def _check(cls, obj): + for attr in cls.__ava__: + try: + getattr(obj, attr) + # TODO: Check function signature + except AttributeError: + return False + return True + + +def make_adapter(cls): + members = dict(getmembers(cls.__ducktype__)) + for k in cls.__ava__: + if hasattr(cls, k): + continue + if k in members: + v = members[k] + if isgeneratorfunction(v): + setattr(cls, k, _generator_deco(cls, k)) + elif isfunction(v): + setattr(cls, k, _func_deco(cls, k)) + else: + setattr(cls, k, _AttrDesc(k)) + return cls + + +class GANAdapter(nn.Layer): + __ducktype__ = nn.Layer + __ava__ = ('state_dict', 'set_state_dict', 'train', 'eval') + + def __init__(self, generators, discriminators): + super(GANAdapter, self).__init__() + self.generators = nn.LayerList(generators) + self.discriminators = nn.LayerList(discriminators) + self._m = [*generators, *discriminators] + + def __len__(self): + return len(self._m) + + def __getitem__(self, key): + return self._m[key] + + def __contains__(self, m): + return m in self._m + + def __repr__(self): + return repr(self._m) + + @property + def generator(self): + return self.generators[0] + + @property + def discriminator(self): + return self.discriminators[0] + + +Adapter.register(GANAdapter) + + +@make_adapter +class OptimizerAdapter(Adapter): + __ducktype__ = paddle.optimizer.Optimizer + __ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr') + + def set_state_dict(self, state_dicts): + # Special dispatching rule + for optim, state_dict in zip(self, state_dicts): + optim.set_state_dict(state_dict) + + def get_lr(self): + # Return the lr of the first optimizer + return self[0].get_lr() diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 
12c3e9a..5550e33 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -638,3 +638,7 @@ def decode_seg_mask(mask_path): mask = np.asarray(Image.open(mask_path)) mask = mask.astype('int64') return mask + + +def calc_hr_shape(lr_shape, sr_factor): + return tuple(int(s * sr_factor) for s in lr_shape) diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index c7603b1..dd21c7a 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -35,7 +35,7 @@ from .functions import ( horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8, - img_flip, img_simple_rotate, decode_seg_mask) + img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape) __all__ = [ "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort", @@ -44,7 +44,7 @@ __all__ = [ "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort", "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand", "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier", - "ArrangeDetector", "RandomFlipOrRotate", "ReloadMask" + "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask" ] interp_dict = { @@ -154,6 +154,8 @@ class Transform(object): if 'aux_masks' in sample: sample['aux_masks'] = list( map(self.apply_mask, sample['aux_masks'])) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target']) return sample @@ -336,6 +338,14 @@ class DecodeImg(Transform): map(self.apply_mask, sample['aux_masks'])) # TODO: check the shape of auxiliary masks + if 'target' in sample: + if self.read_geo_info: + target, geo_info_dict = self.apply_im(sample['target']) + sample['target'] = target + sample['geo_info_dict_tar'] = geo_info_dict + else: + sample['target'] = self.apply_im(sample['target']) + sample['im_shape'] = np.array( 
sample['image'].shape[:2], dtype=np.float32) sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) @@ -457,6 +467,17 @@ class Resize(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm( sample['gt_poly'], [im_h, im_w], [im_scale_x, im_scale_y]) + if 'target' in sample: + if 'sr_factor' in sample: + # For SR tasks + sample['target'] = self.apply_im( + sample['target'], interp, + calc_hr_shape(target_size, sample['sr_factor'])) + else: + # For non-SR tasks + sample['target'] = self.apply_im(sample['target'], interp, + target_size) + sample['im_shape'] = np.asarray( sample['image'].shape[:2], dtype=np.float32) if 'scale_factor' in sample: @@ -730,6 +751,9 @@ class RandomFlipOrRotate(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id, True) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target'], mode_id, + True) elif p_m < self.probs[1]: mode_p = random.random() mode_id = self.judge_probs_range(mode_p, self.probsr) @@ -750,6 +774,9 @@ class RandomFlipOrRotate(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], mode_id, False) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target'], mode_id, + False) return sample @@ -809,6 +836,8 @@ class RandomHorizontalFlip(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h, im_w) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target']) return sample @@ -867,6 +896,8 @@ class RandomVerticalFlip(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_h, im_w) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target']) return sample @@ -883,16 +914,19 @@ class Normalize(Transform): std (list[float] | 
tuple[float], optional): Standard deviation of input image(s). Defaults to [0.229, 0.224, 0.225]. min_val (list[float] | tuple[float], optional): Minimum value of input - image(s). Defaults to [0, 0, 0, ]. - max_val (list[float] | tuple[float], optional): Max value of input image(s). - Defaults to [255., 255., 255.]. + image(s). If None, use 0 for all channels. Defaults to None. + max_val (list[float] | tuple[float], optional): Maximum value of input + image(s). If None, use 255. for all channels. Defaults to None. + apply_to_tar (bool, optional): Whether to apply transformation to the target + image. Defaults to True. """ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], min_val=None, - max_val=None): + max_val=None, + apply_to_tar=True): super(Normalize, self).__init__() channel = len(mean) if min_val is None: @@ -914,6 +948,7 @@ class Normalize(Transform): self.std = std self.min_val = min_val self.max_val = max_val + self.apply_to_tar = apply_to_tar def apply_im(self, image): image = image.astype(np.float32) @@ -927,6 +962,8 @@ class Normalize(Transform): sample['image'] = self.apply_im(sample['image']) if 'image2' in sample: sample['image2'] = self.apply_im(sample['image2']) + if 'target' in sample and self.apply_to_tar: + sample['target'] = self.apply_im(sample['target']) return sample @@ -964,6 +1001,8 @@ class CenterCrop(Transform): if 'aux_masks' in sample: sample['aux_masks'] = list( map(self.apply_mask, sample['aux_masks'])) + if 'target' in sample: + sample['target'] = self.apply_im(sample['target']) return sample @@ -1165,6 +1204,14 @@ class RandomCrop(Transform): self.apply_mask, crop=crop_box), sample['aux_masks'])) + if 'target' in sample: + if 'sr_factor' in sample: + sample['target'] = self.apply_im( + sample['target'], + calc_hr_shape(crop_box, sample['sr_factor'])) + else: + sample['target'] = self.apply_im(sample['target'], crop_box) + if self.crop_size is not None: sample = Resize(self.crop_size)(sample) @@ -1266,6 
+1313,7 @@ class Pad(Transform): pad_mode (int, optional): Pad mode. Currently only four modes are supported: [-1, 0, 1, 2]. if -1, use specified offsets. If 0, only pad to right and bottom If 1, pad according to center. If 2, only pad left and top. Defaults to 0. + offsets (list[int]|None, optional): Padding offsets. Defaults to None. im_padding_value (list[float] | tuple[float]): RGB value of padded area. Defaults to (127.5, 127.5, 127.5). label_padding_value (int, optional): Filling value for the mask. @@ -1332,6 +1380,17 @@ class Pad(Transform): expand_rle(segm, x, y, height, width, h, w)) return expanded_segms + def _get_offsets(self, im_h, im_w, h, w): + if self.pad_mode == -1: + offsets = self.offsets + elif self.pad_mode == 0: + offsets = [0, 0] + elif self.pad_mode == 1: + offsets = [(w - im_w) // 2, (h - im_h) // 2] + else: + offsets = [w - im_w, h - im_h] + return offsets + def apply(self, sample): im_h, im_w = sample['image'].shape[:2] if self.target_size: @@ -1349,14 +1408,7 @@ class Pad(Transform): if h == im_h and w == im_w: return sample - if self.pad_mode == -1: - offsets = self.offsets - elif self.pad_mode == 0: - offsets = [0, 0] - elif self.pad_mode == 1: - offsets = [(w - im_w) // 2, (h - im_h) // 2] - else: - offsets = [w - im_w, h - im_h] + offsets = self._get_offsets(im_h, im_w, h, w) sample['image'] = self.apply_im(sample['image'], offsets, (h, w)) if 'image2' in sample: @@ -1373,6 +1425,16 @@ class Pad(Transform): if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm( sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w]) + if 'target' in sample: + if 'sr_factor' in sample: + hr_shape = calc_hr_shape((h, w), sample['sr_factor']) + hr_offsets = self._get_offsets(*sample['target'].shape[:2], + *hr_shape) + sample['target'] = self.apply_im(sample['target'], hr_offsets, + hr_shape) + else: + sample['target'] = self.apply_im(sample['target'], offsets, + (h, w)) return sample @@ -1688,15 +1750,18 @@ 
class ReduceDim(Transform): Args: joblib_path (str): Path of *.joblib file of PCA. + apply_to_tar (bool, optional): Whether to apply transformation to the target + image. Defaults to True. """ - def __init__(self, joblib_path): + def __init__(self, joblib_path, apply_to_tar=True): super(ReduceDim, self).__init__() ext = joblib_path.split(".")[-1] if ext != "joblib": raise ValueError("`joblib_path` must be *.joblib, not *.{}.".format( ext)) self.pca = load(joblib_path) + self.apply_to_tar = apply_to_tar def apply_im(self, image): H, W, C = image.shape @@ -1709,6 +1774,8 @@ class ReduceDim(Transform): sample['image'] = self.apply_im(sample['image']) if 'image2' in sample: sample['image2'] = self.apply_im(sample['image2']) + if 'target' in sample and self.apply_to_tar: + sample['target'] = self.apply_im(sample['target']) return sample @@ -1719,11 +1786,14 @@ class SelectBand(Transform): Args: band_list (list, optional): Bands to select (band index starts from 1). Defaults to [1, 2, 3]. + apply_to_tar (bool, optional): Whether to apply transformation to the target + image. Defaults to True. 
""" - def __init__(self, band_list=[1, 2, 3]): + def __init__(self, band_list=[1, 2, 3], apply_to_tar=True): super(SelectBand, self).__init__() self.band_list = band_list + self.apply_to_tar = apply_to_tar def apply_im(self, image): image = select_bands(image, self.band_list) @@ -1733,6 +1803,8 @@ class SelectBand(Transform): sample['image'] = self.apply_im(sample['image']) if 'image2' in sample: sample['image2'] = self.apply_im(sample['image2']) + if 'target' in sample and self.apply_to_tar: + sample['target'] = self.apply_im(sample['target']) return sample @@ -1820,6 +1892,8 @@ class _Permute(Transform): sample['image'] = permute(sample['image'], False) if 'image2' in sample: sample['image2'] = permute(sample['image2'], False) + if 'target' in sample: + sample['target'] = permute(sample['target'], False) return sample @@ -1915,3 +1989,16 @@ class ArrangeDetector(Arrange): if self.mode == 'eval' and 'gt_poly' in sample: del sample['gt_poly'] return sample + + +class ArrangeRestorer(Arrange): + def apply(self, sample): + if 'target' in sample: + target = permute(sample['target'], False) + image = permute(sample['image'], False) + if self.mode == 'train': + return image, target + if self.mode == 'eval': + return image, target + if self.mode == 'test': + return image, diff --git a/paddlers/utils/__init__.py b/paddlers/utils/__init__.py index 950ea73..8be069c 100644 --- a/paddlers/utils/__init__.py +++ b/paddlers/utils/__init__.py @@ -16,7 +16,7 @@ from . import logging from . 
import utils from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str, EarlyStop, norm_path, is_pic, MyEncoder, DisablePrint, - Timer) + Timer, to_data_parallel, scheduler_step) from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint from .env import get_environ_info, get_num_workers, init_parallel_env from .download import download_and_decompress, decompress diff --git a/paddlers/utils/utils.py b/paddlers/utils/utils.py index 692a1c6..90e32aa 100644 --- a/paddlers/utils/utils.py +++ b/paddlers/utils/utils.py @@ -20,11 +20,12 @@ import math import imghdr import chardet import json +import platform import numpy as np +import paddle from . import logging -import platform import paddlers @@ -237,3 +238,30 @@ class Timer(Times): self.postprocess_time_s.reset() self.img_num = 0 self.repeats = 0 + + +def to_data_parallel(layers, *args, **kwargs): + from paddlers.tasks.utils.res_adapters import GANAdapter + if isinstance(layers, GANAdapter): + layers = GANAdapter( + [to_data_parallel(g, *args, **kwargs) for g in layers.generators], [ + to_data_parallel(d, *args, **kwargs) + for d in layers.discriminators + ]) + else: + layers = paddle.DataParallel(layers, *args, **kwargs) + return layers + + +def scheduler_step(optimizer, loss=None): + from paddlers.tasks.utils.res_adapters import OptimizerAdapter + if not isinstance(optimizer, OptimizerAdapter): + optimizer = [optimizer] + for optim in optimizer: + if isinstance(optim._learning_rate, paddle.optimizer.lr.LRScheduler): + # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric. + if isinstance(optim._learning_rate, + paddle.optimizer.lr.ReduceOnPlateau): + optim._learning_rate.step(loss.item()) + else: + optim._learning_rate.step() diff --git a/test_tipc/README.md b/test_tipc/README.md index 8f72eb8..70fd203 100644 --- a/test_tipc/README.md +++ b/test_tipc/README.md @@ -23,11 +23,29 @@ | 任务类别 | 模型名称 | 基础<br>训练预测 | 更多<br>训练方式 | 更多<br>部署方式 | Slim<br>训练部署 | 更多<br>训练环境 | | :--- | :--- | :----: | :--------: | :----: | :----: | :----: | | 变化检测 | BIT | 支持 | - | - | - | +| 变化检测 | CDNet | 支持 | - | - | - | +| 变化检测 | DSAMNet | 支持 | - | - | - | +| 变化检测 | DSIFN | 支持 | - | - | - | +| 变化检测 | SNUNet | 支持 | - | - | - | +| 变化检测 | STANet | 支持 | - | - | - | +| 变化检测 | FC-EF | 支持 | - | - | - | +| 变化检测 | FC-Siam-conc | 支持 | - | - | - | +| 变化检测 | FC-Siam-diff | 支持 | - | - | - | +| 变化检测 | ChangeFormer | 支持 | - | - | - | | 场景分类 | HRNet | 支持 | - | - | - | +| 场景分类 | MobileNetV3 | 支持 | - | - | - | +| 场景分类 | ResNet50-vd | 支持 | - | - | - | +| 图像复原 | DRN | 支持 | - | - | - | +| 图像复原 | EARGAN | 支持 | - | - | - | +| 图像复原 | LESRCNN | 支持 | - | - | - | +| 目标检测 | Faster R-CNN | 支持 | - | - | - | | 目标检测 | PP-YOLO | 支持 | - | - | - | +| 目标检测 | PP-YOLO Tiny | 支持 | - | - | - | +| 目标检测 | PP-YOLOv2 | 支持 | - | - | - | +| 目标检测 | YOLOv3 | 支持 | - | - | - | +| 图像分割 | DeepLab V3+ | 支持 | - | - | - | | 图像分割 | UNet | 支持 | - | - | - | - ## 3 测试工具简介 ### 3.1 目录介绍 diff --git a/test_tipc/common_func.sh b/test_tipc/common_func.sh index 0690d87..d6b4bd4 100644 --- a/test_tipc/common_func.sh +++ b/test_tipc/common_func.sh @@ -86,12 +86,14 @@ function download_and_unzip_dataset() { rm -rf "${ds_path}" fi - wget -nc -P "${ds_dir}" "${url}" --no-check-certificate + wget -O "${ds_dir}/${zip_name}" "${url}" --no-check-certificate # The extracted file/directory must have the same name as the zip file. - cd "${ds_dir}" && unzip "${zip_name}" \ - && mv "${zip_name%.*}" ${ds_name} && cd - \ - && echo "Successfully downloaded ${zip_name} from ${url}. File saved in ${ds_path}. 
" + cd "${ds_dir}" && unzip "${zip_name}" + if [ "${zip_name%.*}" != "${ds_name}" ]; then + mv "${zip_name%.*}" "${ds_name}" + fi + cd - } function parse_extra_args() { diff --git a/test_tipc/config_utils.py b/test_tipc/config_utils.py index 9f1b6fc..6e677b4 100644 --- a/test_tipc/config_utils.py +++ b/test_tipc/config_utils.py @@ -118,7 +118,7 @@ def parse_args(*args, **kwargs): conflict_handler='resolve', parents=[cfg_parser]) # Global settings parser.add_argument('cmd', choices=['train', 'eval']) - parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg']) + parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) # Data parser.add_argument('--datasets', type=dict, default={}) diff --git a/test_tipc/configs/cd/_base_/airchange.yaml b/test_tipc/configs/cd/_base_/airchange.yaml index 38ec406..f41f05a 100644 --- a/test_tipc/configs/cd/_base_/airchange.yaml +++ b/test_tipc/configs/cd/_base_/airchange.yaml @@ -60,7 +60,7 @@ download_path: ./test_tipc/data/ num_epochs: 5 train_batch_size: 4 -save_interval_epochs: 3 +save_interval_epochs: 5 log_interval_steps: 50 save_dir: ./test_tipc/output/cd/ learning_rate: 0.01 diff --git a/test_tipc/configs/cd/bit/bit_airchange.yaml b/test_tipc/configs/cd/bit/bit_airchange.yaml index efd6fbb..27e0bb4 100644 --- a/test_tipc/configs/cd/bit/bit_airchange.yaml +++ b/test_tipc/configs/cd/bit/bit_airchange.yaml @@ -1,4 +1,4 @@ -# Basic configurations of BIT with AirChange dataset +# Configurations of BIT with AirChange dataset _base_: ../_base_/airchange.yaml diff --git a/test_tipc/configs/cd/bit/bit_levircd.yaml b/test_tipc/configs/cd/bit/bit_levircd.yaml index 8008901..d9a5dd9 100644 --- a/test_tipc/configs/cd/bit/bit_levircd.yaml +++ b/test_tipc/configs/cd/bit/bit_levircd.yaml @@ -1,4 +1,4 @@ -# Basic configurations of BIT with LEVIR-CD dataset +# Configurations of BIT with LEVIR-CD dataset _base_: ../_base_/levircd.yaml diff --git a/test_tipc/configs/cd/bit/train_infer_python.txt 
b/test_tipc/configs/cd/bit/train_infer_python.txt index 33ee2f3..3cd2de1 100644 --- a/test_tipc/configs/cd/bit/train_infer_python.txt +++ b/test_tipc/configs/cd/bit/train_infer_python.txt @@ -6,7 +6,7 @@ use_gpu:null|null --precision:null --num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 --save_dir:adaptive ---train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 --model_path:null --config:lite_train_lite_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/bit/bit_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/bit/bit_levircd.yaml train_model_name:best_model diff --git a/test_tipc/configs/cd/cdnet/cdnet_airchange.yaml b/test_tipc/configs/cd/cdnet/cdnet_airchange.yaml new file mode 100644 index 0000000..28d3f7a --- /dev/null +++ b/test_tipc/configs/cd/cdnet/cdnet_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of CDNet with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/cdnet/ + +model: !Node + type: CDNet \ No newline at end of file diff --git a/test_tipc/configs/cd/cdnet/cdnet_levircd.yaml b/test_tipc/configs/cd/cdnet/cdnet_levircd.yaml new file mode 100644 index 0000000..586e4e3 --- /dev/null +++ b/test_tipc/configs/cd/cdnet/cdnet_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of CDNet with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/cdnet/ + +model: !Node + type: CDNet \ No newline at end of file diff --git a/test_tipc/configs/cd/cdnet/train_infer_python.txt b/test_tipc/configs/cd/cdnet/train_infer_python.txt new file mode 100644 index 0000000..00ff523 --- /dev/null +++ b/test_tipc/configs/cd/cdnet/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:cdnet +python:python 
+gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/cdnet/cdnet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/cdnet/cdnet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/cdnet/cdnet_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:cdnet +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/changeformer/changeformer_airchange.yaml b/test_tipc/configs/cd/changeformer/changeformer_airchange.yaml new file mode 100644 index 0000000..15a37ea --- /dev/null +++ b/test_tipc/configs/cd/changeformer/changeformer_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of ChangeFormer with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/changeformer/ + +model: !Node + type: ChangeFormer \ No newline at end of file diff --git 
a/test_tipc/configs/cd/changeformer/changeformer_levircd.yaml b/test_tipc/configs/cd/changeformer/changeformer_levircd.yaml new file mode 100644 index 0000000..931a3e8 --- /dev/null +++ b/test_tipc/configs/cd/changeformer/changeformer_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of ChangeFormer with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/changeformer/ + +model: !Node + type: ChangeFormer \ No newline at end of file diff --git a/test_tipc/configs/cd/changeformer/train_infer_python.txt b/test_tipc/configs/cd/changeformer/train_infer_python.txt index 9ac2cdc..47fe600 100644 --- a/test_tipc/configs/cd/changeformer/train_infer_python.txt +++ b/test_tipc/configs/cd/changeformer/train_infer_python.txt @@ -6,14 +6,14 @@ use_gpu:null|null --precision:null --num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 --save_dir:adaptive ---train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 --model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/changeformer/changeformer_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/changeformer/changeformer_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/changeformer/changeformer_levircd.yaml train_model_name:best_model -train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt null:null ## trainer:norm -norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/changeformer/changeformer.yaml +norm_train:test_tipc/run_task.py train cd pact_train:null fpgm_train:null distill_train:null @@ -27,7 +27,7 @@ null:null ===========================export_params=========================== --save_dir:adaptive --model_dir:adaptive ---fixed_input_shape:[1,3,256,256] +--fixed_input_shape:[-1,3,256,256] norm_export:deploy/export/export_model.py 
quant_export:null fpgm_export:null @@ -46,7 +46,7 @@ inference:test_tipc/infer.py --use_trt:False --precision:fp32 --model_dir:null ---file_list:null:null +--config:null --save_log_path:null --benchmark:True --model_name:changeformer diff --git a/test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml b/test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml new file mode 100644 index 0000000..1ede33f --- /dev/null +++ b/test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of DSAMNet with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/dsamnet/ + +model: !Node + type: DSAMNet \ No newline at end of file diff --git a/test_tipc/configs/cd/dsamnet/dsamnet_levircd.yaml b/test_tipc/configs/cd/dsamnet/dsamnet_levircd.yaml new file mode 100644 index 0000000..0fa9900 --- /dev/null +++ b/test_tipc/configs/cd/dsamnet/dsamnet_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of DSAMNet with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/dsamnet/ + +model: !Node + type: DSAMNet \ No newline at end of file diff --git a/test_tipc/configs/cd/dsamnet/train_infer_python.txt b/test_tipc/configs/cd/dsamnet/train_infer_python.txt new file mode 100644 index 0000000..bce8cab --- /dev/null +++ b/test_tipc/configs/cd/dsamnet/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:dsamnet +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/dsamnet/dsamnet_levircd.yaml 
+train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:dsamnet +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/dsifn/dsifn_airchange.yaml b/test_tipc/configs/cd/dsifn/dsifn_airchange.yaml new file mode 100644 index 0000000..7fc661a --- /dev/null +++ b/test_tipc/configs/cd/dsifn/dsifn_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of DSIFN with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/dsifn/ + +model: !Node + type: DSIFN \ No newline at end of file diff --git a/test_tipc/configs/cd/dsifn/dsifn_levircd.yaml b/test_tipc/configs/cd/dsifn/dsifn_levircd.yaml new file mode 100644 index 0000000..c4454a1 --- /dev/null +++ b/test_tipc/configs/cd/dsifn/dsifn_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of DSIFN with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/dsifn/ + +model: !Node + type: DSIFN \ No newline at end of file diff --git a/test_tipc/configs/cd/dsifn/train_infer_python.txt b/test_tipc/configs/cd/dsifn/train_infer_python.txt new file mode 100644 index 0000000..e491797 --- /dev/null +++ 
b/test_tipc/configs/cd/dsifn/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:dsifn +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/dsifn/dsifn_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/dsifn/dsifn_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/dsifn/dsifn_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:dsifn +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml b/test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml new file mode 100644 index 0000000..fc47737 --- /dev/null +++ b/test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of FC-EF with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/fc_ef/ + +model: 
!Node + type: FCEarlyFusion \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_ef/fc_ef_levircd.yaml b/test_tipc/configs/cd/fc_ef/fc_ef_levircd.yaml new file mode 100644 index 0000000..758d4a0 --- /dev/null +++ b/test_tipc/configs/cd/fc_ef/fc_ef_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of FC-EF with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/fc_ef/ + +model: !Node + type: FCEarlyFusion \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_ef/train_infer_python.txt b/test_tipc/configs/cd/fc_ef/train_infer_python.txt new file mode 100644 index 0000000..73da148 --- /dev/null +++ b/test_tipc/configs/cd/fc_ef/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:fc_ef +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_ef/fc_ef_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null 
+infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:fc_ef +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml b/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml new file mode 100644 index 0000000..f4a8111 --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of FC-Siam-conc with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/fc_siam_conc/ + +model: !Node + type: FCSiamConc \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_levircd.yaml b/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_levircd.yaml new file mode 100644 index 0000000..1d49a5d --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of FC-Siam-conc with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/fc_siam_conc/ + +model: !Node + type: FCSiamConc \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_conc/train_infer_python.txt b/test_tipc/configs/cd/fc_siam_conc/train_infer_python.txt new file mode 100644 index 0000000..db1ade5 --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_conc/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:fc_siam_conc +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null 
+--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:fc_siam_conc +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml b/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml new file mode 100644 index 0000000..3453d82 --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of FC-Siam-diff with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/fc_siam_diff/ + +model: !Node + type: FCSiamDiff \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_levircd.yaml b/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_levircd.yaml new file mode 100644 index 0000000..2588cb9 --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations 
of FC-Siam-diff with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/fc_siam_diff/ + +model: !Node + type: FCSiamDiff \ No newline at end of file diff --git a/test_tipc/configs/cd/fc_siam_diff/train_infer_python.txt b/test_tipc/configs/cd/fc_siam_diff/train_infer_python.txt new file mode 100644 index 0000000..245e4ed --- /dev/null +++ b/test_tipc/configs/cd/fc_siam_diff/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:fc_siam_diff +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True 
+--model_name:fc_siam_diff +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/fccdn/fccdn_airchange.yaml b/test_tipc/configs/cd/fccdn/fccdn_airchange.yaml new file mode 100644 index 0000000..12fc83e --- /dev/null +++ b/test_tipc/configs/cd/fccdn/fccdn_airchange.yaml @@ -0,0 +1,13 @@ +# Configurations of FCCDN with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/fccdn/ + +model: !Node + type: FCCDN + +learning_rate: 0.07 +lr_decay_power: 0.6 +log_interval_steps: 100 +save_interval_epochs: 3 \ No newline at end of file diff --git a/test_tipc/configs/cd/fccdn/fccdn_levircd.yaml b/test_tipc/configs/cd/fccdn/fccdn_levircd.yaml new file mode 100644 index 0000000..02586cb --- /dev/null +++ b/test_tipc/configs/cd/fccdn/fccdn_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of FCCDN with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/fccdn/ + +model: !Node + type: FCCDN \ No newline at end of file diff --git a/test_tipc/configs/cd/fccdn/train_infer_python.txt b/test_tipc/configs/cd/fccdn/train_infer_python.txt new file mode 100644 index 0000000..b18ae87 --- /dev/null +++ b/test_tipc/configs/cd/fccdn/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:fccdn +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/fccdn/fccdn_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fccdn/fccdn_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fccdn/fccdn_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null 
+distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:fccdn +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/snunet/snunet_airchange.yaml b/test_tipc/configs/cd/snunet/snunet_airchange.yaml new file mode 100644 index 0000000..eee3b1d --- /dev/null +++ b/test_tipc/configs/cd/snunet/snunet_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of SNUNet with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/snunet/ + +model: !Node + type: SNUNet \ No newline at end of file diff --git a/test_tipc/configs/cd/snunet/snunet_levircd.yaml b/test_tipc/configs/cd/snunet/snunet_levircd.yaml new file mode 100644 index 0000000..7af3bcb --- /dev/null +++ b/test_tipc/configs/cd/snunet/snunet_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of SNUNet with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/snunet/ + +model: !Node + type: SNUNet \ No newline at end of file diff --git a/test_tipc/configs/cd/snunet/train_infer_python.txt b/test_tipc/configs/cd/snunet/train_infer_python.txt new file mode 100644 index 0000000..264ffd9 --- /dev/null +++ b/test_tipc/configs/cd/snunet/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== 
+model_name:cd:snunet +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/snunet/snunet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/snunet/snunet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/snunet/snunet_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:snunet +null:null \ No newline at end of file diff --git a/test_tipc/configs/cd/stanet/stanet_airchange.yaml b/test_tipc/configs/cd/stanet/stanet_airchange.yaml new file mode 100644 index 0000000..7c7c05a --- /dev/null +++ b/test_tipc/configs/cd/stanet/stanet_airchange.yaml @@ -0,0 +1,8 @@ +# Configurations of STANet with AirChange dataset + +_base_: ../_base_/airchange.yaml + +save_dir: ./test_tipc/output/cd/stanet/ + +model: !Node + type: STANet \ No newline at end of file diff --git a/test_tipc/configs/cd/stanet/stanet_levircd.yaml 
b/test_tipc/configs/cd/stanet/stanet_levircd.yaml new file mode 100644 index 0000000..b439ff1 --- /dev/null +++ b/test_tipc/configs/cd/stanet/stanet_levircd.yaml @@ -0,0 +1,8 @@ +# Configurations of STANet with LEVIR-CD dataset + +_base_: ../_base_/levircd.yaml + +save_dir: ./test_tipc/output/cd/stanet/ + +model: !Node + type: STANet \ No newline at end of file diff --git a/test_tipc/configs/cd/stanet/train_infer_python.txt b/test_tipc/configs/cd/stanet/train_infer_python.txt new file mode 100644 index 0000000..0bff7df --- /dev/null +++ b/test_tipc/configs/cd/stanet/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:cd:stanet +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/cd/stanet/stanet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/stanet/stanet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/stanet/stanet_levircd.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train cd +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 
+--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:stanet +null:null \ No newline at end of file diff --git a/test_tipc/configs/clas/_base_/ucmerced.yaml b/test_tipc/configs/clas/_base_/ucmerced.yaml index bff0b8f..1b3b79d 100644 --- a/test_tipc/configs/clas/_base_/ucmerced.yaml +++ b/test_tipc/configs/clas/_base_/ucmerced.yaml @@ -62,7 +62,7 @@ download_path: ./test_tipc/data/ num_epochs: 2 train_batch_size: 16 -save_interval_epochs: 5 +save_interval_epochs: 10 log_interval_steps: 50 save_dir: ./test_tipc/output/clas/ learning_rate: 0.01 diff --git a/test_tipc/configs/clas/hrnet/hrnet.yaml b/test_tipc/configs/clas/hrnet/hrnet.yaml index f402c26..4c9879f 100644 --- a/test_tipc/configs/clas/hrnet/hrnet.yaml +++ b/test_tipc/configs/clas/hrnet/hrnet.yaml @@ -6,5 +6,5 @@ save_dir: ./test_tipc/output/clas/hrnet/ model: !Node type: HRNet_W18_C - args: - num_classes: 21 \ No newline at end of file + args: + num_classes: 21 \ No newline at end of file diff --git a/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml b/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml new file mode 100644 index 0000000..3a09756 --- /dev/null +++ b/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml @@ -0,0 +1,10 @@ +# Configurations of HRNet with UCMerced dataset + +_base_: ../_base_/ucmerced.yaml + +save_dir: ./test_tipc/output/clas/hrnet/ + +model: !Node + type: HRNet_W18_C + args: + num_classes: 21 \ No newline at end of file diff --git a/test_tipc/configs/clas/hrnet/train_infer_python.txt b/test_tipc/configs/clas/hrnet/train_infer_python.txt index 23f3820..1116c77 100644 --- a/test_tipc/configs/clas/hrnet/train_infer_python.txt +++ b/test_tipc/configs/clas/hrnet/train_infer_python.txt @@ -8,12 +8,12 @@ use_gpu:null|null --save_dir:adaptive --train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16 --model_path:null 
+--config:lite_train_lite_infer=./test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml train_model_name:best_model -train_infer_file_list:./test_tipc/data/ucmerced/:./test_tipc/data/ucmerced/val.txt null:null ## trainer:norm -norm_train:test_tipc/run_task.py train clas --config ./test_tipc/configs/clas/hrnet/hrnet.yaml +norm_train:test_tipc/run_task.py train clas pact_train:null fpgm_train:null distill_train:null @@ -46,7 +46,7 @@ inference:test_tipc/infer.py --use_trt:False --precision:fp32 --model_dir:null ---file_list:null:null +--config:null --save_log_path:null --benchmark:True --model_name:hrnet diff --git a/test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml b/test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml new file mode 100644 index 0000000..becdd5f --- /dev/null +++ b/test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml @@ -0,0 +1,10 @@ +# Configurations of MobileNetV3 with UCMerced dataset + +_base_: ../_base_/ucmerced.yaml + +save_dir: ./test_tipc/output/clas/mobilenetv3/ + +model: !Node + type: MobileNetV3_small_x1_0 + args: + num_classes: 21 \ No newline at end of file diff --git a/test_tipc/configs/clas/mobilenetv3/train_infer_python.txt b/test_tipc/configs/clas/mobilenetv3/train_infer_python.txt new file mode 100644 index 0000000..50406f6 --- /dev/null +++ b/test_tipc/configs/clas/mobilenetv3/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:clas:mobilenetv3 +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16 +--model_path:null 
+--config:lite_train_lite_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train clas +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:mobilenetv3 +null:null \ No newline at end of file diff --git a/test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml b/test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml new file mode 100644 index 0000000..4978cfc --- /dev/null +++ b/test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml @@ -0,0 +1,10 @@ +# Configurations of ResNet50-vd with UCMerced dataset + +_base_: ../_base_/ucmerced.yaml + +save_dir: ./test_tipc/output/clas/resnet50_vd/ + +model: !Node + type: ResNet50_vd + args: + num_classes: 21 \ No newline at end of file diff --git a/test_tipc/configs/clas/resnet50_vd/train_infer_python.txt b/test_tipc/configs/clas/resnet50_vd/train_infer_python.txt new file mode 100644 index 0000000..2295361 --- /dev/null +++ b/test_tipc/configs/clas/resnet50_vd/train_infer_python.txt @@ -0,0 +1,53 @@ 
+===========================train_params=========================== +model_name:clas:resnet50_vd +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train clas +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:resnet50_vd +null:null \ No newline at end of file diff --git a/test_tipc/configs/det/_base_/rsod.yaml b/test_tipc/configs/det/_base_/rsod.yaml new file mode 100644 index 0000000..a4fcf69 --- /dev/null +++ b/test_tipc/configs/det/_base_/rsod.yaml @@ -0,0 +1,72 @@ +# Basic configurations of RSOD dataset + +datasets: + train: !Node + type: VOCDetDataset + args: + data_dir: ./test_tipc/data/rsod/ + file_list: 
./test_tipc/data/rsod/train.txt + label_list: ./test_tipc/data/rsod/labels.txt + shuffle: True + eval: !Node + type: VOCDetDataset + args: + data_dir: ./test_tipc/data/rsod/ + file_list: ./test_tipc/data/rsod/val.txt + label_list: ./test_tipc/data/rsod/labels.txt + shuffle: False +transforms: + train: + - !Node + type: DecodeImg + - !Node + type: RandomDistort + - !Node + type: RandomExpand + - !Node + type: RandomCrop + - !Node + type: RandomHorizontalFlip + - !Node + type: BatchRandomResize + args: + target_sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + interp: RANDOM + - !Node + type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - !Node + type: ArrangeDetector + args: ['train'] + eval: + - !Node + type: DecodeImg + - !Node + type: Resize + args: + target_size: 608 + interp: CUBIC + - !Node + type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - !Node + type: ArrangeDetector + args: ['eval'] +download_on: False + +num_epochs: 10 +train_batch_size: 4 +save_interval_epochs: 10 +log_interval_steps: 4 +save_dir: ./test_tipc/output/det/ +learning_rate: 0.0001 +use_vdl: False +resume_checkpoint: '' +train: + pretrain_weights: COCO + warmup_steps: 0 + warmup_start_lr: 0.0 \ No newline at end of file diff --git a/test_tipc/configs/det/_base_/sarship.yaml b/test_tipc/configs/det/_base_/sarship.yaml index c7c6afe..ba38220 100644 --- a/test_tipc/configs/det/_base_/sarship.yaml +++ b/test_tipc/configs/det/_base_/sarship.yaml @@ -62,10 +62,10 @@ download_path: ./test_tipc/data/ num_epochs: 10 train_batch_size: 4 -save_interval_epochs: 5 +save_interval_epochs: 10 log_interval_steps: 4 save_dir: ./test_tipc/output/det/ -learning_rate: 0.0005 +learning_rate: 0.0001 use_vdl: False resume_checkpoint: '' train: diff --git a/test_tipc/configs/det/faster_rcnn/faster_rcnn_rsod.yaml b/test_tipc/configs/det/faster_rcnn/faster_rcnn_rsod.yaml new file mode 100644 index 0000000..c9c3aa4 --- /dev/null +++ 
b/test_tipc/configs/det/faster_rcnn/faster_rcnn_rsod.yaml @@ -0,0 +1,10 @@ +# Configurations of Faster R-CNN with RSOD dataset + +_base_: ../_base_/rsod.yaml + +save_dir: ./test_tipc/output/det/faster_rcnn/ + +model: !Node + type: FasterRCNN + args: + num_classes: 4 \ No newline at end of file diff --git a/test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml b/test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml new file mode 100644 index 0000000..b958be3 --- /dev/null +++ b/test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml @@ -0,0 +1,10 @@ +# Configurations of Faster R-CNN with SARShip dataset + +_base_: ../_base_/sarship.yaml + +save_dir: ./test_tipc/output/det/faster_rcnn/ + +model: !Node + type: FasterRCNN + args: + num_classes: 1 \ No newline at end of file diff --git a/test_tipc/configs/det/faster_rcnn/train_infer_python.txt b/test_tipc/configs/det/faster_rcnn/train_infer_python.txt new file mode 100644 index 0000000..679d81e --- /dev/null +++ b/test_tipc/configs/det/faster_rcnn/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det:faster_rcnn +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_rsod.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train det +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## 
+===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,608,608] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:faster_rcnn +null:null \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolo/ppyolo_rsod.yaml b/test_tipc/configs/det/ppyolo/ppyolo_rsod.yaml new file mode 100644 index 0000000..32c7fca --- /dev/null +++ b/test_tipc/configs/det/ppyolo/ppyolo_rsod.yaml @@ -0,0 +1,10 @@ +# Configurations of PP-YOLO with RSOD dataset + +_base_: ../_base_/rsod.yaml + +save_dir: ./test_tipc/output/det/ppyolo/ + +model: !Node + type: PPYOLO + args: + num_classes: 4 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolo/ppyolo.yaml b/test_tipc/configs/det/ppyolo/ppyolo_sarship.yaml similarity index 56% rename from test_tipc/configs/det/ppyolo/ppyolo.yaml rename to test_tipc/configs/det/ppyolo/ppyolo_sarship.yaml index f36919c..a3fbf58 100644 --- a/test_tipc/configs/det/ppyolo/ppyolo.yaml +++ b/test_tipc/configs/det/ppyolo/ppyolo_sarship.yaml @@ -1,4 +1,4 @@ -# Basic configurations of PP-YOLO +# Configurations of PP-YOLO with SARShip dataset _base_: ../_base_/sarship.yaml @@ -6,5 +6,5 @@ save_dir: ./test_tipc/output/det/ppyolo/ model: !Node type: PPYOLO - args: - num_classes: 1 \ No newline at end of file + args: + num_classes: 1 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolo/train_infer_python.txt b/test_tipc/configs/det/ppyolo/train_infer_python.txt index 43a47fa..eadaaf4 100644 --- 
a/test_tipc/configs/det/ppyolo/train_infer_python.txt +++ b/test_tipc/configs/det/ppyolo/train_infer_python.txt @@ -4,16 +4,16 @@ python:python gpu_list:0|0,1 use_gpu:null|null --precision:null ---num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 --save_dir:adaptive --train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 --model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/det/ppyolo/ppyolo_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/ppyolo/ppyolo_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/ppyolo/ppyolo_rsod.yaml train_model_name:best_model -train_infer_file_list:./test_tipc/data/sarship/:./test_tipc/data/sarship/eval.txt null:null ## trainer:norm -norm_train:test_tipc/run_task.py train det --config ./test_tipc/configs/det/ppyolo/ppyolo.yaml +norm_train:test_tipc/run_task.py train det pact_train:null fpgm_train:null distill_train:null @@ -46,7 +46,7 @@ inference:test_tipc/infer.py --use_trt:False --precision:fp32 --model_dir:null ---file_list:null:null +--config:null --save_log_path:null --benchmark:True --model_name:ppyolo diff --git a/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_rsod.yaml b/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_rsod.yaml new file mode 100644 index 0000000..cdd20a4 --- /dev/null +++ b/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_rsod.yaml @@ -0,0 +1,10 @@ +# Configurations of PP-YOLO Tiny with RSOD dataset + +_base_: ../_base_/rsod.yaml + +save_dir: ./test_tipc/output/det/ppyolo_tiny/ + +model: !Node + type: PPYOLOTiny + args: + num_classes: 4 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_sarship.yaml b/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_sarship.yaml new file mode 100644 index 0000000..24f67a6 --- /dev/null +++ b/test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_sarship.yaml @@ 
-0,0 +1,10 @@ +# Configurations of PP-YOLO Tiny with SARShip dataset + +_base_: ../_base_/sarship.yaml + +save_dir: ./test_tipc/output/det/ppyolo_tiny/ + +model: !Node + type: PPYOLOTiny + args: + num_classes: 1 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolo_tiny/train_infer_python.txt b/test_tipc/configs/det/ppyolo_tiny/train_infer_python.txt new file mode 100644 index 0000000..106610f --- /dev/null +++ b/test_tipc/configs/det/ppyolo_tiny/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det:ppyolo_tiny +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/ppyolo_tiny/ppyolo_tiny_rsod.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train det +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,608,608] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null 
+--config:null +--save_log_path:null +--benchmark:True +--model_name:ppyolo_tiny +null:null \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolov2/ppyolov2_rsod.yaml b/test_tipc/configs/det/ppyolov2/ppyolov2_rsod.yaml new file mode 100644 index 0000000..ec5c423 --- /dev/null +++ b/test_tipc/configs/det/ppyolov2/ppyolov2_rsod.yaml @@ -0,0 +1,10 @@ +# Configurations of PP-YOLOv2 with RSOD dataset + +_base_: ../_base_/rsod.yaml + +save_dir: ./test_tipc/output/det/ppyolov2/ + +model: !Node + type: PPYOLOv2 + args: + num_classes: 4 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolov2/ppyolov2_sarship.yaml b/test_tipc/configs/det/ppyolov2/ppyolov2_sarship.yaml new file mode 100644 index 0000000..45a9d36 --- /dev/null +++ b/test_tipc/configs/det/ppyolov2/ppyolov2_sarship.yaml @@ -0,0 +1,10 @@ +# Configurations of PP-YOLOv2 with SARShip dataset + +_base_: ../_base_/sarship.yaml + +save_dir: ./test_tipc/output/det/ppyolov2/ + +model: !Node + type: PPYOLOv2 + args: + num_classes: 1 \ No newline at end of file diff --git a/test_tipc/configs/det/ppyolov2/train_infer_python.txt b/test_tipc/configs/det/ppyolov2/train_infer_python.txt new file mode 100644 index 0000000..3825157 --- /dev/null +++ b/test_tipc/configs/det/ppyolov2/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det:ppyolov2 +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/det/ppyolov2/ppyolov2_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/ppyolov2/ppyolov2_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/ppyolov2/ppyolov2_rsod.yaml +train_model_name:best_model +null:null +## +trainer:norm 
+norm_train:test_tipc/run_task.py train det
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,608,608]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--config:null
+--save_log_path:null
+--benchmark:True
+--model_name:ppyolov2
+null:null
\ No newline at end of file
diff --git a/test_tipc/configs/det/yolov3/train_infer_python.txt b/test_tipc/configs/det/yolov3/train_infer_python.txt
new file mode 100644
index 0000000..b60be01
--- /dev/null
+++ b/test_tipc/configs/det/yolov3/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:det:yolov3
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/det/yolov3/yolov3_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/yolov3/yolov3_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/yolov3/yolov3_rsod.yaml
+train_model_name:best_model
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train det
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,608,608] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:yolov3 +null:null \ No newline at end of file diff --git a/test_tipc/configs/det/yolov3/yolov3_rsod.yaml b/test_tipc/configs/det/yolov3/yolov3_rsod.yaml new file mode 100644 index 0000000..ce1d4df --- /dev/null +++ b/test_tipc/configs/det/yolov3/yolov3_rsod.yaml @@ -0,0 +1,10 @@ +# Configurations of YOLOv3 with RSOD dataset + +_base_: ../_base_/rsod.yaml + +save_dir: ./test_tipc/output/det/yolov3/ + +model: !Node + type: YOLOv3 + args: + num_classes: 4 \ No newline at end of file diff --git a/test_tipc/configs/det/yolov3/yolov3_sarship.yaml b/test_tipc/configs/det/yolov3/yolov3_sarship.yaml new file mode 100644 index 0000000..3e2659d --- /dev/null +++ b/test_tipc/configs/det/yolov3/yolov3_sarship.yaml @@ -0,0 +1,10 @@ +# Configurations of YOLOv3 with SARShip dataset + +_base_: ../_base_/sarship.yaml + +save_dir: ./test_tipc/output/det/yolov3/ + +model: !Node + type: YOLOv3 + args: + num_classes: 1 \ No newline at end of file diff --git a/test_tipc/configs/res/_base_/rssr.yaml b/test_tipc/configs/res/_base_/rssr.yaml new file mode 100644 index 0000000..c2d5265 --- /dev/null +++ b/test_tipc/configs/res/_base_/rssr.yaml @@ -0,0 +1,72 @@ +# Basic configurations of RSSR dataset + +datasets: + train: !Node + type: ResDataset + args: + data_dir: 
./test_tipc/data/rssr/ + file_list: ./test_tipc/data/rssr/train.txt + num_workers: 0 + shuffle: True + sr_factor: 4 + eval: !Node + type: ResDataset + args: + data_dir: ./test_tipc/data/rssr/ + file_list: ./test_tipc/data/rssr/val.txt + num_workers: 0 + shuffle: False + sr_factor: 4 +transforms: + train: + - !Node + type: DecodeImg + - !Node + type: RandomCrop + args: + crop_size: 32 + - !Node + type: RandomHorizontalFlip + args: + prob: 0.5 + - !Node + type: RandomVerticalFlip + args: + prob: 0.5 + - !Node + type: Normalize + args: + mean: [0.0, 0.0, 0.0] + std: [1.0, 1.0, 1.0] + - !Node + type: ArrangeRestorer + args: ['train'] + eval: + - !Node + type: DecodeImg + - !Node + type: Resize + args: + target_size: 256 + - !Node + type: Normalize + args: + mean: [0.0, 0.0, 0.0] + std: [1.0, 1.0, 1.0] + - !Node + type: ArrangeRestorer + args: ['eval'] +download_on: False +download_url: https://paddlers.bj.bcebos.com/datasets/rssr.zip +download_path: ./test_tipc/data/ + +num_epochs: 10 +train_batch_size: 4 +save_interval_epochs: 10 +log_interval_steps: 10 +save_dir: ./test_tipc/output/res/ +learning_rate: 0.0005 +early_stop: False +early_stop_patience: 5 +use_vdl: False +resume_checkpoint: '' \ No newline at end of file diff --git a/test_tipc/configs/res/drn/drn_rssr.yaml b/test_tipc/configs/res/drn/drn_rssr.yaml new file mode 100644 index 0000000..52625cf --- /dev/null +++ b/test_tipc/configs/res/drn/drn_rssr.yaml @@ -0,0 +1,8 @@ +# Configurations of DRN with RSSR dataset + +_base_: ../_base_/rssr.yaml + +save_dir: ./test_tipc/output/res/drn/ + +model: !Node + type: DRN \ No newline at end of file diff --git a/test_tipc/configs/res/drn/train_infer_python.txt b/test_tipc/configs/res/drn/train_infer_python.txt new file mode 100644 index 0000000..c3ba4b0 --- /dev/null +++ b/test_tipc/configs/res/drn/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:res:drn +python:python +gpu_list:0|0,1 
+use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/res/drn/drn_rssr.yaml|lite_train_whole_infer=./test_tipc/configs/res/drn/drn_rssr.yaml|whole_train_whole_infer=./test_tipc/configs/res/drn/drn_rssr.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train res +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:drn +null:null \ No newline at end of file diff --git a/test_tipc/configs/res/esrgan/esrgan_rssr.yaml b/test_tipc/configs/res/esrgan/esrgan_rssr.yaml new file mode 100644 index 0000000..9dbb2f5 --- /dev/null +++ b/test_tipc/configs/res/esrgan/esrgan_rssr.yaml @@ -0,0 +1,8 @@ +# Configurations of ESRGAN with RSSR dataset + +_base_: ../_base_/rssr.yaml + +save_dir: ./test_tipc/output/res/esrgan/ + +model: !Node + type: ESRGAN \ No newline at end of file diff --git a/test_tipc/configs/res/esrgan/train_infer_python.txt b/test_tipc/configs/res/esrgan/train_infer_python.txt new file mode 100644 index 0000000..9aaab9b --- 
/dev/null +++ b/test_tipc/configs/res/esrgan/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:res:esrgan +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/res/esrgan/esrgan_rssr.yaml|lite_train_whole_infer=./test_tipc/configs/res/esrgan/esrgan_rssr.yaml|whole_train_whole_infer=./test_tipc/configs/res/esrgan/esrgan_rssr.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train res +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:esrgan +null:null \ No newline at end of file diff --git a/test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml b/test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml new file mode 100644 index 0000000..6b4c193 --- /dev/null +++ b/test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml @@ -0,0 +1,8 @@ +# Configurations of LESRCNN with RSSR dataset + +_base_: ../_base_/rssr.yaml + +save_dir: 
./test_tipc/output/res/lesrcnn/ + +model: !Node + type: LESRCNN \ No newline at end of file diff --git a/test_tipc/configs/res/lesrcnn/train_infer_python.txt b/test_tipc/configs/res/lesrcnn/train_infer_python.txt new file mode 100644 index 0000000..97fac6f --- /dev/null +++ b/test_tipc/configs/res/lesrcnn/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:res:lesrcnn +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml|lite_train_whole_infer=./test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml|whole_train_whole_infer=./test_tipc/configs/res/lesrcnn/lesrcnn_rssr.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train res +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,3,256,256] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:lesrcnn +null:null \ No newline at end of file diff --git a/test_tipc/configs/seg/_base_/rsseg.yaml 
b/test_tipc/configs/seg/_base_/rsseg.yaml index 2f1d588..de5b469 100644 --- a/test_tipc/configs/seg/_base_/rsseg.yaml +++ b/test_tipc/configs/seg/_base_/rsseg.yaml @@ -58,7 +58,7 @@ download_path: ./test_tipc/data/ num_epochs: 10 train_batch_size: 4 -save_interval_epochs: 5 +save_interval_epochs: 10 log_interval_steps: 4 save_dir: ./test_tipc/output/seg/ learning_rate: 0.001 diff --git a/test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml b/test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml new file mode 100644 index 0000000..c7e1248 --- /dev/null +++ b/test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml @@ -0,0 +1,11 @@ +# Configurations of DeepLab V3+ with RSSeg dataset + +_base_: ../_base_/rsseg.yaml + +save_dir: ./test_tipc/output/seg/deeplabv3p/ + +model: !Node + type: DeepLabV3P + args: + in_channels: 10 + num_classes: 5 \ No newline at end of file diff --git a/test_tipc/configs/seg/deeplabv3p/train_infer_python.txt b/test_tipc/configs/seg/deeplabv3p/train_infer_python.txt new file mode 100644 index 0000000..de7cac6 --- /dev/null +++ b/test_tipc/configs/seg/deeplabv3p/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:seg:deeplabv3p +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=30 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/deeplabv3p/deeplabv3p_rsseg.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train seg +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## 
+===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,10,512,512] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:deeplabv3p +null:null \ No newline at end of file diff --git a/test_tipc/configs/seg/unet/train_infer_python.txt b/test_tipc/configs/seg/unet/train_infer_python.txt index 1a548e1..8abf325 100644 --- a/test_tipc/configs/seg/unet/train_infer_python.txt +++ b/test_tipc/configs/seg/unet/train_infer_python.txt @@ -4,16 +4,16 @@ python:python gpu_list:0|0,1 use_gpu:null|null --precision:null ---num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 --save_dir:adaptive --train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 --model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/seg/unet/unet_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/unet/unet_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/unet/unet_rsseg.yaml train_model_name:best_model -train_infer_file_list:./test_tipc/data/rsseg/:./test_tipc/data/rsseg/val.txt null:null ## trainer:norm -norm_train:test_tipc/run_task.py train seg --config ./test_tipc/configs/seg/unet/unet.yaml +norm_train:test_tipc/run_task.py train seg pact_train:null fpgm_train:null distill_train:null @@ -46,7 +46,7 @@ 
inference:test_tipc/infer.py
 --use_trt:False
 --precision:fp32
 --model_dir:null
---file_list:null:null
+--config:null
 --save_log_path:null
 --benchmark:True
 --model_name:unet
diff --git a/test_tipc/configs/seg/unet/unet.yaml b/test_tipc/configs/seg/unet/unet.yaml
deleted file mode 100644
index 045347c..0000000
--- a/test_tipc/configs/seg/unet/unet.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-# Basic configurations of UNet
-
-_base_: ../_base_/rsseg.yaml
-
-save_dir: ./test_tipc/output/seg/unet/
-
-model: !Node
-    type: UNet
-    args:
-        in_channels: 10
-        num_classes: 5
diff --git a/test_tipc/configs/seg/unet/unet_rsseg.yaml b/test_tipc/configs/seg/unet/unet_rsseg.yaml
new file mode 100644
index 0000000..18211b5
--- /dev/null
+++ b/test_tipc/configs/seg/unet/unet_rsseg.yaml
@@ -0,0 +1,11 @@
+# Configurations of UNet with RSSeg dataset
+
+_base_: ../_base_/rsseg.yaml
+
+save_dir: ./test_tipc/output/seg/unet/
+
+model: !Node
+    type: UNet
+    args:
+        in_channels: 10
+        num_classes: 5
\ No newline at end of file
diff --git a/test_tipc/docs/test_train_inference_python.md b/test_tipc/docs/test_train_inference_python.md
index 5100f81..72a321b 100644
--- a/test_tipc/docs/test_train_inference_python.md
+++ b/test_tipc/docs/test_train_inference_python.md
@@ -6,22 +6,62 @@ The main program for the basic Linux GPU/CPU training and inference test is `test_train_inference_pytho
 
 - Training-related:
 
-| Task | Model | Single machine, single GPU | Single machine, multiple GPUs |
-| :----: | :----: | :----: | :----: |
-| Change detection | BIT | Trains normally | Trains normally |
-| Scene classification | HRNet | Trains normally | Trains normally |
-| Object detection | PP-YOLO | Trains normally | Trains normally |
-| Image segmentation | UNet | Trains normally | Trains normally |
+| Task | Model | Single machine, single GPU | Single machine, multiple GPUs | Reference prediction accuracy |
+| :----: | :----: | :----: | :----: | :----: |
+| Change detection | BIT | Trains normally | Trains normally | IoU=71.02% |
+| Change detection | CDNet | Trains normally | Trains normally | IoU=56.02% |
+| Change detection | ChangeFormer | Trains normally | Trains normally | IoU=61.65% |
+| Change detection | DSAMNet | Trains normally | Trains normally | IoU=69.76% |
+| Change detection | DSIFN | Trains normally | Trains normally | IoU=72.88% |
+| Change detection | SNUNet | Trains normally | Trains normally | IoU=68.46% |
+| Change detection | STANet | Trains normally | Trains normally | IoU=65.11% |
+| Change detection | FC-EF | Trains normally | Trains normally | IoU=64.22% |
+| Change detection | FC-Siam-conc | Trains normally | Trains normally | IoU=65.79% |
+| Change detection | FC-Siam-diff | Trains normally | Trains normally | IoU=61.23% |
+| Change detection | FCCDN | Trains normally | Trains normally | IoU=24.42% |
+| Scene classification | HRNet | Trains normally | Trains normally | Acc(top1)=99.37% |
+| Scene classification | MobileNetV3 | Trains normally | Trains normally | Acc(top1)=99.58% |
+| Scene classification | ResNet50-vd | Trains normally | Trains normally | Acc(top1)=99.26% |
+| Image restoration | DRN | Trains normally | Trains normally | PSNR=24.23 |
+| Image restoration | ESRGAN | Trains normally | Trains normally | PSNR=21.30 |
+| Image restoration | LESRCNN | Trains normally | Trains normally | PSNR=23.18 |
+| Object detection | Faster R-CNN | Trains normally | Trains normally | mAP=46.99% |
+| Object detection | PP-YOLO | Trains normally | Trains normally | mAP=56.02% |
+| Object detection | PP-YOLO Tiny | Trains normally | Trains normally | mAP=44.27% |
+| Object detection | PP-YOLOv2 | Trains normally | Trains normally | mAP=59.37% |
+| Object detection | YOLOv3 | Trains normally | Trains normally | mAP=47.33% |
+| Image segmentation | DeepLab V3+ | Trains normally | Trains normally | mIoU=56.05% |
+| Image segmentation | UNet | Trains normally | Trains normally | mIoU=55.50% |
+
+*Note: the reference prediction accuracy is the accuracy reported for single-GPU training in whole_train_whole_infer mode.*
 
 - Inference-related:
 
 | Task | Model | device_CPU | device_GPU | batchsize |
 | :----: | :----: | :----: | :----: | :----: |
-| Change detection | BIT | Supported | Supported | 1 |
-| Scene classification | HRNet | Supported | Supported | 1 |
-| Object detection | YOLO | Supported | Supported | 1 |
-| Image segmentation | UNet | Supported | Supported | 1 |
-
+| Change detection | BIT | Supported | Supported | 1 |
+| Change detection | CDNet | Supported | Supported | 1 |
+| Change detection | ChangeFormer | Supported | Supported | 1 |
+| Change detection | DSAMNet | Supported | Supported | 1 |
+| Change detection | DSIFN | Supported | Supported | 1 |
+| Change detection | SNUNet | Supported | Supported | 1 |
+| Change detection | STANet | Supported | Supported | 1 |
+| Change detection | FC-EF | Supported | Supported | 1 |
+| Change detection | FC-Siam-conc | Supported | Supported | 1 |
+| Change detection | FC-Siam-diff | Supported | Supported | 1 |
+| Scene classification | HRNet | Supported | Supported | 1 |
+| Scene classification | MobileNetV3 | Supported | Supported | 1 |
+| Scene classification | ResNet50-vd | Supported | Supported | 1 |
+| Image restoration | DRN | Supported | Supported | 1 |
+| Image restoration | ESRGAN | Supported | Supported | 1 |
+| Image restoration | LESRCNN | Supported | Supported | 1 |
+| Object detection | Faster R-CNN | Supported | Supported | 1 |
+| Object detection | PP-YOLO | Supported | Supported | 1 |
+| Object detection | PP-YOLO Tiny | Supported | Supported | 1 |
+| Object detection | PP-YOLOv2 | Supported | Supported | 1 |
+| Object detection | YOLOv3 | Supported | Supported | 1 |
+| Image segmentation | DeepLab V3+ | Supported | Supported | 1 |
+| Image segmentation | UNet | Supported | Supported | 1 |
 
 ## 2 Test procedure
 
@@ -67,7 +107,7 @@ bash ./test_tipc/test_train_inference_python.sh test_tipc/configs/clas/hrnet/tra
 After the corresponding command is run, the run logs are saved automatically in the `test_tipc/output` directory. For example, in lite_train_lite_infer mode, the directory may contain the following files:
 
 ```
-test_tipc/output/[task name]/[model name]/
+test_tipc/output/{task name}/{model name}/
 |- results_python.log    # Log that records the execution status of each command
 |- norm_gpus_0_autocast_null/  # Training logs and saved models for GPU 0
 ......
diff --git a/test_tipc/infer.py b/test_tipc/infer.py
index 3672940..28a717f 100644
--- a/test_tipc/infer.py
+++ b/test_tipc/infer.py
@@ -13,6 +13,8 @@ from paddle.inference import PrecisionType
 from paddlers.tasks import load_model
 from paddlers.utils import logging
 
+from config_utils import parse_configs
+
 
 class _bool(object):
     def __new__(cls, x):
@@ -101,7 +103,7 @@ class TIPCPredictor(object):
                     logging.warning(
                         "Semantic segmentation models do not support TensorRT acceleration, "
                         "TensorRT is forcibly disabled.")
-                elif 'RCNN' in self._model.__class__.__name__:
+                elif self._model.model_type == 'detector' and 'RCNN' in self._model.__class__.__name__:
                     logging.warning(
                         "RCNN models do not support TensorRT acceleration, "
                         "TensorRT is forcibly disabled.")
@@ -123,7 +125,7 @@ class TIPCPredictor(object):
                     )
         else:
             try:
-                # Cache 10 different shapes for mkldnn to avoid memory leak
+                # Cache 10 different shapes for mkldnn to avoid memory leak.
config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() config.set_cpu_math_library_num_threads(mkl_thread_num) @@ -158,13 +160,23 @@ class TIPCPredictor(object): 'image2': preprocessed_samples[1], 'ori_shape': preprocessed_samples[2] } + elif self._model.model_type == 'restorer': + preprocessed_samples = { + 'image': preprocessed_samples[0], + 'tar_shape': preprocessed_samples[1] + } else: logging.error( "Invalid model type {}".format(self._model.model_type), exit=True) return preprocessed_samples - def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None): + def postprocess(self, + net_outputs, + topk=1, + ori_shape=None, + tar_shape=None, + transforms=None): if self._model.model_type == 'classifier': true_topk = min(self._model.num_classes, topk) if self._model.postprocess is None: @@ -196,6 +208,12 @@ class TIPCPredictor(object): for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs) } preds = self._model.postprocess(net_outputs) + elif self._model.model_type == 'restorer': + res_maps = self._model.postprocess( + net_outputs[0], + batch_tar_shape=tar_shape, + transforms=transforms.transforms) + preds = [{'res_map': res_map} for res_map in res_maps] else: logging.error( "Invalid model type {}.".format(self._model.model_type), @@ -232,6 +250,7 @@ class TIPCPredictor(object): net_outputs, topk, ori_shape=preprocessed_input.get('ori_shape', None), + tar_shape=preprocessed_input.get('tar_shape', None), transforms=transforms) if self.benchmark and time_it: @@ -285,7 +304,8 @@ class TIPCPredictor(object): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--file_list', type=str, nargs=2) + parser.add_argument('--config', type=str) + parser.add_argument('--inherit_off', action='store_true') parser.add_argument('--model_dir', type=str, default='./') parser.add_argument( '--device', type=str, choices=['cpu', 'gpu'], default='cpu') @@ -300,6 +320,11 @@ if __name__ == '__main__': args = parser.parse_args() + cfg = 
parse_configs(args.config, not args.inherit_off) + eval_dataset = cfg['datasets']['eval'] + data_dir = eval_dataset.args['data_dir'] + file_list = eval_dataset.args['file_list'] + predictor = TIPCPredictor( args.model_dir, device=args.device, @@ -310,7 +335,7 @@ if __name__ == '__main__': trt_precision_mode=args.precision, benchmark=args.benchmark) - predictor.predict(args.file_list[0], args.file_list[1]) + predictor.predict(data_dir, file_list) if args.benchmark: predictor.autolog.report() diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index ead48af..0198213 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -35,6 +35,8 @@ if [[ ${MODE} == 'lite_train_lite_infer' \ download_and_unzip_dataset "${DATA_DIR}" ucmerced https://paddlers.bj.bcebos.com/datasets/ucmerced.zip elif [[ ${task_name} == 'det' ]]; then download_and_unzip_dataset "${DATA_DIR}" sarship https://paddlers.bj.bcebos.com/datasets/sarship.zip + elif [[ ${task_name} == 'res' ]]; then + download_and_unzip_dataset "${DATA_DIR}" rssr https://paddlers.bj.bcebos.com/datasets/rssr_mini.zip elif [[ ${task_name} == 'seg' ]]; then download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg_mini.zip fi @@ -42,12 +44,26 @@ if [[ ${MODE} == 'lite_train_lite_infer' \ elif [[ ${MODE} == 'whole_train_whole_infer' ]]; then if [[ ${task_name} == 'cd' ]]; then + rm -rf "${DATA_DIR}/levircd" download_and_unzip_dataset "${DATA_DIR}" raw_levircd https://paddlers.bj.bcebos.com/datasets/raw/LEVIR-CD.zip \ && python tools/prepare_dataset/prepare_levircd.py \ --in_dataset_dir "${DATA_DIR}/raw_levircd" \ --out_dataset_dir "${DATA_DIR}/levircd" \ --crop_size 256 \ --crop_stride 256 + elif [[ ${task_name} == 'clas' ]]; then + download_and_unzip_dataset "${DATA_DIR}" ucmerced https://paddlers.bj.bcebos.com/datasets/ucmerced.zip + elif [[ ${task_name} == 'det' ]]; then + rm -rf "${DATA_DIR}/rsod" + download_and_unzip_dataset "${DATA_DIR}" raw_rsod 
https://paddlers.bj.bcebos.com/datasets/raw/RSOD.zip + python tools/prepare_dataset/prepare_rsod.py \ + --in_dataset_dir "${DATA_DIR}/raw_rsod" \ + --out_dataset_dir "${DATA_DIR}/rsod" \ + --seed 114514 + elif [[ ${task_name} == 'res' ]]; then + download_and_unzip_dataset "${DATA_DIR}" rssr https://paddlers.bj.bcebos.com/datasets/rssr.zip + elif [[ ${task_name} == 'seg' ]]; then + download_and_unzip_dataset "${DATA_DIR}" rsseg https://paddlers.bj.bcebos.com/datasets/rsseg.zip fi fi diff --git a/tests/data/data_utils.py b/tests/data/data_utils.py index 404e04e..afd9a0e 100644 --- a/tests/data/data_utils.py +++ b/tests/data/data_utils.py @@ -14,7 +14,6 @@ import os.path as osp import re -import imghdr import platform from collections import OrderedDict from functools import partial, wraps @@ -34,20 +33,6 @@ def norm_path(path): return path -def is_pic(im_path): - valid_suffix = [ - 'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy' - ] - suffix = im_path.split('.')[-1] - if suffix in valid_suffix: - return True - im_format = imghdr.what(im_path) - _, ext = osp.splitext(im_path) - if im_format == 'tiff' or ext == '.img': - return True - return False - - def get_full_path(p, prefix=''): p = norm_path(p) return osp.join(prefix, p) @@ -323,15 +308,34 @@ class ConstrDetSample(ConstrSample): return samples -def build_input_from_file(file_list, prefix='', task='auto', label_list=None): +class ConstrResSample(ConstrSample): + def __init__(self, prefix, label_list, sr_factor=None): + super().__init__(prefix, label_list) + self.sr_factor = sr_factor + + def __call__(self, src_path, tar_path): + sample = { + 'image': self.get_full_path(src_path), + 'target': self.get_full_path(tar_path) + } + if self.sr_factor is not None: + sample['sr_factor'] = self.sr_factor + return sample + + +def build_input_from_file(file_list, + prefix='', + task='auto', + label_list=None, + **kwargs): """ Construct a list of dictionaries from file. 
Each dict in the list can be used as the input to paddlers.transforms.Transform objects. Args: - file_list (str): Path of file_list. + file_list (str): Path of file list. prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''. - task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input. - Default: 'auto'. + task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto', + automatically determine the task based on the input. Default: 'auto'. label_list (str|None, optional): Path of label_list. Default: None. Returns: @@ -339,22 +343,21 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None): """ def _determine_task(parts): + task = 'unknown' if len(parts) in (3, 5): task = 'cd' elif len(parts) == 2: if parts[1].isdigit(): task = 'clas' - elif is_pic(osp.join(prefix, parts[1])): - task = 'seg' - else: + elif parts[1].endswith('.xml'): task = 'det' - else: + if task == 'unknown': raise RuntimeError( "Cannot automatically determine the task type. Please specify `task` manually." 
) return task - if task not in ('seg', 'det', 'cd', 'clas', 'auto'): + if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'): raise ValueError("Invalid value of `task`") samples = [] @@ -366,9 +369,8 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None): if task == 'auto': task = _determine_task(parts) if ctor is None: - # Select and build sample constructor ctor_class = globals()['Constr' + task.capitalize() + 'Sample'] - ctor = ctor_class(prefix, label_list) + ctor = ctor_class(prefix, label_list, **kwargs) sample = ctor(*parts) if isinstance(sample, list): samples.extend(sample) diff --git a/tests/deploy/test_predictor.py b/tests/deploy/test_predictor.py index 6283951..24db0ff 100644 --- a/tests/deploy/test_predictor.py +++ b/tests/deploy/test_predictor.py @@ -24,7 +24,7 @@ from testing_utils import CommonTest, run_script __all__ = [ 'TestCDPredictor', 'TestClasPredictor', 'TestDetPredictor', - 'TestSegPredictor' + 'TestResPredictor', 'TestSegPredictor' ] @@ -105,7 +105,7 @@ class TestPredictor(CommonTest): dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6) -@TestPredictor.add_tests +# @TestPredictor.add_tests class TestCDPredictor(TestPredictor): MODULE = pdrs.tasks.change_detector TRAINER_NAME_TO_EXPORT_OPTS = { @@ -177,7 +177,7 @@ class TestCDPredictor(TestPredictor): self.assertEqual(len(out_multi_array_t), num_inputs) -@TestPredictor.add_tests +# @TestPredictor.add_tests class TestClasPredictor(TestPredictor): MODULE = pdrs.tasks.classifier TRAINER_NAME_TO_EXPORT_OPTS = { @@ -185,7 +185,7 @@ class TestClasPredictor(TestPredictor): } def check_predictor(self, predictor, trainer): - single_input = "data/ssmt/optical_t1.bmp" + single_input = "data/ssst/optical.bmp" num_inputs = 2 transforms = pdrs.transforms.Compose([ pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(), @@ -242,7 +242,7 @@ class TestClasPredictor(TestPredictor): self.check_dict_equal(out_multi_array_p, out_multi_array_t) -@TestPredictor.add_tests +# 
@TestPredictor.add_tests class TestDetPredictor(TestPredictor): MODULE = pdrs.tasks.object_detector TRAINER_NAME_TO_EXPORT_OPTS = { @@ -253,7 +253,7 @@ class TestDetPredictor(TestPredictor): # For detection tasks, do NOT ensure the consistence of bboxes. # This is because the coordinates of bboxes were observed to be very sensitive to numeric errors, # given that the network is (partially?) randomly initialized. - single_input = "data/ssmt/optical_t1.bmp" + single_input = "data/ssst/optical.bmp" num_inputs = 2 transforms = pdrs.transforms.Compose([ pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(), @@ -303,6 +303,59 @@ class TestDetPredictor(TestPredictor): @TestPredictor.add_tests +class TestResPredictor(TestPredictor): + MODULE = pdrs.tasks.restorer + + def check_predictor(self, predictor, trainer): + # For restoration tasks, do NOT ensure the consistence of numeric values, + # because the output is of uint8 type. + single_input = "data/ssst/optical.bmp" + num_inputs = 2 + transforms = pdrs.transforms.Compose([ + pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(), + pdrs.transforms.ArrangeRestorer('test') + ]) + + # Single input (file path) + input_ = single_input + predictor.predict(input_, transforms=transforms) + trainer.predict(input_, transforms=transforms) + out_single_file_list_p = predictor.predict( + [input_], transforms=transforms) + self.assertEqual(len(out_single_file_list_p), 1) + out_single_file_list_t = trainer.predict( + [input_], transforms=transforms) + self.assertEqual(len(out_single_file_list_t), 1) + + # Single input (ndarray) + input_ = decode_image( + single_input, to_rgb=False) # Reuse the name `input_` + predictor.predict(input_, transforms=transforms) + trainer.predict(input_, transforms=transforms) + out_single_array_list_p = predictor.predict( + [input_], transforms=transforms) + self.assertEqual(len(out_single_array_list_p), 1) + out_single_array_list_t = trainer.predict( + [input_], transforms=transforms) + 
self.assertEqual(len(out_single_array_list_t), 1) + + # Multiple inputs (file paths) + input_ = [single_input] * num_inputs # Reuse the name `input_` + out_multi_file_p = predictor.predict(input_, transforms=transforms) + self.assertEqual(len(out_multi_file_p), num_inputs) + out_multi_file_t = trainer.predict(input_, transforms=transforms) + self.assertEqual(len(out_multi_file_t), num_inputs) + + # Multiple inputs (ndarrays) + input_ = [decode_image( + single_input, to_rgb=False)] * num_inputs # Reuse the name `input_` + out_multi_array_p = predictor.predict(input_, transforms=transforms) + self.assertEqual(len(out_multi_array_p), num_inputs) + out_multi_array_t = trainer.predict(input_, transforms=transforms) + self.assertEqual(len(out_multi_array_t), num_inputs) + + +# @TestPredictor.add_tests class TestSegPredictor(TestPredictor): MODULE = pdrs.tasks.segmenter TRAINER_NAME_TO_EXPORT_OPTS = { @@ -310,7 +363,7 @@ class TestSegPredictor(TestPredictor): } def check_predictor(self, predictor, trainer): - single_input = "data/ssmt/optical_t1.bmp" + single_input = "data/ssst/optical.bmp" num_inputs = 2 transforms = pdrs.transforms.Compose([ pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(), diff --git a/tests/rs_models/test_cd_models.py b/tests/rs_models/test_cd_models.py index 8478ea4..3b3e683 100644 --- a/tests/rs_models/test_cd_models.py +++ b/tests/rs_models/test_cd_models.py @@ -33,9 +33,7 @@ class TestCDModel(TestModel): self.check_output_equal(len(output), len(target)) for o, t in zip(output, target): o = o.numpy() - self.check_output_equal(o.shape[0], t.shape[0]) - self.check_output_equal(len(o.shape), 4) - self.check_output_equal(o.shape[2:], t.shape[2:]) + self.check_output_equal(o.shape, t.shape) def set_inputs(self): if self.EF_MODE == 'Concat': diff --git a/tests/rs_models/test_det_models.py b/tests/rs_models/test_det_models.py index 5aed6ef..112c6cd 100644 --- a/tests/rs_models/test_det_models.py +++ b/tests/rs_models/test_det_models.py @@ -32,3 
+32,6 @@ class TestDetModel(TestModel): def set_inputs(self): self.inputs = cycle([self.get_randn_tensor(3)]) + + def set_targets(self): + self.targets = cycle([None]) diff --git a/tests/rs_models/test_res_models.py b/tests/rs_models/test_res_models.py new file mode 100644 index 0000000..8b6ec56 --- /dev/null +++ b/tests/rs_models/test_res_models.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddlers +from rs_models.test_model import TestModel + +__all__ = [] + + +class TestResModel(TestModel): + def check_output(self, output, target): + output = output.numpy() + self.check_output_equal(output.shape, target.shape) + + def set_inputs(self): + def _gen_data(specs): + for spec in specs: + c = spec.get('in_channels', 3) + yield self.get_randn_tensor(c) + + self.inputs = _gen_data(self.specs) + + def set_targets(self): + def _gen_data(specs): + for spec in specs: + # XXX: Hard coding + if 'out_channels' in spec: + c = spec['out_channels'] + elif 'in_channels' in spec: + c = spec['in_channels'] + else: + c = 3 + yield [self.get_zeros_array(c)] + + self.targets = _gen_data(self.specs) diff --git a/tests/rs_models/test_seg_models.py b/tests/rs_models/test_seg_models.py index cc3415d..88fb6e1 100644 --- a/tests/rs_models/test_seg_models.py +++ b/tests/rs_models/test_seg_models.py @@ -26,9 +26,7 @@ class TestSegModel(TestModel): self.check_output_equal(len(output), len(target)) for 
o, t in zip(output, target):
             o = o.numpy()
-            self.check_output_equal(o.shape[0], t.shape[0])
-            self.check_output_equal(len(o.shape), 4)
-            self.check_output_equal(o.shape[2:], t.shape[2:])
+            self.check_output_equal(o.shape, t.shape)
 
     def set_inputs(self):
         def _gen_data(specs):
@@ -54,3 +52,7 @@ class TestFarSegModel(TestSegModel):
         self.specs = [
             dict(), dict(num_classes=20), dict(encoder_pretrained=False)
         ]
+
+    def set_targets(self):
+        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(20)],
+                        [self.get_zeros_array(16)]]
diff --git a/tests/run_examples.sh b/tests/run_examples.sh
deleted file mode 100644
index f1f641a..0000000
--- a/tests/run_examples.sh
+++ /dev/null
@@ -1 +0,0 @@
-#!/usr/bin/env bash
diff --git a/tests/run_tests.sh b/tests/run_tests.sh
index 236f055..94253e5 100644
--- a/tests/run_tests.sh
+++ b/tests/run_tests.sh
@@ -15,6 +15,3 @@ done
 
 # Test tutorials
 bash run_tutorials.sh
-
-# Test examples
-bash run_examples.sh
diff --git a/tests/transforms/test_operators.py b/tests/transforms/test_operators.py
index 6ddac53..cff8428 100644
--- a/tests/transforms/test_operators.py
+++ b/tests/transforms/test_operators.py
@@ -164,12 +164,15 @@ class TestTransform(CpuCommonTest):
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_optical_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_sar_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_multispectral_seg.txt",
+                task='seg',
                 prefix="./data/ssst"),
             build_input_from_file(
                 "data/ssst/test_optical_det.txt",
@@ -185,7 +188,23 @@ class TestTransform(CpuCommonTest):
                 label_list="data/ssst/labels_det.txt"),
             build_input_from_file(
                 "data/ssst/test_det_coco.txt",
+                task='det',
                 prefix="./data/ssst"),
+            build_input_from_file(
+                "data/ssst/test_optical_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_sar_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
+            build_input_from_file(
+                "data/ssst/test_multispectral_res.txt",
+                task='res',
+                prefix="./data/ssst",
+                sr_factor=4),
             build_input_from_file(
                 "data/ssmt/test_mixed_binary.txt",
                 prefix="./data/ssmt"),
@@ -227,6 +246,8 @@ class TestTransform(CpuCommonTest):
                 self.aux_mask_values = [
                     set(aux_mask.ravel()) for aux_mask in sample['aux_masks']
                 ]
+            if 'target' in sample:
+                self.target_shape = sample['target'].shape
             return sample
 
         def _out_hook_not_keep_ratio(sample):
@@ -243,6 +264,21 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, amv in zip(sample['aux_masks'],
                                          self.aux_mask_values):
                     self.assertLessEqual(set(aux_mask.ravel()), amv)
+            if 'target' in sample:
+                if 'sr_factor' in sample:
+                    self.check_output_equal(
+                        sample['target'].shape[:2],
+                        T.functions.calc_hr_shape(TARGET_SIZE,
+                                                  sample['sr_factor']))
+                else:
+                    self.check_output_equal(sample['target'].shape[:2],
+                                            TARGET_SIZE)
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 
@@ -260,6 +296,13 @@ class TestTransform(CpuCommonTest):
                 for aux_mask, ori_aux_mask_shape in zip(sample['aux_masks'],
                                                         self.aux_mask_shapes):
                     __check_ratio(aux_mask.shape, ori_aux_mask_shape)
+            if 'target' in sample:
+                self.check_output_equal(
+                    sample['target'].shape[0] / self.target_shape[0],
+                    sample['image'].shape[0] / self.image_shape[0])
+                self.check_output_equal(
+                    sample['target'].shape[1] / self.target_shape[1],
+                    sample['image'].shape[1] / self.image_shape[1])
             # TODO: Test gt_bbox and gt_poly
             return sample
 
diff --git a/tools/prepare_dataset/prepare_rsod.py b/tools/prepare_dataset/prepare_rsod.py
new file mode 100644
index 0000000..d9cd5ee
--- /dev/null
+++ b/tools/prepare_dataset/prepare_rsod.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+
+import random
+import os.path as osp
+from functools import reduce, partial
+
+from common import (get_default_parser, get_path_tuples, create_file_list,
+                    link_dataset, random_split, create_label_list)
+
+CLASSES = ('aircraft', 'oiltank', 'overpass', 'playground')
+SUBSETS = ('train', 'val', 'test')
+SUBDIRS = ('JPEGImages', osp.sep.join(['Annotation', 'xml']))
+FILE_LIST_PATTERN = "{subset}.txt"
+LABEL_LIST_NAME = "labels.txt"
+URL = ""
+
+if __name__ == '__main__':
+    parser = get_default_parser()
+    parser.add_argument('--seed', type=int, default=None, help="Random seed.")
+    parser.add_argument(
+        '--ratios',
+        type=float,
+        nargs='+',
+        default=(0.7, 0.2, 0.1),
+        help="Ratios of each subset (train/val or train/val/test).")
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+
+    if len(args.ratios) not in (2, 3):
+        raise ValueError("Wrong number of ratios!")
+
+    out_dir = osp.join(args.out_dataset_dir,
+                       osp.basename(osp.normpath(args.in_dataset_dir)))
+
+    link_dataset(args.in_dataset_dir, args.out_dataset_dir)
+
+    splits_list = []
+    for cls in CLASSES:
+        path_tuples = get_path_tuples(
+            *(osp.join(out_dir, cls, subdir) for subdir in SUBDIRS),
+            data_dir=args.out_dataset_dir)
+        splits = random_split(path_tuples, ratios=args.ratios)
+        splits_list.append(splits)
+    splits = map(partial(reduce, list.__add__), zip(*splits_list))
+
+    for subset, split in zip(SUBSETS, splits):
+        file_list = osp.join(
+            args.out_dataset_dir, FILE_LIST_PATTERN.format(subset=subset))
+        create_file_list(file_list, split)
+        print(f"Write file list to {file_list}.")
+
+    label_list = osp.join(args.out_dataset_dir, LABEL_LIST_NAME)
+    create_label_list(label_list, CLASSES)
+    print(f"Write label list to {label_list}.")
diff --git a/tutorials/train/README.md b/tutorials/train/README.md
index c63cf26..44e93a3 100644
--- a/tutorials/train/README.md
+++ b/tutorials/train/README.md
@@ -9,20 +9,21 @@
 |change_detection/changeformer.py | Change detection | ChangeFormer |
 |change_detection/dsamnet.py | Change detection | DSAMNet |
 |change_detection/dsifn.py | Change detection | DSIFN |
-|change_detection/snunet.py | Change detection | SNUNet |
-|change_detection/stanet.py | Change detection | STANet |
 |change_detection/fc_ef.py | Change detection | FC-EF |
 |change_detection/fc_siam_conc.py | Change detection | FC-Siam-conc |
 |change_detection/fc_siam_diff.py | Change detection | FC-Siam-diff |
+|change_detection/fccdn.py | Change detection | FCCDN |
+|change_detection/snunet.py | Change detection | SNUNet |
+|change_detection/stanet.py | Change detection | STANet |
 |classification/hrnet.py | Scene classification | HRNet |
 |classification/mobilenetv3.py | Scene classification | MobileNetV3 |
 |classification/resnet50_vd.py | Scene classification | ResNet50-vd |
-|image_restoration/drn.py | Super-resolution | DRN |
-|image_restoration/esrgan.py | Super-resolution | ESRGAN |
-|image_restoration/lesrcnn.py | Super-resolution | LESRCNN |
+|image_restoration/drn.py | Image restoration | DRN |
+|image_restoration/esrgan.py | Image restoration | ESRGAN |
+|image_restoration/lesrcnn.py | Image restoration | LESRCNN |
 |object_detection/faster_rcnn.py | Object detection | Faster R-CNN |
 |object_detection/ppyolo.py | Object detection | PP-YOLO |
-|object_detection/ppyolotiny.py | Object detection | PP-YOLO Tiny |
+|object_detection/ppyolo_tiny.py | Object detection | PP-YOLO Tiny |
 |object_detection/ppyolov2.py | Object detection | PP-YOLOv2 |
 |object_detection/yolov3.py | Object detection | YOLOv3 |
 |semantic_segmentation/deeplabv3p.py | Image segmentation | DeepLab V3+ |
diff --git a/tutorials/train/change_detection/bit.py b/tutorials/train/change_detection/bit.py
index 83c96ce..10410f6 100644
--- a/tutorials/train/change_detection/bit.py
+++ b/tutorials/train/change_detection/bit.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.BIT()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/cdnet.py b/tutorials/train/change_detection/cdnet.py
index 2aa2ad6..ca53f94 100644
--- a/tutorials/train/change_detection/cdnet.py
+++ b/tutorials/train/change_detection/cdnet.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.CDNet()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/changeformer.py b/tutorials/train/change_detection/changeformer.py
index 7afbf96..7d4c3cc 100644
--- a/tutorials/train/change_detection/changeformer.py
+++ b/tutorials/train/change_detection/changeformer.py
@@ -72,13 +72,13 @@ eval_dataset = pdrs.datasets.CDDataset(
     binarize_labels=True)
 
 # Build the ChangeFormer model with default parameters
-# For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
+# For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.cd.ChangeFormer()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/dsamnet.py b/tutorials/train/change_detection/dsamnet.py
index 2a0d3ae..5d75af4 100644
--- a/tutorials/train/change_detection/dsamnet.py
+++ b/tutorials/train/change_detection/dsamnet.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.DSAMNet()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/dsifn.py b/tutorials/train/change_detection/dsifn.py
index 6a2ed19..8186ba5 100644
--- a/tutorials/train/change_detection/dsifn.py
+++ b/tutorials/train/change_detection/dsifn.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.DSIFN()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/fc_ef.py b/tutorials/train/change_detection/fc_ef.py
index 4324564..1a2f6ae 100644
--- a/tutorials/train/change_detection/fc_ef.py
+++ b/tutorials/train/change_detection/fc_ef.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCEarlyFusion()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/fc_siam_conc.py b/tutorials/train/change_detection/fc_siam_conc.py
index d63f5dc..19c4912 100644
--- a/tutorials/train/change_detection/fc_siam_conc.py
+++ b/tutorials/train/change_detection/fc_siam_conc.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCSiamConc()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/fc_siam_diff.py b/tutorials/train/change_detection/fc_siam_diff.py
index 55f8681..c289d8d 100644
--- a/tutorials/train/change_detection/fc_siam_diff.py
+++ b/tutorials/train/change_detection/fc_siam_diff.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCSiamDiff()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/fccdn.py b/tutorials/train/change_detection/fccdn.py
new file mode 100644
index 0000000..62abbba
--- /dev/null
+++ b/tutorials/train/change_detection/fccdn.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# Example training script for the change detection model FCCDN
+# Before running this script, make sure that PaddleRS is installed correctly
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Directory where the dataset is stored
+DATA_DIR = './data/airchange/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/fccdn/'
+
+# Download and decompress the AirChange dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms. The transforms in Compose are applied sequentially, in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
+train_transforms = T.Compose([
+    # Read the image
+    T.DecodeImg(),
+    # Random crop
+    T.RandomCrop(
+        # The cropped region is resized to 256x256
+        crop_size=256,
+        # The aspect ratio of the cropped region varies between 0.5 and 2
+        aspect_ratio=[0.5, 2.0],
+        # The side lengths of the cropped region vary within a range relative to the original image, and are no less than 1/5 of the original
+        scaling=[0.2, 1.0]),
+    # Apply a random horizontal flip with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Normalize the data to [-1,1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # The validation stage must use the same normalization as the training stage
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ReloadMask(),
+    T.ArrangeChangeDetector('eval')
+])
+
+# Build the datasets used for training and validation
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=None,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=None,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    with_seg_labels=False,
+    binarize_labels=True)
+
+# Build the FCCDN model with default parameters
+# For the currently supported models and the model input parameters, see:
+# https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
+model = pdrs.tasks.cd.FCCDN()
+
+# Train the model
+model.train(
+    num_epochs=5,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=2,
+    # Number of iterations between log records
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # Whether to use early stopping, which ends training early when the accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
diff --git a/tutorials/train/change_detection/snunet.py b/tutorials/train/change_detection/snunet.py
index a4b6d65..37ef1a6 100644
--- a/tutorials/train/change_detection/snunet.py
+++ b/tutorials/train/change_detection/snunet.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.SNUNet()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/change_detection/stanet.py b/tutorials/train/change_detection/stanet.py
index 4fe9799..9659c5b 100644
--- a/tutorials/train/change_detection/stanet.py
+++ b/tutorials/train/change_detection/stanet.py
@@ -78,7 +78,7 @@ model = pdrs.tasks.cd.STANet()
 
 # Train the model
 model.train(
-    num_epochs=5,
+    num_epochs=10,
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
diff --git a/tutorials/train/classification/hrnet.py b/tutorials/train/classification/hrnet.py
index 658dcef..7a89843 100644
--- a/tutorials/train/classification/hrnet.py
+++ b/tutorials/train/classification/hrnet.py
@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build the HRNet model with default parameters
+# Build the HRNet model
 # For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.HRNet_W18_C(num_classes=len(train_dataset.labels))
diff --git a/tutorials/train/classification/mobilenetv3.py b/tutorials/train/classification/mobilenetv3.py
index 1d85a06..36efe29 100644
--- a/tutorials/train/classification/mobilenetv3.py
+++ b/tutorials/train/classification/mobilenetv3.py
@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build the MobileNetV3 model with default parameters
+# Build the MobileNetV3 model
 # For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.MobileNetV3_small_x1_0(
diff --git a/tutorials/train/classification/resnet50_vd.py b/tutorials/train/classification/resnet50_vd.py
index 40891e6..a0957f2 100644
--- a/tutorials/train/classification/resnet50_vd.py
+++ b/tutorials/train/classification/resnet50_vd.py
@@ -65,7 +65,7 @@ eval_dataset = pdrs.datasets.ClasDataset(
     num_workers=0,
     shuffle=False)
 
-# Build the ResNet50-vd model with default parameters
+# Build the ResNet50-vd model
 # For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
 # For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
 model = pdrs.tasks.clas.ResNet50_vd(num_classes=len(train_dataset.labels))
diff --git a/tutorials/train/image_restoration/data/.gitignore b/tutorials/train/image_restoration/data/.gitignore
new file mode 100644
index 0000000..2d1d39b
--- /dev/null
+++ b/tutorials/train/image_restoration/data/.gitignore
@@ -0,0 +1,3 @@
+*.zip
+*.tar.gz
+rssr/
\ No newline at end of file
diff --git a/tutorials/train/image_restoration/drn.py b/tutorials/train/image_restoration/drn.py
new file mode 100644
index 0000000..6af93ac
--- /dev/null
+++ b/tutorials/train/image_restoration/drn.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example training script for the image restoration model DRN
+# Before running this script, make sure that PaddleRS is installed correctly
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Directory where the dataset is stored
+DATA_DIR = './data/rssr/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/drn/'
+
+# Download and decompress the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms. The transforms in Compose are applied sequentially, in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Read the image
+    T.DecodeImg(),
+    # Crop a 96x96 patch from the input image
+    T.RandomCrop(crop_size=96),
+    # Apply a random horizontal flip with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply a random vertical flip with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # The validation stage must use the same normalization as the training stage
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the datasets used for training and validation
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the DRN model with default parameters
+# For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.DRN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Number of iterations between log records
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use early stopping, which ends training early when the accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
diff --git a/tutorials/train/image_restoration/drn_train.py b/tutorials/train/image_restoration/drn_train.py
deleted file mode 100644
index 2d871a5..0000000
--- a/tutorials/train/image_restoration/drn_train.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddle
-import paddlers as pdrs
-
-# Define the transforms used for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'lqx2', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4,
-        'scale_list': True
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [1.0, 1.0, 1.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path to the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path to the low-resolution images
-num_workers = 4
-batch_size = 8
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-train_dict = train_dataset()
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-model = pdrs.tasks.res.DRNet(
-    n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
-
-model.train(
-    total_iters=100000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    lr_rate=0.0001,
-    log=10)
diff --git a/tutorials/train/image_restoration/esrgan.py b/tutorials/train/image_restoration/esrgan.py
new file mode 100644
index 0000000..33ff3f8
--- /dev/null
+++ b/tutorials/train/image_restoration/esrgan.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example training script for the image restoration model ESRGAN
+# Before running this script, make sure that PaddleRS is installed correctly
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Directory where the dataset is stored
+DATA_DIR = './data/rssr/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/esrgan/'
+
+# Download and decompress the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms. The transforms in Compose are applied sequentially, in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Read the image
+    T.DecodeImg(),
+    # Crop a 32x32 patch from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply a random horizontal flip with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply a random vertical flip with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # The validation stage must use the same normalization as the training stage
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the datasets used for training and validation
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the ESRGAN model with default parameters
+# For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.ESRGAN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Number of iterations between log records
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use early stopping, which ends training early when the accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
diff --git a/tutorials/train/image_restoration/esrgan_train.py b/tutorials/train/image_restoration/esrgan_train.py
deleted file mode 100644
index a972f03..0000000
--- a/tutorials/train/image_restoration/esrgan_train.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms used for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 128,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path to the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path to the low-resolution images
-num_workers = 6
-batch_size = 32
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-# If loss_type='gan', the perceptual, adversarial, and pixel losses are used
-# If loss_type = 'pixel', only the pixel loss is used
-model = pdrs.tasks.res.ESRGANet(loss_type='pixel')
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])
diff --git a/tutorials/train/image_restoration/lesrcnn.py b/tutorials/train/image_restoration/lesrcnn.py
new file mode 100644
index 0000000..0c27823
--- /dev/null
+++ b/tutorials/train/image_restoration/lesrcnn.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+# Example training script for the image restoration model LESRCNN
+# Before running this script, make sure that PaddleRS is installed correctly
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Directory where the dataset is stored
+DATA_DIR = './data/rssr/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
+# Experiment directory, where the output model weights and results are saved
+EXP_DIR = './output/lesrcnn/'
+
+# Download and decompress the remote sensing image super-resolution dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms. The transforms in Compose are applied sequentially, in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Read the image
+    T.DecodeImg(),
+    # Crop a 32x32 patch from the input image
+    T.RandomCrop(crop_size=32),
+    # Apply a random horizontal flip with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Apply a random vertical flip with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [0,1]
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    # Resize the input image to 256x256
+    T.Resize(target_size=256),
+    # The validation stage must use the same normalization as the training stage
+    T.Normalize(
+        mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
+    T.ArrangeRestorer('eval')
+])
+
+# Build the datasets used for training and validation
+train_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True,
+    sr_factor=4)
+
+eval_dataset = pdrs.datasets.ResDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False,
+    sr_factor=4)
+
+# Build the LESRCNN model with default parameters
+# For the currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
+model = pdrs.tasks.res.LESRCNN()
+
+# Train the model
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    # Number of iterations between log records
+    log_interval_steps=10,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.001,
+    # Whether to use early stopping, which ends training early when the accuracy stops improving
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
diff --git a/tutorials/train/image_restoration/lesrcnn_train.py b/tutorials/train/image_restoration/lesrcnn_train.py
deleted file mode 100644
index 7f34f84..0000000
--- a/tutorials/train/image_restoration/lesrcnn_train.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath('../PaddleRS'))
-
-import paddlers as pdrs
-
-# Define the transforms used for training and validation
-train_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'SRPairedRandomCrop',
-        'gt_patch_size': 192,
-        'scale': 4
-    }, {
-        'name': 'PairedRandomHorizontalFlip'
-    }, {
-        'name': 'PairedRandomVerticalFlip'
-    }, {
-        'name': 'PairedRandomTransposeHW'
-    }, {
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-test_transforms = pdrs.datasets.ComposeTrans(
-    input_keys=['lq', 'gt'],
-    output_keys=['lq', 'gt'],
-    pipelines=[{
-        'name': 'Transpose'
-    }, {
-        'name': 'Normalize',
-        'mean': [0.0, 0.0, 0.0],
-        'std': [255.0, 255.0, 255.0]
-    }])
-
-# Define the training set
-train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # Path to the high-resolution images
-train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # Path to the low-resolution images
-num_workers = 4
-batch_size = 16
-scale = 4
-train_dataset = pdrs.datasets.SRdataset(
-    mode='train',
-    gt_floder=train_gt_floder,
-    lq_floder=train_lq_floder,
-    transforms=train_transforms(),
-    scale=scale,
-    num_workers=num_workers,
-    batch_size=batch_size)
-
-# Define the test set
-test_gt_floder = r"../work/RSdata_for_SR/test_HR"
-test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
-test_dataset = pdrs.datasets.SRdataset(
-    mode='test',
-    gt_floder=test_gt_floder,
-    lq_floder=test_lq_floder,
-    transforms=test_transforms(),
-    scale=scale)
-
-# Initialize the model; the parameters of the network architecture can be adjusted
-model = pdrs.tasks.res.LESRCNNet(scale=4, multi_scale=False, group=1)
-
-model.train(
-    total_iters=1000000,
-    train_dataset=train_dataset(),
-    test_dataset=test_dataset(),
-    output_dir='output_dir',
-    validate=5000,
-    snapshot=5000,
-    log=100,
-    lr_rate=0.0001,
-    periods=[250000, 250000, 250000, 250000],
-    restart_weights=[1, 1, 1, 1])
diff --git a/tutorials/train/object_detection/ppyolotiny.py b/tutorials/train/object_detection/ppyolo_tiny.py
similarity index 100%
rename from tutorials/train/object_detection/ppyolotiny.py
rename to tutorials/train/object_detection/ppyolo_tiny.py
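
Note on `tools/prepare_dataset/prepare_rsod.py`: the script splits each class separately and then merges the per-class splits into one list per subset with `map(partial(reduce, list.__add__), zip(*splits_list))`. The following is a minimal, self-contained sketch of that merging step. It is not part of the patch: `random_split` here is a hypothetical stand-in for the helper imported from `common` (its real implementation is not shown in this diff), `merge_class_splits` is a name introduced only for illustration, and the file names are made up.

```python
import random
from functools import partial, reduce


def random_split(items, ratios, rng=random):
    """Hypothetical stand-in for common.random_split: shuffle, then cut by ratio."""
    items = list(items)
    rng.shuffle(items)
    splits, start = [], 0
    for i, ratio in enumerate(ratios):
        # The last subset takes whatever remains so that no item is dropped.
        is_last = i == len(ratios) - 1
        end = len(items) if is_last else start + round(len(items) * ratio)
        splits.append(items[start:end])
        start = end
    return splits


def merge_class_splits(splits_list):
    """Concatenate the i-th split of every class into one list per subset."""
    # zip(*splits_list) groups all 'train' splits, all 'val' splits, and so on;
    # reduce(list.__add__, group) then concatenates the lists in each group.
    return list(map(partial(reduce, list.__add__), zip(*splits_list)))


random.seed(0)
# Made-up (image, annotation) path tuples for two classes.
per_class = {
    "aircraft": [(f"aircraft/{i}.jpg", f"aircraft/{i}.xml") for i in range(10)],
    "oiltank": [(f"oiltank/{i}.jpg", f"oiltank/{i}.xml") for i in range(20)],
}
splits_list = [
    random_split(tuples, ratios=(0.7, 0.2, 0.1)) for tuples in per_class.values()
]
train, val, test = merge_class_splits(splits_list)
print(len(train), len(val), len(test))
```

Splitting per class before merging keeps the class proportions roughly the same in every subset, which a single split over the pooled list would not guarantee.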