Model saved_model export (#151)

pull/147/head
Glenn Jocher 2 years ago committed by GitHub
parent d17d1e064d
commit f8a13c49a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 43
      ultralytics/yolo/engine/exporter.py
  2. 5
      ultralytics/yolo/engine/model.py

@ -313,14 +313,11 @@ class Exporter:
# Simplify
if self.args.simplify:
try:
cuda = torch.cuda.is_available()
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
import onnxsim # noqa
check_requirements('onnxsim')
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
subprocess.run(f'onnxsim {f} {f}', shell=True)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
return f, model_onnx
@ -460,6 +457,40 @@ class Exporter:
iou_thres=0.45,
conf_thres=0.25,
prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf # noqa
except ImportError:
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
import tensorflow as tf # noqa
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"),
cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(self.file).replace(self.file.suffix, '_saved_model')
# Export to ONNX
self._export_onnx()
onnx = self.file.with_suffix('.onnx')
# Export to TF SavedModel
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
# Load saved_model
keras_model = tf.saved_model.load(f, tags=None, options=None)
return f, keras_model
@try_export
def _export_saved_model_OLD(self,
nms=False,
agnostic_nms=False,
topk_per_class=100,
topk_all=100,
iou_thres=0.45,
conf_thres=0.25,
prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf # noqa

@ -52,8 +52,8 @@ class YOLO:
# Load or create new YOLO model
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
def __call__(self, source):
return self.predict(source)
def __call__(self, source, **kwargs):
return self.predict(source, **kwargs)
def _new(self, cfg: str, verbose=True):
"""
@ -218,3 +218,4 @@ class YOLO:
args.pop("name", None)
args.pop("batch", None)
args.pop("epochs", None)
args.pop("cache", None)

Loading…
Cancel
Save