own
juncaipeng 3 years ago
commit 806335ad29
  1. 1
      paddlers/datasets/__init__.py
  2. 139
      paddlers/datasets/raster.py
  3. 2
      paddlers/tools/yolo_cluster.py
  4. 4
      paddlers/transforms/batch_operators.py
  5. 5
      paddlers/transforms/img_decoder.py
  6. 16
      paddlers/transforms/operators.py
  7. 1
      paddlers/utils/__init__.py
  8. 95
      paddlers/utils/convert.py
  9. 4
      requirements.txt
  10. 53
      tutorials/train/README.md
  11. 54
      tutorials/train/object_detection/ppyolo.py

@ -1,2 +1,3 @@
from .voc import VOCDetection
from .seg_dataset import SegDataset
from .raster import Raster

@ -0,0 +1,139 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import numpy as np
from typing import List, Tuple, Union
from paddlers.utils import raster2uint8
try:
from osgeo import gdal
except:
import gdal
class Raster:
    """Reader for a raster file opened through GDAL.

    The whole image or a fixed-size block can be fetched as a numpy array,
    optionally restricted to a subset of bands and optionally converted to
    uint8 through ``raster2uint8``.

    Array layout returned by ``getArray``:
        * no ``band_list`` and a single-band raster: ``(H, W)``;
        * every other case: ``(H, W, C)`` where ``C`` is the number of bands
          read (all bands, or ``len(band_list)``).
    """

    def __init__(self,
                 path: str,
                 band_list: Union[List[int], Tuple[int], None]=None,
                 to_uint8: bool=False) -> None:
        """ Class of read raster.
        Args:
            path (str): The path of raster.
            band_list (Union[List[int], Tuple[int], None], optional):
                band list (start with 1) or None (all of bands). Defaults to None.
            to_uint8 (bool, optional):
                Convert uint8 or return raw data. Defaults to False.
        Raises:
            ValueError: If ``path`` does not exist, cannot be opened, or
                ``band_list`` is out of range.
        """
        super(Raster, self).__init__()
        if not osp.exists(path):
            raise ValueError("The path {0} not exists.".format(path))
        self.path = path
        # NOTE(review): for ".npy" files np.load returns a plain ndarray, which
        # does not expose the RasterCount/RasterXSize/RasterYSize attributes
        # read by __getInfo below -- confirm how .npy input is meant to work.
        self.__src_data = np.load(path) if path.split(".")[-1] == "npy" \
            else gdal.Open(path)
        if self.__src_data is None:
            # gdal.Open reports failure by returning None unless GDAL
            # exceptions are enabled; fail here with a clear message.
            raise ValueError("The path {0} can not be opened.".format(path))
        self.__getInfo()
        self.to_uint8 = to_uint8
        self.setBands(band_list)

    def setBands(self,
                 band_list: Union[List[int], Tuple[int], None]) -> None:
        """ Set band of data.
        Args:
            band_list (Union[List[int], Tuple[int], None]):
                band list (start with 1) or None (all of bands).
        Raises:
            ValueError: If ``band_list`` is longer than the number of bands or
                contains an index outside ``[1, bands]``.
        """
        if band_list is not None:
            if len(band_list) > self.bands:
                raise ValueError("The lenght of band_list must be less than {0}.".format(str(self.bands)))
            if max(band_list) > self.bands or min(band_list) < 1:
                raise ValueError("The range of band_list must within [1, {0}].".format(str(self.bands)))
        self.band_list = band_list

    def getArray(self,
                 start_loc: Union[List[int], Tuple[int], None]=None,
                 block_size: Union[List[int], Tuple[int]]=(512, 512)) -> np.ndarray:
        """ Get ndarray data
        Args:
            start_loc (Union[List[int], Tuple[int], None], optional):
                Coordinates (xoff, yoff) of the upper left corner of the block,
                if None means return full image.
            block_size (Union[List[int], Tuple[int]], optional):
                Block size as (xsize, ysize). Defaults to (512, 512).
        Returns:
            np.ndarray: data's ndarray. When ``start_loc`` is given the result
                always has ``ysize`` rows and ``xsize`` columns; the part that
                falls outside the raster is zero-filled.
        """
        if start_loc is None:
            return self.__getAarray()
        return self.__getBlock(start_loc, block_size)

    def __getInfo(self) -> None:
        """Cache band count, width and height of the opened dataset."""
        self.bands = self.__src_data.RasterCount
        self.width = self.__src_data.RasterXSize
        self.height = self.__src_data.RasterYSize

    def __getAarray(self, window: Union[None, List[int], Tuple[int]]=None) -> np.ndarray:
        """Read the full image or the window ``[xoff, yoff, xsize, ysize]``."""
        # An empty argument tuple makes ReadAsArray read the whole extent.
        read_args = () if window is None else tuple(window)
        if self.band_list is None:
            ima = self.__src_data.ReadAsArray(*read_args)
        else:
            ima = np.stack(
                [
                    self.__src_data.GetRasterBand(b).ReadAsArray(*read_args)
                    for b in self.band_list
                ],
                axis=0)
        # A complex single-band raster is treated as SAR data: keep magnitude.
        # (The previous isinstance(type(...), complex) test could never be True.)
        if self.bands == 1 and np.iscomplexobj(ima):
            ima = np.abs(ima)
        # Band-first (C, H, W) data is moved to (H, W, C). Keyed on ndim so a
        # single-band raster read through band_list is handled like the
        # multi-band case instead of being left as (1, H, W).
        if ima.ndim == 3:
            ima = ima.transpose((1, 2, 0))
        if self.to_uint8 is True:
            ima = raster2uint8(ima)
        return ima

    def __getBlock(self,
                   start_loc: Union[List[int], Tuple[int]],
                   block_size: Union[List[int], Tuple[int]]=(512, 512)) -> np.ndarray:
        """Read one block and zero-pad it to the requested size."""
        if len(start_loc) != 2 or len(block_size) != 2:
            raise ValueError("The length start_loc/block_size must be 2.")
        xoff, yoff = start_loc
        xsize, ysize = block_size
        # The offset must address an existing pixel; xoff == width (or
        # yoff == height) would produce a zero-sized read.
        if (xoff < 0 or xoff >= self.width) or (yoff < 0 or yoff >= self.height):
            raise ValueError(
                "start_loc must be within [0-{0}, 0-{1}].".format(str(self.width), str(self.height)))
        # Clip the read window to the raster extent.
        if xoff + xsize > self.width:
            xsize = self.width - xoff
        if yoff + ysize > self.height:
            ysize = self.height - yoff
        ima = self.__getAarray([int(xoff), int(yoff), int(xsize), int(ysize)])
        h, w = ima.shape[:2]
        # Rows come from the y size and columns from the x size; trailing
        # dimensions (channels) are taken from the data actually read so the
        # padding also works for a band_list subset.
        padded_shape = (int(block_size[1]), int(block_size[0])) + ima.shape[2:]
        tmp = np.zeros(padded_shape, dtype=ima.dtype)
        tmp[:h, :w] = ima
        return tmp

