diff --git a/examples/rs_research/config_utils.py b/examples/rs_research/config_utils.py index 10e7129..62996fe 100644 --- a/examples/rs_research/config_utils.py +++ b/examples/rs_research/config_utils.py @@ -133,6 +133,7 @@ def parse_args(*args, **kwargs): # Global settings parser.add_argument('cmd', choices=['train', 'eval']) parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) + parser.add_argument('--seed', type=int, default=None) # Data parser.add_argument('--datasets', type=dict, default={}) diff --git a/examples/rs_research/run_task.py b/examples/rs_research/run_task.py index a487c45..cab073a 100644 --- a/examples/rs_research/run_task.py +++ b/examples/rs_research/run_task.py @@ -15,7 +15,9 @@ # limitations under the License. import os +import random +import numpy as np # Import cv2 and sklearn before paddlers to solve the # "ImportError: dlopen: cannot load any more object with static TLS" issue. import cv2 @@ -62,6 +64,11 @@ if __name__ == '__main__': cfg = parse_args() print(format_cfg(cfg)) + if cfg['seed'] is not None: + random.seed(cfg['seed']) + np.random.seed(cfg['seed']) + paddle.seed(cfg['seed']) + # Automatically download data if cfg['download_on']: paddlers.utils.download_and_decompress( diff --git a/test_tipc/config_utils.py b/test_tipc/config_utils.py index 6e677b4..3fb17bd 100644 --- a/test_tipc/config_utils.py +++ b/test_tipc/config_utils.py @@ -119,6 +119,7 @@ def parse_args(*args, **kwargs): # Global settings parser.add_argument('cmd', choices=['train', 'eval']) parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) + parser.add_argument('--seed', type=int, default=None) # Data parser.add_argument('--datasets', type=dict, default={}) diff --git a/test_tipc/run_task.py b/test_tipc/run_task.py index 923415b..7c1d40d 100644 --- a/test_tipc/run_task.py +++ b/test_tipc/run_task.py @@ -1,7 +1,9 @@ #!/usr/bin/env python import os +import random +import numpy as np # Import cv2 and sklearn before paddlers to solve the # "ImportError: dlopen: cannot load any more object with static TLS" issue. import cv2 @@ -46,6 +48,11 @@ if __name__ == '__main__': cfg = parse_args() print(format_cfg(cfg)) + if cfg['seed'] is not None: + random.seed(cfg['seed']) + np.random.seed(cfg['seed']) + paddle.seed(cfg['seed']) + # Automatically download data if cfg['download_on']: paddlers.utils.download_and_decompress(