|
|
|
@ -320,6 +320,10 @@ class AutoBackend(nn.Module): |
|
|
|
|
with open(w, "rb") as f: |
|
|
|
|
gd.ParseFromString(f.read()) |
|
|
|
|
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) |
|
|
|
|
try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file |
|
|
|
|
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) |
|
|
|
|
except StopIteration: |
|
|
|
|
pass # no metadata file found |
|
|
|
|
|
|
|
|
|
# TFLite or TFLite Edge TPU |
|
|
|
|
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python |
|
|
|
@ -402,7 +406,7 @@ class AutoBackend(nn.Module): |
|
|
|
|
# Load external metadata YAML |
|
|
|
|
if isinstance(metadata, (str, Path)) and Path(metadata).exists(): |
|
|
|
|
metadata = yaml_load(metadata) |
|
|
|
|
if metadata: |
|
|
|
|
if metadata and isinstance(metadata, dict): |
|
|
|
|
for k, v in metadata.items(): |
|
|
|
|
if k in {"stride", "batch"}: |
|
|
|
|
metadata[k] = int(v) |
|
|
|
@ -563,7 +567,7 @@ class AutoBackend(nn.Module): |
|
|
|
|
y = [y] |
|
|
|
|
elif self.pb: # GraphDef |
|
|
|
|
y = self.frozen_func(x=self.tf.constant(im)) |
|
|
|
|
if len(y) == 2 and len(self.names) == 999: # segments and names not defined |
|
|
|
|
if (self.task == "segment" or len(y) == 2) and len(self.names) == 999: # segments and names not defined |
|
|
|
|
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes |
|
|
|
|
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400) |
|
|
|
|
self.names = {i: f"class{i}" for i in range(nc)} |
|
|
|
|