|
|
|
@ -206,7 +206,7 @@ class Exporter: |
|
|
|
|
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y) |
|
|
|
|
self.pretty_name = self.file.stem.replace('yolo', 'YOLO') |
|
|
|
|
self.metadata = { |
|
|
|
|
'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}", |
|
|
|
|
'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}", |
|
|
|
|
'author': 'Ultralytics', |
|
|
|
|
'license': 'GPL-3.0 https://ultralytics.com/license', |
|
|
|
|
'version': __version__, |
|
|
|
@ -257,11 +257,16 @@ class Exporter: |
|
|
|
|
f = [str(x) for x in f if x] # filter out '' and None |
|
|
|
|
if any(f): |
|
|
|
|
f = str(Path(f[-1])) |
|
|
|
|
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' |
|
|
|
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" |
|
|
|
|
f"\nPredict: yolo task={model.task} mode=predict model={f}" |
|
|
|
|
f"\nValidate: yolo task={model.task} mode=val model={f}" |
|
|
|
|
f"\nVisualize: https://netron.app") |
|
|
|
|
square = self.imgsz[0] == self.imgsz[1] |
|
|
|
|
s = f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not work. Use " \ |
|
|
|
|
f"export 'imgsz={max(self.imgsz)}' if val is required." if not square else '' |
|
|
|
|
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '') |
|
|
|
|
LOGGER.info( |
|
|
|
|
f'\nExport complete ({time.time() - t:.1f}s)' |
|
|
|
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" |
|
|
|
|
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz}" |
|
|
|
|
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}" |
|
|
|
|
f"\nVisualize: https://netron.app") |
|
|
|
|
|
|
|
|
|
self.run_callbacks("on_export_end") |
|
|
|
|
return f # return list of exported files/dirs |
|
|
|
@ -497,7 +502,7 @@ class Exporter: |
|
|
|
|
except ImportError: |
|
|
|
|
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}") |
|
|
|
|
import tensorflow as tf # noqa |
|
|
|
|
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"), |
|
|
|
|
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"), |
|
|
|
|
cmds="--extra-index-url https://pypi.ngc.nvidia.com ") |
|
|
|
|
|
|
|
|
|
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') |
|
|
|
@ -680,24 +685,45 @@ class Exporter: |
|
|
|
|
|
|
|
|
|
def _add_tflite_metadata(self, file): |
|
|
|
|
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata |
|
|
|
|
check_requirements('tflite_support') |
|
|
|
|
|
|
|
|
|
from tflite_support import flatbuffers # noqa |
|
|
|
|
from tflite_support import metadata as _metadata # noqa |
|
|
|
|
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa |
|
|
|
|
|
|
|
|
|
# Creates model info. |
|
|
|
|
model_meta = _metadata_fb.ModelMetadataT() |
|
|
|
|
model_meta.name = self.metadata['description'] |
|
|
|
|
model_meta.version = self.metadata['version'] |
|
|
|
|
model_meta.author = self.metadata['author'] |
|
|
|
|
model_meta.license = self.metadata['license'] |
|
|
|
|
|
|
|
|
|
# Creates input info. |
|
|
|
|
input_meta = _metadata_fb.TensorMetadataT() |
|
|
|
|
input_meta.name = "image" |
|
|
|
|
input_meta.description = "Input image to be detected." |
|
|
|
|
input_meta.content = _metadata_fb.ContentT() |
|
|
|
|
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() |
|
|
|
|
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB |
|
|
|
|
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties |
|
|
|
|
|
|
|
|
|
# Creates output info. |
|
|
|
|
output_meta = _metadata_fb.TensorMetadataT() |
|
|
|
|
output_meta.name = "output" |
|
|
|
|
output_meta.description = "Coordinates of detected objects, class labels, and confidence score." |
|
|
|
|
|
|
|
|
|
# Label file |
|
|
|
|
tmp_file = Path('/tmp/meta.txt') |
|
|
|
|
with open(tmp_file, 'w') as meta_f: |
|
|
|
|
meta_f.write(str(self.metadata)) |
|
|
|
|
|
|
|
|
|
model_meta = _metadata_fb.ModelMetadataT() |
|
|
|
|
label_file = _metadata_fb.AssociatedFileT() |
|
|
|
|
label_file.name = tmp_file.name |
|
|
|
|
model_meta.associatedFiles = [label_file] |
|
|
|
|
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS |
|
|
|
|
output_meta.associatedFiles = [label_file] |
|
|
|
|
|
|
|
|
|
# Creates subgraph info. |
|
|
|
|
subgraph = _metadata_fb.SubGraphMetadataT() |
|
|
|
|
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()] |
|
|
|
|
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape) |
|
|
|
|
subgraph.inputTensorMetadata = [input_meta] |
|
|
|
|
subgraph.outputTensorMetadata = [output_meta] |
|
|
|
|
model_meta.subgraphMetadata = [subgraph] |
|
|
|
|
|
|
|
|
|
b = flatbuffers.Builder(0) |
|
|
|
@ -710,6 +736,14 @@ class Exporter: |
|
|
|
|
populator.populate() |
|
|
|
|
tmp_file.unlink() |
|
|
|
|
|
|
|
|
|
# TODO Rename this here and in `_add_tflite_metadata` |
|
|
|
|
def _extracted_from__add_tflite_metadata_15(self, _metadata_fb, arg1, arg2): |
|
|
|
|
# Creates input info. |
|
|
|
|
result = _metadata_fb.TensorMetadataT() |
|
|
|
|
result.name = arg1 |
|
|
|
|
result.description = arg2 |
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')): |
|
|
|
|
# YOLOv8 CoreML pipeline |
|
|
|
|
import coremltools as ct # noqa |
|
|
|
|