Fix TensorBoard graph `UserWarning` catch (#4513)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/4517/head
Glenn Jocher 1 year ago committed by GitHub
parent c7ceb84fb6
commit 3c40e7a9fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 31
      examples/tutorial.ipynb
  2. 5
      ultralytics/engine/trainer.py
  3. 6
      ultralytics/utils/callbacks/tensorboard.py

@ -66,7 +66,7 @@
"import ultralytics\n",
"ultralytics.checks()"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -102,7 +102,7 @@
"# Run inference on an image with YOLOv8n\n",
"!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -169,7 +169,7 @@
"# Validate YOLOv8n on COCO8 val\n",
"!yolo val model=yolov8n.pt data=coco8.yaml"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -215,6 +215,25 @@
"Train YOLOv8 on [Detect](https://docs.ultralytics.com/tasks/detect/), [Segment](https://docs.ultralytics.com/tasks/segment/), [Classify](https://docs.ultralytics.com/tasks/classify/) and [Pose](https://docs.ultralytics.com/tasks/pose/) datasets. See [YOLOv8 Train Docs](https://docs.ultralytics.com/modes/train/) for more information."
]
},
{
"cell_type": "code",
"source": [
"#@title Select YOLOv8 🚀 logger {run: 'auto'}\n",
"logger = 'Comet' #@param ['Comet', 'TensorBoard']\n",
"\n",
"if logger == 'Comet':\n",
" %pip install -q comet_ml\n",
" import comet_ml; comet_ml.init()\n",
"elif logger == 'TensorBoard':\n",
" %load_ext tensorboard\n",
" %tensorboard --logdir ."
],
"metadata": {
"id": "ktegpM42AooT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
@ -228,7 +247,7 @@
"# Train YOLOv8n on COCO8 for 3 epochs\n",
"!yolo train model=yolov8n.pt data=coco8.yaml epochs=3 imgsz=640"
],
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@ -280,8 +299,6 @@
"\n",
" Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size\n",
" 1/3 0.761G 0.9273 3.155 1.291 32 640: 100% 1/1 [00:01<00:00, 1.23s/it]\n",
"/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
" warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n",
" Class Images Instances Box(P R mAP50 mAP50-95): 100% 1/1 [00:00<00:00, 2.21it/s]\n",
" all 4 17 0.613 0.899 0.888 0.621\n",
"\n",
@ -359,7 +376,7 @@
"id": "CYIjW4igCjqD",
"outputId": "2b65e381-717b-4a6f-d6f5-5254c867f3a4"
},
"execution_count": 5,
"execution_count": null,
"outputs": [
{
"output_type": "stream",

@ -10,6 +10,7 @@ import math
import os
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
@ -378,7 +379,9 @@ class BaseTrainer:
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step()
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
self.scheduler.step()
self.run_callbacks('on_train_epoch_end')
if RANK in (-1, 0):

@ -32,9 +32,9 @@ def _log_tensorboard_graph(trainer):
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
with warnings.catch_warnings(category=UserWarning):
warnings.simplefilter('ignore') # suppress jit trace warning
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f'WARNING ⚠ TensorBoard graph visualization failure {e}')

Loading…
Cancel
Save