Search for model metadata with TensorFlow GraphDef (#13389)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13420/head
Burhan 5 months ago committed by GitHub
parent 22dec59b57
commit fd854a7c68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 8
      ultralytics/nn/autobackend.py

@ -320,6 +320,10 @@ class AutoBackend(nn.Module):
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
except StopIteration:
pass # no metadata file found
# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
@ -402,7 +406,7 @@ class AutoBackend(nn.Module):
# Load external metadata YAML
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
metadata = yaml_load(metadata)
if metadata:
if metadata and isinstance(metadata, dict):
for k, v in metadata.items():
if k in {"stride", "batch"}:
metadata[k] = int(v)
@ -563,7 +567,7 @@ class AutoBackend(nn.Module):
y = [y]
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im))
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
if (self.task == "segment" or len(y) == 2) and len(self.names) == 999: # segments and names not defined
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
self.names = {i: f"class{i}" for i in range(nc)}

Loading…
Cancel
Save