@ -99,7 +99,7 @@ class YOLOAnchorCluster(BaseAnchorCluster):
num_anchors (int): number of clusters
dataset (DataSet): DataSet instance, VOC or COCO
image_size (list or int): [h, w], being an int means image height and image width are the same.
cache (bool): whether using cache Defaults to True.
cache (bool): whether using cache. Defaults to True.
cache_path (str or None, optional): cache directory path. If None, use `data_dir` of dataset. Defaults to None.
iters (int, optional): iters of kmeans algorithm. Defaults to 300.
gen_iters (int, optional): iters of genetic algorithm. Defaults to 1000.

@ -69,7 +69,7 @@ class BatchRandomResize(Transform):
"""
Resize a batch of input to random sizes.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
@ -108,7 +108,7 @@ class BatchRandomResize(Transform):
class BatchRandomResizeByShort(Transform):
"""Resize a batch of input to random sizes with keeping the aspect ratio.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).

@ -1,5 +1,3 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +19,7 @@ import copy
import random
import imghdr
from PIL import Image
try:
from collections.abc import Sequence
except Exception:
@ -103,7 +102,7 @@ class ImgDecode(Transform):
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
else:
return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR)
elif ext == '.npy':
return np.load(img_path)

@ -236,9 +236,9 @@ class Resize(Transform):
"""
Resize input.
- If target_size is an intresize the image(s) to (target_size, target_size).
- If target_size is an int, resize the image(s) to (target_size, target_size).
- If target_size is a list or tuple, resize the image(s) to target_size.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_size (int, List[int] or Tuple[int]): Target size. If int, the height and width share the same target_size.
@ -347,7 +347,7 @@ class RandomResize(Transform):
"""
Resize input to random sizes.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
@ -388,7 +388,7 @@ class ResizeByShort(Transform):
"""
Resize input with keeping the aspect ratio.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_size (int): Target size of the shorter side of the image(s).
@ -427,7 +427,7 @@ class RandomResizeByShort(Transform):
"""
Resize input to random sizes with keeping the aspect ratio.
AttentionIf interp is 'RANDOM', the interpolation method will be chose randomly.
Attention: If interp is 'RANDOM', the interpolation method will be chosen randomly.
Args:
short_sizes (List[int]): Target size of the shorter side of the image(s).
@ -865,8 +865,8 @@ class RandomCrop(Transform):
class RandomScaleAspect(Transform):
"""
Crop input image(s) and resize back to original sizes.
Args
min_scale (float)Minimum ratio between the cropped region and the original image.
Args:
min_scale (float): Minimum ratio between the cropped region and the original image.
If 0, image(s) will not be cropped. Defaults to .5.
aspect_ratio (float): Aspect ratio of cropped region. Defaults to .33.
"""
@ -1262,7 +1262,7 @@ class RandomBlur(Transform):
"""
Randomly blur input image(s).
Args
Args:
prob (float): Probability of blurring.
"""

