From 7d25f5855e80ae55215df40b3afca8a019e5e1a9 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Thu, 3 Mar 2022 20:53:21 +0800 Subject: [PATCH] Add installation and readme --- requirements.txt | 1 + setup.py | 43 +++++++++++++ .../train/detection/faster_rcnn_sar_ship.py | 64 +++++++++++++++++++ tutorials/train/detection/readme.md | 25 ++++++++ 4 files changed, 133 insertions(+) create mode 100644 setup.py create mode 100644 tutorials/train/detection/faster_rcnn_sar_ship.py create mode 100644 tutorials/train/detection/readme.md diff --git a/requirements.txt b/requirements.txt index 804ab1b..10fd638 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ motmetrics matplotlib chardet openpyxl +gdal diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..956d0c0 --- /dev/null +++ b/setup.py @@ -0,0 +1,43 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import setuptools + +long_description = "Awesome Remote Sensing Toolkit based on PaddlePaddle" + +setuptools.setup( + name="paddlers", + version='0.0.1', + author="paddlers", + author_email="paddlers@baidu.com", + description=long_description, + long_description=long_description, + long_description_content_type="text/plain", + url="https://github.com/PaddleCV-SIG/PaddleRS", + packages=setuptools.find_packages(), + setup_requires=['cython', 'numpy'], + install_requires=[ + "pycocotools", 'pyyaml', 'colorama', 'tqdm', 'paddleslim==2.2.1', + 'visualdl>=2.2.2', 'shapely>=1.7.0', 'opencv-python', 'scipy', 'lap', + 'motmetrics', 'scikit-learn==0.23.2', 'chardet', 'flask_cors', + 'openpyxl', 'gdal' + ], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + license='Apache 2.0', + ) + diff --git a/tutorials/train/detection/faster_rcnn_sar_ship.py b/tutorials/train/detection/faster_rcnn_sar_ship.py new file mode 100644 index 0000000..92565e0 --- /dev/null +++ b/tutorials/train/detection/faster_rcnn_sar_ship.py @@ -0,0 +1,64 @@ +import os +import paddlers as pdrs +from paddlers import transforms as T + +# download dataset +data_dir = 'sar_ship_1' +if not os.path.exists(data_dir): + dataset_url = 'https://paddleseg.bj.bcebos.com/dataset/sar_ship_1.tar.gz' + pdrs.utils.download_and_decompress(dataset_url, path='./') + +# define transforms +train_transforms = T.Compose([ + T.RandomDistort(), + T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]), + T.RandomCrop(), + T.RandomHorizontalFlip(), + T.BatchRandomResize( + target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608], + interp='RANDOM'), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +eval_transforms = T.Compose([ + T.Resize( + target_size=608, interp='CUBIC'), T.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +# define dataset +train_file_list = os.path.join(data_dir, 'train.txt') +val_file_list = os.path.join(data_dir, 'valid.txt') +label_file_list = os.path.join(data_dir, 'labels.txt') +train_dataset = pdrs.datasets.VOCDetection( + data_dir=data_dir, + file_list=train_file_list, + label_list=label_file_list, + transforms=train_transforms, + shuffle=True) + +eval_dataset = pdrs.datasets.VOCDetection( + data_dir=data_dir, + file_list=train_file_list, + label_list=label_file_list, + transforms=eval_transforms, + shuffle=False) + +# define models +num_classes = len(train_dataset.labels) +model = pdrs.tasks.det.FasterRCNN(num_classes=num_classes) + +# train +model.train( + num_epochs=60, + train_dataset=train_dataset, + train_batch_size=2, + eval_dataset=eval_dataset, + pretrain_weights='COCO', + learning_rate=0.005 / 12, + warmup_steps=10, + warmup_start_lr=0.0, + save_interval_epochs=5, + lr_decay_epochs=[20, 40], + save_dir='output/faster_rcnn_sar_ship', + use_vdl=True) diff --git a/tutorials/train/detection/readme.md b/tutorials/train/detection/readme.md new file mode 100644 index 0000000..01e6b66 --- /dev/null +++ b/tutorials/train/detection/readme.md @@ -0,0 +1,25 @@ +Run the detection training demo: + +1, Install PaddleRS + +``` +git clone https://github.com/PaddleCV-SIG/PaddleRS.git +cd PaddleRS +pip install -r requirements.txt +python setup.py install +``` + + +2. Run the demo + +``` +cd tutorials/train/detection/ + +# run training on single GPU +export CUDA_VISIBLE_DEVICES=0 +python faster_rcnn_sar_ship.py + +# run traing on multi gpu +export CUDA_VISIBLE_DEVICES=0,1 +python -m paddle.distributed.launch faster_rcnn_sar_ship.py +```