|
|
|
@ -8,6 +8,7 @@ import torch |
|
|
|
|
import torch.nn as nn |
|
|
|
|
from torch.nn.init import constant_, xavier_uniform_ |
|
|
|
|
|
|
|
|
|
from ultralytics.utils import MACOS |
|
|
|
|
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors |
|
|
|
|
|
|
|
|
|
from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto |
|
|
|
@ -151,13 +152,16 @@ class Detect(nn.Module): |
|
|
|
|
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) |
|
|
|
|
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) |
|
|
|
|
|
|
|
|
|
# NOTE: simplify but result slightly lower mAP |
|
|
|
|
# NOTE: simplify result but slightly lower mAP |
|
|
|
|
# scores, labels = scores.max(dim=-1) |
|
|
|
|
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) |
|
|
|
|
|
|
|
|
|
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) |
|
|
|
|
labels = index % nc |
|
|
|
|
index = index // nc |
|
|
|
|
# Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error |
|
|
|
|
if MACOS: |
|
|
|
|
index = index.to(torch.int64) |
|
|
|
|
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) |
|
|
|
|
|
|
|
|
|
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) |
|
|
|
|