|
|
|
@ -1015,12 +1015,17 @@ class Exporter: |
|
|
|
|
|
|
|
|
|
def _add_tflite_metadata(self, file): |
|
|
|
|
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.""" |
|
|
|
|
from tflite_support import flatbuffers # noqa |
|
|
|
|
from tflite_support import metadata as _metadata # noqa |
|
|
|
|
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa |
|
|
|
|
import flatbuffers |
|
|
|
|
|
|
|
|
|
if MACOS: # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845 |
|
|
|
|
from tflite_support import metadata # noqa |
|
|
|
|
from tflite_support import metadata_schema_py_generated as schema # noqa |
|
|
|
|
else: |
|
|
|
|
from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa |
|
|
|
|
from tensorflow_lite_support.metadata.python import metadata # noqa |
|
|
|
|
|
|
|
|
|
# Create model info |
|
|
|
|
model_meta = _metadata_fb.ModelMetadataT() |
|
|
|
|
model_meta = schema.ModelMetadataT() |
|
|
|
|
model_meta.name = self.metadata["description"] |
|
|
|
|
model_meta.version = self.metadata["version"] |
|
|
|
|
model_meta.author = self.metadata["author"] |
|
|
|
@ -1031,41 +1036,41 @@ class Exporter: |
|
|
|
|
with open(tmp_file, "w") as f: |
|
|
|
|
f.write(str(self.metadata)) |
|
|
|
|
|
|
|
|
|
label_file = _metadata_fb.AssociatedFileT() |
|
|
|
|
label_file = schema.AssociatedFileT() |
|
|
|
|
label_file.name = tmp_file.name |
|
|
|
|
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS |
|
|
|
|
label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS |
|
|
|
|
|
|
|
|
|
# Create input info |
|
|
|
|
input_meta = _metadata_fb.TensorMetadataT() |
|
|
|
|
input_meta = schema.TensorMetadataT() |
|
|
|
|
input_meta.name = "image" |
|
|
|
|
input_meta.description = "Input image to be detected." |
|
|
|
|
input_meta.content = _metadata_fb.ContentT() |
|
|
|
|
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() |
|
|
|
|
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB |
|
|
|
|
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties |
|
|
|
|
input_meta.content = schema.ContentT() |
|
|
|
|
input_meta.content.contentProperties = schema.ImagePropertiesT() |
|
|
|
|
input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB |
|
|
|
|
input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties |
|
|
|
|
|
|
|
|
|
# Create output info |
|
|
|
|
output1 = _metadata_fb.TensorMetadataT() |
|
|
|
|
output1 = schema.TensorMetadataT() |
|
|
|
|
output1.name = "output" |
|
|
|
|
output1.description = "Coordinates of detected objects, class labels, and confidence score" |
|
|
|
|
output1.associatedFiles = [label_file] |
|
|
|
|
if self.model.task == "segment": |
|
|
|
|
output2 = _metadata_fb.TensorMetadataT() |
|
|
|
|
output2 = schema.TensorMetadataT() |
|
|
|
|
output2.name = "output" |
|
|
|
|
output2.description = "Mask protos" |
|
|
|
|
output2.associatedFiles = [label_file] |
|
|
|
|
|
|
|
|
|
# Create subgraph info |
|
|
|
|
subgraph = _metadata_fb.SubGraphMetadataT() |
|
|
|
|
subgraph = schema.SubGraphMetadataT() |
|
|
|
|
subgraph.inputTensorMetadata = [input_meta] |
|
|
|
|
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] |
|
|
|
|
model_meta.subgraphMetadata = [subgraph] |
|
|
|
|
|
|
|
|
|
b = flatbuffers.Builder(0) |
|
|
|
|
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) |
|
|
|
|
b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) |
|
|
|
|
metadata_buf = b.Output() |
|
|
|
|
|
|
|
|
|
populator = _metadata.MetadataPopulator.with_model_file(str(file)) |
|
|
|
|
populator = metadata.MetadataPopulator.with_model_file(str(file)) |
|
|
|
|
populator.load_metadata_buffer(metadata_buf) |
|
|
|
|
populator.load_associated_files([str(tmp_file)]) |
|
|
|
|
populator.populate() |
|
|
|
|