@@ -343,6 +343,8 @@ class Exporter:
         requirements = ["onnx>=1.12.0"]
         if self.args.simplify:
             requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
+            if ARM64:
+                check_requirements("cmake")  # 'cmake' is needed to build onnxsim on aarch64
         check_requirements(requirements)
         import onnx  # noqa
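Note (illustrative, not part of the patch): a minimal sketch of what the requirement list in the hunk above resolves to, assuming a CPU-only host with simplify=True; the ARM64 check below is an assumption standing in for the flag used in the hunk.

    # Minimal sketch, assumptions noted inline; not part of the patch.
    import platform

    import torch

    ARM64 = platform.machine() in ("arm64", "aarch64")  # assumption: stands in for the ARM64 flag above
    simplify = True  # assumed export argument

    requirements = ["onnx>=1.12.0"]
    if simplify:
        requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
        # With this change, aarch64 hosts also need 'cmake' so pip can build onnxsim from source.

    print(requirements)  # e.g. ['onnx>=1.12.0', 'onnxsim>=0.4.33', 'onnxruntime'] on a CPU-only machine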
@@ -712,8 +714,12 @@ class Exporter:
         try:
             import tensorflow as tf  # noqa
         except ImportError:
-            check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
+            suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
+            version = "" if ARM64 else "<=2.13.1"
+            check_requirements(f"tensorflow{suffix}{version}")
             import tensorflow as tf  # noqa
+        if ARM64:
+            check_requirements("cmake")  # 'cmake' is needed to build onnxsim on aarch64
         check_requirements(
             (
                 "onnx>=1.12.0",
@@ -722,7 +728,7 @@ class Exporter:
                 "onnxsim>=0.4.33",
                 "onnx_graphsurgeon>=0.3.26",
                 "tflite_support",
-                "flatbuffers>=23.5.26",  # update old 'flatbuffers' included inside tensorflow package
+                "flatbuffers>=23.5.26,<100",  # update old 'flatbuffers' included inside tensorflow package
                 "onnxruntime-gpu" if cuda else "onnxruntime",
             ),
             cmds="--extra-index-url https://pypi.ngc.nvidia.com",
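Note (illustrative, not part of the patch): a minimal sketch that reproduces the suffix/version selection added above, showing which TensorFlow requirement string each platform combination resolves to; the flag values passed in are assumptions for illustration.

    # Minimal sketch: mirrors the two selection lines added in the hunk above.
    def tf_requirement(MACOS: bool, ARM64: bool, cuda: bool) -> str:
        suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
        version = "" if ARM64 else "<=2.13.1"  # aarch64 stays unpinned; all other builds are capped at 2.13.1
        return f"tensorflow{suffix}{version}"

    print(tf_requirement(MACOS=True, ARM64=False, cuda=False))   # tensorflow-macos<=2.13.1
    print(tf_requirement(MACOS=False, ARM64=True, cuda=False))   # tensorflow-aarch64
    print(tf_requirement(MACOS=False, ARM64=False, cuda=True))   # tensorflow<=2.13.1
    print(tf_requirement(MACOS=False, ARM64=False, cuda=False))  # tensorflow-cpu<=2.13.1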