OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
7.2 KiB
208 lines
7.2 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
import random |
|
import warnings |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
|
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, |
|
Fp16OptimizerHook, OptimizerHook, build_optimizer, |
|
build_runner, get_dist_info) |
|
|
|
from mmdet.core import DistEvalHook, EvalHook |
|
from mmdet.datasets import (build_dataloader, build_dataset, |
|
replace_ImageToTensor) |
|
from mmdet.utils import find_latest_checkpoint, get_root_logger |
|
|
|
|
|
def init_random_seed(seed=None, device='cuda'): |
|
"""Initialize random seed. |
|
|
|
If the seed is not set, the seed will be automatically randomized, |
|
and then broadcast to all processes to prevent some potential bugs. |
|
|
|
Args: |
|
seed (int, Optional): The seed. Default to None. |
|
device (str): The device where the seed will be put on. |
|
Default to 'cuda'. |
|
|
|
Returns: |
|
int: Seed to be used. |
|
""" |
|
if seed is not None: |
|
return seed |
|
|
|
# Make sure all ranks share the same random seed to prevent |
|
# some potential bugs. Please refer to |
|
# https://github.com/open-mmlab/mmdetection/issues/6339 |
|
rank, world_size = get_dist_info() |
|
seed = np.random.randint(2**31) |
|
if world_size == 1: |
|
return seed |
|
|
|
if rank == 0: |
|
random_num = torch.tensor(seed, dtype=torch.int32, device=device) |
|
else: |
|
random_num = torch.tensor(0, dtype=torch.int32, device=device) |
|
dist.broadcast(random_num, src=0) |
|
return random_num.item() |
|
|
|
|
|
def set_random_seed(seed, deterministic=False): |
|
"""Set random seed. |
|
|
|
Args: |
|
seed (int): Seed to be used. |
|
deterministic (bool): Whether to set the deterministic option for |
|
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` |
|
to True and `torch.backends.cudnn.benchmark` to False. |
|
Default: False. |
|
""" |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
if deterministic: |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
|
|
def train_detector(model, |
|
dataset, |
|
cfg, |
|
distributed=False, |
|
validate=False, |
|
timestamp=None, |
|
meta=None): |
|
logger = get_root_logger(log_level=cfg.log_level) |
|
|
|
# prepare data loaders |
|
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] |
|
if 'imgs_per_gpu' in cfg.data: |
|
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. ' |
|
'Please use "samples_per_gpu" instead') |
|
if 'samples_per_gpu' in cfg.data: |
|
logger.warning( |
|
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' |
|
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' |
|
f'={cfg.data.imgs_per_gpu} is used in this experiments') |
|
else: |
|
logger.warning( |
|
'Automatically set "samples_per_gpu"="imgs_per_gpu"=' |
|
f'{cfg.data.imgs_per_gpu} in this experiments') |
|
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu |
|
|
|
runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[ |
|
'type'] |
|
data_loaders = [ |
|
build_dataloader( |
|
ds, |
|
cfg.data.samples_per_gpu, |
|
cfg.data.workers_per_gpu, |
|
# `num_gpus` will be ignored if distributed |
|
num_gpus=len(cfg.gpu_ids), |
|
dist=distributed, |
|
seed=cfg.seed, |
|
runner_type=runner_type, |
|
persistent_workers=cfg.data.get('persistent_workers', False)) |
|
for ds in dataset |
|
] |
|
|
|
# put model on gpus |
|
if distributed: |
|
find_unused_parameters = cfg.get('find_unused_parameters', False) |
|
# Sets the `find_unused_parameters` parameter in |
|
# torch.nn.parallel.DistributedDataParallel |
|
model = MMDistributedDataParallel( |
|
model.cuda(), |
|
device_ids=[torch.cuda.current_device()], |
|
broadcast_buffers=False, |
|
find_unused_parameters=find_unused_parameters) |
|
else: |
|
model = MMDataParallel(model, device_ids=cfg.gpu_ids) |
|
|
|
# build runner |
|
optimizer = build_optimizer(model, cfg.optimizer) |
|
|
|
if 'runner' not in cfg: |
|
cfg.runner = { |
|
'type': 'EpochBasedRunner', |
|
'max_epochs': cfg.total_epochs |
|
} |
|
warnings.warn( |
|
'config is now expected to have a `runner` section, ' |
|
'please set `runner` in your config.', UserWarning) |
|
else: |
|
if 'total_epochs' in cfg: |
|
assert cfg.total_epochs == cfg.runner.max_epochs |
|
|
|
runner = build_runner( |
|
cfg.runner, |
|
default_args=dict( |
|
model=model, |
|
optimizer=optimizer, |
|
work_dir=cfg.work_dir, |
|
logger=logger, |
|
meta=meta)) |
|
|
|
# an ugly workaround to make .log and .log.json filenames the same |
|
runner.timestamp = timestamp |
|
|
|
# fp16 setting |
|
fp16_cfg = cfg.get('fp16', None) |
|
if fp16_cfg is not None: |
|
optimizer_config = Fp16OptimizerHook( |
|
**cfg.optimizer_config, **fp16_cfg, distributed=distributed) |
|
elif distributed and 'type' not in cfg.optimizer_config: |
|
optimizer_config = OptimizerHook(**cfg.optimizer_config) |
|
else: |
|
optimizer_config = cfg.optimizer_config |
|
|
|
# register hooks |
|
runner.register_training_hooks( |
|
cfg.lr_config, |
|
optimizer_config, |
|
cfg.checkpoint_config, |
|
cfg.log_config, |
|
cfg.get('momentum_config', None), |
|
custom_hooks_config=cfg.get('custom_hooks', None)) |
|
|
|
if distributed: |
|
if isinstance(runner, EpochBasedRunner): |
|
runner.register_hook(DistSamplerSeedHook()) |
|
|
|
# register eval hooks |
|
if validate: |
|
# Support batch_size > 1 in validation |
|
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1) |
|
if val_samples_per_gpu > 1: |
|
# Replace 'ImageToTensor' to 'DefaultFormatBundle' |
|
cfg.data.val.pipeline = replace_ImageToTensor( |
|
cfg.data.val.pipeline) |
|
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) |
|
val_dataloader = build_dataloader( |
|
val_dataset, |
|
samples_per_gpu=val_samples_per_gpu, |
|
workers_per_gpu=cfg.data.workers_per_gpu, |
|
dist=distributed, |
|
shuffle=False) |
|
eval_cfg = cfg.get('evaluation', {}) |
|
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' |
|
eval_hook = DistEvalHook if distributed else EvalHook |
|
# In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the |
|
# priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. |
|
runner.register_hook( |
|
eval_hook(val_dataloader, **eval_cfg), priority='LOW') |
|
|
|
resume_from = None |
|
if cfg.resume_from is None and cfg.get('auto_resume'): |
|
resume_from = find_latest_checkpoint(cfg.work_dir) |
|
if resume_from is not None: |
|
cfg.resume_from = resume_from |
|
|
|
if cfg.resume_from: |
|
runner.resume(cfg.resume_from) |
|
elif cfg.load_from: |
|
runner.load_checkpoint(cfg.load_from) |
|
runner.run(data_loaders, cfg.workflow)
|
|
|