[Enhance] Continue to speed up training. (#6974)

* [Enhance] Speed up training time.

* set in cfg
pull/7074/head
RangiLyu 3 years ago committed by GitHub
parent 9fcd11e9a3
commit 4b87ddc9ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      configs/_base_/default_runtime.py
  2. 40
      tools/train.py

@ -14,3 +14,8 @@ log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings
@ -19,8 +21,6 @@ from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
cv2.setNumThreads(0)
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
@ -91,12 +91,48 @@ def parse_args():
return args
def setup_multi_processes(cfg):
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
mp.set_start_method(mp_start_method)
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

Loading…
Cancel
Save