Allow missing `thop` package (#19314)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/19312/head^2
Glenn Jocher 1 month ago committed by GitHub
parent b94dd87f9d
commit b50a327a04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      ultralytics/nn/tasks.py
  2. 11
      ultralytics/utils/torch_utils.py

@ -7,7 +7,6 @@ import types
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import thop
import torch import torch
from ultralytics.nn.modules import ( from ultralytics.nn.modules import (
@ -86,6 +85,11 @@ from ultralytics.utils.torch_utils import (
time_sync, time_sync,
) )
try:
import thop
except ImportError:
thop = None # conda support without 'ultralytics-thop' installed
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

@ -12,7 +12,6 @@ from pathlib import Path
from typing import Union from typing import Union
import numpy as np import numpy as np
import thop
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
@ -31,6 +30,11 @@ from ultralytics.utils import (
) )
from ultralytics.utils.checks import check_version from ultralytics.utils.checks import check_version
try:
import thop
except ImportError:
thop = None # conda support without 'ultralytics-thop' installed
# Version checks (all default to version>=min_version) # Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0") TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0") TORCH_1_13 = check_version(torch.__version__, "1.13.0")
@ -370,6 +374,9 @@ def model_info_for_loggers(trainer):
def get_flops(model, imgsz=640): def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs.""" """Return a YOLO model's FLOPs."""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
try: try:
model = de_parallel(model) model = de_parallel(model)
p = next(model.parameters()) p = next(model.parameters())
@ -681,7 +688,7 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try: try:
flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
except Exception: except Exception:
flops = 0 flops = 0

Loading…
Cancel
Save