Refactor run_task.py

own
Bobholamovic 2 years ago
parent 09d3dd1202
commit 5a24513136
  1. 28
      examples/rs_research/run_task.py
  2. 28
      test_tipc/run_task.py

@ -53,6 +53,17 @@ if __name__ == '__main__':
paddlers.utils.download_and_decompress(
cfg['download_url'], path=cfg['download_path'])
if not isinstance(cfg['datasets']['eval'].args, dict):
raise ValueError("args of eval dataset must be a dict!")
if cfg['datasets']['eval'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of eval dataset and the value is not None."
)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
# Inplace modification
cfg['datasets']['eval'].args['transforms'] = eval_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
if cfg['cmd'] == 'train':
if not isinstance(cfg['datasets']['train'].args, dict):
raise ValueError("args of train dataset must be a dict!")
@ -67,21 +78,8 @@ if __name__ == '__main__':
cfg['datasets']['train'].args['transforms'] = train_transforms
train_dataset = build_objects(
cfg['datasets']['train'], mod=paddlers.datasets)
if not isinstance(cfg['datasets']['eval'].args, dict):
raise ValueError("args of eval dataset must be a dict!")
if cfg['datasets']['eval'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of eval dataset and the value is not None."
)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
# Inplace modification
cfg['datasets']['eval'].args['transforms'] = eval_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
model = build_objects(
cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
if cfg['cmd'] == 'train':
if cfg['optimizer']:
if len(cfg['optimizer'].args) == 0:
cfg['optimizer'].args = {}
@ -112,8 +110,6 @@ if __name__ == '__main__':
resume_checkpoint=cfg['resume_checkpoint'] or None,
**cfg['train'])
elif cfg['cmd'] == 'eval':
state_dict = paddle.load(
os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
model.net.set_state_dict(state_dict)
model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
res = model.evaluate(eval_dataset)
print(res)

@ -51,6 +51,17 @@ if __name__ == '__main__':
paddlers.utils.download_and_decompress(
cfg['download_url'], path=cfg['download_path'])
if not isinstance(cfg['datasets']['eval'].args, dict):
raise ValueError("args of eval dataset must be a dict!")
if cfg['datasets']['eval'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of eval dataset and the value is not None."
)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
# Inplace modification
cfg['datasets']['eval'].args['transforms'] = eval_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
if cfg['cmd'] == 'train':
if not isinstance(cfg['datasets']['train'].args, dict):
raise ValueError("args of train dataset must be a dict!")
@ -65,21 +76,8 @@ if __name__ == '__main__':
cfg['datasets']['train'].args['transforms'] = train_transforms
train_dataset = build_objects(
cfg['datasets']['train'], mod=paddlers.datasets)
if not isinstance(cfg['datasets']['eval'].args, dict):
raise ValueError("args of eval dataset must be a dict!")
if cfg['datasets']['eval'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of eval dataset and the value is not None."
)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
# Inplace modification
cfg['datasets']['eval'].args['transforms'] = eval_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
model = build_objects(
cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
if cfg['cmd'] == 'train':
if cfg['optimizer']:
if len(cfg['optimizer'].args) == 0:
cfg['optimizer'].args = {}
@ -110,8 +108,6 @@ if __name__ == '__main__':
resume_checkpoint=cfg['resume_checkpoint'] or None,
**cfg['train'])
elif cfg['cmd'] == 'eval':
state_dict = paddle.load(
os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
model.net.set_state_dict(state_dict)
model = paddlers.tasks.load_model(cfg['resume_checkpoint'])
res = model.evaluate(eval_dataset)
print(res)

Loading…
Cancel
Save