Update rs example

own
Bobholamovic 2 years ago
parent 6479c51d4f
commit f14c45d89e
  1. 18
      examples/rs_research/bootstrap.py
  2. 2
      examples/rs_research/config_utils.py
  3. 6
      examples/rs_research/predict_cd.py
  4. 3
      examples/rs_research/run_task.py
  5. 3
      examples/rs_research/tools/analyze_model.py
  6. 3
      examples/rs_research/tools/visualize_feats.py
  7. 1
      examples/rs_research/train_cd.py

@ -0,0 +1,18 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from custom_model import CustomModel
from custom_trainer import make_trainer
make_trainer(CustomModel)

@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

@ -23,8 +23,7 @@ import paddle
import paddlers
from tqdm import tqdm
from custom_model import CustomModel
from custom_trainer import make_trainer
import bootstrap
def read_file_list(file_list, sep=' '):
@ -57,9 +56,6 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
# 注册训练器
make_trainer(CustomModel)
model = paddlers.tasks.load_model(args.model_dir)
if not osp.exists(args.save_dir):

@ -21,8 +21,7 @@ import paddle
import paddlers
from paddlers import transforms as T
import custom_model
import custom_trainer
import bootstrap
from config_utils import parse_args, build_objects, CfgNode

@ -30,8 +30,7 @@ from paddle.hapi.static_flops import Table
_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
import custom_model
import custom_trainer
import bootstrap
def parse_args():

@ -28,8 +28,7 @@ from sklearn.decomposition import PCA
_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
import custom_model
import custom_trainer
import bootstrap
FILENAME_PATTERN = "{key}_{idx}_vis.png"

@ -93,6 +93,7 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=val_dataset,
optimizer=optimizer,
# 每多少个epoch验证并保存一次模型
save_interval_epochs=5,
# 每多少次迭代记录一次日志

Loading…
Cancel
Save