@ -22,3 +22,4 @@ from .env import get_environ_info, get_num_workers, init_parallel_env
from .download import download_and_decompress, decompress
from .stats import SmoothedValue, TrainingStats
from .shm import _get_shared_memory_size_in_M
from .convert import raster2uint8

@ -0,0 +1,95 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import operator
from functools import reduce
def raster2uint8(image: np.ndarray) -> np.ndarray:
    """ Convert raster to uint8.

    uint8 input is returned untouched. uint16 input goes through the 2%
    linear stretch. float32 input is first passed through ``_sample_norm``
    and then through the 2% linear stretch.

    Args:
        image (np.ndarray): image.
    Returns:
        np.ndarray: image on uint8.
    Raises:
        ValueError: If the dtype of ``image`` is not uint8, uint16 or float32.
    """
    dtype = image.dtype.name
    if dtype not in ("uint8", "uint16", "float32"):
        raise ValueError(f"'dtype' must be uint8/uint16/float32, not {dtype}.")
    # Already in the target type: nothing to do.
    if dtype == "uint8":
        return image
    # Only float32 data needs the extra normalisation step before stretching.
    prepared = _sample_norm(image) if dtype == "float32" else image
    return _two_percentLinear(prepared)
# 2% linear stretch
def _two_percentLinear(image: np.ndarray, max_out: int=255, min_out: int=0) -> np.ndarray:
def _gray_process(gray, maxout=max_out, minout=min_out):
# get the corresponding gray level at 98% histogram
high_value = np.percentile(gray, 98)
low_value = np.percentile(gray, 2)
truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * (maxout - minout)
return processed_gray
if len(image.shape) == 3:
processes = []
for b in range(image.shape[-1]):
processes.append(_gray_process(image[:, :, b]))
result = np.stack(processes, axis=2)
else: # if len(image.shape) == 2
result = _gray_process(image)
return np.uint8(result)
# simple image standardization
def _sample_norm(image: np.ndarray, NUMS: int=65536) -> np.ndarray:
    """Equalize each band with ``_stretch`` and rescale the result to uint8.

    Args:
        image (np.ndarray): 2-D ``(H, W)`` or 3-D ``(H, W, C)`` image.
        NUMS (int, optional): Number of gray levels handed to ``_stretch``.
            Defaults to 65536.
    Returns:
        np.ndarray: uint8 image with the same shape as ``image``.
    """
    scale = float(NUMS)
    if len(image.shape) == 3:
        # Out-of-place division: does not depend on the dtype returned by
        # _stretch and never writes into the caller's array.
        stretches = [
            _stretch(image[:, :, b], NUMS) / scale
            for b in range(image.shape[-1])
        ]
        stretched_img = np.stack(stretches, axis=2)
    else:  # if len(image.shape) == 2
        # The 2-D path previously skipped the division by NUMS that the 3-D
        # path applies, so the values overflowed in the uint8 cast below.
        stretched_img = _stretch(image, NUMS) / scale
    return np.uint8(stretched_img * 255)
