|
|
|
@ -1,5 +1,6 @@ |
|
|
|
|
import glob |
|
|
|
|
import inspect |
|
|
|
|
import math |
|
|
|
|
import platform |
|
|
|
|
import urllib |
|
|
|
|
from pathlib import Path |
|
|
|
@ -13,71 +14,141 @@ import torch |
|
|
|
|
|
|
|
|
|
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, |
|
|
|
|
is_docker, is_jupyter_notebook) |
|
|
|
|
from ultralytics.yolo.utils.ops import make_divisible |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_ascii(s=''): |
|
|
|
|
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) |
|
|
|
|
s = str(s) # convert list, tuple, None, etc. to str |
|
|
|
|
return len(s.encode().decode('ascii', 'ignore')) == len(s) |
|
|
|
|
def is_ascii(s) -> bool: |
|
|
|
|
""" |
|
|
|
|
Check if a string is composed of only ASCII characters. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
s (str): String to be checked. |
|
|
|
|
|
|
|
|
|
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): |
|
|
|
|
# Verify image size is a multiple of stride s in each dimension |
|
|
|
|
Returns: |
|
|
|
|
bool: True if the string is composed only of ASCII characters, False otherwise. |
|
|
|
|
""" |
|
|
|
|
# Convert list, tuple, None, etc. to string |
|
|
|
|
s = str(s) |
|
|
|
|
|
|
|
|
|
# Check if the string is composed of only ASCII characters |
|
|
|
|
return all(ord(c) < 128 for c in s) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): |
|
|
|
|
""" |
|
|
|
|
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the |
|
|
|
|
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
imgsz (int or List[int]): Image size. |
|
|
|
|
stride (int): Stride value. |
|
|
|
|
min_dim (int): Minimum number of dimensions. |
|
|
|
|
floor (int): Minimum allowed value for image size. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
List[int]: Updated image size. |
|
|
|
|
""" |
|
|
|
|
# Convert stride to integer if it is a tensor |
|
|
|
|
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) |
|
|
|
|
if isinstance(imgsz, int): # integer i.e. imgsz=640 |
|
|
|
|
sz = max(make_divisible(imgsz, stride), floor) |
|
|
|
|
else: # list i.e. imgsz=[640, 480] |
|
|
|
|
imgsz = list(imgsz) # convert to list if tuple |
|
|
|
|
sz = [max(make_divisible(x, stride), floor) for x in imgsz] |
|
|
|
|
|
|
|
|
|
# Convert image size to list if it is an integer |
|
|
|
|
if isinstance(imgsz, int): |
|
|
|
|
imgsz = [imgsz] |
|
|
|
|
|
|
|
|
|
# Make image size a multiple of the stride |
|
|
|
|
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] |
|
|
|
|
|
|
|
|
|
# Print warning message if image size was updated |
|
|
|
|
if sz != imgsz: |
|
|
|
|
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}') |
|
|
|
|
|
|
|
|
|
# Check dims |
|
|
|
|
if min_dim == 2: |
|
|
|
|
if isinstance(imgsz, int): |
|
|
|
|
sz = [sz, sz] |
|
|
|
|
elif len(sz) == 1: |
|
|
|
|
sz = [sz[0], sz[0]] |
|
|
|
|
# Add missing dimensions if necessary |
|
|
|
|
if min_dim == 2 and len(sz) == 1: |
|
|
|
|
sz = [sz[0], sz[0]] |
|
|
|
|
|
|
|
|
|
return sz |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): |
|
|
|
|
# Check version vs. required version |
|
|
|
|
current, minimum = (pkg.parse_version(x) for x in (current, minimum)) |
|
|
|
|
def check_version(current: str = "0.0.0", |
|
|
|
|
minimum: str = "0.0.0", |
|
|
|
|
name: str = "version ", |
|
|
|
|
pinned: bool = False, |
|
|
|
|
hard: bool = False, |
|
|
|
|
verbose: bool = False) -> bool: |
|
|
|
|
""" |
|
|
|
|
Check current version against the required minimum version. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
current (str): Current version. |
|
|
|
|
minimum (str): Required minimum version. |
|
|
|
|
name (str): Name to be used in warning message. |
|
|
|
|
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied. |
|
|
|
|
hard (bool): If True, raise an AssertionError if the minimum version is not met. |
|
|
|
|
verbose (bool): If True, print warning message if minimum version is not met. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
bool: True if minimum version is met, False otherwise. |
|
|
|
|
""" |
|
|
|
|
from pkg_resources import parse_version |
|
|
|
|
current, minimum = (parse_version(x) for x in (current, minimum)) |
|
|
|
|
result = (current == minimum) if pinned else (current >= minimum) # bool |
|
|
|
|
s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string |
|
|
|
|
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" |
|
|
|
|
if hard: |
|
|
|
|
assert result, emojis(s) # assert min requirements met |
|
|
|
|
assert result, emojis(warning_message) # assert min requirements met |
|
|
|
|
if verbose and not result: |
|
|
|
|
LOGGER.warning(s) |
|
|
|
|
LOGGER.warning(warning_message) |
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_font(font=FONT, progress=False): |
|
|
|
|
# Download font to CONFIG_DIR if necessary |
|
|
|
|
def check_font(font: str = FONT, progress: bool = False) -> None: |
|
|
|
|
""" |
|
|
|
|
Download font file to the user's configuration directory if it does not already exist. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
font (str): Path to font file. |
|
|
|
|
progress (bool): If True, display a progress bar during the download. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
None |
|
|
|
|
""" |
|
|
|
|
font = Path(font) |
|
|
|
|
|
|
|
|
|
# Destination path for the font file |
|
|
|
|
file = USER_CONFIG_DIR / font.name |
|
|
|
|
|
|
|
|
|
# Check if font file exists at the source or destination path |
|
|
|
|
if not font.exists() and not file.exists(): |
|
|
|
|
# Download font file |
|
|
|
|
url = f'https://ultralytics.com/assets/{font.name}' |
|
|
|
|
LOGGER.info(f'Downloading {url} to {file}...') |
|
|
|
|
torch.hub.download_url_to_file(url, str(file), progress=progress) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_online(): |
|
|
|
|
# Check internet connectivity |
|
|
|
|
def check_online() -> bool: |
|
|
|
|
""" |
|
|
|
|
Check internet connectivity by attempting to connect to a known online host. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
bool: True if connection is successful, False otherwise. |
|
|
|
|
""" |
|
|
|
|
import socket |
|
|
|
|
try: |
|
|
|
|
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility |
|
|
|
|
# Check host accessibility by attempting to establish a connection |
|
|
|
|
socket.create_connection(("1.1.1.1", 443), timeout=5) |
|
|
|
|
return True |
|
|
|
|
except OSError: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_python(minimum='3.7.0'): |
|
|
|
|
# Check current python version vs. required python version |
|
|
|
|
def check_python(minimum: str = '3.7.0') -> bool: |
|
|
|
|
""" |
|
|
|
|
Check current python version against the required minimum version. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
minimum (str): Required minimum version of python. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
None |
|
|
|
|
""" |
|
|
|
|
check_version(platform.python_version(), minimum, name='Python ', hard=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|