`ultralytics 8.3.76` fix `dynamic` batch inference with NMS export (#19249)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/18321/head^2 v8.3.76
Mohammed Yasin 3 weeks ago committed by GitHub
parent 0f81777af5
commit e16593336b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 20
      ultralytics/engine/exporter.py

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.75"
__version__ = "8.3.76"
import os

@ -1560,20 +1560,20 @@ class NMSModel(torch.nn.Module):
preds = self.model(x)
pred = preds[0] if isinstance(preds, tuple) else preds
kwargs = dict(device=pred.device, dtype=pred.dtype)
bs = pred.shape[0]
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose
if self.args.dynamic and self.args.batch > 1: # batch size needs to always be same due to loop unroll
pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs)
pred = torch.cat((pred, pad))
boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)
scores, classes = scores.max(dim=-1)
self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
out = torch.zeros(
boxes.shape[0],
self.args.max_det,
boxes.shape[-1] + 2 + extra_shape,
device=boxes.device,
dtype=boxes.dtype,
)
for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
out = torch.zeros(bs, self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)
for i in range(bs):
box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]
mask = score > self.args.conf
if self.is_tf:
# TFLite GatherND error if mask is empty
@ -1593,7 +1593,7 @@ class NMSModel(torch.nn.Module):
if self.args.format == "tflite": # TFLite is already normalized
nmsbox *= multiplier
else:
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()
if not self.args.agnostic_nms: # class-specific NMS
end = 2 if self.obb else 4
# fully explicit expansion otherwise reshape error
@ -1624,4 +1624,4 @@ class NMSModel(torch.nn.Module):
# Zero-pad to max_det size to avoid reshape error
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
out[i] = torch.nn.functional.pad(dets, pad)
return (out, preds[1]) if self.model.task == "segment" else out
return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs]

Loading…
Cancel
Save