diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py
index 64f3173d5..4bd1019e9 100644
--- a/mmdet/utils/__init__.py
+++ b/mmdet/utils/__init__.py
@@ -2,9 +2,9 @@
 from .collect_env import collect_env
 from .logger import get_root_logger
 from .misc import find_latest_checkpoint
+from .setup_env import setup_multi_processes
 
 __all__ = [
-    'get_root_logger',
-    'collect_env',
-    'find_latest_checkpoint',
+    'get_root_logger', 'collect_env', 'find_latest_checkpoint',
+    'setup_multi_processes'
 ]
diff --git a/mmdet/utils/setup_env.py b/mmdet/utils/setup_env.py
new file mode 100644
index 000000000..21def2f08
--- /dev/null
+++ b/mmdet/utils/setup_env.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import platform
+import warnings
+
+import cv2
+import torch.multiprocessing as mp
+
+
+def setup_multi_processes(cfg):
+    """Setup multi-processing environment variables.
+
+    Applies, in order, the multi-processing start method (skipped on
+    Windows), the OpenCV thread count, and default values for the
+    ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment variables.
+
+    Args:
+        cfg (mmcv.Config): Config to read ``mp_start_method`` (default
+            ``'fork'``), ``opencv_num_threads`` (default ``0``) and
+            ``data.workers_per_gpu`` from.
+    """
+    # set multi-process start method as `fork` to speed up the training
+    if platform.system() != 'Windows':
+        mp_start_method = cfg.get('mp_start_method', 'fork')
+        current_method = mp.get_start_method(allow_none=True)
+        if current_method is not None and current_method != mp_start_method:
+            warnings.warn(
+                f'Multi-processing start method `{mp_start_method}` is '
+                f'different from the previous setting `{current_method}`. '
+                f'It will be force set to `{mp_start_method}`. You can change '
+                f'this behavior by changing `mp_start_method` in your config.')
+        mp.set_start_method(mp_start_method, force=True)
+
+    # disable opencv multithreading to avoid system being overloaded
+    opencv_num_threads = cfg.get('opencv_num_threads', 0)
+    cv2.setNumThreads(opencv_num_threads)
+
+    # setup OMP and MKL threads; a value already exported by the user wins
+    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
+    for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
+        if env_name not in os.environ and cfg.data.workers_per_gpu > 1:
+            num_threads = 1
+            warnings.warn(
+                f'Setting {env_name} environment variable for each process '
+                f'to be {num_threads} in default, to avoid your system being '
+                f'overloaded, please further tune the variable for optimal '
+                f'performance in your application as needed.')
+            os.environ[env_name] = str(num_threads)
diff --git a/tests/test_utils/test_setup_env.py b/tests/test_utils/test_setup_env.py
new file mode 100644
index 000000000..70f01b8ac
--- /dev/null
+++ b/tests/test_utils/test_setup_env.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import multiprocessing as mp
+import os
+import platform
+
+import cv2
+from mmcv import Config
+
+from mmdet.utils import setup_multi_processes
+
+
+def test_setup_multi_processes():
+    """Check the defaults, the worker-count gate and explicit overrides."""
+    # temp save system setting
+    sys_start_method = mp.get_start_method(allow_none=True)
+    sys_cv_threads = cv2.getNumThreads()
+    # pop and temp save system env vars
+    sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
+    sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
+
+    try:
+        # test config without setting env
+        cfg = Config(dict(data=dict(workers_per_gpu=2)))
+        setup_multi_processes(cfg)
+        assert os.getenv('OMP_NUM_THREADS') == '1'
+        assert os.getenv('MKL_NUM_THREADS') == '1'
+        # when set to 0, the num threads will be 1
+        assert cv2.getNumThreads() == 1
+        if platform.system() != 'Windows':
+            assert mp.get_start_method() == 'fork'
+
+        # test num workers <= 1
+        os.environ.pop('OMP_NUM_THREADS')
+        os.environ.pop('MKL_NUM_THREADS')
+        cfg = Config(dict(data=dict(workers_per_gpu=0)))
+        setup_multi_processes(cfg)
+        assert 'OMP_NUM_THREADS' not in os.environ
+        assert 'MKL_NUM_THREADS' not in os.environ
+
+        # test manually set env var
+        os.environ['OMP_NUM_THREADS'] = '4'
+        cfg = Config(dict(data=dict(workers_per_gpu=2)))
+        setup_multi_processes(cfg)
+        assert os.getenv('OMP_NUM_THREADS') == '4'
+
+        # test manually set opencv threads and mp start method
+        config = dict(
+            data=dict(workers_per_gpu=2),
+            opencv_num_threads=4,
+            mp_start_method='spawn')
+        cfg = Config(config)
+        setup_multi_processes(cfg)
+        assert cv2.getNumThreads() == 4
+        assert mp.get_start_method() == 'spawn'
+    finally:
+        # revert setting even when an assertion above fails, so that a
+        # failure here does not leak state into the other tests
+        if sys_start_method:
+            mp.set_start_method(sys_start_method, force=True)
+        cv2.setNumThreads(sys_cv_threads)
+        saved_env = (('OMP_NUM_THREADS', sys_omp_threads),
+                     ('MKL_NUM_THREADS', sys_mkl_threads))
+        for env_name, saved_value in saved_env:
+            # `None` means the variable was not set before the test
+            if saved_value is None:
+                os.environ.pop(env_name, None)
+            else:
+                os.environ[env_name] = saved_value
diff --git a/tools/test.py b/tools/test.py
index da9821de1..0bac77782 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -17,6 +17,7 @@ from mmdet.apis import multi_gpu_test, single_gpu_test
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)
 from mmdet.models import build_detector
+from mmdet.utils import setup_multi_processes
 
 
 def parse_args():
@@ -128,6 +129,10 @@ def main():
     cfg = Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
diff --git a/tools/train.py b/tools/train.py
index d5742c70b..2ff4c93ff 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,14 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import copy
-import multiprocessing as mp
 import os
 import os.path as osp
-import platform
 import time
 import warnings
 
-import cv2
 import mmcv
 import torch
 from mmcv import Config, DictAction
@@ -19,7 +16,7 @@ from mmdet import __version__
 from mmdet.apis import init_random_seed, set_random_seed, train_detector
 from mmdet.datasets import build_dataset
 from mmdet.models import build_detector
-from mmdet.utils import collect_env, get_root_logger
+from mmdet.utils import collect_env, get_root_logger, setup_multi_processes
 
 
 def parse_args():
@@ -91,38 +88,6 @@ def parse_args():
     return args
 
 
-def setup_multi_processes(cfg):
-    # set multi-process start method as `fork` to speed up the training
-    if platform.system() != 'Windows':
-        mp_start_method = cfg.get('mp_start_method', 'fork')
-        mp.set_start_method(mp_start_method)
-
-    # disable opencv multithreading to avoid system being overloaded
-    opencv_num_threads = cfg.get('opencv_num_threads', 0)
-    cv2.setNumThreads(opencv_num_threads)
-
-    # setup OMP threads
-    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
-    if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
-        omp_num_threads = 1
-        warnings.warn(
-            f'Setting OMP_NUM_THREADS environment variable for each process '
-            f'to be {omp_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
-
-    # setup MKL threads
-    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
-        mkl_num_threads = 1
-        warnings.warn(
-            f'Setting MKL_NUM_THREADS environment variable for each process '
-            f'to be {mkl_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
-
-
 def main():
     args = parse_args()
 