`ultralytics 8.2.64` YOLOv10 SavedModel, TFlite, and GraphDef export (#14572)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14644/head v8.2.64
Hassan Ghaffari 7 months ago committed by GitHub
parent 0d7bf447eb
commit c6db604fe1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      docs/en/models/yolov10.md
  2. 2
      ultralytics/__init__.py
  3. 2
      ultralytics/engine/exporter.py
  4. 13
      ultralytics/nn/autobackend.py
  5. 6
      ultralytics/utils/benchmarks.py

@ -198,9 +198,9 @@ Due to the new operations introduced with YOLOv10, not all export formats provid
| [OpenVINO](../integrations/openvino.md) | ✅ |
| [TensorRT](../integrations/tensorrt.md) | ✅ |
| [CoreML](../integrations/coreml.md) | ❌ |
| [TF SavedModel](../integrations/tf-savedmodel.md) | |
| [TF GraphDef](../integrations/tf-graphdef.md) | |
| [TF Lite](../integrations/tflite.md) | |
| [TF SavedModel](../integrations/tf-savedmodel.md) | |
| [TF GraphDef](../integrations/tf-graphdef.md) | |
| [TF Lite](../integrations/tflite.md) | |
| [TF Edge TPU](../integrations/edge-tpu.md) | ❌ |
| [TF.js](../integrations/tfjs.md) | ❌ |
| [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ |

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.63"
__version__ = "8.2.64"
import os

@ -885,6 +885,8 @@ class Exporter:
output_integer_quantized_tflite=self.args.int8,
quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate)
custom_input_op_name_np_data_path=np_data,
disable_group_convolution=True, # for end-to-end model compatibility
enable_batchmatmul_unfold=True, # for end-to-end model compatibility
)
yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml

@ -587,14 +587,21 @@ class AutoBackend(nn.Module):
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
if x.shape[-1] == 6: # end-to-end model
x[:, :, [0, 2]] *= w
x[:, :, [1, 3]] *= h
else:
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
if y[1].shape[-1] == 6: # end-to-end model
y = [y[1]]
else:
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
# for x in y:

@ -100,9 +100,11 @@ def benchmark(
assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
assert not is_end2end, "End-to-end models not supported by onnx2tf yet"
if i in {9, 10}: # TF EdgeTPU and TF.js
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"

Loading…
Cancel
Save