From a334ff7f6e29e4d4ddb045ac639864f8f4930279 Mon Sep 17 00:00:00 2001 From: WangQvQ <1579093407@qq.com> Date: Fri, 5 Jan 2024 23:29:30 +0800 Subject: [PATCH] Fixed RTDETR GFLOPs bug (#7309) Co-authored-by: Glenn Jocher --- ultralytics/utils/torch_utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index e8051571e..57139d8f9 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -274,16 +274,26 @@ def model_info_for_loggers(trainer): def get_flops(model, imgsz=640): """Return a YOLO model's FLOPs.""" + if not thop: + return 0.0 # if not installed return 0.0 GFLOPs + try: model = de_parallel(model) p = next(model.parameters()) - stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride - im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format - flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs - imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float - return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride + im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs + return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs + except Exception: + # Use actual image size for input tensor (i.e. required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs except Exception: - return 0 + return 0.0 def get_flops_with_torch_profiler(model, imgsz=640):