|
|
|
@ -184,9 +184,6 @@ class Exporter: |
|
|
|
|
y = model(im) # dry runs |
|
|
|
|
if self.args.half and not coreml and not xml: |
|
|
|
|
im, model = im.half(), model.half() # to FP16 |
|
|
|
|
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape |
|
|
|
|
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and " |
|
|
|
|
f"output shape {shape} ({file_size(file):.1f} MB)") |
|
|
|
|
|
|
|
|
|
# Warnings |
|
|
|
|
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning |
|
|
|
@ -207,6 +204,9 @@ class Exporter: |
|
|
|
|
'stride': int(max(model.stride)), |
|
|
|
|
'names': model.names} # model metadata |
|
|
|
|
|
|
|
|
|
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and " |
|
|
|
|
f"output shape {self.output_shape} ({file_size(file):.1f} MB)") |
|
|
|
|
|
|
|
|
|
# Exports |
|
|
|
|
f = [''] * len(fmts) # exported filenames |
|
|
|
|
if jit: # TorchScript |
|
|
|
@ -220,9 +220,8 @@ class Exporter: |
|
|
|
|
if coreml: # CoreML |
|
|
|
|
f[4], _ = self._export_coreml() |
|
|
|
|
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats |
|
|
|
|
raise NotImplementedError('YOLOv8 TensorFlow export support is still under development. ' |
|
|
|
|
'Please consider contributing to the effort if you have TF expertise. Thank you!') |
|
|
|
|
assert not isinstance(model, ClassificationModel), 'ClassificationModel TF exports not yet supported.' |
|
|
|
|
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. ' |
|
|
|
|
'Please consider contributing to the effort if you have TF expertise. Thank you!') |
|
|
|
|
nms = False |
|
|
|
|
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs, |
|
|
|
|
agnostic_nms=self.args.agnostic_nms or tfjs) |
|
|
|
@ -236,7 +235,7 @@ class Exporter: |
|
|
|
|
agnostic_nms=self.args.agnostic_nms) |
|
|
|
|
if edgetpu: |
|
|
|
|
f[8], _ = self._export_edgetpu() |
|
|
|
|
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(s_model.outputs)) |
|
|
|
|
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape)) |
|
|
|
|
if tfjs: |
|
|
|
|
f[9], _ = self._export_tfjs() |
|
|
|
|
if paddle: # PaddlePaddle |
|
|
|
@ -552,13 +551,13 @@ class Exporter: |
|
|
|
|
return f, keras_model |
|
|
|
|
|
|
|
|
|
@try_export |
|
|
|
|
def _export_pb(self, keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): |
|
|
|
|
def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')): |
|
|
|
|
# YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow |
|
|
|
|
import tensorflow as tf # noqa |
|
|
|
|
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa |
|
|
|
|
|
|
|
|
|
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') |
|
|
|
|
f = file.with_suffix('.pb') |
|
|
|
|
f = self.file.with_suffix('.pb') |
|
|
|
|
|
|
|
|
|
m = tf.function(lambda x: keras_model(x)) # full model |
|
|
|
|
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) |
|
|
|
|