# histogram equalization
def _stretch(ima: np.ndarray, NUMS: int) -> np.ndarray:
    """Histogram-equalize ``ima`` over ``NUMS`` gray levels.

    A lookup table is built from the cumulative histogram returned by
    ``_histogram`` and every sample is replaced by the table entry of its
    gray level. The input array is left unmodified.

    Args:
        ima (np.ndarray): Single-band image. Samples are truncated to integer
            gray levels and clamped to ``[0, NUMS - 1]``; NaN is treated as 0.
        NUMS (int): Number of gray levels.
    Returns:
        np.ndarray: float array shaped like ``ima`` with values in
            ``[0, NUMS - 1]``.
    """
    hist = np.asarray(_histogram(ima, NUMS))
    total = hist.sum()
    if total == 0:
        # No sample falls into any bin (e.g. every value is negative): the
        # step below would be zero, so return an all-zero result instead.
        return np.zeros(np.shape(ima), dtype=np.float64)
    # step size
    step = total / (NUMS - 1)
    # Balanced lookup table: entry i is the number of samples in bins below i
    # divided by the step (an exclusive cumulative sum of the histogram).
    lut = np.concatenate(([0], np.cumsum(hist)[:-1])) / step
    # Indices must be integers; float32 rasters cannot be used directly as
    # indices, which is why the earlier np.take(lut, ima, out=ima) failed.
    finite = np.nan_to_num(ima, nan=0.0)
    indices = np.clip(finite, 0, NUMS - 1).astype(np.intp)
    return lut[indices]
# calculate histogram
def _histogram(ima: np.ndarray, NUMS: int) -> np.ndarray:
bins = list(range(0, NUMS))
flat = ima.flat
n = np.searchsorted(np.sort(flat), bins)
n = np.concatenate([n, [len(flat)]])
hist = n[1:] - n[:-1]
return hist

@ -8,10 +8,10 @@ paddleslim == 2.2.1
shapely
paddlepaddle-gpu >= 2.2.0
opencv-python
scikit-learn==0.20.3
scikit-learn == 0.20.3
lap
motmetrics
matplotlib
chardet
openpyxl
gdal
GDAL >= 3.2.2

