[Feature] Add ClearMLLoggerHook (#1906)

* added clearml logger support

* review fixes

* review fixes
Author: Artem (committed via GitHub)
Commit: 94c071b310 (parent: 4fb59b9830)
Changed files:
  1. mmcv/runner/__init__.py (9 changed lines)
  2. mmcv/runner/hooks/__init__.py (9 changed lines)
  3. mmcv/runner/hooks/logger/__init__.py (4 changed lines)
  4. mmcv/runner/hooks/logger/clearml.py (62 added lines, new file)
  5. tests/test_runner/test_hooks.py (29 changed lines)
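
For context, the new hook is enabled like the other mmcv logger hooks, typically through the runner's log_config. A minimal sketch of such a config is shown below; the project and task names are placeholders, not values from this PR:

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        # New in this PR: init_kwargs is forwarded to clearml.Task.init.
        dict(
            type='ClearMLLoggerHook',
            init_kwargs=dict(
                project_name='my-project',   # placeholder
                task_name='baseline-run')),  # placeholder
    ])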

@@ -10,9 +10,10 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
-from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
-                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
-                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
+                    DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
+                    EMAHook, EvalHook, Fp16OptimizerHook,
+                    GradientCumulativeFp16OptimizerHook,
                     GradientCumulativeOptimizerHook, Hook, IterTimerHook,
                     LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
                     OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
@@ -68,5 +69,5 @@ __all__ = [
     'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
     'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
     'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
-    'LinearAnnealingLrUpdaterHook'
+    'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
 ]

@@ -5,9 +5,10 @@ from .ema import EMAHook
 from .evaluation import DistEvalHook, EvalHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
-from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
-                     NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook,
-                     TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
+from .logger import (ClearMLLoggerHook, DvcliveLoggerHook, LoggerHook,
+                     MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
+                     SegmindLoggerHook, TensorboardLoggerHook, TextLoggerHook,
+                     WandbLoggerHook)
 from .lr_updater import (CosineAnnealingLrUpdaterHook,
                          CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
                          ExpLrUpdaterHook, FixedLrUpdaterHook,
@@ -43,5 +44,5 @@ __all__ = [
     'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
     'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
     'SegmindLoggerHook', 'LinearAnnealingLrUpdaterHook',
-    'LinearAnnealingMomentumUpdaterHook'
+    'LinearAnnealingMomentumUpdaterHook', 'ClearMLLoggerHook'
 ]

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import LoggerHook
+from .clearml import ClearMLLoggerHook
 from .dvclive import DvcliveLoggerHook
 from .mlflow import MlflowLoggerHook
 from .neptune import NeptuneLoggerHook
@@ -12,5 +13,6 @@ from .wandb import WandbLoggerHook
 __all__ = [
     'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
     'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
-    'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook'
+    'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook',
+    'ClearMLLoggerHook'
 ]

@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class ClearMLLoggerHook(LoggerHook):
+    """Class to log metrics with clearml.
+
+    It requires `clearml`_ to be installed.
+
+    Args:
+        init_kwargs (dict): A dict contains the `clearml.Task.init`
+            initialization keys. See `taskinit`_ for more details.
+        interval (int): Logging interval (every k iterations). Default 10.
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`. Default: True.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+            Default: False.
+        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
+
+    .. _clearml:
+        https://clear.ml/docs/latest/docs/
+    .. _taskinit:
+        https://clear.ml/docs/latest/docs/references/sdk/task/#taskinit
+    """
+
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        super(ClearMLLoggerHook, self).__init__(interval, ignore_last,
+                                                reset_flag, by_epoch)
+        self.import_clearml()
+        self.init_kwargs = init_kwargs
+
+    def import_clearml(self):
+        try:
+            import clearml
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install clearml" to install clearml')
+        self.clearml = clearml
+
+    @master_only
+    def before_run(self, runner):
+        super(ClearMLLoggerHook, self).before_run(runner)
+        task_kwargs = self.init_kwargs if self.init_kwargs else {}
+        self.task = self.clearml.Task.init(**task_kwargs)
+        self.task_logger = self.task.get_logger()
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner)
+        for tag, val in tags.items():
+            self.task_logger.report_scalar(tag, tag, val,
+                                           self.get_iter(runner))
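
At each logging step the hook reports every loggable tag as a ClearML scalar, passing the tag name as both the plot title and the series name so that each metric gets its own plot. Roughly, the hook's lifecycle maps onto the public clearml API as in the sketch below (illustrative only, not part of the diff; names and values are placeholders):

from clearml import Task

# What before_run() does: create/attach a Task and grab its logger.
task = Task.init(project_name='my-project', task_name='baseline-run')
logger = task.get_logger()

# What log() does for every loggable tag, e.g. a learning rate of 0.02
# reported at iteration 6:
logger.report_scalar(title='learning_rate', series='learning_rate',
                     value=0.02, iteration=6)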

@@ -23,8 +23,8 @@ from torch.utils.data import DataLoader
 from mmcv.fileio.file_client import PetrelBackend
 # yapf: disable
-from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
-                         Fp16OptimizerHook,
+from mmcv.runner import (CheckpointHook, ClearMLLoggerHook, DvcliveLoggerHook,
+                         EMAHook, Fp16OptimizerHook,
                          GradientCumulativeFp16OptimizerHook,
                          GradientCumulativeOptimizerHook, IterTimerHook,
                          MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook,
@@ -1572,6 +1572,31 @@ def test_dvclive_hook_model_file(tmp_path):
     shutil.rmtree(runner.work_dir)
 
 
+def test_clearml_hook():
+    sys.modules['clearml'] = MagicMock()
+    runner = _build_demo_runner()
+    hook = ClearMLLoggerHook(init_kwargs={
+        'project_name': 'proj',
+        'task_name': 'task',
+    })
+
+    loader = DataLoader(torch.ones((5, 2)))
+
+    runner.register_hook(hook)
+    runner.run([loader, loader], [('train', 1), ('val', 1)])
+    shutil.rmtree(runner.work_dir)
+
+    hook.clearml.Task.init.assert_called_with(
+        project_name='proj', task_name='task')
+    hook.task.get_logger.assert_called_with()
+    report_scalar_calls = [
+        call('momentum', 'momentum', 0.95, 6),
+        call('learning_rate', 'learning_rate', 0.02, 6),
+    ]
+    hook.task_logger.report_scalar.assert_has_calls(
+        report_scalar_calls, any_order=True)
+
+
 def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
                                     max_epochs=1,
                                     max_iters=None,
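
The test above stubs clearml with MagicMock, so it only verifies the calls the hook makes. Outside the test suite, the hook can also be registered on a runner directly instead of through a config; a rough usage sketch, assuming an already-built EpochBasedRunner and dataloader (both placeholders here):

from mmcv.runner import ClearMLLoggerHook

hook = ClearMLLoggerHook(
    init_kwargs=dict(project_name='my-project', task_name='baseline-run'),
    interval=10)
runner.register_hook(hook)                  # runner: an existing EpochBasedRunner
runner.run([train_loader], [('train', 1)])  # scalars show up in the ClearML UI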
