Use new `ultralytics-thop` package (#13282)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Branch: pull/13288/head
Author: Glenn Jocher (committed via GitHub, 8 months ago)
Commit: 7453753544 (parent 8fb140688a)
Files changed:
1. pyproject.toml (4 changed lines)
2. ultralytics/nn/tasks.py (8 changed lines)
3. ultralytics/utils/torch_utils.py (11 changed lines)

pyproject.toml

@@ -75,9 +75,9 @@ dependencies = [
     "tqdm>=4.64.0", # progress bars
     "psutil", # system utilization
     "py-cpuinfo", # display CPU info
-    "thop>=0.1.1", # FLOPs computation
     "pandas>=1.1.4",
     "seaborn>=0.11.0", # plotting
+    "ultralytics-thop>=0.2.4", # FLOPs computation https://github.com/ultralytics/thop
 ]
 # Optional dependencies ------------------------------------------------------------------------------------------------

@@ -94,7 +94,7 @@ dev = [
     "mkdocstrings[python]",
     "mkdocs-jupyter", # for notebooks
     "mkdocs-redirects", # for 301 redirects
-    "mkdocs-ultralytics-plugin>=0.0.44", # for meta descriptions and images, dates and authors
+    "mkdocs-ultralytics-plugin>=0.0.45", # for meta descriptions and images, dates and authors
 ]
 export = [
     "onnx>=1.12.0", # ONNX export

ultralytics/nn/tasks.py

@@ -4,6 +4,7 @@ import contextlib
 from copy import deepcopy
 from pathlib import Path

+import thop
 import torch
 import torch.nn as nn

@@ -65,11 +66,6 @@ from ultralytics.utils.torch_utils import (
     time_sync,
 )

-try:
-    import thop
-except ImportError:
-    thop = None

 class BaseModel(nn.Module):
     """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

@@ -157,7 +153,7 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
-        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs
+        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2  # GFLOPs
         t = time_sync()
         for _ in range(10):
             m(x.copy() if c else x)
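The `_profile_one_layer` change above can drop the `if thop else 0` fallback because `thop` is now always importable. A rough sketch of the per-layer pattern that line implements (FLOPs via `thop.profile`, then timing repeated forward passes); the toy layer, input shape, and `time.perf_counter` timing below are illustrative stand-ins, not the library's own `time_sync` helper:

```python
# Illustrative per-layer profiling sketch; layer and shapes are assumptions.
import time

import thop
import torch
import torch.nn as nn

layer = nn.Conv2d(64, 64, 3, padding=1)  # stand-in for one YOLO module
x = torch.randn(1, 64, 80, 80)

flops = thop.profile(layer, inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs = 2 * MACs / 1e9

t0 = time.perf_counter()
for _ in range(10):  # average over 10 forward passes, as in the upstream method
    layer(x)
dt_ms = (time.perf_counter() - t0) / 10 * 1e3

print(f"{dt_ms:8.2f} ms {flops:8.2f} GFLOPs  {type(layer).__name__}")
```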

ultralytics/utils/torch_utils.py

@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Union

 import numpy as np
+import thop
 import torch
 import torch.distributed as dist
 import torch.nn as nn

@@ -27,11 +28,6 @@ from ultralytics.utils import (
 )
 from ultralytics.utils.checks import check_version

-try:
-    import thop
-except ImportError:
-    thop = None

 # Version checks (all default to version>=min_version)
 TORCH_1_9 = check_version(torch.__version__, "1.9.0")
 TORCH_1_13 = check_version(torch.__version__, "1.13.0")

@@ -308,9 +304,6 @@ def model_info_for_loggers(trainer):
 def get_flops(model, imgsz=640):
     """Return a YOLO model's FLOPs."""
-    if not thop:
-        return 0.0  # if not installed return 0.0 GFLOPs
     try:
         model = de_parallel(model)
         p = next(model.parameters())

@@ -571,7 +564,7 @@ def profile(input, ops, n=10, device=None):
         m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
         tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
         try:
-            flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
+            flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
         except Exception:
             flops = 0
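With the module-level guard removed, `get_flops` can assume `thop.profile` is always available and only fall back to 0.0 when profiling itself fails. A hedged sketch of that simplified shape; the helper name, the 3-channel dummy input, and the fixed `imgsz` are assumptions for illustration, not the actual implementation:

```python
# Sketch of a get_flops-style helper without the old `if not thop: return 0.0` guard.
import thop
import torch
import torch.nn as nn

def get_flops_sketch(model: nn.Module, imgsz: int = 640) -> float:
    """Return approximate GFLOPs for a square imgsz input, or 0.0 if profiling fails."""
    try:
        p = next(model.parameters())
        im = torch.empty((1, 3, imgsz, imgsz), device=p.device)  # assumes a 3-channel input
        return thop.profile(model, inputs=[im], verbose=False)[0] / 1e9 * 2  # MACs -> GFLOPs
    except Exception:
        return 0.0
```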