@ -0,0 +1,53 @@
# 使用教程——训练模型
本目录下整理了使用PaddleRS训练模型的示例代码,代码中均提供了示例数据的自动下载,并均使用单张GPU卡进行训练。
|代码 | 模型任务 | 数据 |
|------|--------|---------|
|object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
|semantic_segmentation/deeplabv3p_resnet50_vd.py | 语义分割DeepLabV3 | 视盘分割 |
<!-- 可参考API接口说明了解示例代码中的API:
* [数据集读取API](../../docs/apis/datasets.md)
* [数据预处理和数据增强API](../../docs/apis/transforms/transforms.md)
* [模型API/模型加载API](../../docs/apis/models/README.md)
* [预测结果可视化API](../../docs/apis/visualize.md) -->
## 环境准备
- [PaddlePaddle安装](https://www.paddlepaddle.org.cn/install/quick)
* 版本要求:PaddlePaddle>=2.2.0
<!-- - [PaddleRS安装](../../docs/install.md) -->
## 开始训练
* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py中sys.path路径
```
sys.path.append("your/PaddleRS/path")
```
* 在安装PaddleRS后,使用如下命令开始训练,代码会自动下载训练数据, 并均使用单张GPU卡进行训练。
```commandline
export CUDA_VISIBLE_DEVICES=0
python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
```
* 若需使用多张GPU卡进行训练,例如使用2张卡时执行:
```commandline
python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py
```
使用多卡时,参考[训练参数调整](../../docs/parameters.md)调整学习率和批量大小。
## VisualDL可视化训练指标
在模型训练过程,在`train`函数中,将`use_vdl`设为True,则训练过程会自动将训练日志以VisualDL的格式打点在`save_dir`(用户自己指定的路径)下的`vdl_log`目录,用户可以使用如下命令启动VisualDL服务,查看可视化指标
```commandline
visualdl --logdir output/deeplabv3p_resnet50_vd/vdl_log --port 8001
```
服务启动后,使用浏览器打开 http://0.0.0.0:8001 或 http://localhost:8001

@ -0,0 +1,54 @@
import os.path as osp
import sys

# Make the repository root importable without a machine-specific absolute
# path: this script lives in tutorials/train/object_detection/, so the root
# is three directories above it.
REPO_ROOT = osp.abspath(osp.join(osp.dirname(__file__), "..", "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

import paddlers as pdrs
from paddlers import transforms as T

# NOTE(review): the dataset is expected in ./insect_det relative to the
# working directory; this script does not download it -- confirm how the
# data is meant to be provided.
DATA_DIR = 'insect_det'

# Transforms applied to training samples.
train_transforms = T.Compose([
    T.MixupImage(mixup_epoch=-1),
    T.RandomDistort(),
    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]),
    T.RandomCrop(),
    T.RandomHorizontalFlip(),
    T.BatchRandomResize(
        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
        interp='RANDOM'),
    T.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Transforms applied to evaluation samples: fixed-size resize + normalize.
eval_transforms = T.Compose([
    T.Resize(
        target_size=608, interp='CUBIC'),
    T.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = pdrs.datasets.VOCDetection(
    data_dir=DATA_DIR,
    file_list='insect_det/train_list.txt',
    label_list='insect_det/labels.txt',
    transforms=train_transforms,
    shuffle=True)

eval_dataset = pdrs.datasets.VOCDetection(
    data_dir=DATA_DIR,
    file_list='insect_det/val_list.txt',
    label_list='insect_det/labels.txt',
    transforms=eval_transforms,
    shuffle=False)

# The number of classes is taken from the label list of the training set.
num_classes = len(train_dataset.labels)

model = pdrs.tasks.det.PPYOLO(num_classes=num_classes, backbone='ResNet50_vd_dcn')

model.train(
    num_epochs=200,
    train_dataset=train_dataset,
    train_batch_size=8,
    eval_dataset=eval_dataset,
    pretrain_weights='COCO',
    learning_rate=0.005 / 12,
    warmup_steps=500,
    warmup_start_lr=0.0,
    save_interval_epochs=5,
    lr_decay_epochs=[85, 135],
    save_dir='output/ppyolo_r50vd_dcn',
    use_vdl=True)
Loading…
Cancel
Save