`ultralytics 8.3.8` replace `contextlib` with `try` for speed (#16782)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/16784/head v8.3.8
Glenn Jocher 1 month ago committed by GitHub
parent 1e6c454460
commit a6a577961f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 5
      ultralytics/cfg/__init__.py
  3. 46
      ultralytics/data/dataset.py
  4. 5
      ultralytics/data/utils.py
  5. 13
      ultralytics/nn/autobackend.py
  6. 26
      ultralytics/nn/tasks.py
  7. 38
      ultralytics/utils/__init__.py
  8. 32
      ultralytics/utils/callbacks/tensorboard.py
  9. 20
      ultralytics/utils/checks.py
  10. 6
      ultralytics/utils/downloads.py
  11. 5
      ultralytics/utils/plotting.py
  12. 5
      ultralytics/utils/torch_utils.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.7"
__version__ = "8.3.8"
import os

@ -669,9 +669,10 @@ def smart_value(v):
elif v_lower == "false":
return False
else:
with contextlib.suppress(Exception):
try:
return eval(v)
return v
except: # noqa E722
return v
def entrypoint(debug=""):

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import json
from collections import defaultdict
from itertools import repeat
@ -483,7 +482,7 @@ class ClassificationDataset:
desc = f"{self.prefix}Scanning {self.root}..."
path = Path(self.root).with_suffix(".cache") # *.cache file path
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
try:
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
@ -495,24 +494,25 @@ class ClassificationDataset:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
return samples
# Run scan if *.cache retrieval failed
nf, nc, msgs, samples, x = 0, 0, [], [], {}
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
pbar = TQDM(results, desc=desc, total=len(self.samples))
for sample, nf_f, nc_f, msg in pbar:
if nf_f:
samples.append(sample)
if msg:
msgs.append(msg)
nf += nf_f
nc += nc_f
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
except (FileNotFoundError, AssertionError, AttributeError):
# Run scan if *.cache retrieval failed
nf, nc, msgs, samples, x = 0, 0, [], [], {}
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
pbar = TQDM(results, desc=desc, total=len(self.samples))
for sample, nf_f, nc_f, msg in pbar:
if nf_f:
samples.append(sample)
if msg:
msgs.append(msg)
nf += nf_f
nc += nc_f
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import hashlib
import json
import os
@ -60,12 +59,14 @@ def exif_size(img: Image.Image):
"""Returns exif-corrected PIL size."""
s = img.size # (width, height)
if img.format == "JPEG": # only support JPEG images
with contextlib.suppress(Exception):
try:
exif = img.getexif()
if exif:
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
if rotation in {6, 8}: # rotation 270 or 90
s = s[1], s[0]
except: # noqa E722
pass
return s

@ -1,7 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import ast
import contextlib
import json
import platform
import zipfile
@ -45,8 +44,10 @@ def check_class_names(names):
def default_class_names(data=None):
"""Applies default class names to an input YAML file or returns numerical class names."""
if data:
with contextlib.suppress(Exception):
try:
return yaml_load(check_yaml(data))["names"]
except: # noqa E722
pass
return {i: f"class{i}" for i in range(999)} # return default if above errors
@ -321,8 +322,10 @@ class AutoBackend(nn.Module):
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
with contextlib.suppress(StopIteration): # find metadata in SavedModel alongside GraphDef
try: # find metadata in SavedModel alongside GraphDef
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
except StopIteration:
pass
# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
@ -345,10 +348,12 @@ class AutoBackend(nn.Module):
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# Load metadata
with contextlib.suppress(zipfile.BadZipFile):
try:
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
except zipfile.BadZipFile:
pass
# TF.js
elif tfjs:

@ -2,6 +2,7 @@
import contextlib
import pickle
import re
import types
from copy import deepcopy
from pathlib import Path
@ -958,8 +959,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
try:
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
except ValueError:
pass
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
@ -1072,8 +1075,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
def yaml_model_load(path):
"""Load a YOLOv8 model from a YAML file."""
import re
path = Path(path)
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@ -1100,11 +1101,10 @@ def guess_model_scale(model_path):
Returns:
(str): The size character of the model's scale, which can be n, s, m, l, or x.
"""
with contextlib.suppress(AttributeError):
import re
try:
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
return ""
except AttributeError:
return ""
def guess_model_task(model):
@ -1137,17 +1137,23 @@ def guess_model_task(model):
# Guess from model cfg
if isinstance(model, dict):
with contextlib.suppress(Exception):
try:
return cfg2task(model)
except: # noqa E722
pass
# Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model
for x in "model.args", "model.model.args", "model.model.model.args":
with contextlib.suppress(Exception):
try:
return eval(x)["task"]
except: # noqa E722
pass
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
with contextlib.suppress(Exception):
try:
return cfg2task(eval(x))
except: # noqa E722
pass
for m in model.modules():
if isinstance(m, Segment):

@ -523,10 +523,11 @@ def read_device_model() -> str:
Returns:
(str): Model file contents if read successfully or empty string otherwise.
"""
with contextlib.suppress(Exception):
try:
with open("/proc/device-tree/model") as f:
return f.read()
return ""
except: # noqa E722
return ""
def is_ubuntu() -> bool:
@ -536,10 +537,11 @@ def is_ubuntu() -> bool:
Returns:
(bool): True if OS is Ubuntu, False otherwise.
"""
with contextlib.suppress(FileNotFoundError):
try:
with open("/etc/os-release") as f:
return "ID=ubuntu" in f.read()
return False
except FileNotFoundError:
return False
def is_colab():
@ -569,11 +571,7 @@ def is_jupyter():
Returns:
(bool): True if running inside a Jupyter Notebook, False otherwise.
"""
with contextlib.suppress(Exception):
from IPython import get_ipython
return get_ipython() is not None
return False
return "get_ipython" in locals()
def is_docker() -> bool:
@ -583,10 +581,11 @@ def is_docker() -> bool:
Returns:
(bool): True if the script is running inside a Docker container, False otherwise.
"""
with contextlib.suppress(Exception):
try:
with open("/proc/self/cgroup") as f:
return "docker" in f.read()
return False
except: # noqa E722
return False
def is_raspberrypi() -> bool:
@ -617,14 +616,15 @@ def is_online() -> bool:
Returns:
(bool): True if connection is successful, False otherwise.
"""
with contextlib.suppress(Exception):
try:
assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True"
import socket
for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS
socket.create_connection(address=(dns, 80), timeout=2.0).close()
return True
return False
except: # noqa E722
return False
def is_pip_package(filepath: str = __name__) -> bool:
@ -711,9 +711,11 @@ def get_git_origin_url():
(str | None): The origin URL of the git repository or None if not git directory.
"""
if IS_GIT_DIR:
with contextlib.suppress(subprocess.CalledProcessError):
try:
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
return origin.decode().strip()
except subprocess.CalledProcessError:
return None
def get_git_branch():
@ -724,9 +726,11 @@ def get_git_branch():
(str | None): The current git branch name or None if not a git directory.
"""
if IS_GIT_DIR:
with contextlib.suppress(subprocess.CalledProcessError):
try:
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
except subprocess.CalledProcessError:
return None
def get_default_args(func):
@ -751,9 +755,11 @@ def get_ubuntu_version():
(str): Ubuntu version or None if not an Ubuntu OS.
"""
if is_ubuntu():
with contextlib.suppress(FileNotFoundError, AttributeError):
try:
with open("/etc/os-release") as f:
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
except (FileNotFoundError, AttributeError):
return None
def get_user_config_dir(sub_dir="Ultralytics"):

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
@ -45,26 +44,27 @@ def _log_tensorboard_graph(trainer):
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
try:
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠ TensorBoard graph visualization failure {e}")
except: # noqa E722
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import glob
import inspect
import math
@ -271,11 +270,13 @@ def check_latest_pypi_version(package_name="ultralytics"):
Returns:
(str): The latest version of the package.
"""
with contextlib.suppress(Exception):
try:
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()["info"]["version"]
except: # noqa E722
return None
def check_pip_update_available():
@ -286,7 +287,7 @@ def check_pip_update_available():
(bool): True if an update is available, False otherwise.
"""
if ONLINE and IS_PIP_PACKAGE:
with contextlib.suppress(Exception):
try:
from ultralytics import __version__
latest = check_latest_pypi_version()
@ -296,6 +297,8 @@ def check_pip_update_available():
f"Update with 'pip install -U ultralytics'"
)
return True
except: # noqa E722
pass
return False
@ -577,10 +580,12 @@ def check_yolo(verbose=True, device=""):
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
with contextlib.suppress(Exception): # clear display if ipython is installed
try:
from IPython import display
display.clear_output()
display.clear_output() # clear display if notebook
except ImportError:
pass
else:
s = ""
@ -707,9 +712,10 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
with contextlib.suppress(Exception):
try:
return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
return ""
except: # noqa E722
return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import re
import shutil
import subprocess
@ -53,7 +52,7 @@ def is_url(url, check=False):
valid = is_url("https://www.example.com")
```
"""
with contextlib.suppress(Exception):
try:
url = str(url)
result = parse.urlparse(url)
assert all([result.scheme, result.netloc]) # check if is url
@ -61,7 +60,8 @@ def is_url(url, check=False):
with request.urlopen(url) as response:
return response.getcode() == 200 # check if exists online
return True
return False
except: # noqa E722
return False
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import math
import warnings
from pathlib import Path
@ -1115,10 +1114,12 @@ def plot_images(
mask = mask.astype(bool)
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
try:
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
except: # noqa E722
pass
annotator.fromarray(im)
if not save:
return np.asarray(annotator.im)

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import gc
import math
import os
@ -113,13 +112,15 @@ def get_cpu_info():
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
if "cpu_info" not in PERSISTENT_CACHE:
with contextlib.suppress(Exception):
try:
import cpuinfo # pip install py-cpuinfo
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
except: # noqa E722
pass
return PERSISTENT_CACHE.get("cpu_info", "unknown")

Loading…
Cancel
Save