diff --git a/.travis.yml b/.travis.yml index b83708cdd..976076826 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,7 +35,7 @@ before_install: install: - pip install Pillow==6.2.2 # remove this line when torchvision>=0.5 - pip install torch==${TORCH} torchvision==${TORCHVISION} - - pip install mmcv-nightly + - pip install mmcv - pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools" - pip install -r requirements.txt diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index fd8d32116..7e76539cf 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -1,11 +1,9 @@ import random -from collections import OrderedDict import numpy as np import torch -import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import (DistSamplerSeedHook, OptimizerHook, Runner, +from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook, build_optimizer) from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook @@ -32,54 +30,6 @@ def set_random_seed(seed, deterministic=False): torch.backends.cudnn.benchmark = False -def parse_losses(losses): - log_vars = OrderedDict() - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - log_vars[loss_name] = loss_value.mean() - elif isinstance(loss_value, list): - log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) - else: - raise TypeError(f'{loss_name} is not a tensor or list of tensors') - - loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) - - log_vars['loss'] = loss - for loss_name, loss_value in log_vars.items(): - # reduce loss when distributed training - if dist.is_available() and dist.is_initialized(): - loss_value = loss_value.data.clone() - dist.all_reduce(loss_value.div_(dist.get_world_size())) - log_vars[loss_name] = loss_value.item() - - return loss, log_vars - - -def batch_processor(model, data, train_mode): - """Process a data batch. - - This method is required as an argument of Runner, which defines how to - process a data batch and obtain proper outputs. The first 3 arguments of - batch_processor are fixed. - - Args: - model (nn.Module): A PyTorch model. - data (dict): The data batch in a dict. - train_mode (bool): Training mode or not. It may be useless for some - models. - - Returns: - dict: A dict containing losses and log vars. - """ - losses = model(**data) - loss, log_vars = parse_losses(losses) - - outputs = dict( - loss=loss, log_vars=log_vars, num_samples=len(data['img'].data)) - - return outputs - - def train_detector(model, dataset, cfg, @@ -132,11 +82,10 @@ def train_detector(model, # build runner optimizer = build_optimizer(model, cfg.optimizer) - runner = Runner( + runner = EpochBasedRunner( model, - batch_processor, - optimizer, - cfg.work_dir, + optimizer=optimizer, + work_dir=cfg.work_dir, logger=logger, meta=meta) # an ugly workaround to make .log and .log.json filenames the same diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index d22e8e01b..7e00c2c3a 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -1,8 +1,11 @@ import warnings from abc import ABCMeta, abstractmethod +from collections import OrderedDict import mmcv import numpy as np +import torch +import torch.distributed as dist import torch.nn as nn from mmcv.utils import print_log @@ -149,6 +152,90 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): else: return self.forward_test(img, img_metas, **kwargs) + def _parse_losses(self, losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary infomation. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img'].data)) + + return outputs + + def val_step(self, data, optimizer): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img'].data)) + + return outputs + def show_result(self, img, result, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5f2374a94..192f18705 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,5 +1,5 @@ matplotlib -mmcv>=0.5.9 +mmcv>=0.6.0 numpy # need older pillow until torchvision is fixed Pillow<=6.2.2 diff --git a/tests/test_forward.py b/tests/test_forward.py index 2fc235349..208254006 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -153,9 +153,8 @@ def test_faster_rcnn_ohem_forward(): gt_labels=gt_labels, return_loss=True) assert isinstance(losses, dict) - from mmdet.apis.train import parse_losses - total_loss = float(parse_losses(losses)[0].item()) - assert total_loss > 0 + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 # Test forward train with an empty truth batch mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) @@ -170,9 +169,8 @@ def test_faster_rcnn_ohem_forward(): gt_labels=gt_labels, return_loss=True) assert isinstance(losses, dict) - from mmdet.apis.train import parse_losses - total_loss = float(parse_losses(losses)[0].item()) - assert total_loss > 0 + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 # HTC is not ready yet @@ -206,10 +204,10 @@ def test_two_stage_forward(cfg_file): gt_masks=gt_masks, return_loss=True) assert isinstance(losses, dict) - from mmdet.apis.train import parse_losses - total_loss = parse_losses(losses)[0].requires_grad_(True) - assert float(total_loss.item()) > 0 - total_loss.backward() + loss, _ = detector._parse_losses(losses) + loss.requires_grad_(True) + assert float(loss.item()) > 0 + loss.backward() # Test forward train with an empty truth batch mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) @@ -226,10 +224,10 @@ def test_two_stage_forward(cfg_file): gt_masks=gt_masks, return_loss=True) assert isinstance(losses, dict) - from mmdet.apis.train import parse_losses - total_loss = parse_losses(losses)[0].requires_grad_(True) - assert float(total_loss.item()) > 0 - total_loss.backward() + loss, _ = detector._parse_losses(losses) + loss.requires_grad_(True) + assert float(loss.item()) > 0 + loss.backward() # Test forward test with torch.no_grad():