`ultralytics 8.2.79` YOLOv10 CoreML and MPS training "gather" op error fix (#15672)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
fps v8.2.79
Quet Almahdi Morris 3 months ago committed by GitHub
parent bb3850ffa2
commit 7a79680dcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 6
      ultralytics/nn/modules/head.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.78"
__version__ = "8.2.79"
import os

@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils import MACOS
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
@ -151,13 +152,16 @@ class Detect(nn.Module):
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
# NOTE: simplify but result slightly lower mAP
# NOTE: simplify result but slightly lower mAP
# scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc
index = index // nc
# Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error
if MACOS:
index = index.to(torch.int64)
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)

Loading…
Cancel
Save