Use `tensorflow_lite_support` (#13042)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/12955/head
Glenn Jocher 6 months ago committed by GitHub
parent e71efd4830
commit 1a4ac2c6ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 37
      ultralytics/engine/exporter.py

@ -1015,12 +1015,17 @@ class Exporter:
def _add_tflite_metadata(self, file):
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
import flatbuffers
if MACOS: # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845
from tflite_support import metadata # noqa
from tflite_support import metadata_schema_py_generated as schema # noqa
else:
from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa
from tensorflow_lite_support.metadata.python import metadata # noqa
# Create model info
model_meta = _metadata_fb.ModelMetadataT()
model_meta = schema.ModelMetadataT()
model_meta.name = self.metadata["description"]
model_meta.version = self.metadata["version"]
model_meta.author = self.metadata["author"]
@ -1031,41 +1036,41 @@ class Exporter:
with open(tmp_file, "w") as f:
f.write(str(self.metadata))
label_file = _metadata_fb.AssociatedFileT()
label_file = schema.AssociatedFileT()
label_file.name = tmp_file.name
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS
# Create input info
input_meta = _metadata_fb.TensorMetadataT()
input_meta = schema.TensorMetadataT()
input_meta.name = "image"
input_meta.description = "Input image to be detected."
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
input_meta.content = schema.ContentT()
input_meta.content.contentProperties = schema.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB
input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties
# Create output info
output1 = _metadata_fb.TensorMetadataT()
output1 = schema.TensorMetadataT()
output1.name = "output"
output1.description = "Coordinates of detected objects, class labels, and confidence score"
output1.associatedFiles = [label_file]
if self.model.task == "segment":
output2 = _metadata_fb.TensorMetadataT()
output2 = schema.TensorMetadataT()
output2.name = "output"
output2.description = "Mask protos"
output2.associatedFiles = [label_file]
# Create subgraph info
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph = schema.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(str(file))
populator = metadata.MetadataPopulator.with_model_file(str(file))
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()

Loading…
Cancel
Save