Add fix seed

own
Bobholamovic 2 years ago
parent 3929c7e26e
commit 1ef0067ad9
  1. 1
      examples/rs_research/config_utils.py
  2. 7
      examples/rs_research/run_task.py
  3. 1
      test_tipc/config_utils.py
  4. 7
      test_tipc/run_task.py

@ -133,6 +133,7 @@ def parse_args(*args, **kwargs):
# Global settings # Global settings
parser.add_argument('cmd', choices=['train', 'eval']) parser.add_argument('cmd', choices=['train', 'eval'])
parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg'])
parser.add_argument('--seed', type=int, default=None)
# Data # Data
parser.add_argument('--datasets', type=dict, default={}) parser.add_argument('--datasets', type=dict, default={})

@ -15,7 +15,9 @@
# limitations under the License. # limitations under the License.
import os import os
import random
import numpy as np
# Import cv2 and sklearn before paddlers to solve the # Import cv2 and sklearn before paddlers to solve the
# "ImportError: dlopen: cannot load any more object with static TLS" issue. # "ImportError: dlopen: cannot load any more object with static TLS" issue.
import cv2 import cv2
@ -62,6 +64,11 @@ if __name__ == '__main__':
cfg = parse_args() cfg = parse_args()
print(format_cfg(cfg)) print(format_cfg(cfg))
if cfg['seed'] is not None:
random.seed(cfg['seed'])
np.random.seed(cfg['seed'])
paddle.seed(cfg['seed'])
# Automatically download data # Automatically download data
if cfg['download_on']: if cfg['download_on']:
paddlers.utils.download_and_decompress( paddlers.utils.download_and_decompress(

@ -119,6 +119,7 @@ def parse_args(*args, **kwargs):
# Global settings # Global settings
parser.add_argument('cmd', choices=['train', 'eval']) parser.add_argument('cmd', choices=['train', 'eval'])
parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg']) parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg'])
parser.add_argument('--seed', type=int, default=None)
# Data # Data
parser.add_argument('--datasets', type=dict, default={}) parser.add_argument('--datasets', type=dict, default={})

@ -1,7 +1,9 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import random
import numpy as np
# Import cv2 and sklearn before paddlers to solve the # Import cv2 and sklearn before paddlers to solve the
# "ImportError: dlopen: cannot load any more object with static TLS" issue. # "ImportError: dlopen: cannot load any more object with static TLS" issue.
import cv2 import cv2
@ -46,6 +48,11 @@ if __name__ == '__main__':
cfg = parse_args() cfg = parse_args()
print(format_cfg(cfg)) print(format_cfg(cfg))
if cfg['seed'] is not None:
random.seed(cfg['seed'])
np.random.seed(cfg['seed'])
paddle.seed(cfg['seed'])
# Automatically download data # Automatically download data
if cfg['download_on']: if cfg['download_on']:
paddlers.utils.download_and_decompress( paddlers.utils.download_and_decompress(

Loading…
Cancel
Save