[Doc] Format docs (#9)

Lin Manhui 3 years ago committed by GitHub
parent bddddc5164
commit 64c9697a4d
  1. .pre-commit-config.yaml (8)
  2. paddlers/datasets/__init__.py (4)
  3. paddlers/datasets/cd_dataset.py (43)
  4. paddlers/datasets/clas_dataset.py (24)
  5. paddlers/datasets/coco.py (65)
  6. paddlers/datasets/seg_dataset.py (24)
  7. paddlers/datasets/voc.py (69)
  8. paddlers/deploy/predictor.py (64)
  9. paddlers/rs_models/cd/backbones/resnet.py (29)
  10. paddlers/rs_models/cd/bit.py (18)
  11. paddlers/rs_models/cd/cdnet.py (40)
  12. paddlers/rs_models/cd/changestar.py (18)
  13. paddlers/rs_models/cd/dsamnet.py (17)
  14. paddlers/rs_models/cd/dsifn.py (11)
  15. paddlers/rs_models/cd/fc_ef.py (12)
  16. paddlers/rs_models/cd/fc_siam_conc.py (12)
  17. paddlers/rs_models/cd/fc_siam_diff.py (12)
  18. paddlers/rs_models/cd/layers/attention.py (15)
  19. paddlers/rs_models/cd/layers/blocks.py (4)
  20. paddlers/rs_models/cd/snunet.py (13)
  21. paddlers/rs_models/cd/stanet.py (24)
  22. paddlers/rs_models/res/rcan_model.py (3)
  23. paddlers/rs_models/seg/farseg.py (9)
  24. paddlers/rs_models/seg/layers/layers_lib.py (9)
  25. paddlers/tasks/base.py (45)
  26. paddlers/tasks/change_detector.py (189)
  27. paddlers/tasks/classifier.py (143)
  28. paddlers/tasks/load_model.py (8)
  29. paddlers/tasks/object_detector.py (436)
  30. paddlers/tasks/segmenter.py (161)
  31. paddlers/tasks/utils/det_metrics/coco_utils.py (29)
  32. paddlers/tasks/utils/visualize.py (43)
  33. paddlers/tools/yolo_cluster.py (23)
  34. paddlers/transforms/__init__.py (18)
  35. paddlers/transforms/batch_operators.py (39)
  36. paddlers/transforms/functions.py (154)
  37. paddlers/transforms/operators.py (267)
  38. paddlers/utils/download.py (11)
  39. paddlers/utils/env.py (4)
  40. paddlers/utils/stats.py (4)
  41. paddlers/utils/utils.py (6)
  42. tests/data/data_utils.py (2)
  43. tests/testing_utils.py (26)
  44. tools/coco_tools/json_AnnoSta.py (5)
  45. tools/coco_tools/json_Img2Json.py (5)
  46. tools/coco_tools/json_ImgSta.py (4)
  47. tools/coco_tools/json_InfoShow.py (4)
  48. tools/coco_tools/json_Merge.py (4)
  49. tools/coco_tools/json_Split.py (4)
  50. tools/match.py (2)
  51. tools/utils/raster.py (37)
  52. tutorials/train/object_detection/faster_rcnn.py (4)
  53. tutorials/train/object_detection/ppyolo.py (4)
  54. tutorials/train/object_detection/ppyolotiny.py (4)
  55. tutorials/train/object_detection/ppyolov2.py (4)
  56. tutorials/train/object_detection/yolov3.py (4)

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
@@ -16,7 +16,7 @@ repos:
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
rev: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
@@ -25,4 +25,4 @@ repos:
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
files: \.md$

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .voc import VOCDetection
from .coco import COCODetection
from .voc import VOCDetDataset
from .coco import COCODetDataset
from .seg_dataset import SegDataset
from .cd_dataset import CDDataset
from .clas_dataset import ClasDataset

@@ -22,28 +22,33 @@ from paddlers.utils import logging, get_encoding, norm_path, is_pic
class CDDataset(BaseDataset):
"""
Reads a change detection dataset and preprocesses the samples (derived from SegDataset; two label images are required).
Dataset for change detection tasks.
Args:
data_dir (str): Directory of the dataset.
file_list (str): Path of the file that lists the image and annotation files of the dataset. Each path in the file is
relative to `data_dir`. When `with_seg_labels` is False (the default), each line should contain, in order, the
paths of the first-temporal image, the second-temporal image, and the change mask. When `with_seg_labels` is True,
each line should contain, in order, the paths of the first-temporal image, the second-temporal image, the change
mask, the first-temporal building mask, and the second-temporal building mask.
label_list (str): Path of the file that lists the categories of the dataset. Defaults to None.
transforms (paddlers.transforms.Compose): Preprocessing/augmentation operators applied to each sample.
num_workers (int|str): Number of threads or processes used for preprocessing. Defaults to 'auto', which sets
`num_workers` according to the number of CPU cores: if half the number of CPU cores is greater than 8,
`num_workers` is set to 8; otherwise, it is set to half the number of CPU cores.
shuffle (bool): Whether to shuffle the samples. Defaults to False.
with_seg_labels (bool, optional): Whether the dataset contains semantic segmentation masks of both temporal
phases. Defaults to False.
binarize_labels (bool, optional): Whether to binarize the labels. Defaults to False.
data_dir (str): Root directory of the dataset.
file_list (str): Path of the file that contains relative paths of images and annotation files. When
`with_seg_labels` is False, each line in the file contains the paths of the bi-temporal images and
the change mask. When `with_seg_labels` is True, each line in the file contains the paths of the
bi-temporal images, the path of the change mask, and the paths of the segmentation masks in both
temporal phases.
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
label_list (str, optional): Path of the file that contains the category names. Defaults to None.
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
the number of workers will be automatically determined according to the number of CPU cores: If
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
the number of CPU cores. Defaults to 'auto'.
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
with_seg_labels (bool, optional): Set `with_seg_labels` to True if the dataset provides segmentation
masks (e.g., building masks in each temporal phase). Defaults to False.
binarize_labels (bool, optional): Whether to binarize change masks and segmentation masks.
Defaults to False.
"""
def __init__(self,
data_dir,
file_list,
transforms,
label_list=None,
transforms=None,
num_workers='auto',
shuffle=False,
with_seg_labels=False,
@@ -64,8 +69,7 @@ class CDDataset(BaseDataset):
num_items = 3 # RGB1, RGB2, CD
self.binarize_labels = binarize_labels
# TODO: When `label_list` is not None, direct the user to dataset analysis to generate it.
# Do not parse the label file here.
# TODO: If `label_list` is not None, let the user parse `label_list`.
if label_list is not None:
with open(label_list, encoding=get_encoding(label_list)) as f:
for line in f:
@@ -77,7 +81,7 @@ class CDDataset(BaseDataset):
items = line.strip().split(DELIMETER)
if len(items) != num_items:
raise Exception(
raise ValueError(
"Line[{}] in file_list[{}] has an incorrect number of file paths.".
format(line.strip(), file_list))
@@ -148,7 +152,10 @@ class CDDataset(BaseDataset):
class MaskType(IntEnum):
"""Enumeration of the mask types used in the change detection task."""
"""
Enumeration of the mask types used in the change detection task.
"""
CD = 0
SEG_T1 = 1
SEG_T2 = 2
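As a quick illustration of the revised CDDataset interface, a minimal construction sketch follows; the transform pipeline, operator settings, and file layout here are assumptions for illustration, not part of this commit.

import paddlers.transforms as T
from paddlers.datasets import CDDataset

# Assumed layout: each line of train_list.txt holds
# "t1_image t2_image change_mask", relative to data_dir.
train_transforms = T.Compose([
    T.DecodeImg(),  # decode the image files (DecodeImg is imported elsewhere in this diff)
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # illustrative values
])
train_dataset = CDDataset(
    data_dir='data/my_cd_dataset',
    file_list='data/my_cd_dataset/train_list.txt',
    transforms=train_transforms,
    num_workers='auto',
    shuffle=True,
    with_seg_labels=False,  # True if per-phase building masks are listed too
    binarize_labels=True)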

@@ -19,24 +19,26 @@ from paddlers.utils import logging, get_encoding, norm_path, is_pic
class ClasDataset(BaseDataset):
"""读取图像分类任务数据集,并对样本进行相应的处理。
"""
Dataset for scene classification tasks.
Args:
data_dir (str): Directory of the dataset.
file_list (str): Path of the file that lists the image files and the corresponding label IDs. Each path in the
file is relative to `data_dir`.
label_list (str): Path of the file that lists the categories of the dataset, one category per line.
Defaults to None.
transforms (paddlers.transforms.Compose): Preprocessing/augmentation operators applied to each sample.
num_workers (int|str): Number of threads or processes used for preprocessing. Defaults to 'auto', which sets
`num_workers` according to the number of CPU cores: if half the number of CPU cores is greater than 8,
`num_workers` is set to 8; otherwise, it is set to half the number of CPU cores.
shuffle (bool): Whether to shuffle the samples. Defaults to False.
data_dir (str): Root directory of the dataset.
file_list (str): Path of the file that contains relative paths of images and labels.
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
label_list (str, optional): Path of the file that contains the category names. Defaults to None.
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
the number of workers will be automatically determined according to the number of CPU cores: If
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
the number of CPU cores. Defaults to 'auto'.
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
"""
def __init__(self,
data_dir,
file_list,
transforms,
label_list=None,
transforms=None,
num_workers='auto',
shuffle=False):
super(ClasDataset, self).__init__(data_dir, label_list, transforms,
@@ -57,7 +59,7 @@ class ClasDataset(BaseDataset):
for line in f:
items = line.strip().split()
if len(items) > 2:
raise Exception(
raise ValueError(
"A space is defined as the delimiter to separate the image and label path, " \
"so the space cannot be in the image or label path, but the line[{}] of " \
" file_list[{}] has a space in the image or label path.".format(line, file_list))

@@ -27,29 +27,32 @@ from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster
class COCODetection(BaseDataset):
"""读取COCO格式的检测数据集,并对样本进行相应的处理。
class COCODetDataset(BaseDataset):
"""
Dataset with COCO annotations for detection tasks.
Args:
data_dir (str): Directory of the dataset.
image_dir (str): Directory of the images in the dataset.
anno_path (str): Path of the COCO annotation file.
label_list (str): Path of the file that lists the categories of the dataset.
transforms (paddlers.transforms.Compose): Preprocessing/augmentation operators applied to each sample.
num_workers (int|str): Number of threads or processes used for preprocessing. Defaults to 'auto', which sets
`num_workers` according to the number of CPU cores: if half the number of CPU cores is greater than 8,
`num_workers` is set to 8; otherwise, it is set to half the number of CPU cores.
shuffle (bool): Whether to shuffle the samples. Defaults to False.
allow_empty (bool): Whether to load negative samples. Defaults to False.
empty_ratio (float): Ratio of negative samples to all samples. If it is smaller than 0 or no less than 1, all
negative samples are kept. Defaults to 1.
data_dir (str): Root directory of the dataset.
image_dir (str): Directory that contains the images.
anno_path (str): Path of the COCO annotation file.
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
label_list (str, optional): Path of the file that contains the category names. Defaults to None.
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
the number of workers will be automatically determined according to the number of CPU cores: If
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
the number of CPU cores. Defaults to 'auto'.
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
allow_empty (bool, optional): Whether to add negative samples. Defaults to False.
empty_ratio (float, optional): Ratio of negative samples. If `empty_ratio` is smaller than 0 or not less
than 1, keep all generated negative samples. Defaults to 1.0.
"""
def __init__(self,
data_dir,
image_dir,
anno_path,
transforms,
label_list,
transforms=None,
num_workers='auto',
shuffle=False,
allow_empty=False,
@@ -60,8 +63,8 @@ class COCODetection(BaseDataset):
import matplotlib
matplotlib.use('Agg')
from pycocotools.coco import COCO
super(COCODetection, self).__init__(data_dir, label_list, transforms,
num_workers, shuffle)
super(COCODetDataset, self).__init__(data_dir, label_list, transforms,
num_workers, shuffle)
self.data_fields = None
self.num_max_boxes = 50
@@ -281,15 +284,16 @@ class COCODetection(BaseDataset):
https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
Args:
num_anchors (int): number of clusters
image_size (list or int): [h, w], being an int means image height and image width are the same.
cache (bool): whether using cache
cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset.
iters (int, optional): iters of kmeans algorithm
gen_iters (int, optional): iters of genetic algorithm
threshold (float, optional): anchor scale threshold
verbose (bool, optional): whether print results
num_anchors (int): Number of clusters.
image_size (list[int]|int): [h, w] or an int value that corresponds to the shape [image_size, image_size].
cache (bool, optional): Whether to use cache. Defaults to True.
cache_path (str|None, optional): Path of cache directory. If None, use `dataset.data_dir`.
Defaults to None.
iters (int, optional): Iterations of k-means algorithm. Defaults to 300.
gen_iters (int, optional): Iterations of genetic algorithm. Defaults to 1000.
thresh (float, optional): Anchor scale threshold. Defaults to 0.25.
"""
if cache_path is None:
cache_path = self.data_dir
cluster = YOLOAnchorCluster(
@@ -305,17 +309,18 @@ class COCODetection(BaseDataset):
return anchors
def add_negative_samples(self, image_dir, empty_ratio=1):
"""将背景图片加入训练
"""
Generate and add negative samples.
Args:
image_dir (str): Directory of the background images.
empty_ratio (float|None): Ratio of negative samples to all samples. If None, keep the `empty_ratio` specified
when the dataset was constructed; otherwise, update it. If it is smaller than 0 or no less than 1, all
negative samples are kept. Defaults to 1.
image_dir (str): Directory that contains images.
empty_ratio (float|None, optional): Ratio of negative samples. If `empty_ratio` is smaller than
0 or not less than 1, keep all generated negative samples. Defaults to 1.0.
"""
import cv2
if not osp.isdir(image_dir):
raise Exception("{} is not a valid image directory.".format(
raise ValueError("{} is not a valid image directory.".format(
image_dir))
if empty_ratio is not None:
self.empty_ratio = empty_ratio
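For context, a sketch of constructing the renamed COCODetDataset and clustering YOLO anchors on it, per the two docstrings above; the paths, transform pipeline, and anchor settings are illustrative assumptions.

from paddlers.datasets import COCODetDataset

train_dataset = COCODetDataset(
    data_dir='data/det',
    image_dir='data/det/images',
    anno_path='data/det/annotations/train.json',
    transforms=train_transforms,  # a paddlers.transforms.Compose, as in the CDDataset sketch
    label_list='data/det/labels.txt',
    allow_empty=False)
# Cluster anchor boxes for a YOLO-family detector on this dataset.
anchors = train_dataset.cluster_yolo_anchor(num_anchors=9, image_size=608)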

@@ -20,24 +20,26 @@ from paddlers.utils import logging, get_encoding, norm_path, is_pic
class SegDataset(BaseDataset):
"""读取语义分割任务数据集,并对样本进行相应的处理。
"""
Dataset for semantic segmentation tasks.
Args:
data_dir (str): Directory of the dataset.
file_list (str): Path of the file that lists the image files and the corresponding annotation files. Each path
in the file is relative to `data_dir`.
label_list (str): Path of the file that lists the categories of the dataset. Defaults to None.
transforms (paddlers.transforms.Compose): Preprocessing/augmentation operators applied to each sample.
num_workers (int|str): Number of threads or processes used for preprocessing. Defaults to 'auto', which sets
`num_workers` according to the number of CPU cores: if half the number of CPU cores is greater than 8,
`num_workers` is set to 8; otherwise, it is set to half the number of CPU cores.
shuffle (bool): Whether to shuffle the samples. Defaults to False.
data_dir (str): Root directory of the dataset.
file_list (str): Path of the file that contains relative paths of images and annotation files.
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
label_list (str, optional): Path of the file that contains the category names. Defaults to None.
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
the number of workers will be automatically determined according to the number of CPU cores: If
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
the number of CPU cores. Defaults to 'auto'.
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
"""
def __init__(self,
data_dir,
file_list,
transforms,
label_list=None,
transforms=None,
num_workers='auto',
shuffle=False):
super(SegDataset, self).__init__(data_dir, label_list, transforms,
@@ -58,7 +60,7 @@ class SegDataset(BaseDataset):
for line in f:
items = line.strip().split()
if len(items) > 2:
raise Exception(
raise ValueError(
"A space is defined as the delimiter to separate the image and label path, " \
"so the space cannot be in the image or label path, but the line[{}] of " \
" file_list[{}] has a space in the image or label path.".format(line, file_list))

@@ -29,27 +29,30 @@ from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster
class VOCDetection(BaseDataset):
"""读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
class VOCDetDataset(BaseDataset):
"""
Dataset with PASCAL VOC annotations for detection tasks.
Args:
data_dir (str): Directory of the dataset.
file_list (str): Path of the file that lists the image files and the corresponding annotation files. Each path
in the file is relative to `data_dir`.
label_list (str): Path of the file that lists the categories of the dataset.
transforms (paddlers.transforms.Compose): Preprocessing/augmentation operators applied to each sample.
num_workers (int|str): Number of threads or processes used for preprocessing. Defaults to 'auto', which sets
`num_workers` according to the number of CPU cores: if half the number of CPU cores is greater than 8,
`num_workers` is set to 8; otherwise, it is set to half the number of CPU cores.
shuffle (bool): Whether to shuffle the samples. Defaults to False.
allow_empty (bool): Whether to load negative samples. Defaults to False.
empty_ratio (float): Ratio of negative samples to all samples. If it is smaller than 0 or no less than 1, all
negative samples are kept. Defaults to 1.
data_dir (str): Root directory of the dataset.
file_list (str): Path of the file that contains relative paths of images and annotation files.
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
label_list (str, optional): Path of the file that contains the category names. Defaults to None.
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
the number of workers will be automatically determined according to the number of CPU cores: If
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
the number of CPU cores. Defaults to 'auto'.
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
allow_empty (bool, optional): Whether to add negative samples. Defaults to False.
empty_ratio (float, optional): Ratio of negative samples. If `empty_ratio` is smaller than 0 or not less
than 1, keep all generated negative samples. Defaults to 1.0.
"""
def __init__(self,
data_dir,
file_list,
transforms,
label_list,
transforms=None,
num_workers='auto',
shuffle=False,
allow_empty=False,
@@ -60,8 +63,8 @@ class VOCDetection(BaseDataset):
import matplotlib
matplotlib.use('Agg')
from pycocotools.coco import COCO
super(VOCDetection, self).__init__(data_dir, label_list, transforms,
num_workers, shuffle)
super(VOCDetDataset, self).__init__(data_dir, label_list, transforms,
num_workers, shuffle)
self.data_fields = None
self.num_max_boxes = 50
@@ -109,9 +112,9 @@ class VOCDetection(BaseDataset):
if not line:
break
if len(line.strip().split()) > 2:
raise Exception("A space is defined as the separator, "
"but it exists in image or label name {}."
.format(line))
raise ValueError("A space is defined as the separator, "
"but it exists in image or label name {}."
.format(line))
img_file, xml_file = [
osp.join(data_dir, x) for x in line.strip().split()[:2]
]
@@ -345,15 +348,16 @@ class VOCDetection(BaseDataset):
https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
Args:
num_anchors (int): number of clusters
image_size (list or int): [h, w], being an int means image height and image width are the same.
cache (bool): whether using cache
cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset.
iters (int, optional): iters of kmeans algorithm
gen_iters (int, optional): iters of genetic algorithm
threshold (float, optional): anchor scale threshold
verbose (bool, optional): whether print results
num_anchors (int): Number of clusters.
image_size (list[int]|int): [h, w] or an int value that corresponds to the shape [image_size, image_size].
cache (bool, optional): Whether to use cache. Defaults to True.
cache_path (str|None, optional): Path of cache directory. If None, use `dataset.data_dir`.
Defaults to None.
iters (int, optional): Iterations of k-means algorithm. Defaults to 300.
gen_iters (int, optional): Iterations of genetic algorithm. Defaults to 1000.
thresh (float, optional): Anchor scale threshold. Defaults to 0.25.
"""
if cache_path is None:
cache_path = self.data_dir
cluster = YOLOAnchorCluster(
@@ -369,17 +373,18 @@ class VOCDetection(BaseDataset):
return anchors
def add_negative_samples(self, image_dir, empty_ratio=1):
"""将背景图片加入训练
"""
Generate and add negative samples.
Args:
image_dir (str): Directory of the background images.
empty_ratio (float|None): Ratio of negative samples to all samples. If None, keep the `empty_ratio` specified
when the dataset was constructed; otherwise, update it. If it is smaller than 0 or no less than 1, all
negative samples are kept. Defaults to 1.
image_dir (str): Directory that contains images.
empty_ratio (float|None, optional): Ratio of negative samples. If `empty_ratio` is smaller than
0 or not less than 1, keep all generated negative samples. Defaults to 1.0.
"""
import cv2
if not osp.isdir(image_dir):
raise Exception("{} is not a valid image directory.".format(
raise ValueError("{} is not a valid image directory.".format(
image_dir))
if empty_ratio is not None:
self.empty_ratio = empty_ratio
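A sketch of the VOC counterpart together with negative-sample injection, using the two entry points documented above (paths and ratios are illustrative):

from paddlers.datasets import VOCDetDataset

train_dataset = VOCDetDataset(
    data_dir='data/voc',
    file_list='data/voc/train_list.txt',  # each line: "image_path xml_path"
    transforms=train_transforms,
    label_list='data/voc/labels.txt',
    allow_empty=True,   # keep negative (object-free) samples
    empty_ratio=0.5)    # cap negatives at half of all samples
# Background-only images can also be added after construction:
train_dataset.add_negative_samples('data/voc/backgrounds', empty_ratio=0.5)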

@@ -39,20 +39,20 @@ class Predictor(object):
max_trt_batch_size=1,
trt_precision_mode='float32'):
"""
Creates a Paddle Predictor.
Args:
model_dir: Path of the model (must be an exported deployment or quantized model).
use_gpu: Whether to use a GPU. Defaults to False.
gpu_id: ID of the GPU to use. Defaults to 0.
cpu_thread_num: Number of threads when predicting with CPUs. Defaults to 1.
use_mkl: Whether to use the MKL-DNN library (CPU only). Defaults to False.
mkl_thread_num: Number of MKL-DNN threads. Defaults to 4.
use_trt: Whether to use TensorRT. Defaults to False.
use_glog: Whether to enable glog logging. Defaults to False.
memory_optimize: Whether to enable memory optimization. Defaults to True.
max_trt_batch_size: Maximum batch size configured when TensorRT is used. Defaults to 1.
trt_precision_mode: Precision used with TensorRT. Possible values are ['float32', 'float16']. Defaults to 'float32'.
model_dir (str): Path of the exported model.
use_gpu (bool, optional): Whether to use a GPU. Defaults to False.
gpu_id (int, optional): GPU ID. Defaults to 0.
cpu_thread_num (int, optional): Number of threads to use when making predictions using CPUs.
Defaults to 1.
use_mkl (bool, optional): Whether to use MKL-DNN. Defaults to False.
mkl_thread_num (int, optional): Number of MKL threads. Defaults to 4.
use_trt (bool, optional): Whether to use TensorRT. Defaults to False.
use_glog (bool, optional): Whether to enable glog logs. Defaults to False.
memory_optimize (bool, optional): Whether to enable memory optimization. Defaults to True.
max_trt_batch_size (int, optional): Maximum batch size when configured with TensorRT. Defaults to 1.
trt_precision_mode (str, optional): Precision to use when configured with TensorRT. Possible values
are {'float32', 'float16'}. Defaults to 'float32'.
"""
self.model_dir = model_dir
@@ -209,10 +209,13 @@ class Predictor(object):
return preds
def raw_predict(self, inputs):
""" 接受预处理过后的数据进行预测
Args:
inputs(dict): 预处理过后的数据
"""
Predict according to preprocessed inputs.
Args:
inputs (dict): Preprocessed inputs.
"""
input_names = self.predictor.get_input_names()
for name in input_names:
input_tensor = self.predictor.get_input_handle(name)
@@ -253,21 +256,22 @@ class Predictor(object):
warmup_iters=0,
repeats=1):
"""
Do prediction.
Args:
img_file(list[str | tuple | np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict
, a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose | None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
1, the reported time consumption is the average of all repeats. Defaults to 1.
Do prediction.
Args:
img_file (list[str|tuple|np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration,
object detection, and semantic segmentation tasks, `img_file` should be either the path of the image to
predict, a decoded image (an np.ndarray, which should be consistent with what you get from passing the
image path to paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change
detection tasks, `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
topk (int, optional): Top-k values to reserve in a classification result. Defaults to 1.
transforms (paddlers.transforms.Compose|None, optional): Pipeline of data preprocessing. If None, load transforms
from `model.yml`. Defaults to None.
warmup_iters (int, optional): Warm-up iterations before measuring the execution time. Defaults to 0.
repeats (int, optional): Number of repetitions to evaluate model inference and data processing speed. If greater than
1, the reported time consumption is the average of all repeats. Defaults to 1.
"""
if repeats < 1:
logging.error("`repeats` must be greater than 1.", exit=True)
if transforms is None and not hasattr(self._model, 'test_transforms'):
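To connect the pieces, a deployment sketch under the documented signature; the model directory and file names are assumptions:

from paddlers.deploy import Predictor

predictor = Predictor(
    'output/inference_model',  # directory produced by model export
    use_gpu=True,
    gpu_id=0)
# Change detection takes a tuple of bi-temporal images; timing is averaged
# over `repeats` after `warmup_iters` warm-up runs.
result = predictor.predict(
    ('T1.tif', 'T2.tif'), warmup_iters=10, repeats=100)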

@@ -162,14 +162,17 @@ class BottleneckBlock(nn.Layer):
class ResNet(nn.Layer):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
ResNet model from "Deep Residual Learning for Image Recognition"
(https://arxiv.org/pdf/1512.03385.pdf)
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc
layer will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
from paddle.vision.models import ResNet
@@ -283,7 +286,8 @@ def _resnet(arch, Block, depth, pretrained, **kwargs):
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
"""
ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -299,7 +303,8 @@ def resnet18(pretrained=False, **kwargs):
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
"""
ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -316,10 +321,12 @@ def resnet34(pretrained=False, **kwargs):
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
"""
ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet50
@@ -332,10 +339,12 @@ def resnet50(pretrained=False, **kwargs):
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
"""
ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet101
@@ -348,10 +357,12 @@ def resnet101(pretrained=False, **kwargs):
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
"""
ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet152

@@ -42,24 +42,24 @@ class BIT(nn.Layer):
This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly initialized.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and
'resnet34' are supported. Default: 'resnet18'.
n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}.
n_stages (int, optional): Number of ResNet stages used in the backbone, which should be a value in {3,4,5}.
Default: 4.
use_tokenizer (bool, optional): Use a tokenizer or not. Default: True.
token_len (int, optional): The length of input tokens. Default: 4.
token_len (int, optional): Length of input tokens. Default: 4.
pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False. 'max'
for global max pooling and 'avg' for global average pooling. Default: 'max'.
pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to False.
pool_size (int, optional): Height and width of the pooled feature maps when `use_tokenizer` is set to False.
Default: 2.
enc_with_pos (bool, optional): Whether to add learned positional embedding to the input feature sequence of the
encoder. Default: True.
enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1
enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64.
dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8.
dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8.
enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1.
enc_head_dim (int, optional): Embedding dimension of each encoder head. Default: 64.
dec_depth (int, optional): Number of attention blocks used in the decoder. Default: 8.
dec_head_dim (int, optional): Embedding dimension of each decoder head. Default: 8.
Raises:
ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5.
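A forward-pass sketch for BIT under the documented arguments; the export path paddlers.rs_models.cd and the list-valued output follow conventions visible elsewhere in this diff, but are assumptions here:

import paddle
from paddlers.rs_models.cd import BIT

model = BIT(in_channels=3, num_classes=2, backbone='resnet18', n_stages=4)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
logits = model(t1, t2)[0]  # CD networks in this repo return a list of outputs
print(logits.shape)        # expected: [1, 2, 256, 256]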

@@ -15,8 +15,23 @@
import paddle
import paddle.nn as nn
from .layers import Conv7x7
class CDNet(nn.Layer):
"""
The CDNet implementation based on PaddlePaddle.
The original article refers to
Pablo F. Alcantarilla, et al., "Street-View Change Detection with Deconvolutional Networks"
(https://link.springer.com/article/10.1007/s10514-018-9734-5).
Args:
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
"""
def __init__(self, in_channels=6, num_classes=2):
super(CDNet, self).__init__()
self.conv1 = Conv7x7(in_channels, 64, norm=True, act=True)
@@ -48,28 +63,3 @@ class CDNet(nn.Layer):
x = self.conv7(self.upool2(x, ind2))
x = self.conv8(self.upool1(x, ind1))
return [self.conv_out(x)]
class Conv7x7(nn.Layer):
def __init__(self, in_ch, out_ch, norm=False, act=False):
super(Conv7x7, self).__init__()
layers = [
nn.Pad2D(3), nn.Conv2D(
in_ch, out_ch, 7, bias_attr=(False if norm else None))
]
if norm:
layers.append(nn.BatchNorm2D(out_ch))
if act:
layers.append(nn.ReLU())
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
if __name__ == "__main__":
t1 = paddle.randn((1, 3, 512, 512), dtype="float32")
t2 = paddle.randn((1, 3, 512, 512), dtype="float32")
model = CDNet(6, 2)
pred = model(t1, t2)[0]
print(pred.shape)

@@ -86,7 +86,8 @@ class ChangeStar_FarSeg(_ChangeStarBase):
The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle.
The original article refers to
Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery"
Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object
Change Detection in Remote Sensing Imagery"
(https://arxiv.org/abs/2108.07002).
Note that this implementation differs from the original code in two aspects:
@@ -94,12 +95,15 @@ class ChangeStar_FarSeg(_ChangeStarBase):
2. We use conv-bn-relu instead of conv-relu-bn.
Args:
num_classes (int): The number of target classes.
mid_channels (int, optional): The number of channels required by the ChangeMixin module. Default: 256.
inner_channels (int, optional): The number of filters used in the convolutional layers in the ChangeMixin module.
Default: 16.
num_convs (int, optional): The number of convolutional layers used in the ChangeMixin module. Default: 4.
scale_factor (float, optional): The scaling factor of the output upsampling layer. Default: 4.0.
num_classes (int): Number of target classes.
mid_channels (int, optional): Number of channels required by the
ChangeMixin module. Default: 256.
inner_channels (int, optional): Number of filters used in the
convolutional layers in the ChangeMixin module. Default: 16.
num_convs (int, optional): Number of convolutional layers used in the
ChangeMixin module. Default: 4.
scale_factor (float, optional): Scaling factor of the output upsampling
layer. Default: 4.0.
"""
def __init__(

@@ -25,19 +25,22 @@ class DSAMNet(nn.Layer):
The DSAMNet implementation based on PaddlePaddle.
The original article refers to
Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing
Change Detection"
Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an
Open Aerial Image Dataset for Remote Sensing Change Detection"
(https://ieeexplore.ieee.org/document/9467555).
Note that this implementation differs from the original work in two aspects:
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
2. A classification head is used in place of the original metric learning-based
head to stabilize the training process.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
ca_ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
sa_kernel (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
ca_ratio (int, optional): Channel reduction ratio for the channel
attention module. Default: 8.
sa_kernel (int, optional): Size of the convolutional kernel used in the
spatial attention module. Default: 7.
"""
def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7):

@@ -28,16 +28,17 @@ class DSIFN(nn.Layer):
The DSIFN implementation based on PaddlePaddle.
The original article refers to
C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal remote
sensing images"
C. Zhang, et al., "A deeply supervised image fusion network for change
detection in high resolution bi-temporal remote sensing images"
(https://www.sciencedirect.com/science/article/pii/S0924271620301532).
Note that in this implementation, there is a flexible number of target classes.
Args:
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
num_classes (int): Number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use
dropout layers. When the model is trained on a relatively small dataset,
the dropout layers help prevent overfitting. Default: False.
"""
def __init__(self, num_classes, use_dropout=False):

@@ -26,14 +26,16 @@ class FCEarlyFusion(nn.Layer):
The FC-EF implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change
detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use
dropout layers. When the model is trained on a relatively small dataset,
the dropout layers help prevent overfitting. Default: False.
"""
def __init__(self, in_channels, num_classes, use_dropout=False):

@@ -26,14 +26,16 @@ class FCSiamConc(nn.Layer):
The FC-Siam-conc implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change
detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use
dropout layers. When the model is trained on a relatively small dataset,
the dropout layers help prevent overfitting. Default: False.
"""
def __init__(self, in_channels, num_classes, use_dropout=False):

@@ -26,14 +26,16 @@ class FCSiamDiff(nn.Layer):
The FC-Siam-diff implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change
detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use
dropout layers. When the model is trained on a relatively small dataset,
the dropout layers help prevent overfitting. Default: False.
"""
def __init__(self, in_channels, num_classes, use_dropout=False):

@@ -28,8 +28,8 @@ class ChannelAttention(nn.Layer):
(https://arxiv.org/abs/1807.06521).
Args:
in_ch (int): The number of channels of the input features.
ratio (int, optional): The channel reduction ratio. Default: 8.
in_ch (int): Number of channels of the input features.
ratio (int, optional): Channel reduction ratio. Default: 8.
"""
def __init__(self, in_ch, ratio=8):
@@ -55,7 +55,8 @@ class SpatialAttention(nn.Layer):
(https://arxiv.org/abs/1807.06521).
Args:
kernel_size (int, optional): The size of the convolutional kernel. Default: 7.
kernel_size (int, optional): Size of the convolutional kernel.
Default: 7.
"""
def __init__(self, kernel_size=7):
@@ -79,9 +80,11 @@ class CBAM(nn.Layer):
(https://arxiv.org/abs/1807.06521).
Args:
in_ch (int): The number of channels of the input features.
ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
kernel_size (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7.
in_ch (int): Number of channels of the input features.
ratio (int, optional): Channel reduction ratio for the channel
attention module. Default: 8.
kernel_size (int, optional): Size of the convolutional kernel used in
the spatial attention module. Default: 7.
"""
def __init__(self, in_ch, ratio=8, kernel_size=7):
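A self-contained sketch of the attention block documented above (module path per this diff; shapes illustrative):

import paddle
from paddlers.rs_models.cd.layers.attention import CBAM

feats = paddle.randn([1, 64, 32, 32])
cbam = CBAM(in_ch=64, ratio=8, kernel_size=7)
refined = cbam(feats)   # channel attention followed by spatial attention
print(refined.shape)    # same shape as the input: [1, 64, 32, 32]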

@@ -184,7 +184,9 @@ class ConvTransposed3x3(nn.Layer):
class Identity(nn.Layer):
"""A placeholder identity operator that accepts exactly one argument."""
"""
A placeholder identity operator that accepts exactly one argument.
"""
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()

@@ -27,15 +27,18 @@ class SNUNet(nn.Layer, KaimingInitMixin):
The SNUNet implementation based on PaddlePaddle.
The original article refers to
S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images"
S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change
Detection of VHR Images"
(https://ieeexplore.ieee.org/document/9355573).
Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper.
Note that bilinear interpolation is adopted as the upsampling method, which is
different from the paper.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
width (int, optional): The output channels of the first convolutional layer. Default: 32.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
width (int, optional): Output channels of the first convolutional layer.
Default: 32.
"""
def __init__(self, in_channels, num_classes, width=32):

@@ -26,23 +26,29 @@ class STANet(nn.Layer):
The STANet implementation based on PaddlePaddle.
The original article refers to
H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"
H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New
Dataset for Remote Sensing Image Change Detection"
(https://www.mdpi.com/2072-4292/12/10/1662).
Note that this implementation differs from the original work in two aspects:
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
2. A classification head is used in place of the original metric learning-based
head to stabilize the training process.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values
greater than 1, the input features will first be processed by an average pooling layer with the kernel size of
`ds_factor`, before being used to calculate the attention scores. Default: 1.
in_channels (int): Number of bands of the input images.
num_classes (int): Number of target classes.
att_type (str, optional): The attention module used in the model. Options
are 'PAM' and 'BAM'. Default: 'BAM'.
ds_factor (int, optional): Downsampling factor of the attention modules.
When `ds_factor` is set to values greater than 1, the input features
will first be processed by an average pooling layer with the kernel size
of `ds_factor`, before being used to calculate the attention scores.
Default: 1.
Raises:
ValueError: When `att_type` has an illegal value (unsupported attention type).
ValueError: When `att_type` has an illegal value (unsupported attention
type).
"""
def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
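An instantiation sketch matching the documented arguments (export path assumed, as with BIT):

import paddle
from paddlers.rs_models.cd import STANet

# 'PAM' selects the pyramid attention variant; ds_factor=2 average-pools
# the features before the attention scores are computed.
model = STANet(in_channels=3, num_classes=2, att_type='PAM', ds_factor=2)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
pred = model(t1, t2)[0]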

@@ -25,7 +25,8 @@ from ...models.ppgan.modules.init import reset_parameters
@MODELS.register()
class RCANModel(BaseModel):
"""Base SR model for single image super-resolution.
"""
Base SR model for single image super-resolution.
"""
def __init__(self, generator, pixel_criterion=None, use_init_weight=False):

@@ -32,7 +32,7 @@ class FPN(nn.Layer):
"""
Module that adds FPN on top of a list of feature maps.
The feature maps are currently supposed to be in increasing depth
order, and must be consecutive
order, and must be consecutive.
"""
def __init__(self,
@@ -233,13 +233,14 @@ class ResNet50Encoder(nn.Layer):
class FarSeg(nn.Layer):
'''
"""
The FarSeg implementation based on PaddlePaddle.
The original article refers to
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution Remote Sensing Imagery"
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object
Segmentation in High Spatial Resolution Remote Sensing Imagery"
(https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
'''
"""
def __init__(self,
num_classes=16,

@@ -96,16 +96,17 @@ class Activation(nn.Layer):
"""
The wrapper of activations.
Args:
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
act (str, optional): Activation name in lowercase, which must be one of
['elu', 'gelu', 'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu',
'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax', 'softplus',
'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
'hsigmoid']. Default: None, which means identity transformation.
Returns:
A callable object of Activation.
Raises:
KeyError: When parameter `act` is not in the optional range.
Examples:
from paddleseg.models.common.activation import Activation
from paddlers.rs_models.seg.layers import Activation
relu = Activation("relu")
print(relu)
# <class 'paddle.nn.layer.activation.ReLU'>

@@ -126,18 +126,18 @@ class BaseModel(metaclass=ModelMeta):
if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
logging.error(
"Model parameter state dictionary file 'model.pdparams' "
"not found under given checkpoint path {}".format(
"was not found in given checkpoint path {}!".format(
resume_checkpoint),
exit=True)
if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
logging.error(
"Optimizer state dictionary file 'model.pdopt' "
"not found under given checkpoint path {}".format(
"was not found in given checkpoint path {}!".format(
resume_checkpoint),
exit=True)
if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
logging.error(
"'model.yml' not found under given checkpoint path {}".
"'model.yml' was not found in given checkpoint path {}!".
format(resume_checkpoint),
exit=True)
with open(osp.join(resume_checkpoint, "model.yml")) as f:
@@ -264,7 +264,7 @@ class BaseModel(metaclass=ModelMeta):
def build_data_loader(self, dataset, batch_size, mode='train'):
if dataset.num_samples < batch_size:
raise Exception(
raise ValueError(
'The volume of the dataset ({}) must be no smaller than the batch size ({}).'
.format(dataset.num_samples, batch_size))
batch_size_each_card = get_single_card_bs(batch_size=batch_size)
@@ -478,17 +478,21 @@ class BaseModel(metaclass=ModelMeta):
save_dir='output'):
"""
Args:
dataset(paddlers.dataset): Dataset used for evaluation during sensitivity analysis.
batch_size(int, optional): Batch size used in evaluation. Defaults to 8.
criterion({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
save_dir(str, optional): The directory to save sensitivity file of the model. Defaults to 'output'.
dataset (paddlers.datasets.BaseDataset): Dataset used for evaluation during
sensitivity analysis.
batch_size (int, optional): Batch size used in evaluation. Defaults to 8.
criterion (str, optional): Pruning criterion. Choices are {'l1_norm', 'fpgm'}.
Defaults to 'l1_norm'.
save_dir (str, optional): Directory to save sensitivity file of the model.
Defaults to 'output'.
"""
if self.__class__.__name__ in {'FasterRCNN', 'MaskRCNN', 'PicoDet'}:
raise Exception("{} does not support pruning currently!".format(
raise ValueError("{} does not support pruning currently!".format(
self.__class__.__name__))
assert criterion in {'l1_norm', 'fpgm'}, \
"Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
"Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
self._check_transforms(dataset.transforms, 'eval')
if self.model_type == 'detector':
self.net.eval()
@@ -515,13 +519,14 @@ class BaseModel(metaclass=ModelMeta):
def prune(self, pruned_flops, save_dir=None):
"""
Args:
pruned_flops(float): Ratio of FLOPs to be pruned.
save_dir(None or str, optional): If None, the pruned model will not be saved.
Otherwise, the pruned model will be saved at save_dir. Defaults to None.
pruned_flops (float): Ratio of FLOPs to be pruned.
save_dir (str|None, optional): If None, the pruned model will not be
saved. Otherwise, the pruned model will be saved at `save_dir`.
Defaults to None.
"""
if self.status == "Pruned":
raise Exception(
"A pruned model cannot be done model pruning again!")
raise ValueError(
"A pruned model cannot be pruned for a second time!")
pre_pruning_flops = flops(self.net, self.pruner.inputs)
logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
pre_pruning_flops))
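Taken together, the two docstrings above describe a two-step prune workflow; a hedged sketch follows (the method name analyze_sensitivity and the trainer instance `model` are assumptions inferred from these docstrings):

# `model` is a trained task object; FasterRCNN, MaskRCNN, and PicoDet are excluded.
model.analyze_sensitivity(
    dataset=eval_dataset,   # evaluation dataset used during the analysis
    batch_size=8,
    criterion='l1_norm',
    save_dir='output/prune')
# Prune 20% of FLOPs, then re-train (strongly recommended per the warning below).
model.prune(pruned_flops=0.2, save_dir='output/pruned_model')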
@@ -529,8 +534,8 @@ class BaseModel(metaclass=ModelMeta):
post_pruning_flops = flops(self.net, self.pruner.inputs)
logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
post_pruning_flops))
logging.warning("Pruning the model may hurt its performance, "
"retraining is highly recommended")
logging.warning("Pruning the model may hurt its performance. "
"Re-training is highly recommended.")
self.status = 'Pruned'
if save_dir is not None:
@@ -540,7 +545,7 @@ class BaseModel(metaclass=ModelMeta):
def _prepare_qat(self, quant_config):
if self.status == 'Infer':
logging.error(
"Exported inference model does not support quantization aware training.",
"Exported inference model does not support quantization-aware training.",
exit=True)
if quant_config is None:
# default quantization configuration
@@ -578,7 +583,7 @@ class BaseModel(metaclass=ModelMeta):
elif quant_config != self.quant_config:
logging.error(
"The model has been quantized with the following quant_config: {}."
"Doing quantization-aware training with a quantized model "
"Performing quantization-aware training with a quantized model "
"using a different configuration is not supported."
.format(self.quant_config),
exit=True)
@@ -666,7 +671,7 @@ class BaseModel(metaclass=ModelMeta):
# Flag file indicating that the model was saved successfully
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("The model for the inference deployment is saved in {}.".
logging.info("The inference model for deployment is saved in {}.".
format(save_dir))
def _check_transforms(self, transforms, mode):

@@ -238,29 +238,37 @@ class BaseChangeDetector(BaseModel):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to None.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.CDDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.CDDataset|None, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded. Defaults to None.
learning_rate (float, optional): Learning rate for training. Defaults to .01.
lr_decay_power (float, optional): Learning decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy. Defaults
to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if self.status == 'Infer':
logging.error(
"Exported inference model does not support training.",
@@ -336,28 +344,37 @@ class BaseChangeDetector(BaseModel):
quant_config=None):
"""
Quantization-aware training.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
configuration will be used. Defaults to None.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
from. If None, no training checkpoint will be resumed. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.CDDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.CDDataset, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
learning_rate (float, optional): Learning rate for training.
Defaults to .0001.
lr_decay_power (float, optional): Learning decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
quant_config (dict|None, optional): Quantization configuration. If None,
a default rule of thumb configuration will be used. Defaults to None.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
quantization-aware training from. If None, no training checkpoint will
be resumed. Defaults to None.
"""
self._prepare_qat(quant_config)
self.train(
num_epochs=num_epochs,
@@ -379,27 +396,32 @@ class BaseChangeDetector(BaseModel):
def evaluate(self, eval_dataset, batch_size=1, return_details=False):
"""
Evaluate the model.
Args:
eval_dataset(paddlers.dataset): Evaluation dataset.
batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
return_details(bool, optional): Whether to return evaluation details. Defaults to False.
eval_dataset (paddlers.datasets.CDDataset): Evaluation dataset.
batch_size (int, optional): Total batch size among all cards used for
evaluation. Defaults to 1.
return_details (bool, optional): Whether to return evaluation details.
Defaults to False.
Returns:
collections.OrderedDict with key-value pairs:
For binary change detection (number of classes == 2), the key-value pairs are like:
{"iou": `intersection over union for the change class`,
"f1": `F1 score for the change class`,
"oacc": `overall accuracy`,
"kappa": ` kappa coefficient`}.
For multi-class change detection (number of classes > 2), the key-value pairs are like:
{"miou": `mean intersection over union`,
"category_iou": `category-wise mean intersection over union`,
"oacc": `overall accuracy`,
"category_acc": `category-wise accuracy`,
"kappa": ` kappa coefficient`,
"category_F1-score": `F1 score`}.
For binary change detection (number of classes == 2), the key-value
pairs are like:
{"iou": `intersection over union for the change class`,
"f1": `F1 score for the change class`,
"oacc": `overall accuracy`,
"kappa": ` kappa coefficient`}.
For multi-class change detection (number of classes > 2), the key-value
pairs are like:
{"miou": `mean intersection over union`,
"category_iou": `category-wise mean intersection over union`,
"oacc": `overall accuracy`,
"category_acc": `category-wise accuracy`,
"kappa": ` kappa coefficient`,
"category_F1-score": `F1 score`}.
"""
self._check_transforms(eval_dataset.transforms, 'eval')
self.net.eval()
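
For illustration, a sketch of consuming the metrics described above in the binary case; `model` and `eval_ds` are placeholders assumed to exist:

    metrics = model.evaluate(eval_ds, batch_size=1, return_details=False)
    # Binary change detection returns the four keys documented above.
    print(metrics['iou'], metrics['f1'], metrics['oacc'], metrics['kappa'])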
@ -500,24 +522,27 @@ class BaseChangeDetector(BaseModel):
def predict(self, img_file, transforms=None):
"""
Do inference.
Args:
img_file (list[tuple] | tuple[str | np.ndarray]):
Tuple of image paths or decoded image data for bi-temporal images, which also could constitute a list,
meaning all image pairs to be predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
img_file (list[tuple] | tuple[str|np.ndarray]): Tuple of image paths or
decoded image data for bi-temporal images, which also could constitute
a list, meaning all image pairs to be predicted as a mini-batch.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
Returns:
If img_file is a tuple of string or np.array, the result is a dict with key-value pairs:
{"label map": `label map`, "score_map": `score map`}.
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
label_map(np.ndarray): the predicted label map (HW)
score_map(np.ndarray): the prediction score map (HWC)
If `img_file` is a tuple of strings or np.ndarray, the result is a dict with
key-value pairs:
{"label_map": `label map`, "score_map": `score map`}.
If `img_file` is a list, the result is a list composed of dicts with the
corresponding fields:
label_map (np.ndarray): the predicted label map (HW)
score_map (np.ndarray): the prediction score map (HWC)
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
raise ValueError("transforms need to be defined, now is None.")
if transforms is None:
transforms = self.test_transforms
if isinstance(img_file, tuple):
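
A hedged sketch of the bi-temporal `predict` interface above; the file paths are placeholders:

    # Single pair: pass a tuple of two image paths (or decoded arrays).
    result = model.predict(('T1.tif', 'T2.tif'))
    label_map = result['label_map']   # np.ndarray of shape HW
    score_map = result['score_map']   # np.ndarray of shape HWC

    # Mini-batch: pass a list of bi-temporal tuples; one dict is returned per pair.
    results = model.predict([('a_T1.tif', 'a_T2.tif'), ('b_T1.tif', 'b_T2.tif')])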
@ -555,26 +580,24 @@ class BaseChangeDetector(BaseModel):
transforms=None):
"""
Do inference.
Args:
img_file(list[str]):
List of image paths.
save_dir(str):
Directory that contains saved geotiff file.
block_size(list[int] | tuple[int] | int, optional):
Size of block.
overlap(list[int] | tuple[int] | int, optional):
Overlap between two blocks. Defaults to 36.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
img_file (tuple[str]): Tuple of image paths.
save_dir (str): Directory to save the output geotiff file.
block_size (list[int] | tuple[int] | int, optional): Size of block.
overlap (list[int] | tuple[int] | int, optional): Overlap between two blocks.
Defaults to 36.
transforms (paddlers.transforms.Compose|None, optional): Transforms for inputs.
If None, the transforms for evaluation process will be used. Defaults to None.
"""
try:
from osgeo import gdal
except ImportError:
import gdal
if len(img_file) != 2:
raise ValueError("`img_file` must be a list of length 2.")
if not isinstance(img_file, tuple) or len(img_file) != 2:
raise ValueError("`img_file` must be a tuple of length 2.")
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
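
A minimal sketch of sliding-window inference per the signature above, with placeholder geotiff paths:

    model.slider_predict(
        ('scene_T1.tif', 'scene_T2.tif'),  # must be a tuple of length 2
        save_dir='output/slider',
        block_size=512,                    # an int is expanded to (512, 512), as above
        overlap=36,
        transforms=None)                   # falls back to the evaluation transforms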

@ -52,7 +52,7 @@ class BaseClassifier(BaseModel):
super(BaseClassifier, self).__init__('classifier')
if not hasattr(paddleclas.arch.backbone, model_name) and \
not hasattr(cmcls, model_name):
raise Exception("ERROR: There's no model named {}.".format(
raise ValueError("ERROR: There is no model named {}.".format(
model_name))
self.model_name = model_name
self.in_channels = in_channels
@ -202,29 +202,39 @@ class BaseClassifier(BaseModel):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.ClasDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.ClasDataset, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'IMAGENET'.
learning_rate (float, optional): Learning rate for training.
Defaults to .1.
lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if self.status == 'Infer':
logging.error(
"Exported inference model does not support training.",
@ -303,28 +313,37 @@ class BaseClassifier(BaseModel):
quant_config=None):
"""
Quantization-aware training.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
configuration will be used. Defaults to None.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
from. If None, no training checkpoint will be resumed. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.ClasDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.ClasDataset, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
learning_rate (float, optional): Learning rate for training.
Defaults to .0001.
lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
quant_config (dict|None, optional): Quantization configuration. If None,
a default rule of thumb configuration will be used. Defaults to None.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
quantization-aware training from. If None, no training checkpoint will
be resumed. Defaults to None.
"""
self._prepare_qat(quant_config)
self.train(
num_epochs=num_epochs,
@ -346,17 +365,20 @@ class BaseClassifier(BaseModel):
def evaluate(self, eval_dataset, batch_size=1, return_details=False):
"""
Evaluate the model.
Args:
eval_dataset(paddlers.dataset): Evaluation dataset.
batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
return_details(bool, optional): Whether to return evaluation details. Defaults to False.
eval_dataset (paddlers.datasets.ClasDataset): Evaluation dataset.
batch_size (int, optional): Total batch size among all cards used for
evaluation. Defaults to 1.
return_details (bool, optional): Whether to return evaluation details.
Defaults to False.
Returns:
collections.OrderedDict with key-value pairs:
{"top1": `acc of top1`,
"top5": `acc of top5`}.
"""
self._check_transforms(eval_dataset.transforms, 'eval')
self.net.eval()
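
A short sketch of reading the returned metrics; `model` and `eval_ds` (a ClasDataset) are assumed to exist:

    metrics = model.evaluate(eval_ds, batch_size=1)
    print(metrics['top1'], metrics['top5'])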
@ -404,25 +426,28 @@ class BaseClassifier(BaseModel):
def predict(self, img_file, transforms=None):
"""
Do inference.
Args:
img_file(list[np.ndarray | str] | str | np.ndarray):
Image path or decoded image data, which also could constitute a list, meaning all images to be
img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
image data, which also could constitute a list, meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
Returns:
If img_file is a string or np.array, the result is a dict with key-value pairs:
{"label map": `class_ids_map`, "scores_map": `label_names_map`}.
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
class_ids_map(np.ndarray): class_ids
scores_map(np.ndarray): scores
label_names_map(np.ndarray): label_names
If `img_file` is a string or np.ndarray, the result is a dict with key-value
pairs:
{"class_ids_map": `class IDs`, "scores_map": `scores`, "label_names_map": `label names`}.
If `img_file` is a list, the result is a list composed of dicts with the
corresponding fields:
class_ids_map (np.ndarray): class IDs
scores_map (np.ndarray): scores
label_names_map (np.ndarray): label names
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
raise ValueError("transforms need to be defined, now is None.")
if transforms is None:
transforms = self.test_transforms
if isinstance(img_file, (str, np.ndarray)):
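
A usage sketch of the classifier `predict` interface above; image paths are placeholders:

    result = model.predict('scene.jpg')
    print(result['class_ids_map'], result['scores_map'], result['label_names_map'])

    # A list input returns one dict per image.
    results = model.predict(['a.jpg', 'b.jpg'])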

@ -52,7 +52,7 @@ def load_model(model_dir, **params):
Load saved model from a given directory.
Args:
model_dir(str): The directory where the model is saved.
model_dir (str): Directory where the model is saved.
Returns:
The model loaded from the directory.
@ -61,8 +61,8 @@ def load_model(model_dir, **params):
if not osp.exists(model_dir):
logging.error("Directory '{}' does not exist!".format(model_dir))
if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There is no file named model.yml in {}.".format(
model_dir))
raise FileNotFoundError(
"There is no file named model.yml in {}.".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f:
model_info = yaml.load(f.read(), Loader=yaml.Loader)
@ -76,7 +76,7 @@ def load_model(model_dir, **params):
model_type = model_info['_Attributes']['model_type']
mod = getattr(paddlers.tasks, model_type)
if not hasattr(mod, model_info['Model']):
raise Exception("There is no {} attribute in {}.".format(model_info[
raise ValueError("There is no {} attribute in {}.".format(model_info[
'Model'], mod))
if 'model_name' in model_info['_init_params']:
del model_info['_init_params']['model_name']
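
A minimal loading sketch, assuming `load_model` is re-exported at the package level; the directory must contain the model.yml written at save time, otherwise the FileNotFoundError above is raised:

    import paddlers
    model = paddlers.tasks.load_model('output/best_model')  # placeholder directory
    result = model.predict('demo.tif')                      # placeholder image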

@ -81,7 +81,7 @@ class BaseDetector(BaseModel):
if len(image_shape) == 2:
image_shape = [1, 3] + image_shape
if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
raise Exception(
raise ValueError(
"Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
format(image_shape[-2:]))
return image_shape
@ -206,34 +206,51 @@ class BaseDetector(BaseModel):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used for training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
learning_rate(float, optional): Learning rate for training. Defaults to .001.
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 64.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training
process. Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 10.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'IMAGENET'.
learning_rate (float, optional): Learning rate for training. Defaults to .001.
warmup_steps (int, optional): Number of steps of warm-up training.
Defaults to 0.
warmup_start_lr (float, optional): Start learning rate of warm-up training.
Defaults to 0.0.
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
rate decay. Defaults to (216, 243).
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
Defaults to .1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
use_ema (bool, optional): Whether to use exponential moving average
strategy. Defaults to False.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if self.status == 'Infer':
logging.error(
"Exported inference model does not support training.",
@ -242,7 +259,7 @@ class BaseDetector(BaseModel):
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
exit=True)
if train_dataset.__class__.__name__ == 'VOCDetection':
if train_dataset.__class__.__name__ == 'VOCDetDataset':
train_dataset.data_fields = {
'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
'difficult'
@ -260,13 +277,13 @@ class BaseDetector(BaseModel):
}
if metric is None:
if eval_dataset.__class__.__name__ == 'VOCDetection':
if eval_dataset.__class__.__name__ == 'VOCDetDataset':
self.metric = 'voc'
elif eval_dataset.__class__.__name__ == 'CocoDetection':
elif eval_dataset.__class__.__name__ == 'COCODetDataset':
self.metric = 'coco'
else:
assert metric.lower() in ['coco', 'voc'], \
"Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
"Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'."
self.metric = metric.lower()
self.labels = train_dataset.labels
@ -355,33 +372,50 @@ class BaseDetector(BaseModel):
quant_config=None):
"""
Quantization-aware training.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used for training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
learning_rate(float, optional): Learning rate for training. Defaults to .001.
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
configuration will be used. Defaults to None.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
from. If None, no training checkpoint will be resumed. Defaults to None.
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 64.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training
process. Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 10.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
learning_rate (float, optional): Learning rate for training.
Defaults to .00001.
warmup_steps (int, optional): Number of steps of warm-up training.
Defaults to 0.
warmup_start_lr (float, optional): Start learning rate of warm-up training.
Defaults to 0.0.
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning rate
decay. Defaults to (216, 243).
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
Defaults to .1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
use_ema (bool, optional): Whether to use exponential moving average strategy.
Defaults to False.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
quant_config (dict|None, optional): Quantization configuration. If None,
a default rule of thumb configuration will be used. Defaults to None.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
quantization-aware training from. If None, no training checkpoint will
be resumed. Defaults to None.
"""
self._prepare_qat(quant_config)
self.train(
num_epochs=num_epochs,
@ -412,25 +446,32 @@ class BaseDetector(BaseModel):
return_details=False):
"""
Evaluate the model.
Args:
eval_dataset(paddlers.dataset): Evaluation dataset.
batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
return_details(bool, optional): Whether to return evaluation details. Defaults to False.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Evaluation dataset.
batch_size (int, optional): Total batch size among all cards used for
evaluation. Defaults to 1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
return_details (bool, optional): Whether to return evaluation details.
Defaults to False.
Returns:
collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
collections.OrderedDict with key-value pairs:
{"mAP(0.50, 11point)":`mean average precision`}.
"""
if metric is None:
if not hasattr(self, 'metric'):
if eval_dataset.__class__.__name__ == 'VOCDetection':
if eval_dataset.__class__.__name__ == 'VOCDetDataset':
self.metric = 'voc'
elif eval_dataset.__class__.__name__ == 'CocoDetection':
elif eval_dataset.__class__.__name__ == 'COCODetDataset':
self.metric = 'coco'
else:
assert metric.lower() in ['coco', 'voc'], \
"Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
"Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'."
self.metric = metric.lower()
if self.metric == 'voc':
@ -506,24 +547,32 @@ class BaseDetector(BaseModel):
def predict(self, img_file, transforms=None):
"""
Do inference.
Args:
img_file(list[np.ndarray | str] | str | np.ndarray):
Image path or decoded image data, which also could constitute a list,meaning all images to be
img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
image data, which also could constitute a list, meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
Returns:
If img_file is a string or np.array, the result is a list of dict with key-value pairs:
{"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
category(str): category name
bbox(list): bounding box in [x, y, w, h] format
score(str): confidence
mask(dict): Only for instance segmentation task. Mask of the object in RLE format
If `img_file` is a string or np.ndarray, the result is a list of dicts with
key-value pairs:
{"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
If `img_file` is a list, the result is a list composed of dicts with the
corresponding fields:
category_id (int): the predicted category ID. 0 represents the first
category in the dataset, and so on.
category (str): category name
bbox (list): bounding box in [x, y, w, h] format
score (float): confidence
mask (dict): Only for instance segmentation task. Mask of the object in
RLE format
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
raise ValueError("transforms need to be defined, now is None.")
if transforms is None:
transforms = self.test_transforms
if isinstance(img_file, (str, np.ndarray)):
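
A sketch of iterating over the per-image detection results documented above; the path is a placeholder:

    dets = model.predict('scene.jpg')
    for det in dets:
        print(det['category_id'], det['category'], det['bbox'], det['score'])
        # det['mask'] additionally holds an RLE mask for instance segmentation models.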
@ -649,7 +698,7 @@ class PicoDet(BaseDetector):
}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd')".
"{'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd'}.".
format(backbone))
self.backbone_name = backbone
if params.get('with_net', True):
@ -772,7 +821,7 @@ class PicoDet(BaseDetector):
for i, op in enumerate(transforms.transforms):
if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
if mode != 'train':
raise Exception(
raise ValueError(
"{} cannot be present in the {} transforms. ".format(
op.__class__.__name__, mode) +
"Please check the {} transforms.".format(mode))
@ -851,34 +900,51 @@ class PicoDet(BaseDetector):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used for training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
learning_rate(float, optional): Learning rate for training. Defaults to .001.
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 64.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training
process. Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 10.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'IMAGENET'.
learning_rate (float, optional): Learning rate for training. Defaults to .001.
warmup_steps (int, optional): Number of steps of warm-up training.
Defaults to 0.
warmup_start_lr (float, optional): Start learning rate of warm-up training.
Defaults to 0.0.
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
rate decay. Defaults to (216, 243).
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
Defaults to .1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
use_ema (bool, optional): Whether to use exponential moving average
strategy. Defaults to False.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if optimizer is None:
num_steps_each_epoch = len(train_dataset) // train_batch_size
optimizer = self.default_optimizer(
@ -936,8 +1002,8 @@ class YOLOv3(BaseDetector):
}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
"'ResNet50_vd_dcn', 'ResNet34')".format(backbone))
"{'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
"'ResNet50_vd_dcn', 'ResNet34'}.".format(backbone))
self.backbone_name = backbone
if params.get('with_net', True):
@ -1030,7 +1096,7 @@ class YOLOv3(BaseDetector):
for i, op in enumerate(transforms.transforms):
if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
if mode != 'train':
raise Exception(
raise ValueError(
"{} cannot be present in the {} transforms. ".format(
op.__class__.__name__, mode) +
"Please check the {} transforms.".format(mode))
@ -1089,8 +1155,8 @@ class FasterRCNN(BaseDetector):
}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
"'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
"{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
"'ResNet101', 'ResNet101_vd', 'HRNet_W18'}.".format(backbone))
self.backbone_name = backbone
if params.get('with_net', True):
@ -1327,34 +1393,51 @@ class FasterRCNN(BaseDetector):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used for training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
learning_rate(float, optional): Learning rate for training. Defaults to .001.
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 64.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training
process. Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 10.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'IMAGENET'.
learning_rate (float, optional): Learning rate for training. Defaults to .001.
warmup_steps (int, optional): Number of steps of warm-up training.
Defaults to 0.
warmup_start_lr (float, optional): Start learning rate of warm-up training.
Defaults to 0..
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
rate decay. Defaults to (216, 243).
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
Defaults to .1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
use_ema (bool, optional): Whether to use exponential moving average
strategy. Defaults to False.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if train_dataset.pos_num < len(train_dataset.file_list):
train_dataset.num_workers = 0
super(FasterRCNN, self).train(
@ -1377,7 +1460,7 @@ class FasterRCNN(BaseDetector):
for i, op in enumerate(transforms.transforms):
if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
if mode != 'train':
raise Exception(
raise ValueError(
"{} cannot be present in the {} transforms. ".format(
op.__class__.__name__, mode) +
"Please check the {} transforms.".format(mode))
@ -1456,7 +1539,7 @@ class PPYOLO(YOLOv3):
}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
"{'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small'}.".
format(backbone))
self.backbone_name = backbone
self.downsample_ratios = [
@ -1769,7 +1852,7 @@ class PPYOLOv2(YOLOv3):
if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
"{'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}.".format(backbone))
self.backbone_name = backbone
self.downsample_ratios = [32, 16, 8]
@ -1916,7 +1999,7 @@ class MaskRCNN(BaseDetector):
}:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
"{'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd'}.".
format(backbone))
self.backbone_name = backbone + '_fpn' if with_fpn else backbone
@ -2152,34 +2235,51 @@ class MaskRCNN(BaseDetector):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used for training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
learning_rate(float, optional): Learning rate for training. Defaults to .001.
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
metric({'VOC', 'COCO', None}, optional):
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 64.
eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset, optional):
Evaluation dataset. If None, the model will not be evaluated during training
process. Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 10.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'IMAGENET'.
learning_rate (float, optional): Learning rate for training. Defaults to .001.
warmup_steps (int, optional): Number of steps of warm-up training.
Defaults to 0.
warmup_start_lr (float, optional): Start learning rate of warm-up training.
Defaults to 0..
lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
rate decay. Defaults to (216, 243).
lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
Defaults to .1.
metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', None}.
If None, determine the metric according to the dataset format.
Defaults to None.
use_ema (bool, optional): Whether to use exponential moving average
strategy. Defaults to False.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if train_dataset.pos_num < len(train_dataset.file_list):
train_dataset.num_workers = 0
super(MaskRCNN, self).train(
@ -2202,7 +2302,7 @@ class MaskRCNN(BaseDetector):
for i, op in enumerate(transforms.transforms):
if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
if mode != 'train':
raise Exception(
raise ValueError(
"{} cannot be present in the {} transforms. ".format(
op.__class__.__name__, mode) +
"Please check the {} transforms.".format(mode))

@ -228,29 +228,38 @@ class BaseSegmenter(BaseModel):
resume_checkpoint=None):
"""
Train the model.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights(str or None, optional):
None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
`pretrain_weights` can be set simultaneously. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.SegDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
pretrain_weights (str|None, optional): None or name/path of pretrained
weights. If None, no pretrained weights will be loaded.
Defaults to 'CITYSCAPES'.
learning_rate (float, optional): Learning rate for training. Defaults to .025.
lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy. Defaults
to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
training from. If None, no training checkpoint will be resumed. At most
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
Defaults to None.
"""
if self.status == 'Infer':
logging.error(
"Exported inference model does not support training.",
@ -326,28 +335,37 @@ class BaseSegmenter(BaseModel):
quant_config=None):
"""
Quantization-aware training.
Args:
num_epochs(int): The number of epochs.
train_dataset(paddlers.dataset): Training dataset.
train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
eval_dataset(paddlers.dataset, optional):
Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
optimizer(paddle.optimizer.Optimizer or None, optional):
Optimizer used in training. If None, a default optimizer is used. Defaults to None.
save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
learning_rate(float, optional): Learning rate for training. Defaults to .025.
lr_decay_power(float, optional): Learning decay power. Defaults to .9.
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
configuration will be used. Defaults to None.
resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
from. If None, no training checkpoint will be resumed. Defaults to None.
Args:
num_epochs (int): Number of epochs.
train_dataset (paddlers.datasets.SegDataset): Training dataset.
train_batch_size (int, optional): Total batch size among all cards used in
training. Defaults to 2.
eval_dataset (paddlers.datasets.SegDataset|None, optional): Evaluation dataset.
If None, the model will not be evaluated during training process.
Defaults to None.
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
training. If None, a default optimizer will be used. Defaults to None.
save_interval_epochs (int, optional): Epoch interval for saving the model.
Defaults to 1.
log_interval_steps (int, optional): Step interval for printing training
information. Defaults to 2.
save_dir (str, optional): Directory to save the model. Defaults to 'output'.
learning_rate (float, optional): Learning rate for training.
Defaults to .0001.
lr_decay_power (float, optional): Learning rate decay power. Defaults to .9.
early_stop (bool, optional): Whether to adopt early stop strategy.
Defaults to False.
early_stop_patience (int, optional): Early stop patience. Defaults to 5.
use_vdl (bool, optional): Whether to use VisualDL to monitor the training
process. Defaults to True.
quant_config (dict|None, optional): Quantization configuration. If None,
a default rule of thumb configuration will be used. Defaults to None.
resume_checkpoint (str|None, optional): Path of the checkpoint to resume
quantization-aware training from. If None, no training checkpoint will
be resumed. Defaults to None.
"""
self._prepare_qat(quant_config)
self.train(
num_epochs=num_epochs,
@ -369,10 +387,13 @@ class BaseSegmenter(BaseModel):
def evaluate(self, eval_dataset, batch_size=1, return_details=False):
"""
Evaluate the model.
Args:
eval_dataset(paddlers.dataset): Evaluation dataset.
batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
return_details(bool, optional): Whether to return evaluation details. Defaults to False.
eval_dataset (paddlers.datasets.SegDataset): Evaluation dataset.
batch_size (int, optional): Total batch size among all cards used for
evaluation. Defaults to 1.
return_details (bool, optional): Whether to return evaluation details.
Defaults to False.
Returns:
collections.OrderedDict with key-value pairs:
@ -384,6 +405,7 @@ class BaseSegmenter(BaseModel):
"category_F1-score": `F1 score`}.
"""
self._check_transforms(eval_dataset.transforms, 'eval')
self.net.eval()
@ -477,24 +499,27 @@ class BaseSegmenter(BaseModel):
def predict(self, img_file, transforms=None):
"""
Do inference.
Args:
img_file(list[np.ndarray | str] | str | np.ndarray):
Image path or decoded image data, which also could constitute a list,meaning all images to be
img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
image data, which also could constitute a list, meaning all images to be
predicted as a mini-batch.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
Returns:
If img_file is a string or np.array, the result is a dict with key-value pairs:
{"label map": `label map`, "score_map": `score map`}.
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
label_map(np.ndarray): the predicted label map (HW)
score_map(np.ndarray): the prediction score map (HWC)
If `img_file` is a string or np.ndarray, the result is a dict with key-value
pairs:
{"label_map": `label map`, "score_map": `score map`}.
If `img_file` is a list, the result is a list composed of dicts with the
corresponding fields:
label_map (np.ndarray): the predicted label map (HW)
score_map (np.ndarray): the prediction score map (HWC)
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
raise ValueError("transforms need to be defined, now is None.")
if transforms is None:
transforms = self.test_transforms
if isinstance(img_file, (str, np.ndarray)):
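
A usage sketch mirroring the documented return structure; the image paths are placeholders:

    result = model.predict('scene.tif')
    label_map = result['label_map']   # HW label map
    score_map = result['score_map']   # HWC score map

    results = model.predict(['a.tif', 'b.tif'])  # list input returns a list of dicts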
@ -528,19 +553,19 @@ class BaseSegmenter(BaseModel):
transforms=None):
"""
Do inference.
Args:
img_file(str):
Image path.
save_dir(str):
Directory that contains saved geotiff file.
block_size(list[int] | tuple[int] | int):
img_file (str): Image path.
save_dir (str): Directory to save the output geotiff file.
block_size (list[int] | tuple[int] | int):
Size of block.
overlap(list[int] | tuple[int] | int, optional):
overlap (list[int] | tuple[int] | int, optional):
Overlap between two blocks. Defaults to 36.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
transforms (paddlers.transforms.Compose|None, optional): Transforms for
inputs. If None, the transforms for evaluation process will be used.
Defaults to None.
"""
try:
from osgeo import gdal
except ImportError:
@ -790,7 +815,7 @@ class DeepLabV3P(BaseSegmenter):
if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
raise ValueError(
"backbone: {} is not supported. Please choose one of "
"('ResNet50_vd', 'ResNet101_vd')".format(backbone))
"{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone))
if params.get('with_net', True):
with DisablePrint():
backbone = getattr(paddleseg.models, backbone)(
@ -834,8 +859,8 @@ class HRNet(BaseSegmenter):
**params):
if width not in (18, 48):
raise ValueError(
"width={} is not supported, please choose from [18, 48]".format(
width))
"width={} is not supported, please choose from {18, 48}.".
format(width))
self.backbone_name = 'HRNet_W{}'.format(width)
if params.get('with_net', True):
with DisablePrint():

@ -71,13 +71,13 @@ def cocoapi_eval(anns,
classwise=False):
"""
Args:
anns: Evaluation result.
style (str): COCOeval style, can be `bbox` , `segm` and `proposal`.
coco_gt (str): Whether to load COCOAPI through anno_file,
anns (list): Evaluation result.
style (str): COCOeval style. Choices are 'bbox', 'segm' and 'proposal'.
coco_gt (COCO, optional): COCO API object loaded from the ground-truth annotations,
eg: coco_gt = COCO(anno_file)
anno_file (str): COCO annotations file.
max_dets (tuple): COCO evaluation maxDets.
classwise (bool): Whether per-category AP and draw P-R Curve or not.
anno_file (str, optional): COCO annotations file. Defaults to None.
max_dets (tuple, optional): COCO evaluation maxDets. Defaults to (100, 300, 1000).
classwise (bool, optional): Whether to calculate per-category statistics or not. Defaults to False.
"""
assert coco_gt is not None or anno_file is not None
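A hedged example of the documented usage via the pycocotools API (file names are placeholders):
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("instances_val2017.json")        # ground truth
coco_dt = coco_gt.loadRes("bbox_results.json")  # predictions
coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()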
@ -148,12 +148,6 @@ def cocoapi_eval(anns,
def loadRes(coco_obj, anns):
"""
Load result file and return a result api object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
# This function has the same functionality as pycocotools.COCO.loadRes,
# except that the input anns is a list of results rather than a JSON file.
# Refer to
@ -294,7 +288,6 @@ def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None):
int:
dict: Contains the keys 'ps_supercategory' and 'ps_allcategory'. The value of 'ps_supercategory' is the precision when
confusion between subcategories is ignored; the value of 'ps_allcategory' is the precision when confusion between all categories is ignored.
"""
# matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
@ -402,13 +395,13 @@ def coco_error_analysis(eval_details_file=None,
pred_mask = eval_details['mask']
gt = eval_details['gt']
if gt is None or pred_bbox is None:
raise Exception(
"gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
raise ValueError(
"gt/pred_bbox/pred_mask is None now. Please set right eval_details_file or gt/pred_bbox/pred_mask."
)
if pred_bbox is not None and len(pred_bbox) == 0:
raise Exception("There is no predicted bbox.")
raise ValueError("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.")
raise ValueError("There is no predicted mask.")
def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
"""
@ -474,4 +467,4 @@ def coco_error_analysis(eval_details_file=None,
if pred_mask is not None:
coco_dt = loadRes(coco_gt, pred_mask)
_analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
logging.info("The analysis figures are saved in {}".format(save_dir))
logging.info("The analysis figures are saved in {}.".format(save_dir))

@ -50,11 +50,12 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./',
Convert segmentation results to a color image and save the blended image.
Args:
image: the path of origin image
result: the predict result of image
weight: the image weight of visual image, and the result weight is (1 - weight)
save_dir: the directory for saving visual image
color: the list of a BGR-mode color for each label.
image (str): Path of original image.
result (dict): Predicted results.
weight (float, optional): Weight used to mix the original image with the predicted image.
Defaults to 0.6.
save_dir (str, optional): Directory for saving visualized image. Defaults to './'.
color (list|None, optional): List of BGR values for each label, or None. Defaults to None.
"""
label_map = result['label_map'].astype("uint8")
@ -106,14 +107,15 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./',
def get_color_map_list(num_classes):
"""
Returns the color map for visualizing the segmentation mask, which can support arbitrary number of classes.
"""
Get the color map for visualizing a segmentation mask.
This function supports arbitrary number of classes.
Args:
num_classes: Number of classes
num_classes (int): Number of classes.
Returns:
The color map
list: Color map.
"""
color_map = num_classes * [0, 0, 0]
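As the initialization above suggests, the returned map is a flat [R, G, B, R, G, B, ...] list; a small usage sketch:
import numpy as np

cmap = get_color_map_list(4)   # flat list of length 4 * 3
colors = np.asarray(cmap, dtype=np.uint8).reshape(-1, 3)   # one RGB row per class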
@ -130,10 +132,10 @@ def get_color_map_list(num_classes):
return color_map
# expand an array of boxes by a given scale.
def expand_boxes(boxes, scale):
"""
"""
Expand an array of boxes by a given scale.
"""
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
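A self-contained sketch completing the expansion math started above (scaled half-extents around an unchanged box center; an illustration, not the library's exact code):
import numpy as np

def expand_boxes_sketch(boxes, scale):
    w_half = (boxes[:, 2] - boxes[:, 0]) * .5 * scale
    h_half = (boxes[:, 3] - boxes[:, 1]) * .5 * scale
    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
    boxes_exp = np.zeros_like(boxes)
    boxes_exp[:, 0], boxes_exp[:, 2] = x_c - w_half, x_c + w_half
    boxes_exp[:, 1], boxes_exp[:, 3] = y_c - h_half, y_c + h_half
    return boxes_exp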
@ -175,7 +177,7 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
else:
color_map = np.asarray(color_map)
if color_map.shape[0] != len(labels) or color_map.shape[1] != 3:
raise Exception(
raise ValueError(
"The shape for color_map is required to be {}x3, but recieved shape is {}x{}.".
format(len(labels), color_map.shape))
if np.max(color_map) > 255 or np.min(color_map) < 0:
@ -203,11 +205,11 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
ymax = ymin + h
color = tuple(map(int, color_map[labels.index(cname)]))
# draw bbox
# Draw bbox
image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color,
linewidth)
# draw mask
# Draw mask
if 'mask' in dt:
mask = dt['mask'] * 255
image = image.astype('float32')
@ -230,7 +232,7 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
thickness=1,
lineType=cv2.LINE_AA)
# draw label
# Draw label
text_pos = (xmin, ymin)
instance_area = w * h
if (instance_area < _SMALL_OBJECT_AREA_THRESH or h < 40):
@ -279,13 +281,13 @@ def draw_pr_curve(eval_details_file=None,
pred_mask = eval_details['mask']
gt = eval_details['gt']
if gt is None or pred_bbox is None:
raise Exception(
raise ValueError(
"gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
)
if pred_bbox is not None and len(pred_bbox) == 0:
raise Exception("There is no predicted bbox.")
raise ValueError("There is no predicted bbox.")
if pred_mask is not None and len(pred_mask) == 0:
raise Exception("There is no predicted mask.")
raise ValueError("There is no predicted mask.")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
@ -297,7 +299,8 @@ def draw_pr_curve(eval_details_file=None,
def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
"""
This function has the same functionality as _summarize() in pycocotools.COCOeval.summarize().
This function has the same functionality as _summarize() in
pycocotools.COCOeval.summarize().
Refer to
https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L427,
@ -336,7 +339,7 @@ def draw_pr_curve(eval_details_file=None,
stats = _summarize(coco_eval, iouThr=iou_thresh)
catIds = coco_gt.getCatIds()
if len(catIds) != coco_eval.eval['precision'].shape[2]:
raise Exception(
raise ValueError(
"The category number must be same as the third dimension of precisions."
)
x = np.arange(0.0, 1.01, 0.01)

@ -30,9 +30,9 @@ class BaseAnchorCluster(object):
Base Anchor Cluster
Args:
num_anchors (int): number of clusters
cache (bool): whether using cache
cache_path (str): cache directory path
num_anchors (int): Number of clusters.
cache (bool): Whether to use cache.
cache_path (str): Cache directory path.
"""
super(BaseAnchorCluster, self).__init__()
self.num_anchors = num_anchors
@ -99,14 +99,15 @@ class YOLOAnchorCluster(BaseAnchorCluster):
https://github.com/ultralytics/yolov5/blob/master/utils/autoanchor.py
Args:
num_anchors (int): number of clusters
dataset (DataSet): DataSet instance, VOC or COCO
image_size (list or int): [h, w], being an int means image height and image width are the same.
cache (bool): whether using cache. Defaults to True.
cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset. Defaults to None.
iters (int, optional): iters of kmeans algorithm. Defaults to 300.
gen_iters (int, optional): iters of genetic algorithm. Defaults to 1000.
thresh (float, optional): anchor scale threshold. Defaults to 0.25.
num_anchors (int): Number of clusters.
dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset): Dataset instance.
image_size (list[int] | int): [h, w] or an int value that corresponds to the shape [image_size, image_size].
cache (bool, optional): Whether to use cache. Defaults to True.
cache_path (str|None, optional): Path of cache directory. If None, use `dataset.data_dir`.
Defaults to None.
iters (int, optional): Iterations of k-means algorithm. Defaults to 300.
gen_iters (int, optional): Iterations of genetic algorithm. Defaults to 1000.
thresh (float, optional): Anchor scale threshold. Defaults to 0.25.
"""
self.dataset = dataset
if cache_path is None:
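A hypothetical invocation matching the documented parameters (`train_dataset` is a placeholder, and the call/return convention is assumed):
clusterer = YOLOAnchorCluster(
    num_anchors=9,
    dataset=train_dataset,   # a VOCDetDataset/COCODetDataset instance
    image_size=608,          # int -> [608, 608]
    cache=True)
anchors = clusterer()        # assumed to yield 9 clustered (w, h) pairs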

@ -29,14 +29,16 @@ def decode_image(im_path,
Decode an image.
Args:
im_path (str): Path of the image to decode.
to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image.
Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a
SAR image, set this argument to True. Defaults to True.
to_rgb (bool, optional): If True, convert input image(s) from BGR format to
RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo
image (e.g. jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
Returns:
np.ndarray: Decoded image.
"""

@ -72,17 +72,18 @@ class BatchCompose(Transform):
class BatchRandomResize(Transform):
"""
Resize a batch of input to random sizes.
Resize a batch of inputs to random sizes.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_sizes (list[int] | list[list | tuple] | tuple[list | tuple]):
Multiple target sizes, each target size is an int or list/tuple of length 2.
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
Interpolation method of resize. Defaults to 'LINEAR'.
target_sizes (list[int] | list[list|tuple] | tuple[list|tuple]):
Multiple target sizes, each of which should be an int or list/tuple of length 2.
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
Raises:
TypeError: Invalid type of target_size.
TypeError: Invalid type of `target_size`.
ValueError: Invalid interpolation method.
See Also:
@ -111,23 +112,27 @@ class BatchRandomResize(Transform):
class BatchRandomResizeByShort(Transform):
"""Resize a batch of input to random sizes with keeping the aspect ratio.
"""
Resize a batch of inputs to random sizes while keeping the aspect ratio.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_sizes (list[int] | tuple[int]): Target sizes of the shorter side of the image(s).
max_size (int, optional): The upper bound of longer side of the image(s).
If max_size is -1, no upper bound is applied. Defaults to -1.
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
Interpolation method of resize. Defaults to 'LINEAR'.
short_sizes (list[int] | tuple[int]): Target sizes of the shorter side of
the image(s).
max_size (int, optional): Upper bound of longer side of the image(s).
If `max_size` is -1, no upper bound will be applied. Defaults to -1.
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
Raises:
TypeError: Invalid type of target_size.
TypeError: Invalid type of `target_size`.
ValueError: Invalid interpolation method.
See Also:
RandomResizeByShort: Resize input to random sizes with keeping the aspect ratio.
RandomResizeByShort: Resize input to random sizes while keeping the aspect
ratio.
"""
def __init__(self, short_sizes, max_size=-1, interp='NEAREST'):
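For intuition, the aspect-ratio-preserving scale computation such a resize implies (a sketch, not the library code):
def short_side_scale(h, w, short_size, max_size=-1):
    scale = short_size / min(h, w)
    if max_size > 0:
        scale = min(scale, max_size / max(h, w))
    return round(h * scale), round(w * scale)

print(short_side_scale(600, 800, 480, max_size=1000))   # (480, 640)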
@ -180,7 +185,7 @@ class _BatchPad(Transform):
class _Gt2YoloTarget(Transform):
"""
Generate YOLOv3 targets from ground truth data. This operator is only used in
fine grained YOLOv3 loss mode
fine-grained YOLOv3 loss mode.
"""
def __init__(self,

@ -58,75 +58,73 @@ def center_crop(im, crop_size=224):
# region flip
def img_flip(im, method=0):
"""
flip image in different ways, this function provides 5 method to filp
this function can be applied to 2D or 3D images
Flip an image.
This function provides 5 flipping methods and can be applied to 2D or 3D numpy arrays.
Args:
im(array): image array
method(int or string): choose the flip method, it must be one of [
0, 1, 2, 3, 4, 'h', 'v', 'hv', 'rt2lb', 'lt2rb', 'dia', 'adia']
0 or 'h': flipped in horizontal direction, which is the most frequently used method
1 or 'v': flipped in vertical direction
2 or 'hv': flipped in both horizontal diction and vertical direction
3 or 'rt2lb' or 'dia': flipped around the diagonal,
which also can be thought as changing the RightTop part with LeftBottom part,
so it is called 'rt2lb' as well.
4 or 'lt2rb' or 'adia': flipped around the anti-diagonal
which also can be thought as changing the LeftTop part with RightBottom part,
so it is called 'lt2rb' as well.
im (np.ndarray): Input image.
method (int|string): Flipping method. Must be one of [
0, 1, 2, 3, 4, 'h', 'v', 'hv', 'rt2lb', 'lt2rb',
'dia', 'adia'].
0 or 'h': flip the image in horizontal direction, which is the most frequently
used method;
1 or 'v': flip the image in vertical direction;
2 or 'hv': flip the image in both horizontal and vertical directions;
3 or 'rt2lb' or 'dia': flip the image across the diagonal;
4 or 'lt2rb' or 'adia': flip the image across the anti-diagonal.
Returns:
flipped image(array)
np.ndarray: Flipped image.
Raises:
ValueError: Shape of image should 2d, 3d or more.
ValueError: Invalid shape of images.
Examples:
--assume an image is like this:
Assume an image is like this:
img:
/ + +
- / *
- * /
--we can flip it in following code:
We can flip it with the following code:
img_h = im_flip(img, 'h')
img_v = im_flip(img, 'v')
img_vh = im_flip(img, 2)
img_rt2lb = im_flip(img, 3)
img_lt2rb = im_flip(img, 4)
img_h = img_flip(img, 'h')
img_v = img_flip(img, 'v')
img_vh = img_flip(img, 2)
img_rt2lb = img_flip(img, 3)
img_lt2rb = img_flip(img, 4)
--we can get flipped image:
Then we get the flipped images:
img_h, flipped in horizontal direction
img_h, flipped in horizontal direction:
+ + \
* \ -
\ * -
img_v, flipped in vertical direction
img_v, flipped in vertical direction:
- * \
- \ *
\ + +
img_vh, flipped in both horizontal diction and vertical direction
img_vh, flipped in both horizontal and vertical directions:
/ * -
* / -
+ + /
img_rt2lb, flipped around the diagonal
img_rt2lb, mirrored on the diagonal:
/ | |
+ / *
+ * /
img_lt2rb, flipped around the anti-diagonal
img_lt2rb, mirrored on the anti-diagonal:
/ * +
* / +
| | /
"""
if not len(im.shape) >= 2:
raise ValueError("Shape of image should 2d, 3d or more")
raise ValueError("The number of image dimensions is less than 2.")
if method == 0 or method == 'h':
return horizontal_flip(im)
elif method == 1 or method == 'v':
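The flip semantics above can be sanity-checked against numpy equivalents (a sketch):
import numpy as np

img = np.array([[1, 2], [3, 4]])
assert (img_flip(img, 'h') == np.fliplr(img)).all()    # horizontal
assert (img_flip(img, 'v') == np.flipud(img)).all()    # vertical
assert (img_flip(img, 2) == img[::-1, ::-1]).all()     # both directions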
@ -176,61 +174,58 @@ def lt2rb_flip(im):
# region rotation
def img_simple_rotate(im, method=0):
"""
rotate image in simple ways, this function provides 3 method to rotate
this function can be applied to 2D or 3D images
Rotate an image.
This function provides 3 rotating methods and can be applied to 2D or 3D numpy arrays.
Args:
im(array): image array
method(int or string): choose the flip method, it must be one of [
im (np.ndarray): Input image.
method (int|string): Rotating method, which must be one of [
0, 1, 2, 90, 180, 270
]
0 or 90 : rotated in 90 degree, clockwise
1 or 180: rotated in 180 degree, clockwise
2 or 270: rotated in 270 degree, clockwise
].
0 or 90 : rotate the image by 90 degrees, clockwise;
1 or 180: rotate the image by 180 degrees, clockwise;
2 or 270: rotate the image by 270 degrees, clockwise.
Returns:
flipped image(array)
np.ndarray: Rotated image.
Raises:
ValueError: Shape of image should 2d, 3d or more.
ValueError: Invalid shape of images.
Examples:
--assume an image is like this:
Assume an image is like this:
img:
/ + +
- / *
- * /
--we can rotate it in following code:
We can rotate it with the following code:
img_r90 = img_simple_rotate(img, 90)
img_r180 = img_simple_rotate(img, 1)
img_r270 = img_simple_rotate(img, 2)
--we can get rotated image:
Then we get the following rotated images:
img_r90, rotated in 90 degree
img_r90, rotated by 90°:
| | \
* \ +
\ * +
img_r180, rotated in 180 degree
img_r180, rotated by 180°:
/ * -
* / -
+ + /
img_r270, rotated in 270 degree
img_r270, rotated by 270°:
+ * \
+ \ *
\ | |
"""
if not len(im.shape) >= 2:
raise ValueError("Shape of image should 2d, 3d or more")
raise ValueError("The number of image dimensions is less than 2.")
if method == 0 or method == 90:
return rot_90(im)
elif method == 1 or method == 180:
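Likewise for rotation (a sketch; clockwise 90° is assumed to match np.rot90 with k=-1):
import numpy as np

img = np.arange(6).reshape(2, 3)
assert (img_simple_rotate(img, 90) == np.rot90(img, k=-1)).all()
assert (img_simple_rotate(img, 180) == np.rot90(img, k=2)).all()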
@ -396,14 +391,15 @@ def resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y, interp):
def to_uint8(im, is_linear=False):
""" Convert raster to uint8.
"""
Convert raster data to uint8 type.
Args:
im (np.ndarray): The image.
im (np.ndarray): Input raster image.
is_linear (bool, optional): Whether to use 2% linear stretch. Defaults to False.
Returns:
np.ndarray: Image on uint8.
np.ndarray: Image data of uint8 type.
"""
# 2% linear stretch
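A minimal sketch of what a per-band 2% linear stretch computes (assumed to reflect the idea behind `is_linear=True`, not the exact implementation):
import numpy as np

def stretch_2pct(band):
    lo, hi = np.percentile(band, (2, 98))   # clip the 2% tails
    band = np.clip(band, lo, hi)
    return ((band - lo) / max(hi - lo, 1e-12) * 255).astype(np.uint8)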
@ -448,16 +444,18 @@ def to_uint8(im, is_linear=False):
def to_intensity(im):
""" calculate SAR data's intensity diagram.
"""
Calculate the intensity of SAR data.
Args:
im (np.ndarray): The SAR image.
im (np.ndarray): SAR image.
Returns:
np.ndarray: Intensity diagram.
np.ndarray: Intensity image.
"""
if len(im.shape) != 2:
raise ValueError("im's shape must be 2.")
raise ValueError("`len(im.shape) must be 2.")
# A complex dtype indicates SAR data.
if isinstance(im[0, 0], complex):
im = abs(im)
@ -465,15 +463,18 @@ def to_intensity(im):
def select_bands(im, band_list=[1, 2, 3]):
""" Select bands.
"""
Select bands of a multi-band image.
Args:
im (np.ndarray): The image.
band_list (list, optional): Bands of selected (Start with 1). Defaults to [1, 2, 3].
im (np.ndarray): Input image.
band_list (list, optional): Bands to select (band index starts from 1).
Defaults to [1, 2, 3].
Returns:
np.ndarray: The image after band selected.
np.ndarray: Image with selected bands.
"""
if len(im.shape) == 2: # just have one channel
return im
if not isinstance(band_list, list) or len(band_list) == 0:
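An equivalent numpy view of the 1-based band selection (sketch):
import numpy as np

im = np.random.rand(64, 64, 4)               # e.g. B, G, R, NIR
rgb = im[:, :, [b - 1 for b in [3, 2, 1]]]   # pick R, G, B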
@ -492,14 +493,14 @@ def select_bands(im, band_list=[1, 2, 3]):
def dehaze(im, gamma=False):
"""
Single image haze removal using dark channel prior.
Perform single image haze removal using dark channel prior.
Args:
im (np.ndarray): Input image.
gamma (bool, optional): Use gamma correction or not. Defaults to False.
Returns:
np.ndarray: The image after dehazed.
np.ndarray: Output dehazed image.
"""
def _guided_filter(I, p, r, eps):
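For reference, the dark channel at the heart of this prior is a channel-wise minimum followed by a local minimum filter (a scipy-based sketch; the library may compute it differently):
import numpy as np
from scipy.ndimage import minimum_filter

def dark_channel(im, patch=15):
    return minimum_filter(im.min(axis=2), size=patch)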
@ -549,7 +550,8 @@ def match_histograms(im, ref):
Args:
im (np.ndarray): Input image.
ref (np.ndarray): Reference image to match histogram of. `ref` must have the same number of channels as `im`.
ref (np.ndarray): Reference image to match histogram of. `ref` must have
the same number of channels as `im`.
Returns:
np.ndarray: Transformed input image.
@ -557,6 +559,7 @@ def match_histograms(im, ref):
Raises:
ValueError: When the number of channels of `ref` differs from that of `im`.
"""
# TODO: Check the data types of the inputs to see if they are supported by skimage
return exposure.match_histograms(
im, ref, channel_axis=-1 if im.ndim > 2 else None)
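Standalone usage of the same skimage call, with random placeholder data:
import numpy as np
from skimage import exposure

im, ref = np.random.rand(32, 32, 3), np.random.rand(32, 32, 3)
matched = exposure.match_histograms(im, ref, channel_axis=-1)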
@ -568,10 +571,12 @@ def match_by_regression(im, ref, pif_loc=None):
Args:
im (np.ndarray): Input image.
ref (np.ndarray): Reference image to match. `ref` must have the same shape as `im`.
pif_loc (tuple|None, optional): Spatial locations where pseudo-invariant features (PIFs) are obtained. If
`pif_loc` is set to None, all pixels in the image will be used as training samples for the regression model.
In other cases, `pif_loc` should be a tuple of np.ndarrays. Default: None.
ref (np.ndarray): Reference image to match. `ref` must have the same shape
as `im`.
pif_loc (tuple|None, optional): Spatial locations where pseudo-invariant
features (PIFs) are obtained. If `pif_loc` is set to None, all pixels in
the image will be used as training samples for the regression model. In
other cases, `pif_loc` should be a tuple of np.ndarrays. Default: None.
Returns:
np.ndarray: Transformed input image.
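One plausible per-band realization of regression matching (an assumption about the approach, not the library's exact code): fit ref ≈ a * im + b on the PIF pixels, then apply the fit to the whole band:
import numpy as np

def match_band(band, ref_band, loc=None):
    x = band[loc] if loc is not None else band
    y = ref_band[loc] if loc is not None else ref_band
    a, b = np.polyfit(x.ravel(), y.ravel(), deg=1)
    return a * band + b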
@ -609,15 +614,16 @@ def match_by_regression(im, ref, pif_loc=None):
def inv_pca(im, joblib_path):
"""
Restore PCA result.
Perform inverse PCA transformation.
Args:
im (np.ndarray): The input image after PCA.
joblib_path (str): Path of *.joblib about PCA.
im (np.ndarray): Input image after performing PCA.
joblib_path (str): Path of *.joblib file that stores PCA information.
Returns:
np.ndarray: The raw input image.
np.ndarray: Reconstructed input image.
"""
pca = load(joblib_path)
H, W, C = im.shape
n_im = np.reshape(im, (-1, C))
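The remaining step presumably mirrors scikit-learn's PCA API (an assumption; continuing the snippet above):
raw = pca.inverse_transform(n_im)    # back to the original feature space
raw = np.reshape(raw, (H, W, -1))    # restore the spatial layout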

@ -123,7 +123,7 @@ class Compose(object):
class Transform(object):
"""
Parent class of all data augmentation operations
Parent class of all data augmentation operators.
"""
def __init__(self):
@ -171,12 +171,15 @@ class DecodeImg(Transform):
Decode image(s) in input.
Args:
to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image.
Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a
SAR image, set this argument to True. Defaults to True.
to_rgb (bool, optional): If True, convert input image(s) from BGR format to
RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image
(e.g., jpeg images) as a BGR image. Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel
geo image (e.g. geotiff images) as a SAR image. Defaults to True.
"""
def __init__(self,
@ -262,7 +265,7 @@ class DecodeImg(Transform):
sample (dict): Input sample.
Returns:
dict: Decoded sample.
dict: Sample with decoded images.
"""
if 'image' in sample:
@ -299,17 +302,20 @@ class Resize(Transform):
"""
Resize input.
- If target_size is an int, resize the image(s) to (target_size, target_size).
- If target_size is a list or tuple, resize the image(s) to target_size.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
- If `target_size` is an int, resize the image(s) to (`target_size`, `target_size`).
- If `target_size` is a list or tuple, resize the image(s) to `target_size`.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_size (int, list[int] | tuple[int]): Target size. If int, the height and width share the same target_size.
Otherwise, target_size represents [target height, target width].
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
Interpolation method of resize. Defaults to 'LINEAR'.
keep_ratio (bool): the resize scale of width/height is same and width/height after resized is not greater
than target width/height. Defaults to False.
target_size (int | list[int] | tuple[int]): Target size. If it is an integer, the
target height and width will be both set to `target_size`. Otherwise,
`target_size` represents [target height, target width].
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
keep_ratio (bool, optional): If True, the scaling factors of width and height
will be set to the same value, and the height/width of the resized image will
not be greater than the target width/height. Defaults to False.
Raises:
TypeError: Invalid type of `target_size`.
@ -420,20 +426,18 @@ class RandomResize(Transform):
"""
Resize input to random sizes.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_sizes (list[int] | list[list | tuple] | tuple[list | tuple]):
Multiple target sizes, each target size is an int or list/tuple.
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
Interpolation method of resize. Defaults to 'LINEAR'.
target_sizes (list[int] | list[list|tuple] | tuple[list|tuple]):
Multiple target sizes, each of which should be int, list, or tuple.
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
Raises:
TypeError: Invalid type of target_size.
TypeError: Invalid type of `target_size`.
ValueError: Invalid interpolation method.
See Also:
Resize: Resize input to a specific size.
"""
def __init__(self, target_sizes, interp='LINEAR'):
@ -459,14 +463,17 @@ class RandomResize(Transform):
class ResizeByShort(Transform):
"""
Resize input with keeping the aspect ratio.
Resize input while keeping the aspect ratio.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_size (int): Target size of the shorter side of the image(s).
max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
max_size (int, optional): Upper bound of longer side of the image(s). If
`max_size` is -1, no upper bound will be applied. Defaults to -1.
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
Raises:
ValueError: Invalid interpolation method.
@ -498,21 +505,24 @@ class ResizeByShort(Transform):
class RandomResizeByShort(Transform):
"""
Resize input to random sizes with keeping the aspect ratio.
Resize input to random sizes while keeping the aspect ratio.
Attention: If interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If `interp` is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_sizes (list[int]): Target sizes of the shorter side of the image(s).
max_size (int, optional): The upper bound of longer side of the image(s). If max_size is -1, no upper bound is applied. Defaults to -1.
interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional): Interpolation method of resize. Defaults to 'LINEAR'.
max_size (int, optional): Upper bound of longer side of the image(s).
If `max_size` is -1, no upper bound will be applied. Defaults to -1.
interp (str, optional): Interpolation method for resizing image(s). One of
{'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}.
Defaults to 'LINEAR'.
Raises:
TypeError: Invalid type of target_size.
TypeError: Invalid type of `target_size`.
ValueError: Invalid interpolation method.
See Also:
ResizeByShort: Resize image(s) in input with keeping the aspect ratio.
ResizeByShort: Resize image(s) in input while keeping the aspect ratio.
"""
def __init__(self, short_sizes, max_size=-1, interp='LINEAR'):
@ -555,29 +565,30 @@ class ResizeByLong(Transform):
class RandomFlipOrRotate(Transform):
"""
Flip or Rotate an image in different ways with a certain probability.
Flip or rotate an image in different directions with a certain probability.
Args:
probs (list of float): Probabilities of flipping and rotation. Default: [0.35,0.25].
probsf (list of float): Probabilities of 5 flipping mode
(horizontal, vertical, both horizontal diction and vertical, diagonal, anti-diagonal).
Default: [0.3, 0.3, 0.2, 0.1, 0.1].
probsr (list of float): Probabilities of 3 rotation mode(90°, 180°, 270° clockwise). Default: [0.25,0.5,0.25].
probs (list[float]): Probabilities of performing flipping and rotation.
Default: [0.35, 0.25].
probsf (list[float]): Probabilities of 5 flipping modes (horizontal,
vertical, both horizontal and vertical, diagonal,
anti-diagonal). Default: [0.3, 0.3, 0.2, 0.1, 0.1].
probsr (list[float]): Probabilities of 3 rotation modes (90°, 180°, 270°
clockwise). Default: [0.25, 0.5, 0.25].
Examples:
from paddlers import transforms as T
# 定义数据增强
# Define operators for data augmentation
train_transforms = T.Compose([
T.DecodeImg(),
T.RandomFlipOrRotate(
probs = [0.3, 0.2] # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5
probsf = [0.3, 0.25, 0, 0, 0] # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45
probsr = [0, 0.65, 0]), # rotate增强时,顺时针旋转90度的概率是0,顺时针旋转180度的概率是0.65,顺时针旋转90度的概率是0,不变的概率是0.35
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
probs=[0.3, 0.2],  # p=0.3 to flip the image, p=0.2 to rotate the image, p=0.5 to keep the image unchanged.
probsf=[0.3, 0.25, 0, 0, 0],  # p=0.3 and p=0.25 to perform horizontal and vertical flipping; probability of no flipping is 0.45.
probsr=[0, 0.65, 0]),  # p=0.65 to rotate the image by 180°; probability of no rotation is 0.35.
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
"""
def __init__(self,
@ -615,15 +626,16 @@ class RandomFlipOrRotate(Transform):
)
def get_probs_range(self, probs):
'''
Change various probabilities into cumulative probabilities
"""
Change list of probabilities into cumulative probability intervals.
Args:
probs(list of float): probabilities of different mode, shape:[n]
probs (list[float]): Probabilities of different modes, shape: [n].
Returns:
probability intervals(list of binary list): shape:[n, 2]
'''
list[list]: Probability intervals, shape: [n, 2].
"""
ps = []
last_prob = 0
for prob in probs:
@ -635,17 +647,17 @@ class RandomFlipOrRotate(Transform):
return ps
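A worked example of the two helpers (values follow the docstrings; the instance name is a placeholder):
aug = RandomFlipOrRotate()
intervals = aug.get_probs_range([0.3, 0.2])       # [[0, 0.3], [0.3, 0.5]]
mode = aug.judge_probs_range(0.42, intervals)     # 1 -> rotation branch
mode = aug.judge_probs_range(0.90, intervals)     # -1 -> leave unchanged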
def judge_probs_range(self, p, probs):
'''
Judge whether a probability value falls within the given probability interval
"""
Judge whether the value of `p` falls within the given probability interval.
Args:
p(float): probability
probs(list of binary list): probability intervals, shape:[n, 2]
p (float): Value between 0 and 1.
probs (list[list]): Probability intervals, shape: [n, 2].
Returns:
mode id(int):the probability interval number where the input probability falls,
if return -1, the image will remain as it is and will not be processed
'''
int: Index of the interval that `p` falls into. -1 means that `p` does not fall into any interval, and the image will remain unchanged.
"""
for id, id_range in enumerate(probs):
if p > id_range[0] and p < id_range[1]:
return id
@ -702,7 +714,7 @@ class RandomHorizontalFlip(Transform):
Randomly flip the input horizontally.
Args:
prob(float, optional): Probability of flipping the input. Defaults to .5.
prob (float, optional): Probability of flipping the input. Defaults to .5.
"""
def __init__(self, prob=0.5):
@ -760,7 +772,7 @@ class RandomVerticalFlip(Transform):
Randomly flip the input vertically.
Args:
prob(float, optional): Probability of flipping the input. Defaults to .5.
prob (float, optional): Probability of flipping the input. Defaults to .5.
"""
def __init__(self, prob=0.5):
@ -821,10 +833,14 @@ class Normalize(Transform):
3. im = im / std
Args:
mean(list[float] | tuple[float], optional): Mean of input image(s). Defaults to [0.485, 0.456, 0.406].
std(list[float] | tuple[float], optional): Standard deviation of input image(s). Defaults to [0.229, 0.224, 0.225].
min_val(list[float] | tuple[float], optional): Minimum value of input image(s). Defaults to [0, 0, 0, ].
max_val(list[float] | tuple[float], optional): Max value of input image(s). Defaults to [255., 255., 255.].
mean (list[float] | tuple[float], optional): Mean of input image(s).
Defaults to [0.485, 0.456, 0.406].
std (list[float] | tuple[float], optional): Standard deviation of input
image(s). Defaults to [0.229, 0.224, 0.225].
min_val (list[float] | tuple[float], optional): Minimum value of input
image(s). Defaults to [0, 0, 0].
max_val (list[float] | tuple[float], optional): Maximum value of input image(s).
Defaults to [255., 255., 255.].
"""
def __init__(self,
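The three documented steps written out on a toy array (a sketch of the arithmetic, not the operator itself):
import numpy as np

im = np.full((2, 2, 3), 127.5, dtype=np.float32)
min_val, max_val = np.zeros(3), np.full(3, 255.)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
im = (im - min_val) / (max_val - min_val)   # 1. scale to [0, 1]
im = (im - mean) / std                      # 2-3. subtract mean, divide by std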
@ -872,12 +888,13 @@ class Normalize(Transform):
class CenterCrop(Transform):
"""
Crop the input at the center.
Crop the input image(s) at the center.
1. Locate the center of the image.
2. Crop the sample.
2. Crop the image.
Args:
crop_size(int, optional): target size of the cropped image(s). Defaults to 224.
crop_size (int, optional): Target size of the cropped image(s).
Defaults to 224.
"""
def __init__(self, crop_size=224):
@ -908,22 +925,27 @@ class CenterCrop(Transform):
class RandomCrop(Transform):
"""
Randomly crop the input.
1. Compute the height and width of cropped area according to aspect_ratio and scaling.
1. Compute the height and width of cropped area according to `aspect_ratio` and
`scaling`.
2. Locate the upper left corner of cropped area randomly.
3. Crop the image(s).
4. Resize the cropped area to crop_size by crop_size.
4. Resize the cropped area to `crop_size` x `crop_size`.
Args:
crop_size(int, list[int] | tuple[int]): Target size of the cropped area. If None, the cropped area will not be
resized. Defaults to None.
aspect_ratio (list[float], optional): Aspect ratio of cropped region in [min, max] format. Defaults to [.5, 2.].
thresholds (list[float], optional): Iou thresholds to decide a valid bbox crop.
Defaults to [.0, .1, .3, .5, .7, .9].
scaling (list[float], optional): Ratio between the cropped region and the original image in [min, max] format.
Defaults to [.3, 1.].
num_attempts (int, optional): The number of tries before giving up. Defaults to 50.
allow_no_crop (bool, optional): Whether returning without doing crop is allowed. Defaults to True.
cover_all_box (bool, optional): Whether to ensure all bboxes are covered in the final crop. Defaults to False.
crop_size (int | list[int] | tuple[int]): Target size of the cropped area. If
None, the cropped area will not be resized. Defaults to None.
aspect_ratio (list[float], optional): Aspect ratio of cropped region in
[min, max] format. Defaults to [.5, 2.].
thresholds (list[float], optional): IoU thresholds to decide a valid bbox
crop. Defaults to [.0, .1, .3, .5, .7, .9].
scaling (list[float], optional): Ratio between the cropped region and the
original image in [min, max] format. Defaults to [.3, 1.].
num_attempts (int, optional): Max number of tries before giving up.
Defaults to 50.
allow_no_crop (bool, optional): Whether to allow returning without performing
a crop. Defaults to True.
cover_all_box (bool, optional): Whether to ensure all bboxes are covered in
the final crop. Defaults to False.
"""
def __init__(self,
@ -1107,9 +1129,10 @@ class RandomCrop(Transform):
class RandomScaleAspect(Transform):
"""
Crop input image(s) and resize back to original sizes.
Args:
min_scale (float): Minimum ratio between the cropped region and the original image.
If 0, image(s) will not be cropped. Defaults to .5.
min_scale (float): Minimum ratio between the cropped region and the original
image. If 0, image(s) will not be cropped. Defaults to .5.
aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
"""
@ -1135,10 +1158,13 @@ class RandomExpand(Transform):
Randomly expand the input by padding according to random offsets.
Args:
upper_ratio(float, optional): The maximum ratio to which the original image is expanded. Defaults to 4..
prob(float, optional): The probability of apply expanding. Defaults to .5.
im_padding_value(list[float] | tuple[float], optional): RGB filling value for the image. Defaults to (127.5, 127.5, 127.5).
label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
upper_ratio (float, optional): Maximum ratio to which the original image
is expanded. Defaults to 4.0.
prob (float, optional): Probability of applying expansion. Defaults to .5.
im_padding_value (list[float] | tuple[float], optional): RGB filling value
for the image. Defaults to (127.5, 127.5, 127.5).
label_padding_value (int, optional): Filling value for the mask.
Defaults to 255.
See Also:
paddlers.transforms.Pad
@ -1187,15 +1213,20 @@ class Pad(Transform):
label_padding_value=255,
size_divisor=32):
"""
Pad image to a specified size or multiple of size_divisor.
Pad image to a specified size or multiple of `size_divisor`.
Args:
target_size(int, Sequence, optional): Image target size, if None, pad to multiple of size_divisor. Defaults to None.
pad_mode({-1, 0, 1, 2}, optional): Pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
if 0, only pad to right and bottom. If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
im_padding_value(Sequence[float]): RGB value of pad area. Defaults to (127.5, 127.5, 127.5).
label_padding_value(int, optional): Filling value for the mask. Defaults to 255.
size_divisor(int): Image width and height after padding is a multiple of coarsest_stride.
target_size (list[int] | tuple[int], optional): Image target size. If None, pad
to a multiple of `size_divisor`. Defaults to None.
pad_mode (int, optional): Pad mode. Currently only four modes are supported:
[-1, 0, 1, 2]. If -1, use specified offsets. If 0, only pad to right and bottom.
If 1, pad according to center. If 2, only pad left and top. Defaults to 0.
im_padding_value (list[float] | tuple[float]): RGB value of padded area.
Defaults to (127.5, 127.5, 127.5).
label_padding_value (int, optional): Filling value for the mask.
Defaults to 255.
size_divisor (int): Image width and height after padding will be a multiple of
`size_divisor`.
"""
super(Pad, self).__init__()
if isinstance(target_size, (list, tuple)):
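The rounding implied by `size_divisor` (sketch):
import math

def padded_hw(h, w, size_divisor=32):
    return (math.ceil(h / size_divisor) * size_divisor,
            math.ceil(w / size_divisor) * size_divisor)

print(padded_hw(500, 375))   # (512, 384)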
@ -1306,8 +1337,10 @@ class MixupImage(Transform):
Mixup two images and their gt_bbox/gt_score.
Args:
alpha (float, optional): Alpha parameter of beta distribution. Defaults to 1.5.
beta (float, optional): Beta parameter of beta distribution. Defaults to 1.5.
alpha (float, optional): Alpha parameter of beta distribution.
Defaults to 1.5.
beta (float, optional): Beta parameter of beta distribution.
Defaults to 1.5.
"""
super(MixupImage, self).__init__()
if alpha <= 0.0:
@ -1385,18 +1418,25 @@ class RandomDistort(Transform):
Random color distortion.
Args:
brightness_range(float, optional): Range of brightness distortion. Defaults to .5.
brightness_prob(float, optional): Probability of brightness distortion. Defaults to .5.
contrast_range(float, optional): Range of contrast distortion. Defaults to .5.
contrast_prob(float, optional): Probability of contrast distortion. Defaults to .5.
saturation_range(float, optional): Range of saturation distortion. Defaults to .5.
saturation_prob(float, optional): Probability of saturation distortion. Defaults to .5.
hue_range(float, optional): Range of hue distortion. Defaults to .5.
hue_prob(float, optional): Probability of hue distortion. Defaults to .5.
random_apply (bool, optional): whether to apply in random (yolo) or fixed (SSD)
order. Defaults to True.
count (int, optional): the number of doing distortion. Defaults to 4.
shuffle_channel (bool, optional): whether to swap channels randomly. Defaults to False.
brightness_range (float, optional): Range of brightness distortion.
Defaults to .5.
brightness_prob (float, optional): Probability of brightness distortion.
Defaults to .5.
contrast_range (float, optional): Range of contrast distortion.
Defaults to .5.
contrast_prob (float, optional): Probability of contrast distortion.
Defaults to .5.
saturation_range (float, optional): Range of saturation distortion.
Defaults to .5.
saturation_prob (float, optional): Probability of saturation distortion.
Defaults to .5.
hue_range (float, optional): Range of hue distortion. Defaults to .5.
hue_prob (float, optional): Probability of hue distortion. Defaults to .5.
random_apply (bool, optional): Apply the transformations in random (YOLO) or
fixed (SSD) order. Defaults to True.
count (int, optional): Number of distortions to apply. Defaults to 4.
shuffle_channel (bool, optional): Whether to swap channels randomly.
Defaults to False.
"""
def __init__(self,
@ -1632,7 +1672,8 @@ class SelectBand(Transform):
Select a set of bands of input image(s).
Args:
band_list (list, optional): Bands to select (the band index starts with 1). Defaults to [1, 2, 3].
band_list (list, optional): Bands to select (band index starts from 1).
Defaults to [1, 2, 3].
"""
def __init__(self, band_list=[1, 2, 3]):
@ -1653,11 +1694,12 @@ class SelectBand(Transform):
class _PadBox(Transform):
def __init__(self, num_max_boxes=50):
"""
Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
Pad zeros to bboxes if number of bboxes is less than `num_max_boxes`.
Args:
num_max_boxes (int, optional): the max number of bboxes. Defaults to 50.
num_max_boxes (int, optional): Max number of bboxes. Defaults to 50.
"""
self.num_max_boxes = num_max_boxes
super(_PadBox, self).__init__()
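The padding itself amounts to the following (sketch):
import numpy as np

def pad_boxes(bboxes, num_max_boxes=50):
    out = np.zeros((num_max_boxes, 4), dtype=bboxes.dtype)
    n = min(len(bboxes), num_max_boxes)
    out[:n] = bboxes[:n]
    return out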
@ -1741,7 +1783,8 @@ class RandomSwap(Transform):
Randomly swap multi-temporal images.
Args:
prob (float, optional): Probability of swapping the input images. Default: 0.2.
prob (float, optional): Probability of swapping the input images.
Default: 0.2.
"""
def __init__(self, prob=0.2):

@ -50,8 +50,7 @@ def md5check(fullname, md5sum=None):
def move_and_merge_tree(src, dst):
"""
Move src directory to dst, if dst is already exists,
merge src to dst
Move `src` to `dst`. If `dst` already exists, merge `src` with `dst`.
"""
if not osp.exists(dst):
shutil.move(src, dst)
@ -71,10 +70,10 @@ def move_and_merge_tree(src, dst):
def download(url, path, md5sum=None):
"""
Download from url, save to path.
Download from `url` and save the result to `path`.
url (str): download url
path (str): download to given path
url (str): URL.
path (str): Path to save the downloaded result.
"""
if not osp.exists(path):
os.makedirs(path)
@ -136,7 +135,7 @@ def download(url, path, md5sum=None):
def decompress(fname):
"""
Decompress for zip and tar file
Decompress zip or tar files.
"""
logging.info("Decompressing {}...".format(fname))

@ -22,7 +22,9 @@ import paddle
def get_environ_info():
"""collect environment information"""
"""
Collect environment information.
"""
env_info = dict()
# TODO is_compiled_with_cuda() has not been moved

@ -19,8 +19,8 @@ import numpy as np
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window.
"""
Track a series of values and provide access to smoothed values over a window.
"""
def __init__(self, window_size=20):

@ -53,9 +53,9 @@ def get_single_card_bs(batch_size):
# Evaluation of detection task only supports single card with batch size 1
return batch_size
else:
raise Exception("Please support correct batch_size, \
raise ValueError("Please support correct batch_size, \
which can be divided by available cards({}) in {}"
.format(card_num, place))
.format(card_num, place))
def dict2str(dict_input):
@ -113,7 +113,7 @@ class EarlyStop:
self.max = 0
self.thresh = thresh
if patience < 1:
raise Exception("Argument patience should be a positive integer.")
raise ValueError("Argument patience should be a positive integer.")
def __call__(self, current_score):
if self.score is None:

@ -332,7 +332,7 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input.
Default: 'auto'.
label_list (str | None, optional): Path of label_list. Default: None.
label_list (str|None, optional): Path of label_list. Default: None.
Returns:
list: List of samples.
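A call matching the signature shown above (paths are placeholders):
samples = build_input_from_file(
    "train_list.txt", prefix="./data", task="auto", label_list=None)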

@ -118,7 +118,7 @@ class _CommonTestNamespace:
assertForFloat = self.assertTrue
result_t = type(result)
error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
error_msg = "Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}"
if result_t in [list, tuple]:
result_t = get_container_type(result)
if result_t in [
@ -144,8 +144,8 @@ class _CommonTestNamespace:
result.shape, self.__class__.__name__))
else:
raise ValueError(
'result type must be str, int, bool, set, np.bool, np.int32, '
'np.int64, np.str, float, np.ndarray, np.float32, np.float64'
"result type must be str, int, bool, set, np.bool, np.int32, "
"np.int64, np.str, float, np.ndarray, np.float32, np.float64"
)
def check_output_equal(self,
@ -157,13 +157,13 @@ class _CommonTestNamespace:
Check whether result and expected result are equal, including shape.
Args:
result: str, int, bool, set, np.ndarray.
result (str|int|bool|set|np.ndarray):
The result to be checked.
expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
Use the expected result to check result.
rtol: float
expected_result (str|int|bool|set|np.ndarray): The type has to be the same as
result's. Use the expected result to check result.
rtol (float, optional):
Relative tolerance. Defaults to 1.e-5.
atol: float
atol (float, optional):
Absolute tolerance. Defaults to 1.e-8.
"""
@ -178,13 +178,13 @@ class _CommonTestNamespace:
Check whether result and expected result are not equal, including shape.
Args:
result: str, int, bool, set, np.ndarray.
result (str|int|bool|set|np.ndarray):
The result to be checked.
expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
Use the expected result to check result.
rtol: float
expected_result (str|int|bool|set|np.ndarray): The type has to be the same
as result's. Use the expected result to check result.
rtol (float, optional):
Relative tolerance. Defaults to 1.e-5.
atol: float
atol (float, optional):
Absolute tolerance. Defaults to 1.e-8.
"""

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Collect statistics on the annotations in a JSON file and save them to a CSV, including the distributions of object box shapes, box shape ratios, box start and end positions, object categories, and the number of objects per image.
python ./coco_tools/json_AnnoSta.py \
@ -24,7 +24,8 @@ python ./coco_tools/json_AnnoSta.py \
--png_cat_path=./anno_sta/annos_cat.png \
--png_objNum_path=./anno_sta/annos_objNum.png \
--get_relative=True
'''
"""
import os
import json
import argparse

@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Generate test.json from the folder of test images.
python ./coco_tools/json_Img2Json.py \
--test_image_path=./test2017 \
--json_train_path=./annotations/instances_val2017.json \
--json_test_path=./test.json
'''
"""
import os, cv2
import json
import argparse

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Collect statistics on the images in a JSON file, save them to a CSV, and plot the 2D distributions of image shapes and image shape ratios.
python ./coco_tools/json_ImgSta.py \
@ -19,7 +19,7 @@ python ./coco_tools/json_ImgSta.py \
--csv_path=./img_sta/images.csv \
--png_shape_path=./img_sta/images_shape.png \
--png_shapeRate_path=./img_sta/images_shapeRate.png
'''
"""
import json
import argparse

@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Show basic information of a JSON file.
python ./coco_tools/json_InfoShow.py \
--json_path=./annotations/instances_val2017.json \
--show_num 5
'''
"""
import json
import argparse

@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Merge two JSON files. The fields to merge can be controlled via merge_keys; by default, the 'images' and 'annotations' fields are merged.
python ./coco_tools/json_Merge.py \
--json1_path=./annotations/instances_train2017.json \
--json2_path=./annotations/instances_val2017.json \
--save_path=./instances_trainval2017.json
'''
"""
import json
import argparse

@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
"""
@File Description:
# Split a JSON dataset. The split ratio or number can be controlled via val_split_rate and val_split_num; keep_val_inTrain determines whether val-related information is kept in train.
python ./coco_tools/json_Split.py \
--json_all_path=./annotations/instances_val2017.json \
--json_train_path=./instances_val2017_train.json \
--json_val_path=./instances_val2017_val.json
'''
"""
import json
import argparse

@ -22,7 +22,7 @@ from utils import Raster, raster2uint8, save_geotiff, time_it
class MatchError(Exception):
def __str__(self):
return "Cannot match two images."
return "Cannot match the two images."
def _calcu_tf(im1, im2):

@ -48,21 +48,23 @@ def _get_type(type_name: str) -> int:
class Raster:
def __init__(self,
path: Optional[str],
path: str,
gdal_obj: Optional[gdal.Dataset]=None,
band_list: Union[List[int], Tuple[int], None]=None,
to_uint8: bool=False) -> None:
"""
Class of raster reader.
Reader of raster files.
Args:
path (Optional[str]): Path of raster file.
gdal_obj (Optional[Any], optional): GDAL dataset. Defaults to None.
band_list (Union[List[int], Tuple[int], None], optional):
Select a set of bands (the band index starts from 1) or None (read all bands). Defaults to None.
to_uint8 (bool, optional):
Whether to convert data type to uint8. Defaults to False.
path (str): Path of raster file.
gdal_obj (gdal.Dataset|None, optional): GDAL dataset. Defaults to None.
band_list (list[int] | tuple[int] | None, optional): Select a set of
bands (the band index starts from 1). If None, read all bands.
Defaults to None.
to_uint8 (bool, optional): Whether to convert data type to uint8.
Defaults to False.
"""
super(Raster, self).__init__()
if path is not None:
if osp.exists(path):
@ -92,13 +94,15 @@ class Raster:
self._getType()
def setBands(self, band_list: Union[List[int], Tuple[int], None]) -> None:
"""
"""
Set bands of data.
Args:
band_list (Union[List[int], Tuple[int], None]):
Select a set of bands (the band index starts from 1) or None (read all bands). Defaults to None.
band_list (list[int] | tuple[int] | None): Select a set of
bands (the band index starts from 1). If None, read all bands.
"""
if band_list is not None:
if len(band_list) > self.bands:
raise ValueError(
@ -113,18 +117,19 @@ class Raster:
start_loc: Union[List[int], Tuple[int, int], None]=None,
block_size: Union[List[int], Tuple[int, int]]=[512, 512]
) -> np.ndarray:
"""
"""
Fetch data as a np.ndarray.
Args:
start_loc (Union[List[int], Tuple[int], None], optional):
Coordinates of the upper left corner of the block. None value means returning full image.
block_size (Union[List[int], Tuple[int]], optional):
Block size. Defaults to [512, 512].
start_loc (list[int] | tuple[int] | None, optional): Coordinates of the
upper left corner of the block. None value means returning full image.
block_size (list[int] | tuple[int], optional): Block size.
Defaults to [512, 512].
Returns:
np.ndarray: Fetched data.
"""
if self._src_data is not None:
if start_loc is None:
return self._getArray()
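Putting the pieces together (a sketch; the block-reading method name follows the docstring above and may differ in the source):
raster = Raster("image.tif", band_list=[3, 2, 1], to_uint8=True)
tile = raster.getArray(start_loc=[0, 0], block_size=[512, 512])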

@ -54,14 +54,14 @@ eval_transforms = T.Compose([
])
# Build the datasets used for training and validation, respectively.
train_dataset = pdrs.datasets.VOCDetection(
train_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdrs.datasets.VOCDetection(
eval_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,

@ -54,14 +54,14 @@ eval_transforms = T.Compose([
])
# Build the datasets used for training and validation, respectively.
train_dataset = pdrs.datasets.VOCDetection(
train_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdrs.datasets.VOCDetection(
eval_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,

@ -54,14 +54,14 @@ eval_transforms = T.Compose([
])
# Build the datasets used for training and validation, respectively.
train_dataset = pdrs.datasets.VOCDetection(
train_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdrs.datasets.VOCDetection(
eval_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,

@ -54,14 +54,14 @@ eval_transforms = T.Compose([
])
# Build the datasets used for training and validation, respectively.
train_dataset = pdrs.datasets.VOCDetection(
train_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdrs.datasets.VOCDetection(
eval_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,

@ -54,14 +54,14 @@ eval_transforms = T.Compose([
])
# Build the datasets used for training and validation, respectively.
train_dataset = pdrs.datasets.VOCDetection(
train_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdrs.datasets.VOCDetection(
eval_dataset = pdrs.datasets.VOCDetDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
