@@ -3,7 +3,9 @@ import warnings
 from math import inf
 
 import mmcv
+import torch.distributed as dist
 from mmcv.runner import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader
 
 from mmdet.utils import get_root_logger
@@ -199,6 +201,9 @@ class DistEvalHook(EvalHook):
             ``CheckpointHook`` should divide that of ``EvalHook``. Default: None.
         rule (str | None): Comparison rule for best score. If set to None,
             it will infer a reasonable rule. Default: None.
+        broadcast_bn_buffer (bool): Whether to broadcast the buffers
+            (running_mean and running_var) of rank 0 to other ranks
+            before evaluation. Default: True.
         **eval_kwargs: Evaluation arguments fed into the evaluate function of
             the dataset.
     """
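In a downstream config, the new flag rides along with the rest of the `evaluation` dict, whose entries mmdet forwards to the hook as keyword arguments when training runs in distributed mode. A minimal sketch, assuming the standard mmdet training entry point (the metric name is illustrative):

    # Config fragment; the ``evaluation`` kwargs are forwarded to
    # DistEvalHook by the distributed training launcher.
    evaluation = dict(
        interval=1,                # evaluate after every epoch
        metric='bbox',             # illustrative metric choice
        broadcast_bn_buffer=True)  # re-sync BN running stats from rank 0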
@@ -211,6 +216,7 @@ class DistEvalHook(EvalHook):
                  gpu_collect=False,
                  save_best=None,
                  rule=None,
+                 broadcast_bn_buffer=True,
                  **eval_kwargs):
         super().__init__(
             dataloader,
@@ -219,10 +225,24 @@ class DistEvalHook(EvalHook):
             save_best=save_best,
             rule=rule,
             **eval_kwargs)
+        self.broadcast_bn_buffer = broadcast_bn_buffer
         self.tmpdir = tmpdir
         self.gpu_collect = gpu_collect
 
     def after_train_epoch(self, runner):
+        # Synchronization of BatchNorm's buffers (running_mean and
+        # running_var) is not supported by PyTorch DDP, which can leave
+        # the models on different ranks with inconsistent statistics,
+        # so we broadcast the BatchNorm buffers of rank 0 to all other
+        # ranks before evaluation to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
         if not self.evaluation_flag(runner):
             return
 
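For reference, the broadcast pattern the hook now applies can be exercised on its own. A minimal, self-contained sketch, not part of the patch: `broadcast_bn_buffers` is a hypothetical helper, and a process group is assumed to be set up by a launcher such as `torchrun --nproc_per_node=2 demo.py`:

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.modules.batchnorm import _BatchNorm


    def broadcast_bn_buffers(model, src=0):
        """Copy rank ``src``'s BatchNorm running stats to every other rank."""
        for _, module in model.named_modules():
            if isinstance(module, _BatchNorm) and module.track_running_stats:
                dist.broadcast(module.running_var, src)
                dist.broadcast(module.running_mean, src)


    if __name__ == '__main__':
        # Rendezvous details come from the torchrun environment variables.
        dist.init_process_group('gloo')
        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        model.train()
        # Let each rank's running stats drift apart, then re-sync from rank 0.
        model(torch.randn(2 + dist.get_rank(), 3, 16, 16))
        broadcast_bn_buffers(model, src=0)
        # Every rank now prints the same values as rank 0.
        print(f'rank {dist.get_rank()}:', model[1].running_mean[:3])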