diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 654496f1..e183b3e1 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -95,7 +95,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- python-version: ['3.10']
+ python-version: ['3.11']
model: [yolov8n]
steps:
- uses: actions/checkout@v4
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 00000000..48a56e3d
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,25 @@
+# Ultralytics 🚀 - AGPL-3.0 license
+# Ultralytics Actions https://github.com/ultralytics/actions
+# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
+
+name: Ultralytics Actions
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ format:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Run Ultralytics Formatting
+ uses: ultralytics/actions@main
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }} # automatically generated
+ python: true
+ docstrings: true
+ markdown: true
+ spelling: true
+ links: false
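
The quote-style churn that dominates the rest of this patch follows from this workflow: with `python: true`, the formatter rewrites string literals to Black-style double quotes, which is also why the single-quote-enforcing `double-quote-string-fixer` pre-commit hook is dropped in the next hunk. A minimal before/after illustration, assuming Black-compatible formatting rules:

```python
name, score = "person", 0.88
# Before the formatting pass (single quotes, as the removed pre-commit hook enforced):
label = f'{name}: {score:.2f}'
# After (Black-style double quotes, as applied throughout this patch):
label = f"{name}: {score:.2f}"
print(label)  # person: 0.88
```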
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 72cef360..07ae822f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,6 @@ repos:
- id: check-case-conflict
# - id: check-yaml
- id: check-docstring-first
- - id: double-quote-string-fixer
- id: detect-private-key
- repo: https://github.com/asottile/pyupgrade
@@ -64,7 +63,7 @@ repos:
- id: codespell
exclude: 'docs/de|docs/fr|docs/pt|docs/es|docs/mkdocs_de.yml'
args:
- - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo,gool,winn
+ - --ignore-words-list=crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
diff --git a/docs/build_docs.py b/docs/build_docs.py
index 914f2fe8..9c9d75ed 100644
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -30,45 +30,47 @@ import subprocess
from pathlib import Path
DOCS = Path(__file__).parent.resolve()
-SITE = DOCS.parent / 'site'
+SITE = DOCS.parent / "site"
def build_docs():
"""Build docs using mkdocs."""
if SITE.exists():
- print(f'Removing existing {SITE}')
+ print(f"Removing existing {SITE}")
shutil.rmtree(SITE)
# Build the main documentation
- print(f'Building docs from {DOCS}')
- subprocess.run(f'mkdocs build -f {DOCS}/mkdocs.yml', check=True, shell=True)
+ print(f"Building docs from {DOCS}")
+ subprocess.run(f"mkdocs build -f {DOCS}/mkdocs.yml", check=True, shell=True)
# Build other localized documentations
- for file in DOCS.glob('mkdocs_*.yml'):
- print(f'Building MkDocs site with configuration file: {file}')
- subprocess.run(f'mkdocs build -f {file}', check=True, shell=True)
- print(f'Site built at {SITE}')
+ for file in DOCS.glob("mkdocs_*.yml"):
+ print(f"Building MkDocs site with configuration file: {file}")
+ subprocess.run(f"mkdocs build -f {file}", check=True, shell=True)
+ print(f"Site built at {SITE}")
def update_html_links():
"""Update href links in HTML files to remove '.md' and '/index.md', excluding links starting with 'https://'."""
- html_files = Path(SITE).rglob('*.html')
+ html_files = Path(SITE).rglob("*.html")
total_updated_links = 0
for html_file in html_files:
- with open(html_file, 'r+', encoding='utf-8') as file:
+ with open(html_file, "r+", encoding="utf-8") as file:
content = file.read()
# Find all links to be updated, excluding those starting with 'https://'
links_to_update = re.findall(r'href="(?!https://)([^"]+?)(/index)?\.md"', content)
# Update the content and count the number of links updated
- updated_content, number_of_links_updated = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"',
- r'href="\1"', content)
+ updated_content, number_of_links_updated = re.subn(
+ r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content
+ )
total_updated_links += number_of_links_updated
# Special handling for '/index' links
- updated_content, number_of_index_links_updated = re.subn(r'href="([^"]+)/index"', r'href="\1/"',
- updated_content)
+ updated_content, number_of_index_links_updated = re.subn(
+ r'href="([^"]+)/index"', r'href="\1/"', updated_content
+ )
total_updated_links += number_of_index_links_updated
# Write the updated content back to the file
@@ -78,23 +80,23 @@ def update_html_links():
# Print updated links for this file
for link in links_to_update:
- print(f'Updated link in {html_file}: {link[0]}')
+ print(f"Updated link in {html_file}: {link[0]}")
- print(f'Total number of links updated: {total_updated_links}')
+ print(f"Total number of links updated: {total_updated_links}")
def update_page_title(file_path: Path, new_title: str):
"""Update the title of an HTML file."""
# Read the content of the file
- with open(file_path, encoding='utf-8') as file:
+ with open(file_path, encoding="utf-8") as file:
content = file.read()
# Replace the existing title with the new title
-    updated_content = re.sub(r'<title>.*?</title>', f'<title>{new_title}</title>', content)
+    updated_content = re.sub(r"<title>.*?</title>", f"<title>{new_title}</title>", content)
# Write the updated content back to the file
- with open(file_path, 'w', encoding='utf-8') as file:
+ with open(file_path, "w", encoding="utf-8") as file:
file.write(updated_content)
@@ -109,8 +111,8 @@ def main():
print('Serve site at http://localhost:8000 with "python -m http.server --directory site"')
# Update titles
- update_page_title(SITE / '404.html', new_title='Ultralytics Docs - Not Found')
+ update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
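
For context on the hunks above: the first `re.subn` strips `.md` (and `/index.md`) from local hrefs, while the negative lookahead leaves absolute `https://` links untouched; the second pass then normalizes any remaining `/index` hrefs to a trailing slash. A standalone sketch with invented sample HTML:

```python
import re

# Hypothetical input; the real script iterates over every *.html file under site/.
html = '<a href="modes/train.md">Train</a> <a href="usage/index.md">Usage</a> <a href="https://x.md">X</a>'
html, n = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', html)
print(n)     # 2 -- the https:// link is skipped
print(html)  # <a href="modes/train">Train</a> <a href="usage">Usage</a> <a href="https://x.md">X</a>
```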
diff --git a/docs/build_reference.py b/docs/build_reference.py
index cb15d34a..65736ac0 100644
--- a/docs/build_reference.py
+++ b/docs/build_reference.py
@@ -14,14 +14,14 @@ from ultralytics.utils import ROOT
NEW_YAML_DIR = ROOT.parent
CODE_DIR = ROOT
-REFERENCE_DIR = ROOT.parent / 'docs/en/reference'
+REFERENCE_DIR = ROOT.parent / "docs/en/reference"
def extract_classes_and_functions(filepath: Path) -> tuple:
"""Extracts class and function names from a given Python file."""
content = filepath.read_text()
- class_pattern = r'(?:^|\n)class\s(\w+)(?:\(|:)'
- func_pattern = r'(?:^|\n)def\s(\w+)\('
+ class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)"
+ func_pattern = r"(?:^|\n)def\s(\w+)\("
classes = re.findall(class_pattern, content)
functions = re.findall(func_pattern, content)
@@ -31,31 +31,31 @@ def extract_classes_and_functions(filepath: Path) -> tuple:
def create_markdown(py_filepath: Path, module_path: str, classes: list, functions: list):
"""Creates a Markdown file containing the API reference for the given Python module."""
- md_filepath = py_filepath.with_suffix('.md')
+ md_filepath = py_filepath.with_suffix(".md")
# Read existing content and keep header content between first two ---
- header_content = ''
+ header_content = ""
if md_filepath.exists():
existing_content = md_filepath.read_text()
- header_parts = existing_content.split('---')
+ header_parts = existing_content.split("---")
for part in header_parts:
- if 'description:' in part or 'comments:' in part:
- header_content += f'---{part}---\n\n'
+ if "description:" in part or "comments:" in part:
+ header_content += f"---{part}---\n\n"
- module_name = module_path.replace('.__init__', '')
- module_path = module_path.replace('.', '/')
- url = f'https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py'
- edit = f'https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py'
+ module_name = module_path.replace(".__init__", "")
+ module_path = module_path.replace(".", "/")
+ url = f"https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py"
+ edit = f"https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py"
title_content = (
- f'# Reference for `{module_path}.py`\n\n'
- f'!!! Note\n\n'
- f' This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n'
+ f"# Reference for `{module_path}.py`\n\n"
+ f"!!! Note\n\n"
+ f" This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n"
)
-    md_content = ['<br><br>\n'] + [f'## ::: {module_name}.{class_name}\n\n<br><br>\n' for class_name in classes]
-    md_content.extend(f'## ::: {module_name}.{func_name}\n\n<br><br>\n' for func_name in functions)
- md_content = header_content + title_content + '\n'.join(md_content)
- if not md_content.endswith('\n'):
- md_content += '\n'
+ md_content = ["
\n"] + [f"## ::: {module_name}.{class_name}\n\n
\n" for class_name in classes]
+ md_content.extend(f"## ::: {module_name}.{func_name}\n\n
\n" for func_name in functions)
+ md_content = header_content + title_content + "\n".join(md_content)
+ if not md_content.endswith("\n"):
+ md_content += "\n"
md_filepath.parent.mkdir(parents=True, exist_ok=True)
md_filepath.write_text(md_content)
@@ -80,28 +80,28 @@ def create_nav_menu_yaml(nav_items: list):
for item_str in nav_items:
item = Path(item_str)
parts = item.parts
- current_level = nav_tree['reference']
+ current_level = nav_tree["reference"]
for part in parts[2:-1]: # skip the first two parts (docs and reference) and the last part (filename)
current_level = current_level[part]
- md_file_name = parts[-1].replace('.md', '')
+ md_file_name = parts[-1].replace(".md", "")
current_level[md_file_name] = item
nav_tree_sorted = sort_nested_dict(nav_tree)
def _dict_to_yaml(d, level=0):
"""Converts a nested dictionary to a YAML-formatted string with indentation."""
- yaml_str = ''
- indent = ' ' * level
+ yaml_str = ""
+ indent = " " * level
for k, v in d.items():
if isinstance(v, dict):
- yaml_str += f'{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}'
+ yaml_str += f"{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}"
else:
yaml_str += f"{indent}- {k}: {str(v).replace('docs/en/', '')}\n"
return yaml_str
# Print updated YAML reference section
- print('Scan complete, new mkdocs.yaml reference section is:\n\n', _dict_to_yaml(nav_tree_sorted))
+ print("Scan complete, new mkdocs.yaml reference section is:\n\n", _dict_to_yaml(nav_tree_sorted))
# Save new YAML reference section
# (NEW_YAML_DIR / 'nav_menu_updated.yml').write_text(_dict_to_yaml(nav_tree_sorted))
@@ -111,7 +111,7 @@ def main():
"""Main function to extract class and function names, create Markdown files, and generate a YAML navigation menu."""
nav_items = []
- for py_filepath in CODE_DIR.rglob('*.py'):
+ for py_filepath in CODE_DIR.rglob("*.py"):
classes, functions = extract_classes_and_functions(py_filepath)
if classes or functions:
@@ -124,5 +124,5 @@ def main():
create_nav_menu_yaml(nav_items)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
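
The extraction regexes above only match top-level definitions: the `(?:^|\n)` anchor requires `class`/`def` to start a line, so indented (nested) definitions are skipped. A short sketch on an invented module source:

```python
import re

content = "class Model(nn.Module):\n    def forward(self, x):\n        pass\n\ndef check_imgsz(imgsz):\n    pass\n"
classes = re.findall(r"(?:^|\n)class\s(\w+)(?:\(|:)", content)
functions = re.findall(r"(?:^|\n)def\s(\w+)\(", content)
print(classes, functions)  # ['Model'] ['check_imgsz'] -- the nested forward() is ignored
```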
diff --git a/docs/update_translations.py b/docs/update_translations.py
index 9c27c700..8b64a4a4 100644
--- a/docs/update_translations.py
+++ b/docs/update_translations.py
@@ -22,69 +22,232 @@ class MarkdownLinkFixer:
self.base_dir = Path(base_dir)
self.update_links = update_links
self.update_text = update_text
- self.md_link_regex = re.compile(r'\[([^]]+)]\(([^:)]+)\.md\)')
+ self.md_link_regex = re.compile(r"\[([^]]+)]\(([^:)]+)\.md\)")
@staticmethod
def replace_front_matter(content, lang_dir):
"""Ensure front matter keywords remain in English."""
- english = ['comments', 'description', 'keywords']
+ english = ["comments", "description", "keywords"]
translations = {
- 'zh': ['评论', '描述', '关键词'], # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
- 'es': ['comentarios', 'descripción', 'palabras clave'], # Spanish
- 'ru': ['комментарии', 'описание', 'ключевые слова'], # Russian
- 'pt': ['comentários', 'descrição', 'palavras-chave'], # Portuguese
- 'fr': ['commentaires', 'description', 'mots-clés'], # French
- 'de': ['kommentare', 'beschreibung', 'schlüsselwörter'], # German
- 'ja': ['コメント', '説明', 'キーワード'], # Japanese
- 'ko': ['댓글', '설명', '키워드'], # Korean
- 'hi': ['टिप्पणियाँ', 'विवरण', 'कीवर्ड'], # Hindi
- 'ar': ['التعليقات', 'الوصف', 'الكلمات الرئيسية'] # Arabic
+        "zh": ["评论", "描述", "关键词"],  # Mandarin Chinese (Simplified); caution: sometimes translated as 关键字
+ "es": ["comentarios", "descripción", "palabras clave"], # Spanish
+ "ru": ["комментарии", "описание", "ключевые слова"], # Russian
+ "pt": ["comentários", "descrição", "palavras-chave"], # Portuguese
+ "fr": ["commentaires", "description", "mots-clés"], # French
+ "de": ["kommentare", "beschreibung", "schlüsselwörter"], # German
+ "ja": ["コメント", "説明", "キーワード"], # Japanese
+ "ko": ["댓글", "설명", "키워드"], # Korean
+ "hi": ["टिप्पणियाँ", "विवरण", "कीवर्ड"], # Hindi
+ "ar": ["التعليقات", "الوصف", "الكلمات الرئيسية"], # Arabic
} # front matter translations for comments, description, keyword
for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
- content = re.sub(rf'{term} *[::].*', f'{eng_key}: true', content, flags=re.IGNORECASE) if \
- eng_key == 'comments' else re.sub(rf'{term} *[::] *', f'{eng_key}: ', content, flags=re.IGNORECASE)
+ content = (
+ re.sub(rf"{term} *[::].*", f"{eng_key}: true", content, flags=re.IGNORECASE)
+ if eng_key == "comments"
+ else re.sub(rf"{term} *[::] *", f"{eng_key}: ", content, flags=re.IGNORECASE)
+ )
return content
@staticmethod
    def replace_admonitions(content, lang_dir):
        """Ensure admonition keywords remain in English."""
english = [
- 'Note', 'Summary', 'Tip', 'Info', 'Success', 'Question', 'Warning', 'Failure', 'Danger', 'Bug', 'Example',
- 'Quote', 'Abstract', 'Seealso', 'Admonition']
+ "Note",
+ "Summary",
+ "Tip",
+ "Info",
+ "Success",
+ "Question",
+ "Warning",
+ "Failure",
+ "Danger",
+ "Bug",
+ "Example",
+ "Quote",
+ "Abstract",
+ "Seealso",
+ "Admonition",
+ ]
translations = {
- 'en':
- english,
- 'zh': ['笔记', '摘要', '提示', '信息', '成功', '问题', '警告', '失败', '危险', '故障', '示例', '引用', '摘要', '另见', '警告'],
- 'es': [
- 'Nota', 'Resumen', 'Consejo', 'Información', 'Éxito', 'Pregunta', 'Advertencia', 'Fracaso', 'Peligro',
- 'Error', 'Ejemplo', 'Cita', 'Abstracto', 'Véase También', 'Amonestación'],
- 'ru': [
- 'Заметка', 'Сводка', 'Совет', 'Информация', 'Успех', 'Вопрос', 'Предупреждение', 'Неудача', 'Опасность',
- 'Ошибка', 'Пример', 'Цитата', 'Абстракт', 'См. Также', 'Предостережение'],
- 'pt': [
- 'Nota', 'Resumo', 'Dica', 'Informação', 'Sucesso', 'Questão', 'Aviso', 'Falha', 'Perigo', 'Bug',
- 'Exemplo', 'Citação', 'Abstrato', 'Veja Também', 'Advertência'],
- 'fr': [
- 'Note', 'Résumé', 'Conseil', 'Info', 'Succès', 'Question', 'Avertissement', 'Échec', 'Danger', 'Bug',
- 'Exemple', 'Citation', 'Abstrait', 'Voir Aussi', 'Admonestation'],
- 'de': [
- 'Hinweis', 'Zusammenfassung', 'Tipp', 'Info', 'Erfolg', 'Frage', 'Warnung', 'Ausfall', 'Gefahr',
- 'Fehler', 'Beispiel', 'Zitat', 'Abstrakt', 'Siehe Auch', 'Ermahnung'],
- 'ja': ['ノート', '要約', 'ヒント', '情報', '成功', '質問', '警告', '失敗', '危険', 'バグ', '例', '引用', '抄録', '参照', '訓告'],
- 'ko': ['노트', '요약', '팁', '정보', '성공', '질문', '경고', '실패', '위험', '버그', '예제', '인용', '추상', '참조', '경고'],
- 'hi': [
- 'नोट', 'सारांश', 'सुझाव', 'जानकारी', 'सफलता', 'प्रश्न', 'चेतावनी', 'विफलता', 'खतरा', 'बग', 'उदाहरण',
- 'उद्धरण', 'सार', 'देखें भी', 'आगाही'],
- 'ar': [
- 'ملاحظة', 'ملخص', 'نصيحة', 'معلومات', 'نجاح', 'سؤال', 'تحذير', 'فشل', 'خطر', 'عطل', 'مثال', 'اقتباس',
- 'ملخص', 'انظر أيضاً', 'تحذير']}
+ "en": english,
+ "zh": [
+ "笔记",
+ "摘要",
+ "提示",
+ "信息",
+ "成功",
+ "问题",
+ "警告",
+ "失败",
+ "危险",
+ "故障",
+ "示例",
+ "引用",
+ "摘要",
+ "另见",
+ "警告",
+ ],
+ "es": [
+ "Nota",
+ "Resumen",
+ "Consejo",
+ "Información",
+ "Éxito",
+ "Pregunta",
+ "Advertencia",
+ "Fracaso",
+ "Peligro",
+ "Error",
+ "Ejemplo",
+ "Cita",
+ "Abstracto",
+ "Véase También",
+ "Amonestación",
+ ],
+ "ru": [
+ "Заметка",
+ "Сводка",
+ "Совет",
+ "Информация",
+ "Успех",
+ "Вопрос",
+ "Предупреждение",
+ "Неудача",
+ "Опасность",
+ "Ошибка",
+ "Пример",
+ "Цитата",
+ "Абстракт",
+ "См. Также",
+ "Предостережение",
+ ],
+ "pt": [
+ "Nota",
+ "Resumo",
+ "Dica",
+ "Informação",
+ "Sucesso",
+ "Questão",
+ "Aviso",
+ "Falha",
+ "Perigo",
+ "Bug",
+ "Exemplo",
+ "Citação",
+ "Abstrato",
+ "Veja Também",
+ "Advertência",
+ ],
+ "fr": [
+ "Note",
+ "Résumé",
+ "Conseil",
+ "Info",
+ "Succès",
+ "Question",
+ "Avertissement",
+ "Échec",
+ "Danger",
+ "Bug",
+ "Exemple",
+ "Citation",
+ "Abstrait",
+ "Voir Aussi",
+ "Admonestation",
+ ],
+ "de": [
+ "Hinweis",
+ "Zusammenfassung",
+ "Tipp",
+ "Info",
+ "Erfolg",
+ "Frage",
+ "Warnung",
+ "Ausfall",
+ "Gefahr",
+ "Fehler",
+ "Beispiel",
+ "Zitat",
+ "Abstrakt",
+ "Siehe Auch",
+ "Ermahnung",
+ ],
+ "ja": [
+ "ノート",
+ "要約",
+ "ヒント",
+ "情報",
+ "成功",
+ "質問",
+ "警告",
+ "失敗",
+ "危険",
+ "バグ",
+ "例",
+ "引用",
+ "抄録",
+ "参照",
+ "訓告",
+ ],
+ "ko": [
+ "노트",
+ "요약",
+ "팁",
+ "정보",
+ "성공",
+ "질문",
+ "경고",
+ "실패",
+ "위험",
+ "버그",
+ "예제",
+ "인용",
+ "추상",
+ "참조",
+ "경고",
+ ],
+ "hi": [
+ "नोट",
+ "सारांश",
+ "सुझाव",
+ "जानकारी",
+ "सफलता",
+ "प्रश्न",
+ "चेतावनी",
+ "विफलता",
+ "खतरा",
+ "बग",
+ "उदाहरण",
+ "उद्धरण",
+ "सार",
+ "देखें भी",
+ "आगाही",
+ ],
+ "ar": [
+ "ملاحظة",
+ "ملخص",
+ "نصيحة",
+ "معلومات",
+ "نجاح",
+ "سؤال",
+ "تحذير",
+ "فشل",
+ "خطر",
+ "عطل",
+ "مثال",
+ "اقتباس",
+ "ملخص",
+ "انظر أيضاً",
+ "تحذير",
+ ],
+ }
for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
- if lang_dir.stem != 'en':
- content = re.sub(rf'!!! *{eng_key} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
- content = re.sub(rf'!!! *{term} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
- content = re.sub(rf'!!! *{term}', f'!!! {eng_key}', content, flags=re.IGNORECASE)
+ if lang_dir.stem != "en":
+ content = re.sub(rf"!!! *{eng_key} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+ content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
+ content = re.sub(rf"!!! *{term}", f"!!! {eng_key}", content, flags=re.IGNORECASE)
content = re.sub(r'!!! *"', '!!! Example "', content, flags=re.IGNORECASE)
return content
@@ -92,30 +255,30 @@ class MarkdownLinkFixer:
@staticmethod
def update_iframe(content):
"""Update the 'allow' attribute of iframe if it does not contain the specific English permissions."""
- english = 'accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share'
+ english = "accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
pattern = re.compile(f'allow="(?!{re.escape(english)}).+?"')
return pattern.sub(f'allow="{english}"', content)
def link_replacer(self, match, parent_dir, lang_dir, use_abs_link=False):
"""Replace broken links with corresponding links in the /en/ directory."""
text, path = match.groups()
- linked_path = (parent_dir / path).resolve().with_suffix('.md')
+ linked_path = (parent_dir / path).resolve().with_suffix(".md")
if not linked_path.exists():
- en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / 'en')))
+ en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / "en")))
if en_linked_path.exists():
if use_abs_link:
# Use absolute links WARNING: BUGS, DO NOT USE
docs_root_relative_path = en_linked_path.relative_to(lang_dir.parent)
- updated_path = str(docs_root_relative_path).replace('en/', '/../')
+ updated_path = str(docs_root_relative_path).replace("en/", "/../")
else:
# Use relative links
steps_up = len(parent_dir.relative_to(self.base_dir).parts)
- updated_path = Path('../' * steps_up) / en_linked_path.relative_to(self.base_dir)
- updated_path = str(updated_path).replace('/en/', '/')
+ updated_path = Path("../" * steps_up) / en_linked_path.relative_to(self.base_dir)
+ updated_path = str(updated_path).replace("/en/", "/")
print(f"Redirecting link '[{text}]({path})' from {parent_dir} to {updated_path}")
- return f'[{text}]({updated_path})'
+ return f"[{text}]({updated_path})"
else:
print(f"Warning: Broken link '[{text}]({path})' found in {parent_dir} does not exist in /docs/en/.")
@@ -124,28 +287,30 @@ class MarkdownLinkFixer:
@staticmethod
def update_html_tags(content):
"""Updates HTML tags in docs."""
- alt_tag = 'MISSING'
+ alt_tag = "MISSING"
# Remove closing slashes from self-closing HTML tags
- pattern = re.compile(r'<([^>]+?)\s*/>')
- content = re.sub(pattern, r'<\1>', content)
+ pattern = re.compile(r"<([^>]+?)\s*/>")
+ content = re.sub(pattern, r"<\1>", content)
# Find all images without alt tags and add placeholder alt text
- pattern = re.compile(r'!\[(.*?)\]\((.*?)\)')
-        content, num_replacements = re.subn(pattern, lambda match: f'![{match.group(1) or alt_tag}]({match.group(2)})',
-                                            content)
+ pattern = re.compile(r"!\[(.*?)\]\((.*?)\)")
+ content, num_replacements = re.subn(
+            pattern, lambda match: f"![{match.group(1) or alt_tag}]({match.group(2)})", content
+ )
# Add missing alt tags to HTML images
        pattern = re.compile(r'<img [^>]*src=["\'](.*?)["\'][^>]*>')
- content, num_replacements = re.subn(pattern, lambda match: match.group(0).replace('>', f' alt="{alt_tag}">', 1),
- content)
+ content, num_replacements = re.subn(
+ pattern, lambda match: match.group(0).replace(">", f' alt="{alt_tag}">', 1), content
+ )
return content
def process_markdown_file(self, md_file_path, lang_dir):
"""Process each markdown file in the language directory."""
- print(f'Processing file: {md_file_path}')
- with open(md_file_path, encoding='utf-8') as file:
+ print(f"Processing file: {md_file_path}")
+ with open(md_file_path, encoding="utf-8") as file:
content = file.read()
if self.update_links:
@@ -157,23 +322,23 @@ class MarkdownLinkFixer:
content = self.update_iframe(content)
content = self.update_html_tags(content)
- with open(md_file_path, 'w', encoding='utf-8') as file:
+ with open(md_file_path, "w", encoding="utf-8") as file:
file.write(content)
def process_language_directory(self, lang_dir):
"""Process each language-specific directory."""
- print(f'Processing language directory: {lang_dir}')
- for md_file in lang_dir.rglob('*.md'):
+ print(f"Processing language directory: {lang_dir}")
+ for md_file in lang_dir.rglob("*.md"):
self.process_markdown_file(md_file, lang_dir)
def run(self):
"""Run the link fixing and front matter updating process for each language-specific directory."""
for subdir in self.base_dir.iterdir():
- if subdir.is_dir() and re.match(r'^\w\w$', subdir.name):
+ if subdir.is_dir() and re.match(r"^\w\w$", subdir.name):
self.process_language_directory(subdir)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Set the path to your MkDocs 'docs' directory here
docs_dir = str(Path(__file__).parent.resolve())
fixer = MarkdownLinkFixer(docs_dir, update_links=True, update_text=True)
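
The admonition pass above maps a translated keyword back to the English key that MkDocs recognizes, keeping the translation as the quoted display title. A minimal sketch for the French `Warning` term (sample content invented):

```python
import re

content = "!!! Avertissement\n    Texte de l'avertissement...\n"
eng_key, term = "Warning", "Avertissement"
content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
print(content)  # !!! Warning "Avertissement" -- parseable by MkDocs, title stays French
```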
diff --git a/examples/YOLOv8-ONNXRuntime/main.py b/examples/YOLOv8-ONNXRuntime/main.py
index ec768713..e3755a44 100644
--- a/examples/YOLOv8-ONNXRuntime/main.py
+++ b/examples/YOLOv8-ONNXRuntime/main.py
@@ -28,7 +28,7 @@ class YOLOv8:
self.iou_thres = iou_thres
# Load the class names from the COCO dataset
- self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+ self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Generate a color palette for the classes
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -57,7 +57,7 @@ class YOLOv8:
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
# Create the label text with class name and score
- label = f'{self.classes[class_id]}: {score:.2f}'
+ label = f"{self.classes[class_id]}: {score:.2f}"
# Calculate the dimensions of the label text
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -67,8 +67,9 @@ class YOLOv8:
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
# Draw a filled rectangle as the background for the label text
- cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color,
- cv2.FILLED)
+ cv2.rectangle(
+ img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
+ )
# Draw the label text on the image
cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -182,7 +183,7 @@ class YOLOv8:
output_img: The output image with drawn detections.
"""
# Create an inference session using the ONNX model and specify execution providers
- session = ort.InferenceSession(self.onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
# Get the model inputs
model_inputs = session.get_inputs()
@@ -202,17 +203,17 @@ class YOLOv8:
return self.postprocess(self.img, outputs) # output image
-if __name__ == '__main__':
+if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
- parser.add_argument('--model', type=str, default='yolov8n.onnx', help='Input your ONNX model.')
- parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
- parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
- parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+ parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.")
+ parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+ parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+ parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
args = parser.parse_args()
# Check the requirements and select the appropriate backend (CPU or GPU)
- check_requirements('onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')
+ check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")
# Create an instance of the YOLOv8 class with the specified arguments
detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)
@@ -221,8 +222,8 @@ if __name__ == '__main__':
output_image = detection.main()
# Display the output image in a window
- cv2.namedWindow('Output', cv2.WINDOW_NORMAL)
- cv2.imshow('Output', output_image)
+ cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+ cv2.imshow("Output", output_image)
# Wait for a key press to exit
cv2.waitKey(0)
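
On the provider list above: ONNX Runtime tries providers in order, so `CUDAExecutionProvider` is preferred and `CPUExecutionProvider` is the fallback when CUDA is unavailable. A minimal session sketch, assuming a `yolov8n.onnx` file exists in the working directory:

```python
import onnxruntime as ort

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if ort.get_device() == "GPU" else ["CPUExecutionProvider"]
session = ort.InferenceSession("yolov8n.onnx", providers=providers)  # model path is an assumption
print(session.get_providers())        # providers actually in use
print(session.get_inputs()[0].shape)  # e.g. [1, 3, 640, 640]
```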
diff --git a/examples/YOLOv8-OpenCV-ONNX-Python/main.py b/examples/YOLOv8-OpenCV-ONNX-Python/main.py
index 78b0b08e..c0564d15 100644
--- a/examples/YOLOv8-OpenCV-ONNX-Python/main.py
+++ b/examples/YOLOv8-OpenCV-ONNX-Python/main.py
@@ -6,7 +6,7 @@ import numpy as np
from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_yaml
-CLASSES = yaml_load(check_yaml('coco128.yaml'))['names']
+CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"]
colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))
@@ -23,7 +23,7 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
"""
- label = f'{CLASSES[class_id]} ({confidence:.2f})'
+ label = f"{CLASSES[class_id]} ({confidence:.2f})"
color = colors[class_id]
cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
@@ -76,8 +76,11 @@ def main(onnx_model, input_image):
(minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
if maxScore >= 0.25:
box = [
- outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]),
- outputs[0][i][2], outputs[0][i][3]]
+ outputs[0][i][0] - (0.5 * outputs[0][i][2]),
+ outputs[0][i][1] - (0.5 * outputs[0][i][3]),
+ outputs[0][i][2],
+ outputs[0][i][3],
+ ]
boxes.append(box)
scores.append(maxScore)
class_ids.append(maxClassIndex)
@@ -92,26 +95,34 @@ def main(onnx_model, input_image):
index = result_boxes[i]
box = boxes[index]
detection = {
- 'class_id': class_ids[index],
- 'class_name': CLASSES[class_ids[index]],
- 'confidence': scores[index],
- 'box': box,
- 'scale': scale}
+ "class_id": class_ids[index],
+ "class_name": CLASSES[class_ids[index]],
+ "confidence": scores[index],
+ "box": box,
+ "scale": scale,
+ }
detections.append(detection)
- draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale),
- round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale))
+ draw_bounding_box(
+ original_image,
+ class_ids[index],
+ scores[index],
+ round(box[0] * scale),
+ round(box[1] * scale),
+ round((box[0] + box[2]) * scale),
+ round((box[1] + box[3]) * scale),
+ )
# Display the image with bounding boxes
- cv2.imshow('image', original_image)
+ cv2.imshow("image", original_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
return detections
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.')
- parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
+ parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.")
+ parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.")
args = parser.parse_args()
main(args.model, args.img)
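
The box arithmetic above converts model outputs from center format `(cx, cy, w, h)` to the top-left format `(x, y, w, h)` that `cv2.dnn.NMSBoxes` expects. A worked example with invented values:

```python
# Center-format detection (cx, cy, w, h) -> top-left format (x, y, w, h)
cx, cy, w, h = 320.0, 240.0, 100.0, 50.0
box = [cx - 0.5 * w, cy - 0.5 * h, w, h]
print(box)  # [270.0, 215.0, 100.0, 50.0]
```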
diff --git a/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py b/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py
index 04b39e9e..9c23173c 100644
--- a/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py
+++ b/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py
@@ -13,14 +13,9 @@ img_height = 640
class LetterBox:
-
- def __init__(self,
- new_shape=(img_width, img_height),
- auto=False,
- scaleFill=False,
- scaleup=True,
- center=True,
- stride=32):
+ def __init__(
+ self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
+ ):
self.new_shape = new_shape
self.auto = auto
self.scaleFill = scaleFill
@@ -33,9 +28,9 @@ class LetterBox:
if labels is None:
labels = {}
- img = labels.get('img') if image is None else image
+ img = labels.get("img") if image is None else image
shape = img.shape[:2] # current shape [height, width]
- new_shape = labels.pop('rect_shape', self.new_shape)
+ new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
@@ -63,15 +58,16 @@ class LetterBox:
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
- value=(114, 114, 114)) # add border
- if labels.get('ratio_pad'):
- labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation
+ img = cv2.copyMakeBorder(
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+ ) # add border
+ if labels.get("ratio_pad"):
+ labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
- labels['img'] = img
- labels['resized_shape'] = new_shape
+ labels["img"] = img
+ labels["resized_shape"] = new_shape
return labels
else:
return img
@@ -79,15 +75,14 @@ class LetterBox:
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
- labels['instances'].convert_bbox(format='xyxy')
- labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
- labels['instances'].scale(*ratio)
- labels['instances'].add_padding(padw, padh)
+ labels["instances"].convert_bbox(format="xyxy")
+ labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+ labels["instances"].scale(*ratio)
+ labels["instances"].add_padding(padw, padh)
return labels
class Yolov8TFLite:
-
def __init__(self, tflite_model, input_image, confidence_thres, iou_thres):
"""
Initializes an instance of the Yolov8TFLite class.
@@ -105,7 +100,7 @@ class Yolov8TFLite:
self.iou_thres = iou_thres
# Load the class names from the COCO dataset
- self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+ self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Generate a color palette for the classes
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -134,7 +129,7 @@ class Yolov8TFLite:
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
# Create the label text with class name and score
- label = f'{self.classes[class_id]}: {score:.2f}'
+ label = f"{self.classes[class_id]}: {score:.2f}"
# Calculate the dimensions of the label text
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -144,8 +139,13 @@ class Yolov8TFLite:
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
# Draw a filled rectangle as the background for the label text
- cv2.rectangle(img, (int(label_x), int(label_y - label_height)),
- (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED)
+ cv2.rectangle(
+ img,
+ (int(label_x), int(label_y - label_height)),
+ (int(label_x + label_width), int(label_y + label_height)),
+ color,
+ cv2.FILLED,
+ )
# Draw the label text on the image
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -161,7 +161,7 @@ class Yolov8TFLite:
# Read the input image using OpenCV
self.img = cv2.imread(self.input_image)
- print('image befor', self.img)
+ print("image before", self.img)
# Get the height and width of the input image
self.img_height, self.img_width = self.img.shape[:2]
@@ -209,8 +209,10 @@ class Yolov8TFLite:
# Get the box, score, and class ID corresponding to the index
box = boxes[i]
gain = min(img_width / self.img_width, img_height / self.img_height)
- pad = round((img_width - self.img_width * gain) / 2 -
- 0.1), round((img_height - self.img_height * gain) / 2 - 0.1)
+ pad = (
+ round((img_width - self.img_width * gain) / 2 - 0.1),
+ round((img_height - self.img_height * gain) / 2 - 0.1),
+ )
box[0] = (box[0] - pad[0]) / gain
box[1] = (box[1] - pad[1]) / gain
box[2] = box[2] / gain
@@ -242,7 +244,7 @@ class Yolov8TFLite:
output_details = interpreter.get_output_details()
# Store the shape of the input for later use
- input_shape = input_details[0]['shape']
+ input_shape = input_details[0]["shape"]
self.input_width = input_shape[1]
self.input_height = input_shape[2]
@@ -251,19 +253,19 @@ class Yolov8TFLite:
img_data = img_data
# img_data = img_data.cpu().numpy()
# Set the input tensor to the interpreter
- print(input_details[0]['index'])
+ print(input_details[0]["index"])
print(img_data.shape)
img_data = img_data.transpose((0, 2, 3, 1))
- scale, zero_point = input_details[0]['quantization']
- interpreter.set_tensor(input_details[0]['index'], img_data)
+ scale, zero_point = input_details[0]["quantization"]
+ interpreter.set_tensor(input_details[0]["index"], img_data)
# Run inference
interpreter.invoke()
# Get the output tensor from the interpreter
- output = interpreter.get_tensor(output_details[0]['index'])
- scale, zero_point = output_details[0]['quantization']
+ output = interpreter.get_tensor(output_details[0]["index"])
+ scale, zero_point = output_details[0]["quantization"]
output = (output.astype(np.float32) - zero_point) * scale
output[:, [0, 2]] *= img_width
@@ -273,16 +275,15 @@ class Yolov8TFLite:
return self.postprocess(self.img, output)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
- parser.add_argument('--model',
- type=str,
- default='yolov8n_full_integer_quant.tflite',
- help='Input your TFLite model.')
- parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
- parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
- parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
+ parser.add_argument(
+ "--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model."
+ )
+ parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
+ parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
+ parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
args = parser.parse_args()
# Create an instance of the Yolov8TFLite class with the specified arguments
@@ -292,7 +293,7 @@ if __name__ == '__main__':
output_image = detection.main()
# Display the output image in a window
- cv2.imshow('Output', output_image)
+ cv2.imshow("Output", output_image)
# Wait for a key press to exit
cv2.waitKey(0)
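
The dequantization step above recovers float activations from the int8 output tensor via `(q - zero_point) * scale`, with both parameters read from the interpreter's output details. A worked example with invented quantization parameters:

```python
import numpy as np

scale, zero_point = 0.00392, 0  # invented; real values come from output_details[0]["quantization"]
q = np.array([0, 64, 255], dtype=np.uint8)
real = (q.astype(np.float32) - zero_point) * scale
print(real)  # ~[0.0, 0.25088, 0.9996]
```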
diff --git a/examples/YOLOv8-Region-Counter/yolov8_region_counter.py b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py
index 5379fd3b..70bcfd5a 100644
--- a/examples/YOLOv8-Region-Counter/yolov8_region_counter.py
+++ b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py
@@ -16,21 +16,22 @@ track_history = defaultdict(list)
current_region = None
counting_regions = [
{
- 'name': 'YOLOv8 Polygon Region',
- 'polygon': Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
- 'counts': 0,
- 'dragging': False,
- 'region_color': (255, 42, 4), # BGR Value
- 'text_color': (255, 255, 255) # Region Text Color
+ "name": "YOLOv8 Polygon Region",
+ "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
+ "counts": 0,
+ "dragging": False,
+ "region_color": (255, 42, 4), # BGR Value
+ "text_color": (255, 255, 255), # Region Text Color
},
{
- 'name': 'YOLOv8 Rectangle Region',
- 'polygon': Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
- 'counts': 0,
- 'dragging': False,
- 'region_color': (37, 255, 225), # BGR Value
- 'text_color': (0, 0, 0), # Region Text Color
- }, ]
+ "name": "YOLOv8 Rectangle Region",
+ "polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
+ "counts": 0,
+ "dragging": False,
+ "region_color": (37, 255, 225), # BGR Value
+ "text_color": (0, 0, 0), # Region Text Color
+ },
+]
def mouse_callback(event, x, y, flags, param):
@@ -40,32 +41,33 @@ def mouse_callback(event, x, y, flags, param):
# Mouse left button down event
if event == cv2.EVENT_LBUTTONDOWN:
for region in counting_regions:
- if region['polygon'].contains(Point((x, y))):
+ if region["polygon"].contains(Point((x, y))):
current_region = region
- current_region['dragging'] = True
- current_region['offset_x'] = x
- current_region['offset_y'] = y
+ current_region["dragging"] = True
+ current_region["offset_x"] = x
+ current_region["offset_y"] = y
# Mouse move event
elif event == cv2.EVENT_MOUSEMOVE:
- if current_region is not None and current_region['dragging']:
- dx = x - current_region['offset_x']
- dy = y - current_region['offset_y']
- current_region['polygon'] = Polygon([
- (p[0] + dx, p[1] + dy) for p in current_region['polygon'].exterior.coords])
- current_region['offset_x'] = x
- current_region['offset_y'] = y
+ if current_region is not None and current_region["dragging"]:
+ dx = x - current_region["offset_x"]
+ dy = y - current_region["offset_y"]
+ current_region["polygon"] = Polygon(
+ [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords]
+ )
+ current_region["offset_x"] = x
+ current_region["offset_y"] = y
# Mouse left button up event
elif event == cv2.EVENT_LBUTTONUP:
- if current_region is not None and current_region['dragging']:
- current_region['dragging'] = False
+ if current_region is not None and current_region["dragging"]:
+ current_region["dragging"] = False
def run(
- weights='yolov8n.pt',
+ weights="yolov8n.pt",
source=None,
- device='cpu',
+ device="cpu",
view_img=False,
save_img=False,
exist_ok=False,
@@ -100,8 +102,8 @@ def run(
raise FileNotFoundError(f"Source path '{source}' does not exist.")
# Setup Model
- model = YOLO(f'{weights}')
- model.to('cuda') if device == '0' else model.to('cpu')
+ model = YOLO(f"{weights}")
+ model.to("cuda") if device == "0" else model.to("cpu")
# Extract classes names
names = model.model.names
@@ -109,12 +111,12 @@ def run(
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
- fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
+ fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
# Output setup
- save_dir = increment_path(Path('ultralytics_rc_output') / 'exp', exist_ok)
+ save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok)
save_dir.mkdir(parents=True, exist_ok=True)
- video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
+ video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
# Iterate over video frames
while videocapture.isOpened():
@@ -146,43 +148,48 @@ def run(
# Check if detection inside region
for region in counting_regions:
- if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))):
- region['counts'] += 1
+ if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))):
+ region["counts"] += 1
# Draw regions (Polygons/Rectangles)
for region in counting_regions:
- region_label = str(region['counts'])
- region_color = region['region_color']
- region_text_color = region['text_color']
+ region_label = str(region["counts"])
+ region_color = region["region_color"]
+ region_text_color = region["text_color"]
- polygon_coords = np.array(region['polygon'].exterior.coords, dtype=np.int32)
- centroid_x, centroid_y = int(region['polygon'].centroid.x), int(region['polygon'].centroid.y)
+ polygon_coords = np.array(region["polygon"].exterior.coords, dtype=np.int32)
+ centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y)
- text_size, _ = cv2.getTextSize(region_label,
- cv2.FONT_HERSHEY_SIMPLEX,
- fontScale=0.7,
- thickness=line_thickness)
+ text_size, _ = cv2.getTextSize(
+ region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness
+ )
text_x = centroid_x - text_size[0] // 2
text_y = centroid_y + text_size[1] // 2
- cv2.rectangle(frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5),
- region_color, -1)
- cv2.putText(frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color,
- line_thickness)
+ cv2.rectangle(
+ frame,
+ (text_x - 5, text_y - text_size[1] - 5),
+ (text_x + text_size[0] + 5, text_y + 5),
+ region_color,
+ -1,
+ )
+ cv2.putText(
+ frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness
+ )
cv2.polylines(frame, [polygon_coords], isClosed=True, color=region_color, thickness=region_thickness)
if view_img:
if vid_frame_count == 1:
- cv2.namedWindow('Ultralytics YOLOv8 Region Counter Movable')
- cv2.setMouseCallback('Ultralytics YOLOv8 Region Counter Movable', mouse_callback)
- cv2.imshow('Ultralytics YOLOv8 Region Counter Movable', frame)
+ cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable")
+ cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback)
+ cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame)
if save_img:
video_writer.write(frame)
for region in counting_regions: # Reinitialize count for each region
- region['counts'] = 0
+ region["counts"] = 0
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ if cv2.waitKey(1) & 0xFF == ord("q"):
break
del vid_frame_count
@@ -194,16 +201,16 @@ def run(
def parse_opt():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
- parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
- parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
- parser.add_argument('--source', type=str, required=True, help='video file path')
- parser.add_argument('--view-img', action='store_true', help='show results')
- parser.add_argument('--save-img', action='store_true', help='save results')
- parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
- parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
- parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness')
- parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness')
- parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness')
+ parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
+ parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
+ parser.add_argument("--source", type=str, required=True, help="video file path")
+ parser.add_argument("--view-img", action="store_true", help="show results")
+ parser.add_argument("--save-img", action="store_true", help="save results")
+ parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
+ parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
+ parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness")
+ parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness")
+ parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness")
return parser.parse_args()
@@ -213,6 +220,6 @@ def main(opt):
run(**vars(opt))
-if __name__ == '__main__':
+if __name__ == "__main__":
opt = parse_opt()
main(opt)
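
The counting test above is a plain Shapely point-in-polygon check against each detection's bbox center. A minimal sketch using the polygon from the first region:

```python
from shapely.geometry import Point, Polygon

region = Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)])
print(region.contains(Point(200, 200)))  # True -- would increment the region count
print(region.contains(Point(0, 0)))      # False
```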
diff --git a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py
index 7ab84417..37d78114 100644
--- a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py
+++ b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py
@@ -9,7 +9,7 @@ from sahi.utils.yolov8 import download_yolov8s_model
from ultralytics.utils.files import increment_path
-def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
+def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False):
"""
Run object detection on a video using YOLOv8 and SAHI.
@@ -25,41 +25,41 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
if not Path(source).exists():
raise FileNotFoundError(f"Source path '{source}' does not exist.")
- yolov8_model_path = f'models/{weights}'
+ yolov8_model_path = f"models/{weights}"
download_yolov8s_model(yolov8_model_path)
- detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
- model_path=yolov8_model_path,
- confidence_threshold=0.3,
- device='cpu')
+ detection_model = AutoDetectionModel.from_pretrained(
+ model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu"
+ )
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
- fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
+ fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
# Output setup
- save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
+ save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
save_dir.mkdir(parents=True, exist_ok=True)
- video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
+ video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
while videocapture.isOpened():
success, frame = videocapture.read()
if not success:
break
- results = get_sliced_prediction(frame,
- detection_model,
- slice_height=512,
- slice_width=512,
- overlap_height_ratio=0.2,
- overlap_width_ratio=0.2)
+ results = get_sliced_prediction(
+ frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
+ )
object_prediction_list = results.object_prediction_list
boxes_list = []
clss_list = []
for ind, _ in enumerate(object_prediction_list):
- boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
- object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
+ boxes = (
+ object_prediction_list[ind].bbox.minx,
+ object_prediction_list[ind].bbox.miny,
+ object_prediction_list[ind].bbox.maxx,
+ object_prediction_list[ind].bbox.maxy,
+ )
clss = object_prediction_list[ind].category.name
boxes_list.append(boxes)
clss_list.append(clss)
@@ -69,21 +69,19 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
label = str(cls)
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
- cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
- -1)
- cv2.putText(frame,
- label, (int(x1), int(y1) - 2),
- 0,
- 0.6, [255, 255, 255],
- thickness=1,
- lineType=cv2.LINE_AA)
+ cv2.rectangle(
+ frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1
+ )
+ cv2.putText(
+ frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA
+ )
if view_img:
cv2.imshow(Path(source).stem, frame)
if save_img:
video_writer.write(frame)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ if cv2.waitKey(1) & 0xFF == ord("q"):
break
video_writer.release()
videocapture.release()
@@ -93,11 +91,11 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
def parse_opt():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
- parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
- parser.add_argument('--source', type=str, required=True, help='video file path')
- parser.add_argument('--view-img', action='store_true', help='show results')
- parser.add_argument('--save-img', action='store_true', help='save results')
- parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+ parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
+ parser.add_argument("--source", type=str, required=True, help="video file path")
+ parser.add_argument("--view-img", action="store_true", help="show results")
+ parser.add_argument("--save-img", action="store_true", help="save results")
+ parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
return parser.parse_args()
@@ -106,6 +104,6 @@ def main(opt):
run(**vars(opt))
-if __name__ == '__main__':
+if __name__ == "__main__":
opt = parse_opt()
main(opt)
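
On the slicing parameters above: a 0.2 overlap ratio means adjacent 512-px slices share roughly a fifth of their width and height, so objects cut by one slice boundary tend to fall wholly inside a neighbour. Illustrative arithmetic only, not SAHI's exact tiling logic:

```python
slice_px, overlap_ratio = 512, 0.2
print(int(slice_px * overlap_ratio))  # 102 -- approximate pixel overlap between neighbouring slices
```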
diff --git a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py
index b13eab35..7dd11dc3 100644
--- a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py
+++ b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py
@@ -21,18 +21,21 @@ class YOLOv8Seg:
"""
# Build Ort session
- self.session = ort.InferenceSession(onnx_model,
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
- if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])
+ self.session = ort.InferenceSession(
+ onnx_model,
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+ if ort.get_device() == "GPU"
+ else ["CPUExecutionProvider"],
+ )
# Numpy dtype: support both FP32 and FP16 onnx model
- self.ndtype = np.half if self.session.get_inputs()[0].type == 'tensor(float16)' else np.single
+ self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single
# Get model width and height(YOLOv8-seg only has one input)
self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:]
# Load COCO class names
- self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
+ self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Create color palette
self.color_palette = Colors()
@@ -60,14 +63,16 @@ class YOLOv8Seg:
preds = self.session.run(None, {self.session.get_inputs()[0].name: im})
# Post-process
- boxes, segments, masks = self.postprocess(preds,
- im0=im0,
- ratio=ratio,
- pad_w=pad_w,
- pad_h=pad_h,
- conf_threshold=conf_threshold,
- iou_threshold=iou_threshold,
- nm=nm)
+ boxes, segments, masks = self.postprocess(
+ preds,
+ im0=im0,
+ ratio=ratio,
+ pad_w=pad_w,
+ pad_h=pad_h,
+ conf_threshold=conf_threshold,
+ iou_threshold=iou_threshold,
+ nm=nm,
+ )
return boxes, segments, masks
def preprocess(self, img):
@@ -98,7 +103,7 @@ class YOLOv8Seg:
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
# Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
- img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
+ img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0
img_process = img[None] if len(img.shape) == 3 else img
return img_process, ratio, (pad_w, pad_h)
@@ -124,7 +129,7 @@ class YOLOv8Seg:
x, protos = preds[0], preds[1] # Two outputs: predictions and protos
# Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
- x = np.einsum('bcn->bnc', x)
+ x = np.einsum("bcn->bnc", x)
# Predictions filtering by conf-threshold
x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]
@@ -138,7 +143,6 @@ class YOLOv8Seg:
# Decode and return
if len(x) > 0:
-
# Bounding boxes format change: cxcywh -> xyxy
x[..., [0, 1]] -= x[..., [2, 3]] / 2
x[..., [2, 3]] += x[..., [0, 1]]
@@ -173,13 +177,13 @@ class YOLOv8Seg:
segments (List): list of segment masks.
"""
segments = []
- for x in masks.astype('uint8'):
+ for x in masks.astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE
if c:
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
- segments.append(c.astype('float32'))
+ segments.append(c.astype("float32"))
return segments
@staticmethod
@@ -219,7 +223,7 @@ class YOLOv8Seg:
masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN
masks = np.ascontiguousarray(masks)
masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape
- masks = np.einsum('HWN -> NHW', masks) # HWN -> NHW
+ masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW
masks = self.crop_mask(masks, bboxes)
return np.greater(masks, 0.5)
@@ -250,8 +254,9 @@ class YOLOv8Seg:
if len(masks.shape) < 2:
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
masks = masks[top:bottom, left:right]
- masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]),
- interpolation=cv2.INTER_LINEAR) # INTER_CUBIC would be better
+ masks = cv2.resize(
+ masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
+ ) # INTER_CUBIC would be better
if len(masks.shape) == 2:
masks = masks[:, :, None]
return masks
@@ -279,32 +284,46 @@ class YOLOv8Seg:
cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True))
# draw bbox rectangle
- cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
- self.color_palette(int(cls_), bgr=True), 1, cv2.LINE_AA)
- cv2.putText(im, f'{self.classes[cls_]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.color_palette(int(cls_), bgr=True), 2, cv2.LINE_AA)
+ cv2.rectangle(
+ im,
+ (int(box[0]), int(box[1])),
+ (int(box[2]), int(box[3])),
+ self.color_palette(int(cls_), bgr=True),
+ 1,
+ cv2.LINE_AA,
+ )
+ cv2.putText(
+ im,
+ f"{self.classes[cls_]}: {conf:.3f}",
+ (int(box[0]), int(box[1] - 9)),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.7,
+ self.color_palette(int(cls_), bgr=True),
+ 2,
+ cv2.LINE_AA,
+ )
# Mix image
im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0)
# Show image
if vis:
- cv2.imshow('demo', im)
+ cv2.imshow("demo", im)
cv2.waitKey(0)
cv2.destroyAllWindows()
# Save image
if save:
- cv2.imwrite('demo.jpg', im)
+ cv2.imwrite("demo.jpg", im)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
- parser.add_argument('--model', type=str, required=True, help='Path to ONNX model')
- parser.add_argument('--source', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image')
- parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
- parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold')
+ parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
+ parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
+ parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
+ parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
args = parser.parse_args()
# Build model
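
# Example: the parser above can be exercised without a shell by passing argv
# explicitly; "model.onnx" is a placeholder path:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
args = parser.parse_args(["--model", "model.onnx", "--conf", "0.3"])
print(args.model, args.conf, args.iou)  # model.onnx 0.3 0.45
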
diff --git a/pyproject.toml b/pyproject.toml
index 7b734356..a994e2ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,9 +107,9 @@ export = [
]
explorer = [
- "lancedb", # vector search
- "duckdb", # SQL queries, supports lancedb tables
- "streamlit", # visualizing with GUI
+ "lancedb", # vector search
+ "duckdb", # SQL queries, supports lancedb tables
+ "streamlit", # visualizing with GUI
]
# tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos)
@@ -179,5 +179,5 @@ pre-summary-newline = true
close-quotes-on-newline = true
[tool.codespell]
-ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend"
-skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml'
+ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall"
+skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml"
diff --git a/tests/test_explorer.py b/tests/test_explorer.py
index 23121c43..6a0995bc 100644
--- a/tests/test_explorer.py
+++ b/tests/test_explorer.py
@@ -3,6 +3,8 @@
+import PIL
+
from ultralytics import Explorer
from ultralytics.utils import ASSETS
def test_similarity():
"""Test similarity calculations and SQL queries for correctness and response length."""
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 4656c5a9..a8a23d32 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.238'
+__version__ = "8.0.239"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
@@ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
-__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'Explorer'
+__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 2488aa6b..7d9aca7e 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -8,34 +8,53 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Union
-from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR,
- SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks,
- colorstr, deprecation_warn, yaml_load, yaml_print)
+from ultralytics.utils import (
+ ASSETS,
+ DEFAULT_CFG,
+ DEFAULT_CFG_DICT,
+ DEFAULT_CFG_PATH,
+ LOGGER,
+ RANK,
+ ROOT,
+ RUNS_DIR,
+ SETTINGS,
+ SETTINGS_YAML,
+ TESTS_RUNNING,
+ IterableSimpleNamespace,
+ __version__,
+ checks,
+ colorstr,
+ deprecation_warn,
+ yaml_load,
+ yaml_print,
+)
# Define valid tasks and modes
-MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
-TASKS = 'detect', 'segment', 'classify', 'pose', 'obb'
+MODES = "train", "val", "predict", "export", "track", "benchmark"
+TASKS = "detect", "segment", "classify", "pose", "obb"
TASK2DATA = {
- 'detect': 'coco8.yaml',
- 'segment': 'coco8-seg.yaml',
- 'classify': 'imagenet10',
- 'pose': 'coco8-pose.yaml',
- 'obb': 'dota8-obb.yaml'} # not implemented yet
+ "detect": "coco8.yaml",
+ "segment": "coco8-seg.yaml",
+ "classify": "imagenet10",
+ "pose": "coco8-pose.yaml",
+ "obb": "dota8-obb.yaml",
+}
TASK2MODEL = {
- 'detect': 'yolov8n.pt',
- 'segment': 'yolov8n-seg.pt',
- 'classify': 'yolov8n-cls.pt',
- 'pose': 'yolov8n-pose.pt',
- 'obb': 'yolov8n-obb.pt'}
+ "detect": "yolov8n.pt",
+ "segment": "yolov8n-seg.pt",
+ "classify": "yolov8n-cls.pt",
+ "pose": "yolov8n-pose.pt",
+ "obb": "yolov8n-obb.pt",
+}
TASK2METRIC = {
- 'detect': 'metrics/mAP50-95(B)',
- 'segment': 'metrics/mAP50-95(M)',
- 'classify': 'metrics/accuracy_top1',
- 'pose': 'metrics/mAP50-95(P)',
- 'obb': 'metrics/mAP50-95(OBB)'}
-
-CLI_HELP_MSG = \
- f"""
+ "detect": "metrics/mAP50-95(B)",
+ "segment": "metrics/mAP50-95(M)",
+ "classify": "metrics/accuracy_top1",
+ "pose": "metrics/mAP50-95(P)",
+ "obb": "metrics/mAP50-95(OBB)",
+}
+
+CLI_HELP_MSG = f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@@ -74,16 +93,83 @@ CLI_HELP_MSG = \
"""
# Define keys for arg type checks
-CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time'
-CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
- 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
- 'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
-CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
- 'line_width', 'workspace', 'nbs', 'save_period')
-CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
- 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
- 'save_frames', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
- 'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile', 'multi_scale')
+CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
+CFG_FRACTION_KEYS = (
+ "dropout",
+ "iou",
+ "lr0",
+ "lrf",
+ "momentum",
+ "weight_decay",
+ "warmup_momentum",
+ "warmup_bias_lr",
+ "label_smoothing",
+ "hsv_h",
+ "hsv_s",
+ "hsv_v",
+ "translate",
+ "scale",
+ "perspective",
+ "flipud",
+ "fliplr",
+ "mosaic",
+ "mixup",
+ "copy_paste",
+ "conf",
+ "iou",
+ "fraction",
+) # fraction floats 0.0 - 1.0
+CFG_INT_KEYS = (
+ "epochs",
+ "patience",
+ "batch",
+ "workers",
+ "seed",
+ "close_mosaic",
+ "mask_ratio",
+ "max_det",
+ "vid_stride",
+ "line_width",
+ "workspace",
+ "nbs",
+ "save_period",
+)
+CFG_BOOL_KEYS = (
+ "save",
+ "exist_ok",
+ "verbose",
+ "deterministic",
+ "single_cls",
+ "rect",
+ "cos_lr",
+ "overlap_mask",
+ "val",
+ "save_json",
+ "save_hybrid",
+ "half",
+ "dnn",
+ "plots",
+ "show",
+ "save_txt",
+ "save_conf",
+ "save_crop",
+ "save_frames",
+ "show_labels",
+ "show_conf",
+ "visualize",
+ "augment",
+ "agnostic_nms",
+ "retina_masks",
+ "show_boxes",
+ "keras",
+ "optimize",
+ "int8",
+ "dynamic",
+ "simplify",
+ "nms",
+ "profile",
+ "multi_scale",
+)
def cfg2dict(cfg):
@@ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
- if 'save_dir' not in cfg:
- overrides.pop('save_dir', None) # special override keys to ignore
+ if "save_dir" not in cfg:
+ overrides.pop("save_dir", None) # special override keys to ignore
check_dict_alignment(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Special handling for numeric project/name
- for k in 'project', 'name':
+ for k in "project", "name":
if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
- if cfg.get('name') == 'model': # assign model to 'name' arg
- cfg['name'] = cfg.get('model', '').split('.')[0]
+ if cfg.get("name") == "model": # assign model to 'name' arg
+ cfg["name"] = cfg.get("model", "").split(".")[0]
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
# Type and Value checks
for k, v in cfg.items():
if v is not None: # None values may be from optional args
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
- raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
- f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
+ raise TypeError(
+ f"'{k}={v}' is of invalid type {type(v).__name__}. "
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
+ )
elif k in CFG_FRACTION_KEYS:
if not isinstance(v, (int, float)):
- raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
- f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
+ raise TypeError(
+ f"'{k}={v}' is of invalid type {type(v).__name__}. "
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
+ )
if not (0.0 <= v <= 1.0):
- raise ValueError(f"'{k}={v}' is an invalid value. "
- f"Valid '{k}' values are between 0.0 and 1.0.")
+ raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
elif k in CFG_INT_KEYS and not isinstance(v, int):
- raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
- f"'{k}' must be an int (i.e. '{k}=8')")
+ raise TypeError(
+ f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
+ )
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
- raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
- f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
+ raise TypeError(
+ f"'{k}={v}' is of invalid type {type(v).__name__}. "
+ f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
+ )
# Return instance
return IterableSimpleNamespace(**cfg)
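
# Example: get_cfg merges overrides into the defaults and returns an
# attribute-style namespace; bad values raise the errors constructed above:
from ultralytics.cfg import get_cfg

cfg = get_cfg(overrides={"imgsz": 320, "conf": 0.25})
print(cfg.imgsz, cfg.conf)  # 320 0.25
# get_cfg(overrides={"conf": 2.0})  # ValueError: fraction keys must be 0.0-1.0
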
@@ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
def get_save_dir(args, name=None):
"""Return save_dir as created from train/val/predict arguments."""
- if getattr(args, 'save_dir', None):
+ if getattr(args, "save_dir", None):
save_dir = args.save_dir
else:
from ultralytics.utils.files import increment_path
- project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task
- name = name or args.name or f'{args.mode}'
+ project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
+ name = name or args.name or f"{args.mode}"
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
return Path(save_dir)
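
# Example: get_save_dir composes {project}/{name} and auto-increments on
# collisions (runs/detect/predict, runs/detect/predict2, ...); a sketch with
# a stub args namespace carrying the attributes read above:
from types import SimpleNamespace

from ultralytics.cfg import get_save_dir

args = SimpleNamespace(save_dir=None, project=None, name=None, task="detect", mode="predict", exist_ok=False)
print(get_save_dir(args))  # e.g. .../runs/detect/predict
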
@@ -175,18 +267,18 @@ def _handle_deprecation(custom):
"""Hardcoded function to handle deprecated config keys."""
for key in custom.copy().keys():
- if key == 'boxes':
- deprecation_warn(key, 'show_boxes')
- custom['show_boxes'] = custom.pop('boxes')
- if key == 'hide_labels':
- deprecation_warn(key, 'show_labels')
- custom['show_labels'] = custom.pop('hide_labels') == 'False'
- if key == 'hide_conf':
- deprecation_warn(key, 'show_conf')
- custom['show_conf'] = custom.pop('hide_conf') == 'False'
- if key == 'line_thickness':
- deprecation_warn(key, 'line_width')
- custom['line_width'] = custom.pop('line_thickness')
+ if key == "boxes":
+ deprecation_warn(key, "show_boxes")
+ custom["show_boxes"] = custom.pop("boxes")
+ if key == "hide_labels":
+ deprecation_warn(key, "show_labels")
+ custom["show_labels"] = custom.pop("hide_labels") == "False"
+ if key == "hide_conf":
+ deprecation_warn(key, "show_conf")
+ custom["show_conf"] = custom.pop("hide_conf") == "False"
+ if key == "line_thickness":
+ deprecation_warn(key, "line_width")
+ custom["line_width"] = custom.pop("line_thickness")
return custom
@@ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
if mismatched:
from difflib import get_close_matches
- string = ''
+ string = ""
for x in mismatched:
matches = get_close_matches(x, base_keys) # key list
- matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches]
- match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
+ matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]
+ match_str = f"Similar arguments are i.e. {matches}." if matches else ""
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
@@ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
"""
new_args = []
for i, arg in enumerate(args):
- if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
- new_args[-1] += f'={args[i + 1]}'
+ if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
+ new_args[-1] += f"={args[i + 1]}"
del args[i + 1]
- elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
- new_args.append(f'{arg}{args[i + 1]}')
+ elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val']
+ new_args.append(f"{arg}{args[i + 1]}")
del args[i + 1]
- elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
+ elif arg.startswith("=") and i > 0: # merge ['arg', '=val']
new_args[-1] += arg
else:
new_args.append(arg)
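
# Example: merge_equals_args rejoins arguments that a shell split around '=':
from ultralytics.cfg import merge_equals_args

print(merge_equals_args(["imgsz", "=", "640"]))  # ['imgsz=640']
print(merge_equals_args(["imgsz=", "640"]))  # ['imgsz=640']
print(merge_equals_args(["imgsz", "=640"]))  # ['imgsz=640']
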
@@ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None:
"""
from ultralytics import hub
- if args[0] == 'login':
- key = args[1] if len(args) > 1 else ''
+ if args[0] == "login":
+ key = args[1] if len(args) > 1 else ""
# Log in to Ultralytics HUB using the provided API key
hub.login(key)
- elif args[0] == 'logout':
+ elif args[0] == "logout":
# Log out from Ultralytics HUB
hub.logout()
@@ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None:
python my_script.py yolo settings reset
```
"""
- url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings' # help URL
+ url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL
try:
if any(args):
- if args[0] == 'reset':
+ if args[0] == "reset":
SETTINGS_YAML.unlink() # delete the settings file
SETTINGS.reset() # create new settings
- LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
+ LOGGER.info("Settings reset successfully") # inform the user that settings have been reset
else: # save a new setting
new = dict(parse_key_value_pair(a) for a in args)
check_dict_alignment(SETTINGS, new)
SETTINGS.update(new)
- LOGGER.info(f'💡 Learn about settings at {url}')
+ LOGGER.info(f"💡 Learn about settings at {url}")
yaml_print(SETTINGS_YAML) # print the current settings
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
@@ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None:
def handle_explorer():
"""Open the Ultralytics Explorer GUI."""
- checks.check_requirements('streamlit')
- subprocess.run(['streamlit', 'run', ROOT / 'data/explorer/gui/dash.py', '--server.maxMessageSize', '2048'])
+ checks.check_requirements("streamlit")
+ subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])
def parse_key_value_pair(pair):
"""Parse one 'key=value' pair and return key and value."""
- k, v = pair.split('=', 1) # split on first '=' sign
+ k, v = pair.split("=", 1) # split on first '=' sign
k, v = k.strip(), v.strip() # remove spaces
assert v, f"missing '{k}' value"
return k, smart_value(v)
@@ -318,11 +410,11 @@ def parse_key_value_pair(pair):
def smart_value(v):
"""Convert a string to an underlying type such as int, float, bool, etc."""
v_lower = v.lower()
- if v_lower == 'none':
+ if v_lower == "none":
return None
- elif v_lower == 'true':
+ elif v_lower == "true":
return True
- elif v_lower == 'false':
+ elif v_lower == "false":
return False
else:
with contextlib.suppress(Exception):
@@ -330,7 +422,7 @@ def smart_value(v):
return v
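
# Example: parse_key_value_pair splits on the first '=' and smart_value
# coerces the string to None/bool/number where possible, else leaves it a str:
from ultralytics.cfg import parse_key_value_pair, smart_value

print(parse_key_value_pair("conf=0.25"))  # ('conf', 0.25)
print(smart_value("true"), smart_value("none"))  # True None
print(smart_value("yolov8n.pt"))  # 'yolov8n.pt'
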
-def entrypoint(debug=''):
+def entrypoint(debug=""):
"""
This function is the ultralytics package entrypoint, responsible for parsing the command-line arguments passed
to the package.
@@ -345,139 +437,150 @@ def entrypoint(debug=''):
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg.
"""
- args = (debug.split(' ') if debug else sys.argv)[1:]
+ args = (debug.split(" ") if debug else sys.argv)[1:]
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
special = {
- 'help': lambda: LOGGER.info(CLI_HELP_MSG),
- 'checks': checks.collect_system_info,
- 'version': lambda: LOGGER.info(__version__),
- 'settings': lambda: handle_yolo_settings(args[1:]),
- 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
- 'hub': lambda: handle_yolo_hub(args[1:]),
- 'login': lambda: handle_yolo_hub(args),
- 'copy-cfg': copy_default_cfg,
- 'explorer': lambda: handle_explorer()}
+ "help": lambda: LOGGER.info(CLI_HELP_MSG),
+ "checks": checks.collect_system_info,
+ "version": lambda: LOGGER.info(__version__),
+ "settings": lambda: handle_yolo_settings(args[1:]),
+ "cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
+ "hub": lambda: handle_yolo_hub(args[1:]),
+ "login": lambda: handle_yolo_hub(args),
+ "copy-cfg": copy_default_cfg,
+ "explorer": lambda: handle_explorer(),
+ }
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
# Define common misuses of special commands, i.e. -h, -help, --help
special.update({k[0]: v for k, v in special.items()}) # singular
- special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
- special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
+ special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular
+ special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}
overrides = {} # basic overrides, i.e. imgsz=320
for a in merge_equals_args(args): # merge spaces around '=' sign
- if a.startswith('--'):
+ if a.startswith("--"):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
a = a[2:]
- if a.endswith(','):
+ if a.endswith(","):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
a = a[:-1]
- if '=' in a:
+ if "=" in a:
try:
k, v = parse_key_value_pair(a)
- if k == 'cfg' and v is not None: # custom.yaml passed
- LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
- overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
+ if k == "cfg" and v is not None: # custom.yaml passed
+ LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
+ overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}
else:
overrides[k] = v
except (NameError, SyntaxError, ValueError, AssertionError) as e:
- check_dict_alignment(full_args_dict, {a: ''}, e)
+ check_dict_alignment(full_args_dict, {a: ""}, e)
elif a in TASKS:
- overrides['task'] = a
+ overrides["task"] = a
elif a in MODES:
- overrides['mode'] = a
+ overrides["mode"] = a
elif a.lower() in special:
special[a.lower()]()
return
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
elif a in DEFAULT_CFG_DICT:
- raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
- f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
+ raise SyntaxError(
+ f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
+ f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}"
+ )
else:
- check_dict_alignment(full_args_dict, {a: ''})
+ check_dict_alignment(full_args_dict, {a: ""})
# Check keys
check_dict_alignment(full_args_dict, overrides)
# Mode
- mode = overrides.get('mode')
+ mode = overrides.get("mode")
if mode is None:
- mode = DEFAULT_CFG.mode or 'predict'
+ mode = DEFAULT_CFG.mode or "predict"
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
elif mode not in MODES:
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
# Task
- task = overrides.pop('task', None)
+ task = overrides.pop("task", None)
if task:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
- if 'model' not in overrides:
- overrides['model'] = TASK2MODEL[task]
+ if "model" not in overrides:
+ overrides["model"] = TASK2MODEL[task]
# Model
- model = overrides.pop('model', DEFAULT_CFG.model)
+ model = overrides.pop("model", DEFAULT_CFG.model)
if model is None:
- model = 'yolov8n.pt'
+ model = "yolov8n.pt"
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
- overrides['model'] = model
+ overrides["model"] = model
stem = Path(model).stem.lower()
- if 'rtdetr' in stem: # guess architecture
+ if "rtdetr" in stem: # guess architecture
from ultralytics import RTDETR
+
model = RTDETR(model) # no task argument
- elif 'fastsam' in stem:
+ elif "fastsam" in stem:
from ultralytics import FastSAM
+
model = FastSAM(model)
- elif 'sam' in stem:
+ elif "sam" in stem:
from ultralytics import SAM
+
model = SAM(model)
else:
from ultralytics import YOLO
+
model = YOLO(model, task=task)
- if isinstance(overrides.get('pretrained'), str):
- model.load(overrides['pretrained'])
+ if isinstance(overrides.get("pretrained"), str):
+ model.load(overrides["pretrained"])
# Task Update
if task != model.task:
if task:
- LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
- f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
+ LOGGER.warning(
+ f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
+ f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model."
+ )
task = model.task
# Mode
- if mode in ('predict', 'track') and 'source' not in overrides:
- overrides['source'] = DEFAULT_CFG.source or ASSETS
+ if mode in ("predict", "track") and "source" not in overrides:
+ overrides["source"] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
- elif mode in ('train', 'val'):
- if 'data' not in overrides and 'resume' not in overrides:
- overrides['data'] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
+ elif mode in ("train", "val"):
+ if "data" not in overrides and "resume" not in overrides:
+ overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
- elif mode == 'export':
- if 'format' not in overrides:
- overrides['format'] = DEFAULT_CFG.format or 'torchscript'
+ elif mode == "export":
+ if "format" not in overrides:
+ overrides["format"] = DEFAULT_CFG.format or "torchscript"
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
getattr(model, mode)(**overrides) # default args from model
# Show help
- LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}')
+ LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
"""Copy and create a new default configuration file with '_copy' appended to its name."""
- new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
+ new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
shutil.copy2(DEFAULT_CFG_PATH, new_file)
- LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
- f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
+ LOGGER.info(
+ f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
+ f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8"
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
# Example: entrypoint(debug='yolo predict model=yolov8n.pt')
- entrypoint(debug='')
+ entrypoint(debug="")
diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py
index 6fa7e845..9f91ce97 100644
--- a/ultralytics/data/__init__.py
+++ b/ultralytics/data/__init__.py
@@ -4,5 +4,12 @@ from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
-__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
- 'build_dataloader', 'load_inference_source')
+__all__ = (
+ "BaseDataset",
+ "ClassificationDataset",
+ "SemanticDataset",
+ "YOLODataset",
+ "build_yolo_dataset",
+ "build_dataloader",
+ "load_inference_source",
+)
diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py
index b4e08c76..cd0444f3 100644
--- a/ultralytics/data/annotator.py
+++ b/ultralytics/data/annotator.py
@@ -5,7 +5,7 @@ from pathlib import Path
from ultralytics import SAM, YOLO
-def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
+def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
@@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
data = Path(data)
if not output_dir:
- output_dir = data.parent / f'{data.stem}_auto_annotate_labels'
+ output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
Path(output_dir).mkdir(exist_ok=True, parents=True)
det_results = det_model(data, stream=True, device=device)
@@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
- with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f:
+ with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
for i in range(len(segments)):
s = segments[i]
if len(s) == 0:
continue
segment = map(str, segments[i].reshape(-1).tolist())
- f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
+ f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")
diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py
index a1cd1c3c..9dd5ef59 100644
--- a/ultralytics/data/augment.py
+++ b/ultralytics/data/augment.py
@@ -117,11 +117,11 @@ class BaseMixTransform:
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
- labels['mix_labels'] = mix_labels
+ labels["mix_labels"] = mix_labels
# Mosaic or MixUp
labels = self._mix_transform(labels)
- labels.pop('mix_labels', None)
+ labels.pop("mix_labels", None)
return labels
def _mix_transform(self, labels):
@@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform):
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
- assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
- assert n in (4, 9), 'grid must be equal to 4 or 9.'
+ assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
+ assert n in (4, 9), "grid must be equal to 4 or 9."
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
@@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform):
def _mix_transform(self, labels):
"""Apply mixup transformation to the input image and labels."""
- assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
- assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
- return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(
- labels) # This code is modified for mosaic3 method.
+ assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
+ assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
+ return (
+ self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
+ ) # This code is modified for mosaic3 method.
def _mosaic3(self, labels):
"""Create a 1x3 image mosaic."""
mosaic_labels = []
s = self.imgsz
for i in range(3):
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
- img = labels_patch['img']
- h, w = labels_patch.pop('resized_shape')
+ img = labels_patch["img"]
+ h, w = labels_patch.pop("resized_shape")
# Place img in img3
if i == 0: # center
@@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform):
padw, padh = c[:2]
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
- img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img3[ymin:ymax, xmin:xmax]
+ img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]
# hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
@@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
- final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
+ final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels
def _mosaic4(self, labels):
@@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
for i in range(4):
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
- img = labels_patch['img']
- h, w = labels_patch.pop('resized_shape')
+ img = labels_patch["img"]
+ h, w = labels_patch.pop("resized_shape")
# Place img in img4
if i == 0: # top left
@@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform):
labels_patch = self._update_labels(labels_patch, padw, padh)
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
- final_labels['img'] = img4
+ final_labels["img"] = img4
return final_labels
def _mosaic9(self, labels):
@@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
hp, wp = -1, -1 # height, width previous
for i in range(9):
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
- img = labels_patch['img']
- h, w = labels_patch.pop('resized_shape')
+ img = labels_patch["img"]
+ h, w = labels_patch.pop("resized_shape")
# Place img in img9
if i == 0: # center
@@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform):
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# Image
- img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
+ img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
@@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
- final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
+ final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels
@staticmethod
def _update_labels(labels, padw, padh):
"""Update labels."""
- nh, nw = labels['img'].shape[:2]
- labels['instances'].convert_bbox(format='xyxy')
- labels['instances'].denormalize(nw, nh)
- labels['instances'].add_padding(padw, padh)
+ nh, nw = labels["img"].shape[:2]
+ labels["instances"].convert_bbox(format="xyxy")
+ labels["instances"].denormalize(nw, nh)
+ labels["instances"].add_padding(padw, padh)
return labels
def _cat_labels(self, mosaic_labels):
@@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform):
instances = []
imgsz = self.imgsz * 2 # mosaic imgsz
for labels in mosaic_labels:
- cls.append(labels['cls'])
- instances.append(labels['instances'])
+ cls.append(labels["cls"])
+ instances.append(labels["instances"])
+ # Final labels
final_labels = {
- 'im_file': mosaic_labels[0]['im_file'],
- 'ori_shape': mosaic_labels[0]['ori_shape'],
- 'resized_shape': (imgsz, imgsz),
- 'cls': np.concatenate(cls, 0),
- 'instances': Instances.concatenate(instances, axis=0),
- 'mosaic_border': self.border} # final_labels
- final_labels['instances'].clip(imgsz, imgsz)
- good = final_labels['instances'].remove_zero_area_boxes()
- final_labels['cls'] = final_labels['cls'][good]
+ "im_file": mosaic_labels[0]["im_file"],
+ "ori_shape": mosaic_labels[0]["ori_shape"],
+ "resized_shape": (imgsz, imgsz),
+ "cls": np.concatenate(cls, 0),
+ "instances": Instances.concatenate(instances, axis=0),
+ "mosaic_border": self.border,
+ }
+ final_labels["instances"].clip(imgsz, imgsz)
+ good = final_labels["instances"].remove_zero_area_boxes()
+ final_labels["cls"] = final_labels["cls"][good]
return final_labels
@@ -335,10 +338,10 @@ class MixUp(BaseMixTransform):
def _mix_transform(self, labels):
"""Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
- labels2 = labels['mix_labels'][0]
- labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
- labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
- labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
+ labels2 = labels["mix_labels"][0]
+ labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
+ labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
+ labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
return labels
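
# Example: the MixUp blend above, image-only; Beta(32, 32) concentrates the
# mixing ratio near 0.5:
import numpy as np

img1 = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
r = np.random.beta(32.0, 32.0)  # mixup ratio
mixed = (img1 * r + img2 * (1 - r)).astype(np.uint8)
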
@@ -366,14 +369,9 @@ class RandomPerspective:
box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation.
"""
- def __init__(self,
- degrees=0.0,
- translate=0.1,
- scale=0.5,
- shear=0.0,
- perspective=0.0,
- border=(0, 0),
- pre_transform=None):
+ def __init__(
+ self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
+ ):
"""Initializes RandomPerspective object with transformation parameters."""
self.degrees = degrees
@@ -519,18 +517,18 @@ class RandomPerspective:
Args:
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
- if self.pre_transform and 'mosaic_border' not in labels:
+ if self.pre_transform and "mosaic_border" not in labels:
labels = self.pre_transform(labels)
- labels.pop('ratio_pad', None) # do not need ratio pad
+ labels.pop("ratio_pad", None) # do not need ratio pad
- img = labels['img']
- cls = labels['cls']
- instances = labels.pop('instances')
+ img = labels["img"]
+ cls = labels["cls"]
+ instances = labels.pop("instances")
# Make sure the coord formats are right
- instances.convert_bbox(format='xyxy')
+ instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])
- border = labels.pop('mosaic_border', self.border)
+ border = labels.pop("mosaic_border", self.border)
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
# M is affine matrix
# Scale for func:`box_candidates`
@@ -546,20 +544,20 @@ class RandomPerspective:
if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
- new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
+ new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
# Clip
new_instances.clip(*self.size)
# Filter instances
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# Make the bboxes have the same scale with new_bboxes
- i = self.box_candidates(box1=instances.bboxes.T,
- box2=new_instances.bboxes.T,
- area_thr=0.01 if len(segments) else 0.10)
- labels['instances'] = new_instances[i]
- labels['cls'] = cls[i]
- labels['img'] = img
- labels['resized_shape'] = img.shape[:2]
+ i = self.box_candidates(
+ box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
+ )
+ labels["instances"] = new_instances[i]
+ labels["cls"] = cls[i]
+ labels["img"] = img
+ labels["resized_shape"] = img.shape[:2]
return labels
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
@@ -611,7 +609,7 @@ class RandomHSV:
The modified image replaces the original image in the input 'labels' dict.
"""
- img = labels['img']
+ img = labels["img"]
if self.hgain or self.sgain or self.vgain:
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
@@ -634,7 +632,7 @@ class RandomFlip:
Also updates any instances (bounding boxes, keypoints, etc.) accordingly.
"""
- def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
+ def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
"""
Initializes the RandomFlip class with probability and direction.
@@ -644,7 +642,7 @@ class RandomFlip:
Default is 'horizontal'.
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
"""
- assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
+ assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0
self.p = p
@@ -662,25 +660,25 @@ class RandomFlip:
Returns:
(dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys.
"""
- img = labels['img']
- instances = labels.pop('instances')
- instances.convert_bbox(format='xywh')
+ img = labels["img"]
+ instances = labels.pop("instances")
+ instances.convert_bbox(format="xywh")
h, w = img.shape[:2]
h = 1 if instances.normalized else h
w = 1 if instances.normalized else w
# Flip up-down
- if self.direction == 'vertical' and random.random() < self.p:
+ if self.direction == "vertical" and random.random() < self.p:
img = np.flipud(img)
instances.flipud(h)
- if self.direction == 'horizontal' and random.random() < self.p:
+ if self.direction == "horizontal" and random.random() < self.p:
img = np.fliplr(img)
instances.fliplr(w)
# For keypoints
if self.flip_idx is not None and instances.keypoints is not None:
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
- labels['img'] = np.ascontiguousarray(img)
- labels['instances'] = instances
+ labels["img"] = np.ascontiguousarray(img)
+ labels["instances"] = instances
return labels
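
# Example: with normalized xywh boxes (h = w = 1 above), a horizontal flip
# mirrors the image and only the x-center of each box:
import numpy as np

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
boxes = np.array([[0.25, 0.5, 0.2, 0.3]])  # normalized cx, cy, w, h
img = np.fliplr(img)
boxes[:, 0] = 1.0 - boxes[:, 0]  # cx -> 1 - cx, giving [[0.75, 0.5, 0.2, 0.3]]
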
@@ -700,9 +698,9 @@ class LetterBox:
"""Return updated labels and image with added border."""
if labels is None:
labels = {}
- img = labels.get('img') if image is None else image
+ img = labels.get("img") if image is None else image
shape = img.shape[:2] # current shape [height, width]
- new_shape = labels.pop('rect_shape', self.new_shape)
+ new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
@@ -730,25 +728,26 @@ class LetterBox:
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
- value=(114, 114, 114)) # add border
- if labels.get('ratio_pad'):
- labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation
+ img = cv2.copyMakeBorder(
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+ ) # add border
+ if labels.get("ratio_pad"):
+ labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
- labels['img'] = img
- labels['resized_shape'] = new_shape
+ labels["img"] = img
+ labels["resized_shape"] = new_shape
return labels
else:
return img
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
- labels['instances'].convert_bbox(format='xyxy')
- labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
- labels['instances'].scale(*ratio)
- labels['instances'].add_padding(padw, padh)
+ labels["instances"].convert_bbox(format="xyxy")
+ labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+ labels["instances"].scale(*ratio)
+ labels["instances"].add_padding(padw, padh)
return labels
@@ -785,11 +784,11 @@ class CopyPaste:
1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work.
2. This method modifies the input dictionary 'labels' in place.
"""
- im = labels['img']
- cls = labels['cls']
+ im = labels["img"]
+ cls = labels["cls"]
h, w = im.shape[:2]
- instances = labels.pop('instances')
- instances.convert_bbox(format='xyxy')
+ instances = labels.pop("instances")
+ instances.convert_bbox(format="xyxy")
instances.denormalize(w, h)
if self.p and len(instances.segments):
n = len(instances)
@@ -812,9 +811,9 @@ class CopyPaste:
i = cv2.flip(im_new, 1).astype(bool)
im[i] = result[i]
- labels['img'] = im
- labels['cls'] = cls
- labels['instances'] = instances
+ labels["img"] = im
+ labels["cls"] = cls
+ labels["instances"] = instances
return labels
@@ -831,12 +830,13 @@ class Albumentations:
"""Initialize the transform object for YOLO bbox formatted params."""
self.p = p
self.transform = None
- prefix = colorstr('albumentations: ')
+ prefix = colorstr("albumentations: ")
try:
import albumentations as A
- check_version(A.__version__, '1.0.3', hard=True) # version requirement
+ check_version(A.__version__, "1.0.3", hard=True) # version requirement
+ # Transforms
T = [
A.Blur(p=0.01),
A.MedianBlur(p=0.01),
@@ -844,31 +844,32 @@ class Albumentations:
A.CLAHE(p=0.01),
A.RandomBrightnessContrast(p=0.0),
A.RandomGamma(p=0.0),
- A.ImageCompression(quality_lower=75, p=0.0)] # transforms
- self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
+ A.ImageCompression(quality_lower=75, p=0.0),
+ ]
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
- LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
+ LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
except ImportError: # package not installed, skip
pass
except Exception as e:
- LOGGER.info(f'{prefix}{e}')
+ LOGGER.info(f"{prefix}{e}")
def __call__(self, labels):
"""Generates object detections and returns a dictionary with detection results."""
- im = labels['img']
- cls = labels['cls']
+ im = labels["img"]
+ cls = labels["cls"]
if len(cls):
- labels['instances'].convert_bbox('xywh')
- labels['instances'].normalize(*im.shape[:2][::-1])
- bboxes = labels['instances'].bboxes
+ labels["instances"].convert_bbox("xywh")
+ labels["instances"].normalize(*im.shape[:2][::-1])
+ bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
- if len(new['class_labels']) > 0: # skip update if no bbox in new im
- labels['img'] = new['image']
- labels['cls'] = np.array(new['class_labels'])
- bboxes = np.array(new['bboxes'], dtype=np.float32)
- labels['instances'].update(bboxes=bboxes)
+ if len(new["class_labels"]) > 0: # skip update if no bbox in new im
+ labels["img"] = new["image"]
+ labels["cls"] = np.array(new["class_labels"])
+ bboxes = np.array(new["bboxes"], dtype=np.float32)
+ labels["instances"].update(bboxes=bboxes)
return labels
@@ -888,15 +889,17 @@ class Format:
batch_idx (bool): Keep batch indexes. Default is True.
"""
- def __init__(self,
- bbox_format='xywh',
- normalize=True,
- return_mask=False,
- return_keypoint=False,
- return_obb=False,
- mask_ratio=4,
- mask_overlap=True,
- batch_idx=True):
+ def __init__(
+ self,
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=False,
+ return_keypoint=False,
+ return_obb=False,
+ mask_ratio=4,
+ mask_overlap=True,
+ batch_idx=True,
+ ):
"""Initializes the Format class with given parameters."""
self.bbox_format = bbox_format
self.normalize = normalize
@@ -909,10 +912,10 @@ class Format:
def __call__(self, labels):
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
- img = labels.pop('img')
+ img = labels.pop("img")
h, w = img.shape[:2]
- cls = labels.pop('cls')
- instances = labels.pop('instances')
+ cls = labels.pop("cls")
+ instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
@@ -922,22 +925,24 @@ class Format:
masks, instances, cls = self._format_segments(instances, cls, w, h)
masks = torch.from_numpy(masks)
else:
- masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
- img.shape[1] // self.mask_ratio)
- labels['masks'] = masks
+ masks = torch.zeros(
+ 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
+ )
+ labels["masks"] = masks
if self.normalize:
instances.normalize(w, h)
- labels['img'] = self._format_img(img)
- labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
- labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
+ labels["img"] = self._format_img(img)
+ labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
+ labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
- labels['keypoints'] = torch.from_numpy(instances.keypoints)
+ labels["keypoints"] = torch.from_numpy(instances.keypoints)
if self.return_obb:
- labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(
- instances.segments) else torch.zeros((0, 5))
+ labels["bboxes"] = (
+ xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
+ )
# Then we can use collate_fn
if self.batch_idx:
- labels['batch_idx'] = torch.zeros(nl)
+ labels["batch_idx"] = torch.zeros(nl)
return labels
def _format_img(self, img):
@@ -964,33 +969,39 @@ class Format:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
- pre_transform = Compose([
- Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
- CopyPaste(p=hyp.copy_paste),
- RandomPerspective(
- degrees=hyp.degrees,
- translate=hyp.translate,
- scale=hyp.scale,
- shear=hyp.shear,
- perspective=hyp.perspective,
- pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
- )])
- flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
+ pre_transform = Compose(
+ [
+ Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
+ CopyPaste(p=hyp.copy_paste),
+ RandomPerspective(
+ degrees=hyp.degrees,
+ translate=hyp.translate,
+ scale=hyp.scale,
+ shear=hyp.shear,
+ perspective=hyp.perspective,
+ pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
+ ),
+ ]
+ )
+ flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
if dataset.use_keypoints:
- kpt_shape = dataset.data.get('kpt_shape', None)
+ kpt_shape = dataset.data.get("kpt_shape", None)
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
- raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
+ raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")
- return Compose([
- pre_transform,
- MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
- Albumentations(p=1.0),
- RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
- RandomFlip(direction='vertical', p=hyp.flipud),
- RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
+ return Compose(
+ [
+ pre_transform,
+ MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
+ Albumentations(p=1.0),
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
+ RandomFlip(direction="vertical", p=hyp.flipud),
+ RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
+ ]
+ ) # transforms
# Classification augmentations -----------------------------------------------------------------------------------------
@@ -1031,10 +1042,13 @@ def classify_transforms(
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
- tfl += [T.ToTensor(), T.Normalize(
- mean=torch.tensor(mean),
- std=torch.tensor(std),
- )]
+ tfl += [
+ T.ToTensor(),
+ T.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std),
+ ),
+ ]
return T.Compose(tfl)
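
# Example: the eval-time classification pipeline above is Resize ->
# CenterCrop -> ToTensor -> Normalize; an equivalent standalone composition
# (the 256/224 sizes and 0/1 normalization are illustrative values):
import torch
import torchvision.transforms as T

tf = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor((0.0, 0.0, 0.0)), std=torch.tensor((1.0, 1.0, 1.0))),
    ]
)
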
@@ -1053,7 +1067,7 @@ def classify_augmentations(
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
- erasing=0.,
+ erasing=0.0,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
"""
@@ -1080,13 +1094,13 @@ def classify_augmentations(
"""
# Transforms to apply if albumentations not installed
if not isinstance(size, int):
- raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
+ raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
- ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
+ ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
- if hflip > 0.:
+ if hflip > 0.0:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
- if vflip > 0.:
+ if vflip > 0.0:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
@@ -1097,27 +1111,29 @@ def classify_augmentations(
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter
- if auto_augment == 'randaugment':
+ if auto_augment == "randaugment":
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
- elif auto_augment == 'augmix':
+ elif auto_augment == "augmix":
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
- elif auto_augment == 'autoaugment':
+ elif auto_augment == "autoaugment":
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
else:
- raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
- f'"augmix", "autoaugment" or None')
+ raise ValueError(
+ f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
+ f'"augmix", "autoaugment" or None'
+ )
if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
@@ -1125,7 +1141,8 @@ def classify_augmentations(
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
- T.RandomErasing(p=erasing, inplace=True)]
+ T.RandomErasing(p=erasing, inplace=True),
+ ]
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
@@ -1177,7 +1194,7 @@ class ClassifyLetterBox:
# Create padded image
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
- im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
+ im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out
@@ -1205,7 +1222,7 @@ class CenterCrop:
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
- return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+ return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
# NOTE: keep this class for backward compatibility
diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py
index 1df546b3..1a392e0c 100644
--- a/ultralytics/data/base.py
+++ b/ultralytics/data/base.py
@@ -47,20 +47,22 @@ class BaseDataset(Dataset):
transforms (callable): Image transformation function.
"""
- def __init__(self,
- img_path,
- imgsz=640,
- cache=False,
- augment=True,
- hyp=DEFAULT_CFG,
- prefix='',
- rect=False,
- batch_size=16,
- stride=32,
- pad=0.5,
- single_cls=False,
- classes=None,
- fraction=1.0):
+ def __init__(
+ self,
+ img_path,
+ imgsz=640,
+ cache=False,
+ augment=True,
+ hyp=DEFAULT_CFG,
+ prefix="",
+ rect=False,
+ batch_size=16,
+ stride=32,
+ pad=0.5,
+ single_cls=False,
+ classes=None,
+ fraction=1.0,
+ ):
"""Initialize BaseDataset with given configuration and options."""
super().__init__()
self.img_path = img_path
@@ -86,10 +88,10 @@ class BaseDataset(Dataset):
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
# Cache images
- if cache == 'ram' and not self.check_cache_ram():
+ if cache == "ram" and not self.check_cache_ram():
cache = False
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
- self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
+ self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
if cache:
self.cache_images(cache)
@@ -103,23 +105,23 @@ class BaseDataset(Dataset):
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p) # os-agnostic
if p.is_dir(): # dir
- f += glob.glob(str(p / '**' / '*.*'), recursive=True)
+ f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# f = list(p.rglob('*.*')) # pathlib
elif p.is_file(): # file
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
- f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
+ f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
- raise FileNotFoundError(f'{self.prefix}{p} does not exist')
- im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
+ raise FileNotFoundError(f"{self.prefix}{p} does not exist")
+ im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
- assert im_files, f'{self.prefix}No images found in {img_path}'
+ assert im_files, f"{self.prefix}No images found in {img_path}"
except Exception as e:
- raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
+ raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:
- im_files = im_files[:round(len(im_files) * self.fraction)]
+ im_files = im_files[: round(len(im_files) * self.fraction)]
return im_files
def update_labels(self, include_class: Optional[list]):
@@ -127,19 +129,19 @@ class BaseDataset(Dataset):
include_class_array = np.array(include_class).reshape(1, -1)
for i in range(len(self.labels)):
if include_class is not None:
- cls = self.labels[i]['cls']
- bboxes = self.labels[i]['bboxes']
- segments = self.labels[i]['segments']
- keypoints = self.labels[i]['keypoints']
+ cls = self.labels[i]["cls"]
+ bboxes = self.labels[i]["bboxes"]
+ segments = self.labels[i]["segments"]
+ keypoints = self.labels[i]["keypoints"]
j = (cls == include_class_array).any(1)
- self.labels[i]['cls'] = cls[j]
- self.labels[i]['bboxes'] = bboxes[j]
+ self.labels[i]["cls"] = cls[j]
+ self.labels[i]["bboxes"] = bboxes[j]
if segments:
- self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
+ self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
if keypoints is not None:
- self.labels[i]['keypoints'] = keypoints[j]
+ self.labels[i]["keypoints"] = keypoints[j]
if self.single_cls:
- self.labels[i]['cls'][:, 0] = 0
+ self.labels[i]["cls"][:, 0] = 0
def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
@@ -149,13 +151,13 @@ class BaseDataset(Dataset):
try:
im = np.load(fn)
except Exception as e:
- LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
+ LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f) # BGR
else: # read image
im = cv2.imread(f) # BGR
if im is None:
- raise FileNotFoundError(f'Image Not Found {f}')
+ raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
@@ -181,17 +183,17 @@ class BaseDataset(Dataset):
def cache_images(self, cache):
"""Cache images to memory or disk."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
- fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
+ fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
for i, x in pbar:
- if cache == 'disk':
+ if cache == "disk":
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
- pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
+ pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
pbar.close()
def cache_images_to_disk(self, i):
@@ -207,15 +209,17 @@ class BaseDataset(Dataset):
for _ in range(n):
im = cv2.imread(random.choice(self.im_files)) # sample image
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
- b += im.nbytes * ratio ** 2
+ b += im.nbytes * ratio**2
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset into RAM
mem = psutil.virtual_memory()
cache = mem_required < mem.available # to cache or not to cache, that is the question
if not cache:
- LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
- f'with {int(safety_margin * 100)}% safety margin but only '
- f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
- f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
+ LOGGER.info(
+ f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
+ f'with {int(safety_margin * 100)}% safety margin but only '
+ f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
+ )
return cache
def set_rectangle(self):
@@ -223,7 +227,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
- s = np.array([x.pop('shape') for x in self.labels]) # hw
+ s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]
@@ -250,12 +254,14 @@ class BaseDataset(Dataset):
def get_image_and_label(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
- label.pop('shape', None) # shape is for rect, remove it
- label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
- label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
- label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
+ label.pop("shape", None) # shape is for rect, remove it
+ label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
+ label["ratio_pad"] = (
+ label["resized_shape"][0] / label["ori_shape"][0],
+ label["resized_shape"][1] / label["ori_shape"][1],
+ ) # for evaluation
if self.rect:
- label['rect_shape'] = self.batch_shapes[self.batch[index]]
+ label["rect_shape"] = self.batch_shapes[self.batch[index]]
return self.update_labels_info(label)
def __len__(self):
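
For context on the RAM-cache check reformatted above: check_cache_ram() samples a few images, scales their byte size by the resize ratio, and compares the extrapolated total against available memory. The standalone sketch below mirrors that estimate; the function name and sampled shapes are illustrative and not part of this patch.

import psutil

def estimate_cache_feasibility(image_shapes, imgsz=640, safety_margin=0.5):
    """Return True if resized 3-channel uint8 images should fit in available RAM."""
    b = 0
    for h, w in image_shapes:  # sampled (height, width) pairs
        ratio = imgsz / max(h, w)  # long side resized to imgsz
        b += h * w * 3 * ratio**2  # approximate bytes after resize
    mem_required = b * (1 + safety_margin)
    return mem_required < psutil.virtual_memory().available

print(estimate_cache_feasibility([(1080, 1920)] * 3))  # e.g. three hypothetical 1080p frames
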
diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py
index 99fed0c4..0892a2d9 100644
--- a/ultralytics/data/build.py
+++ b/ultralytics/data/build.py
@@ -9,8 +9,16 @@ import torch
from PIL import Image
from torch.utils.data import dataloader, distributed
-from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
- SourceTypes, autocast_list)
+from ultralytics.data.loaders import (
+ LOADERS,
+ LoadImages,
+ LoadPilAndNumpy,
+ LoadScreenshots,
+ LoadStreams,
+ LoadTensor,
+ SourceTypes,
+ autocast_list,
+)
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
@@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
def __init__(self, *args, **kwargs):
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
super().__init__(*args, **kwargs)
- object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+ object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
@@ -70,29 +78,30 @@ class _RepeatSampler:
def seed_worker(worker_id): # noqa
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
- worker_seed = torch.initial_seed() % 2 ** 32
+ worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
-def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
+def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
"""Build YOLO Dataset."""
return YOLODataset(
img_path=img_path,
imgsz=cfg.imgsz,
batch_size=batch,
- augment=mode == 'train', # augmentation
+ augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
- pad=0.0 if mode == 'train' else 0.5,
- prefix=colorstr(f'{mode}: '),
+ pad=0.0 if mode == "train" else 0.5,
+ prefix=colorstr(f"{mode}: "),
task=cfg.task,
classes=cfg.classes,
data=data,
- fraction=cfg.fraction if mode == 'train' else 1.0)
+ fraction=cfg.fraction if mode == "train" else 1.0,
+ )
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
@@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
- return InfiniteDataLoader(dataset=dataset,
- batch_size=batch,
- shuffle=shuffle and sampler is None,
- num_workers=nw,
- sampler=sampler,
- pin_memory=PIN_MEMORY,
- collate_fn=getattr(dataset, 'collate_fn', None),
- worker_init_fn=seed_worker,
- generator=generator)
+ return InfiniteDataLoader(
+ dataset=dataset,
+ batch_size=batch,
+ shuffle=shuffle and sampler is None,
+ num_workers=nw,
+ sampler=sampler,
+ pin_memory=PIN_MEMORY,
+ collate_fn=getattr(dataset, "collate_fn", None),
+ worker_init_fn=seed_worker,
+ generator=generator,
+ )
def check_source(source):
@@ -120,9 +131,9 @@ def check_source(source):
if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
- is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://'))
- webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
- screenshot = source.lower() == 'screen'
+ is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
+ webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
+ screenshot = source.lower() == "screen"
if is_url and is_file:
source = check_file(source) # download
elif isinstance(source, LOADERS):
@@ -135,7 +146,7 @@ def check_source(source):
elif isinstance(source, torch.Tensor):
tensor = True
else:
- raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
+ raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")
return source, webcam, screenshot, from_img, in_memory, tensor
@@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
# Attach source types to the dataset
- setattr(dataset, 'source_type', source_type)
+ setattr(dataset, "source_type", source_type)
return dataset
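
The build.py changes above are purely stylistic, so the public entry points keep their signatures. A minimal usage sketch, assuming a local file path ("video.mp4" is a placeholder):

from ultralytics.data.build import load_inference_source

dataset = load_inference_source(source="video.mp4", imgsz=640, vid_stride=1)
print(dataset.source_type)  # SourceTypes namespace attached by load_inference_source()
for batch in dataset:  # the underlying loader yields one batch per frame/image
    ...
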
diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py
index 5714320f..1a875f3b 100644
--- a/ultralytics/data/converter.py
+++ b/ultralytics/data/converter.py
@@ -20,10 +20,98 @@ def coco91_to_coco80_class():
corresponding 91-index class ID.
"""
return [
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
- None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
- 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
- None, 73, 74, 75, 76, 77, 78, 79, None]
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ None,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ None,
+ 24,
+ 25,
+ None,
+ None,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ None,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ None,
+ 60,
+ None,
+ None,
+ 61,
+ None,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ None,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ None,
+ ]
def coco80_to_coco91_class():
@@ -42,16 +130,96 @@ def coco80_to_coco91_class():
```
"""
return [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
- 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
- 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
-
-
-def convert_coco(labels_dir='../coco/annotations/',
- save_dir='coco_converted/',
- use_segments=False,
- use_keypoints=False,
- cls91to80=True):
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 27,
+ 28,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 67,
+ 70,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ ]
+
+
+def convert_coco(
+ labels_dir="../coco/annotations/",
+ save_dir="coco_converted/",
+ use_segments=False,
+ use_keypoints=False,
+ cls91to80=True,
+):
"""
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@@ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/',
# Create dataset directory
save_dir = increment_path(save_dir) # increment if save directory already exists
- for p in save_dir / 'labels', save_dir / 'images':
+ for p in save_dir / "labels", save_dir / "images":
p.mkdir(parents=True, exist_ok=True) # make dir
# Convert classes
coco80 = coco91_to_coco80_class()
# Import json
- for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
- fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
+ for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
+ fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
fn.mkdir(parents=True, exist_ok=True)
with open(json_file) as f:
data = json.load(f)
# Create image dict
- images = {f'{x["id"]:d}': x for x in data['images']}
+ images = {f'{x["id"]:d}': x for x in data["images"]}
# Create image-annotations dict
imgToAnns = defaultdict(list)
- for ann in data['annotations']:
- imgToAnns[ann['image_id']].append(ann)
+ for ann in data["annotations"]:
+ imgToAnns[ann["image_id"]].append(ann)
# Write labels file
- for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'):
- img = images[f'{img_id:d}']
- h, w, f = img['height'], img['width'], img['file_name']
+ for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
+ img = images[f"{img_id:d}"]
+ h, w, f = img["height"], img["width"], img["file_name"]
bboxes = []
segments = []
keypoints = []
for ann in anns:
- if ann['iscrowd']:
+ if ann["iscrowd"]:
continue
# The COCO box format is [top left x, top left y, width, height]
- box = np.array(ann['bbox'], dtype=np.float64)
+ box = np.array(ann["bbox"], dtype=np.float64)
box[:2] += box[2:] / 2 # xy top-left corner to center
box[[0, 2]] /= w # normalize x
box[[1, 3]] /= h # normalize y
                if box[2] <= 0 or box[3] <= 0:  # if w <= 0 or h <= 0
continue
- cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
+ cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
- if use_segments and ann.get('segmentation') is not None:
- if len(ann['segmentation']) == 0:
+ if use_segments and ann.get("segmentation") is not None:
+ if len(ann["segmentation"]) == 0:
segments.append([])
continue
- elif len(ann['segmentation']) > 1:
- s = merge_multi_segment(ann['segmentation'])
+ elif len(ann["segmentation"]) > 1:
+ s = merge_multi_segment(ann["segmentation"])
s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
else:
- s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
+ s = [j for i in ann["segmentation"] for j in i] # all segments concatenated
s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
s = [cls] + s
segments.append(s)
- if use_keypoints and ann.get('keypoints') is not None:
- keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) /
- np.array([w, h, 1])).reshape(-1).tolist())
+ if use_keypoints and ann.get("keypoints") is not None:
+ keypoints.append(
+ box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
+ )
# Write
- with open((fn / f).with_suffix('.txt'), 'a') as file:
+ with open((fn / f).with_suffix(".txt"), "a") as file:
for i in range(len(bboxes)):
if use_keypoints:
- line = *(keypoints[i]), # cls, box, keypoints
+ line = (*(keypoints[i]),) # cls, box, keypoints
else:
- line = *(segments[i]
- if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
- file.write(('%g ' * len(line)).rstrip() % line + '\n')
+ line = (
+ *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
+ ) # cls, box or segments
+ file.write(("%g " * len(line)).rstrip() % line + "\n")
- LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}')
+ LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
def convert_dota_to_yolo_obb(dota_root_path: str):
@@ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
# Class names to indices mapping
class_mapping = {
- 'plane': 0,
- 'ship': 1,
- 'storage-tank': 2,
- 'baseball-diamond': 3,
- 'tennis-court': 4,
- 'basketball-court': 5,
- 'ground-track-field': 6,
- 'harbor': 7,
- 'bridge': 8,
- 'large-vehicle': 9,
- 'small-vehicle': 10,
- 'helicopter': 11,
- 'roundabout': 12,
- 'soccer-ball-field': 13,
- 'swimming-pool': 14,
- 'container-crane': 15,
- 'airport': 16,
- 'helipad': 17}
+ "plane": 0,
+ "ship": 1,
+ "storage-tank": 2,
+ "baseball-diamond": 3,
+ "tennis-court": 4,
+ "basketball-court": 5,
+ "ground-track-field": 6,
+ "harbor": 7,
+ "bridge": 8,
+ "large-vehicle": 9,
+ "small-vehicle": 10,
+ "helicopter": 11,
+ "roundabout": 12,
+ "soccer-ball-field": 13,
+ "swimming-pool": 14,
+ "container-crane": 15,
+ "airport": 16,
+ "helipad": 17,
+ }
def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
"""Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
- orig_label_path = orig_label_dir / f'{image_name}.txt'
- save_path = save_dir / f'{image_name}.txt'
+ orig_label_path = orig_label_dir / f"{image_name}.txt"
+ save_path = save_dir / f"{image_name}.txt"
- with orig_label_path.open('r') as f, save_path.open('w') as g:
+ with orig_label_path.open("r") as f, save_path.open("w") as g:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
@@ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
class_idx = class_mapping[class_name]
coords = [float(p) for p in parts[:8]]
normalized_coords = [
- coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)]
- formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords]
+ coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
+ ]
+ formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]
g.write(f"{class_idx} {' '.join(formatted_coords)}\n")
- for phase in ['train', 'val']:
- image_dir = dota_root_path / 'images' / phase
- orig_label_dir = dota_root_path / 'labels' / f'{phase}_original'
- save_dir = dota_root_path / 'labels' / phase
+ for phase in ["train", "val"]:
+ image_dir = dota_root_path / "images" / phase
+ orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
+ save_dir = dota_root_path / "labels" / phase
save_dir.mkdir(parents=True, exist_ok=True)
image_paths = list(image_dir.iterdir())
- for image_path in TQDM(image_paths, desc=f'Processing {phase} images'):
- if image_path.suffix != '.png':
+ for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
+ if image_path.suffix != ".png":
continue
image_name_without_ext = image_path.stem
img = cv2.imread(str(image_path))
@@ -293,7 +465,7 @@ def merge_multi_segment(segments):
s.append(segments[i])
else:
idx = [0, idx[1] - idx[0]]
- s.append(segments[i][idx[0]:idx[1] + 1])
+ s.append(segments[i][idx[0] : idx[1] + 1])
else:
for i in range(len(idx_list) - 1, -1, -1):
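
Since convert_coco() above only gained keyword-per-line formatting, calling it is unchanged. A hedged example with placeholder paths:

from ultralytics.data.converter import convert_coco

convert_coco(
    labels_dir="../coco/annotations/",  # directory containing instances_*.json files
    save_dir="coco_converted/",  # incremented automatically if it already exists
    use_segments=True,
    cls91to80=True,  # remap 91-index COCO IDs to the 80-class scheme listed above
)
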
diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py
index ad0ba56e..0475508e 100644
--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -18,7 +18,7 @@ from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
-DATASET_CACHE_VERSION = '1.0.3'
+DATASET_CACHE_VERSION = "1.0.3"
class YOLODataset(BaseDataset):
@@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
- def __init__(self, *args, data=None, task='detect', **kwargs):
+ def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
- self.use_segments = task == 'segment'
- self.use_keypoints = task == 'pose'
- self.use_obb = task == 'obb'
+ self.use_segments = task == "segment"
+ self.use_keypoints = task == "pose"
+ self.use_obb = task == "obb"
self.data = data
- assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
+        assert not (self.use_segments and self.use_keypoints), "Cannot use both segments and keypoints."
super().__init__(*args, **kwargs)
- def cache_labels(self, path=Path('./labels.cache')):
+ def cache_labels(self, path=Path("./labels.cache")):
"""
Cache dataset labels, check images and read shapes.
@@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
Returns:
(dict): labels.
"""
- x = {'labels': []}
+ x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
- desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
+ desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
- nkpt, ndim = self.data.get('kpt_shape', (0, 0))
+ nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
- raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
- "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
+ raise ValueError(
+ "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
+ "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
+ )
with ThreadPool(NUM_THREADS) as pool:
- results = pool.imap(func=verify_image_label,
- iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
- repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
- repeat(ndim)))
+ results = pool.imap(
+ func=verify_image_label,
+ iterable=zip(
+ self.im_files,
+ self.label_files,
+ repeat(self.prefix),
+ repeat(self.use_keypoints),
+ repeat(len(self.data["names"])),
+ repeat(nkpt),
+ repeat(ndim),
+ ),
+ )
pbar = TQDM(results, desc=desc, total=total)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
@@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
ne += ne_f
nc += nc_f
if im_file:
- x['labels'].append(
+ x["labels"].append(
dict(
im_file=im_file,
shape=shape,
@@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
segments=segments,
keypoints=keypoint,
normalized=True,
- bbox_format='xywh'))
+ bbox_format="xywh",
+ )
+ )
if msg:
msgs.append(msg)
- pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
if msgs:
- LOGGER.info('\n'.join(msgs))
+ LOGGER.info("\n".join(msgs))
if nf == 0:
- LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
- x['hash'] = get_hash(self.label_files + self.im_files)
- x['results'] = nf, nm, ne, nc, len(self.im_files)
- x['msgs'] = msgs # warnings
+ LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
+ x["hash"] = get_hash(self.label_files + self.im_files)
+ x["results"] = nf, nm, ne, nc, len(self.im_files)
+ x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
return x
def get_labels(self):
"""Returns dictionary of labels for YOLO training."""
self.label_files = img2label_paths(self.im_files)
- cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
+ cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
- assert cache['version'] == DATASET_CACHE_VERSION # matches current version
- assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
+ assert cache["version"] == DATASET_CACHE_VERSION # matches current version
+ assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
- nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
+ nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0):
- d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
- if cache['msgs']:
- LOGGER.info('\n'.join(cache['msgs'])) # display warnings
+ if cache["msgs"]:
+ LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache
- [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
- labels = cache['labels']
+ [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
+ labels = cache["labels"]
if not labels:
- LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
- self.im_files = [lb['im_file'] for lb in labels] # update im_files
+ LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
+ self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
- lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
+ lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
if len_segments and len_boxes != len_segments:
LOGGER.warning(
- f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
- f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
- 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
+ f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
+ f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
+ "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
+ )
for lb in labels:
- lb['segments'] = []
+ lb["segments"] = []
if len_cls == 0:
- LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
+ LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
return labels
def build_transforms(self, hyp=None):
@@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
transforms.append(
- Format(bbox_format='xywh',
- normalize=True,
- return_mask=self.use_segments,
- return_keypoint=self.use_keypoints,
- return_obb=self.use_obb,
- batch_idx=True,
- mask_ratio=hyp.mask_ratio,
- mask_overlap=hyp.overlap_mask))
+ Format(
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ return_obb=self.use_obb,
+ batch_idx=True,
+ mask_ratio=hyp.mask_ratio,
+ mask_overlap=hyp.overlap_mask,
+ )
+ )
return transforms
def close_mosaic(self, hyp):
@@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
"""Custom your label format here."""
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
# We can make it also support classification and semantic segmentation by add or remove some dict keys there.
- bboxes = label.pop('bboxes')
- segments = label.pop('segments', [])
- keypoints = label.pop('keypoints', None)
- bbox_format = label.pop('bbox_format')
- normalized = label.pop('normalized')
+ bboxes = label.pop("bboxes")
+ segments = label.pop("segments", [])
+ keypoints = label.pop("keypoints", None)
+ bbox_format = label.pop("bbox_format")
+ normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes
segment_resamples = 100 if self.use_obb else 1000
@@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
- label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+ label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
@staticmethod
@@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
values = list(zip(*[list(b.values()) for b in batch]))
for i, k in enumerate(keys):
value = values[i]
- if k == 'img':
+ if k == "img":
value = torch.stack(value, 0)
- if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
+ if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
value = torch.cat(value, 0)
new_batch[k] = value
- new_batch['batch_idx'] = list(new_batch['batch_idx'])
- for i in range(len(new_batch['batch_idx'])):
- new_batch['batch_idx'][i] += i # add target image index for build_targets()
- new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
+ new_batch["batch_idx"] = list(new_batch["batch_idx"])
+ for i in range(len(new_batch["batch_idx"])):
+ new_batch["batch_idx"][i] += i # add target image index for build_targets()
+ new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
return new_batch
@@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
"""
- def __init__(self, root, args, augment=False, cache=False, prefix=''):
+ def __init__(self, root, args, augment=False, cache=False, prefix=""):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
@@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
super().__init__(root=root)
if augment and args.fraction < 1.0: # reduce training fraction
- self.samples = self.samples[:round(len(self.samples) * args.fraction)]
- self.prefix = colorstr(f'{prefix}: ') if prefix else ''
- self.cache_ram = cache is True or cache == 'ram'
- self.cache_disk = cache == 'disk'
+ self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+ self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+ self.cache_ram = cache is True or cache == "ram"
+ self.cache_disk = cache == "disk"
self.samples = self.verify_images() # filter out bad images
- self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
+ self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
- self.torch_transforms = classify_augmentations(size=args.imgsz,
- scale=scale,
- hflip=args.fliplr,
- vflip=args.flipud,
- erasing=args.erasing,
- auto_augment=args.auto_augment,
- hsv_h=args.hsv_h,
- hsv_s=args.hsv_s,
- hsv_v=args.hsv_v) if augment else classify_transforms(
- size=args.imgsz, crop_fraction=args.crop_fraction)
+ self.torch_transforms = (
+ classify_augmentations(
+ size=args.imgsz,
+ scale=scale,
+ hflip=args.fliplr,
+ vflip=args.flipud,
+ erasing=args.erasing,
+ auto_augment=args.auto_augment,
+ hsv_h=args.hsv_h,
+ hsv_s=args.hsv_s,
+ hsv_v=args.hsv_v,
+ )
+ if augment
+ else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+ )
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
@@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
# Convert NumPy array to PIL image
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
sample = self.torch_transforms(im)
- return {'img': sample, 'cls': j}
+ return {"img": sample, "cls": j}
def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
@@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
def verify_images(self):
"""Verify all images in dataset."""
- desc = f'{self.prefix}Scanning {self.root}...'
- path = Path(self.root).with_suffix('.cache') # *.cache file path
+ desc = f"{self.prefix}Scanning {self.root}..."
+ path = Path(self.root).with_suffix(".cache") # *.cache file path
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
- assert cache['version'] == DATASET_CACHE_VERSION # matches current version
- assert cache['hash'] == get_hash([x[0] for x in self.samples]) # identical hash
- nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total
+ assert cache["version"] == DATASET_CACHE_VERSION # matches current version
+ assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
+ nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
if LOCAL_RANK in (-1, 0):
- d = f'{desc} {nf} images, {nc} corrupt'
+ d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
- if cache['msgs']:
- LOGGER.info('\n'.join(cache['msgs'])) # display warnings
+ if cache["msgs"]:
+ LOGGER.info("\n".join(cache["msgs"])) # display warnings
return samples
# Run scan if *.cache retrieval failed
@@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
msgs.append(msg)
nf += nf_f
nc += nc_f
- pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+ pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
if msgs:
- LOGGER.info('\n'.join(msgs))
- x['hash'] = get_hash([x[0] for x in self.samples])
- x['results'] = nf, nc, len(samples), samples
- x['msgs'] = msgs # warnings
+ LOGGER.info("\n".join(msgs))
+ x["hash"] = get_hash([x[0] for x in self.samples])
+ x["results"] = nf, nc, len(samples), samples
+ x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
return samples
@@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
+
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
@@ -320,15 +342,15 @@ def load_dataset_cache_file(path):
def save_dataset_cache_file(prefix, path, x):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
- x['version'] = DATASET_CACHE_VERSION # add cache version
+ x["version"] = DATASET_CACHE_VERSION # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
- path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
- LOGGER.info(f'{prefix}New cache created: {path}')
+ path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
+ LOGGER.info(f"{prefix}New cache created: {path}")
else:
- LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+ LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
# TODO: support semantic segmentation
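
The save/load helpers reformatted above rely on a small numpy detail: np.save() appends ".npy", which is then renamed away. A self-contained sketch of that round trip, with an illustrative payload:

import numpy as np
from pathlib import Path

path = Path("labels.cache")
x = {"labels": [], "hash": "abc123", "version": "1.0.3"}
np.save(str(path), x)  # writes labels.cache.npy
path.with_suffix(".cache.npy").rename(path)  # drop the .npy suffix, as in save_dataset_cache_file()
cache = np.load(str(path), allow_pickle=True).item()  # as in load_dataset_cache_file()
assert cache["version"] == "1.0.3"
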
diff --git a/ultralytics/data/explorer/__init__.py b/ultralytics/data/explorer/__init__.py
index f0344e2d..ce594dc1 100644
--- a/ultralytics/data/explorer/__init__.py
+++ b/ultralytics/data/explorer/__init__.py
@@ -2,4 +2,4 @@
from .utils import plot_query_result
-__all__ = ['plot_query_result']
+__all__ = ["plot_query_result"]
diff --git a/ultralytics/data/explorer/explorer.py b/ultralytics/data/explorer/explorer.py
index 4a8595d6..5381a0a1 100644
--- a/ultralytics/data/explorer/explorer.py
+++ b/ultralytics/data/explorer/explorer.py
@@ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
class ExplorerDataset(YOLODataset):
-
def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs)
@@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset):
else: # read image
im = cv2.imread(f) # BGR
if im is None:
- raise FileNotFoundError(f'Image Not Found {f}')
+ raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
return im, (h0, w0), im.shape[:2]
@@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset):
def build_transforms(self, hyp: IterableSimpleNamespace = None):
"""Creates transforms for dataset images without resizing."""
return Format(
- bbox_format='xyxy',
+ bbox_format="xyxy",
normalize=False,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
@@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset):
class Explorer:
-
- def __init__(self,
- data: Union[str, Path] = 'coco128.yaml',
- model: str = 'yolov8n.pt',
- uri: str = '~/ultralytics/explorer') -> None:
- checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
+ def __init__(
+ self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
+ ) -> None:
+ checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
import lancedb
self.connection = lancedb.connect(uri)
- self.table_name = Path(data).name.lower() + '_' + model.lower()
- self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
+ self.table_name = Path(data).name.lower() + "_" + model.lower()
+ self.sim_idx_base_name = (
+ f"{self.table_name}_sim_idx".lower()
) # Use this name and append thres and top_k to reuse the table
self.model = YOLO(model)
self.data = data # None
@@ -74,7 +72,7 @@ class Explorer:
self.table = None
self.progress = 0
- def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
+ def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
@@ -90,20 +88,20 @@ class Explorer:
```
"""
if self.table is not None and not force:
- LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
+ LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
return
if self.table_name in self.connection.table_names() and not force:
- LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
+ LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
self.table = self.connection.open_table(self.table_name)
self.progress = 1
return
if self.data is None:
- raise ValueError('Data must be provided to create embeddings table')
+ raise ValueError("Data must be provided to create embeddings table")
data_info = check_det_dataset(self.data)
if split not in data_info:
raise ValueError(
- f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
+ f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
)
choice_set = data_info[split]
@@ -113,13 +111,16 @@ class Explorer:
# Create the table schema
batch = dataset[0]
- vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
- table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
+ vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
+ table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
table.add(
- self._yield_batches(dataset,
- data_info,
- self.model,
- exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
+ self._yield_batches(
+ dataset,
+ data_info,
+ self.model,
+ exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
+ )
+ )
self.table = table
@@ -131,12 +132,12 @@ class Explorer:
for k in exclude_keys:
batch.pop(k, None)
batch = sanitize_batch(batch, data_info)
- batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
+ batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
yield [batch]
- def query(self,
- imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
- limit: int = 25) -> Any: # pyarrow.Table
+ def query(
+ self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
+ ) -> Any: # pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@@ -157,18 +158,18 @@ class Explorer:
```
"""
if self.table is None:
- raise ValueError('Table is not created. Please create the table first.')
+ raise ValueError("Table is not created. Please create the table first.")
if isinstance(imgs, str):
imgs = [imgs]
- assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
+ assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow()
- def sql_query(self,
- query: str,
- return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
+ def sql_query(
+ self, query: str, return_type: str = "pandas"
+ ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
@@ -187,27 +188,29 @@ class Explorer:
result = exp.sql_query(query)
```
"""
- assert return_type in ['pandas',
- 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
+ assert return_type in [
+ "pandas",
+ "arrow",
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
import duckdb
if self.table is None:
- raise ValueError('Table is not created. Please create the table first.')
+ raise ValueError("Table is not created. Please create the table first.")
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
- if not query.startswith('SELECT') and not query.startswith('WHERE'):
+ if not query.startswith("SELECT") and not query.startswith("WHERE"):
raise ValueError(
- f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
+ f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
)
- if query.startswith('WHERE'):
+ if query.startswith("WHERE"):
query = f"SELECT * FROM 'table' {query}"
- LOGGER.info(f'Running query: {query}')
+ LOGGER.info(f"Running query: {query}")
rs = duckdb.sql(query)
- if return_type == 'pandas':
+ if return_type == "pandas":
return rs.df()
- elif return_type == 'arrow':
+ elif return_type == "arrow":
return rs.arrow()
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
@@ -228,18 +231,20 @@ class Explorer:
result = exp.plot_sql_query(query)
```
"""
- result = self.sql_query(query, return_type='arrow')
+ result = self.sql_query(query, return_type="arrow")
if len(result) == 0:
- LOGGER.info('No results found.')
+ LOGGER.info("No results found.")
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
- def get_similar(self,
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
- idx: Union[int, List[int]] = None,
- limit: int = 25,
- return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
+ def get_similar(
+ self,
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
+ idx: Union[int, List[int]] = None,
+ limit: int = 25,
+ return_type: str = "pandas",
+ ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@@ -259,21 +264,25 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
- assert return_type in ['pandas',
- 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
+ assert return_type in [
+ "pandas",
+ "arrow",
+ ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
- if return_type == 'pandas':
+ if return_type == "pandas":
return similar.to_pandas()
- elif return_type == 'arrow':
+ elif return_type == "arrow":
return similar
- def plot_similar(self,
- img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
- idx: Union[int, List[int]] = None,
- limit: int = 25,
- labels: bool = True) -> Image.Image:
+ def plot_similar(
+ self,
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
+ idx: Union[int, List[int]] = None,
+ limit: int = 25,
+ labels: bool = True,
+ ) -> Image.Image:
"""
Plot the similar images. Accepts images or indexes.
@@ -293,9 +302,9 @@ class Explorer:
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
- similar = self.get_similar(img, idx, limit, return_type='arrow')
+ similar = self.get_similar(img, idx, limit, return_type="arrow")
if len(similar) == 0:
- LOGGER.info('No results found.')
+ LOGGER.info("No results found.")
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
@@ -323,34 +332,37 @@ class Explorer:
```
"""
if self.table is None:
- raise ValueError('Table is not created. Please create the table first.')
- sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
+ raise ValueError("Table is not created. Please create the table first.")
+ sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
if sim_idx_table_name in self.connection.table_names() and not force:
- LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
+ LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
return self.connection.open_table(sim_idx_table_name).to_pandas()
if top_k and not (1.0 >= top_k >= 0.0):
- raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
+ raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
if max_dist < 0.0:
- raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
+ raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
top_k = max(top_k, 1)
- features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
- im_files = features['im_file']
- embeddings = features['vector']
+ features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
+ im_files = features["im_file"]
+ embeddings = features["vector"]
- sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
+ sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
def _yield_sim_idx():
"""Generates a dataframe with similarity indices and distances for images."""
for i in tqdm(range(len(embeddings))):
- sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
- yield [{
- 'idx': i,
- 'im_file': im_files[i],
- 'count': len(sim_idx),
- 'sim_im_files': sim_idx['im_file'].tolist()}]
+ sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
+ yield [
+ {
+ "idx": i,
+ "im_file": im_files[i],
+ "count": len(sim_idx),
+ "sim_im_files": sim_idx["im_file"].tolist(),
+ }
+ ]
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
@@ -381,7 +393,7 @@ class Explorer:
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
- sim_count = sim_idx['count'].tolist()
+ sim_count = sim_idx["count"].tolist()
sim_count = np.array(sim_count)
indices = np.arange(len(sim_count))
@@ -390,25 +402,26 @@ class Explorer:
plt.bar(indices, sim_count)
# Customize the plot (optional)
- plt.xlabel('data idx')
- plt.ylabel('Count')
- plt.title('Similarity Count')
+ plt.xlabel("data idx")
+ plt.ylabel("Count")
+ plt.title("Similarity Count")
buffer = BytesIO()
- plt.savefig(buffer, format='png')
+ plt.savefig(buffer, format="png")
buffer.seek(0)
# Use Pillow to open the image from the buffer
return Image.fromarray(np.array(Image.open(buffer)))
- def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
- idx: Union[None, int, List[int]]) -> List[np.ndarray]:
+ def _check_imgs_or_idxs(
+ self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
+ ) -> List[np.ndarray]:
if img is None and idx is None:
- raise ValueError('Either img or idx must be provided.')
+ raise ValueError("Either img or idx must be provided.")
if img is not None and idx is not None:
- raise ValueError('Only one of img or idx must be provided.')
+ raise ValueError("Only one of img or idx must be provided.")
if idx is not None:
idx = idx if isinstance(idx, list) else [idx]
- img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
+ img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
return img if isinstance(img, list) else [img]
@@ -433,7 +446,7 @@ class Explorer:
try:
df = self.sql_query(result)
except Exception as e:
- LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
+ LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
LOGGER.error(e)
return None
return df
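
Pulling the docstring examples from this file together, a typical Explorer session looks like the sketch below; the dataset, model, and query strings are the defaults and examples used above:

from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()
similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg", limit=10)  # pandas by default
df = exp.sql_query("WHERE labels LIKE '%person%'")  # bare WHERE clauses are expanded to SELECT * FROM 'table' ...
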
diff --git a/ultralytics/data/explorer/gui/dash.py b/ultralytics/data/explorer/gui/dash.py
index d2b5ec21..1ef7fe42 100644
--- a/ultralytics/data/explorer/gui/dash.py
+++ b/ultralytics/data/explorer/gui/dash.py
@@ -9,100 +9,114 @@ from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
-check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
+check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
+
import streamlit as st
from streamlit_select import image_select
def _get_explorer():
"""Initializes and returns an instance of the Explorer class."""
- exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
- thread = Thread(target=exp.create_embeddings_table,
- kwargs={'force': st.session_state.get('force_recreate_embeddings')})
+ exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
+ thread = Thread(
+ target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
+ )
thread.start()
- progress_bar = st.progress(0, text='Creating embeddings table...')
+ progress_bar = st.progress(0, text="Creating embeddings table...")
while exp.progress < 1:
time.sleep(0.1)
- progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
+ progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
thread.join()
- st.session_state['explorer'] = exp
+ st.session_state["explorer"] = exp
progress_bar.empty()
def init_explorer_form():
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
- datasets = ROOT / 'cfg' / 'datasets'
- ds = [d.name for d in datasets.glob('*.yaml')]
+ datasets = ROOT / "cfg" / "datasets"
+ ds = [d.name for d in datasets.glob("*.yaml")]
models = [
- 'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
- 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
- 'yolov8l-pose.pt', 'yolov8x-pose.pt']
- with st.form(key='explorer_init_form'):
+ "yolov8n.pt",
+ "yolov8s.pt",
+ "yolov8m.pt",
+ "yolov8l.pt",
+ "yolov8x.pt",
+ "yolov8n-seg.pt",
+ "yolov8s-seg.pt",
+ "yolov8m-seg.pt",
+ "yolov8l-seg.pt",
+ "yolov8x-seg.pt",
+ "yolov8n-pose.pt",
+ "yolov8s-pose.pt",
+ "yolov8m-pose.pt",
+ "yolov8l-pose.pt",
+ "yolov8x-pose.pt",
+ ]
+ with st.form(key="explorer_init_form"):
col1, col2 = st.columns(2)
with col1:
- st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
+ st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
with col2:
- st.selectbox('Select model', models, key='model')
- st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
+ st.selectbox("Select model", models, key="model")
+ st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
- st.form_submit_button('Explore', on_click=_get_explorer)
+ st.form_submit_button("Explore", on_click=_get_explorer)
def query_form():
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
- with st.form('query_form'):
+ with st.form("query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
- st.text_input('Query',
- "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
- label_visibility='collapsed',
- key='query')
+ st.text_input(
+ "Query",
+ "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
+ label_visibility="collapsed",
+ key="query",
+ )
with col2:
- st.form_submit_button('Query', on_click=run_sql_query)
+ st.form_submit_button("Query", on_click=run_sql_query)
def ai_query_form():
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
- with st.form('ai_query_form'):
+ with st.form("ai_query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
- st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
+ st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
with col2:
- st.form_submit_button('Ask AI', on_click=run_ai_query)
+ st.form_submit_button("Ask AI", on_click=run_ai_query)
def find_similar_imgs(imgs):
"""Initializes a Streamlit form for AI-based image querying with custom input."""
- exp = st.session_state['explorer']
- similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
- paths = similar.to_pydict()['im_file']
- st.session_state['imgs'] = paths
+ exp = st.session_state["explorer"]
+ similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
+ paths = similar.to_pydict()["im_file"]
+ st.session_state["imgs"] = paths
def similarity_form(selected_imgs):
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
- st.write('Similarity Search')
- with st.form('similarity_form'):
+ st.write("Similarity Search")
+ with st.form("similarity_form"):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
- st.number_input('limit',
- min_value=None,
- max_value=None,
- value=25,
- label_visibility='collapsed',
- key='limit')
+ st.number_input(
+ "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
+ )
with subcol2:
disabled = not len(selected_imgs)
- st.write('Selected: ', len(selected_imgs))
+ st.write("Selected: ", len(selected_imgs))
st.form_submit_button(
- 'Search',
+ "Search",
disabled=disabled,
on_click=find_similar_imgs,
- args=(selected_imgs, ),
+ args=(selected_imgs,),
)
if disabled:
- st.error('Select at least one image to search.')
+ st.error("Select at least one image to search.")
# def persist_reset_form():
@@ -117,100 +131,108 @@ def similarity_form(selected_imgs):
def run_sql_query():
"""Executes an SQL query and returns the results."""
- st.session_state['error'] = None
- query = st.session_state.get('query')
+ st.session_state["error"] = None
+ query = st.session_state.get("query")
if query.rstrip().lstrip():
- exp = st.session_state['explorer']
- res = exp.sql_query(query, return_type='arrow')
- st.session_state['imgs'] = res.to_pydict()['im_file']
+ exp = st.session_state["explorer"]
+ res = exp.sql_query(query, return_type="arrow")
+ st.session_state["imgs"] = res.to_pydict()["im_file"]
def run_ai_query():
"""Execute SQL query and update session state with query results."""
- if not SETTINGS['openai_api_key']:
+ if not SETTINGS["openai_api_key"]:
st.session_state[
- 'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
+ "error"
+ ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
return
- st.session_state['error'] = None
- query = st.session_state.get('ai_query')
+ st.session_state["error"] = None
+ query = st.session_state.get("ai_query")
if query.rstrip().lstrip():
- exp = st.session_state['explorer']
+ exp = st.session_state["explorer"]
res = exp.ask_ai(query)
if not isinstance(res, pd.DataFrame) or res.empty:
- st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
+ st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
return
- st.session_state['imgs'] = res['im_file'].to_list()
+ st.session_state["imgs"] = res["im_file"].to_list()
def reset_explorer():
"""Resets the explorer to its initial state by clearing session variables."""
- st.session_state['explorer'] = None
- st.session_state['imgs'] = None
- st.session_state['error'] = None
+ st.session_state["explorer"] = None
+ st.session_state["imgs"] = None
+ st.session_state["error"] = None
def ultralytics_explorer_docs_callback():
    """Displays an Ultralytics Explorer docs panel with a link to the API reference."""
with st.container(border=True):
- st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
- width=100)
+ st.image(
+ "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
+ width=100,
+ )
st.markdown(
"This demo is built using Ultralytics Explorer API. Visit API docs to try examples & learn more
",
unsafe_allow_html=True,
- help=None)
- st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
+ help=None,
+ )
+ st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
def layout():
"""Resets explorer session variables and provides documentation with a link to API docs."""
- st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
+ st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("Ultralytics Explorer Demo
", unsafe_allow_html=True)
- if st.session_state.get('explorer') is None:
+ if st.session_state.get("explorer") is None:
init_explorer_form()
return
- st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
- exp = st.session_state.get('explorer')
- col1, col2 = st.columns([0.75, 0.25], gap='small')
+ st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
+ exp = st.session_state.get("explorer")
+ col1, col2 = st.columns([0.75, 0.25], gap="small")
imgs = []
- if st.session_state.get('error'):
- st.error(st.session_state['error'])
+ if st.session_state.get("error"):
+ st.error(st.session_state["error"])
else:
- imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
+ imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
- st.write('Max Images Displayed:')
+ st.write("Max Images Displayed:")
with subcol2:
- num = st.number_input('Max Images Displayed',
- min_value=0,
- max_value=total_imgs,
- value=min(500, total_imgs),
- key='num_imgs_displayed',
- label_visibility='collapsed')
+ num = st.number_input(
+ "Max Images Displayed",
+ min_value=0,
+ max_value=total_imgs,
+ value=min(500, total_imgs),
+ key="num_imgs_displayed",
+ label_visibility="collapsed",
+ )
with subcol3:
- st.write('Start Index:')
+ st.write("Start Index:")
with subcol4:
- start_idx = st.number_input('Start Index',
- min_value=0,
- max_value=total_imgs,
- value=0,
- key='start_index',
- label_visibility='collapsed')
+ start_idx = st.number_input(
+ "Start Index",
+ min_value=0,
+ max_value=total_imgs,
+ value=0,
+ key="start_index",
+ label_visibility="collapsed",
+ )
with subcol5:
- reset = st.button('Reset', use_container_width=False, key='reset')
+ reset = st.button("Reset", use_container_width=False, key="reset")
if reset:
- st.session_state['imgs'] = None
+ st.session_state["imgs"] = None
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
- imgs_displayed = imgs[start_idx:start_idx + num]
+ imgs_displayed = imgs[start_idx : start_idx + num]
selected_imgs = image_select(
- f'Total samples: {total_imgs}',
+ f"Total samples: {total_imgs}",
images=imgs_displayed,
use_container_width=False,
# indices=[i for i in range(num)] if select_all else None,
@@ -222,5 +244,5 @@ def layout():
ultralytics_explorer_docs_callback()
-if __name__ == '__main__':
+if __name__ == "__main__":
layout()
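For reference, a minimal sketch of the Explorer flow this dashboard wraps; dataset and model arguments are illustrative, and the calls mirror find_similar_imgs() above.

    from ultralytics import Explorer

    exp = Explorer(data="coco8.yaml", model="yolov8n.pt")  # illustrative arguments
    exp.create_embeddings_table()  # build the LanceDB embeddings table once

    # Same call the dashboard issues in find_similar_imgs()
    similar = exp.get_similar(img=["path/to/query.jpg"], limit=25, return_type="arrow")
    paths = similar.to_pydict()["im_file"]  # file paths of the most similar images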
diff --git a/ultralytics/data/explorer/utils.py b/ultralytics/data/explorer/utils.py
index 0064d362..c70722bc 100644
--- a/ultralytics/data/explorer/utils.py
+++ b/ultralytics/data/explorer/utils.py
@@ -46,14 +46,13 @@ def get_sim_index_schema():
def sanitize_batch(batch, dataset_info):
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
- batch['cls'] = batch['cls'].flatten().int().tolist()
- box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
- batch['bboxes'] = [box for box, _ in box_cls_pair]
- batch['cls'] = [cls for _, cls in box_cls_pair]
- batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
- batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
- batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
-
+ batch["cls"] = batch["cls"].flatten().int().tolist()
+ box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
+ batch["bboxes"] = [box for box, _ in box_cls_pair]
+ batch["cls"] = [cls for _, cls in box_cls_pair]
+ batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
+ batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
+ batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
return batch
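A worked example of what sanitize_batch() returns for a toy two-object batch; torch tensors are assumed, as produced by the dataset pipeline.

    import torch

    batch = {
        "cls": torch.tensor([[1.0], [0.0]]),
        "bboxes": torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.3, 0.3]]),
    }
    out = sanitize_batch(batch, {"names": {0: "person", 1: "dog"}})
    # out["cls"]    == [0, 1]  (box/class pairs re-sorted by class id)
    # out["bboxes"] == [[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.2, 0.2]]
    # out["labels"] == ["person", "dog"]
    # out["masks"] and out["keypoints"] fall back to [[[]]] when absent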
@@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
similar_set (pyarrow.Table | pandas.DataFrame): Object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
- similar_set = similar_set.to_dict(
- orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
+ similar_set = (
+ similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
+ )
empty_masks = [[[]]]
empty_boxes = [[]]
- images = similar_set.get('im_file', [])
- bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
- masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
- kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
- cls = similar_set.get('cls', [])
+ images = similar_set.get("im_file", [])
+ bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
+ masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
+ kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
+ cls = similar_set.get("cls", [])
plot_size = 640
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
- return plot_images(imgs,
- batch_idx,
- cls,
- bboxes=boxes,
- masks=masks,
- kpts=kpts,
- max_subplots=len(images),
- save=False,
- threaded=False)
+ return plot_images(
+ imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
+ )
def prompt_sql_query(query):
"""Plots images with optional labels from a similar data set."""
- check_requirements('openai>=1.6.1')
+ check_requirements("openai>=1.6.1")
from openai import OpenAI
- if not SETTINGS['openai_api_key']:
- logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
- openai_api_key = getpass.getpass('OpenAI API key: ')
- SETTINGS.update({'openai_api_key': openai_api_key})
- openai = OpenAI(api_key=SETTINGS['openai_api_key'])
+ if not SETTINGS["openai_api_key"]:
+ logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
+ openai_api_key = getpass.getpass("OpenAI API key: ")
+ SETTINGS.update({"openai_api_key": openai_api_key})
+ openai = OpenAI(api_key=SETTINGS["openai_api_key"])
messages = [
{
- 'role':
- 'system',
- 'content':
- '''
+ "role": "system",
+ "content": """
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
@@ -165,10 +157,10 @@ def prompt_sql_query(query):
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
- '''},
- {
- 'role': 'user',
- 'content': f'{query}'}, ]
+ """,
+ },
+ {"role": "user", "content": f"{query}"},
+ ]
- response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
+ response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
return response.choices[0].message.content
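A hedged usage sketch: prompt_sql_query() returns raw model text, which callers such as Explorer.ask_ai() then execute against the embeddings table (exp is an assumed Explorer instance).

    sql = prompt_sql_query("show 10 images with exactly one person")
    # e.g. "SELECT * FROM 'table' WHERE ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) = 1 LIMIT 10"
    df = exp.sql_query(sql, return_type="pandas")  # run the generated query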
diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py
index 2545e9ab..a2cc2be7 100644
--- a/ultralytics/data/loaders.py
+++ b/ultralytics/data/loaders.py
@@ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements
@dataclass
class SourceTypes:
"""Class to represent various types of input sources for predictions."""
+
webcam: bool = False
screenshot: bool = False
from_img: bool = False
@@ -59,12 +60,12 @@ class LoadStreams:
__len__: Return the length of the sources object.
"""
- def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
+ def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
- self.mode = 'stream'
+ self.mode = "stream"
self.imgsz = imgsz
self.vid_stride = vid_stride # video frame-rate stride
@@ -79,33 +80,36 @@ class LoadStreams:
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
- st = f'{i + 1}/{n}: {s}... '
- if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
+ st = f"{i + 1}/{n}: {s}... "
+ if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
s = get_best_youtube_url(s)
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0 and (is_colab() or is_kaggle()):
- raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
- "Try running 'source=0' in a local environment.")
+ raise NotImplementedError(
+ "'source=0' webcam not supported in Colab and Kaggle notebooks. "
+ "Try running 'source=0' in a local environment."
+ )
self.caps[i] = cv2.VideoCapture(s) # store video capture object
if not self.caps[i].isOpened():
- raise ConnectionError(f'{st}Failed to open {s}')
+ raise ConnectionError(f"{st}Failed to open {s}")
w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
- 'inf') # infinite stream fallback
+ "inf"
+ ) # infinite stream fallback
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
success, im = self.caps[i].read() # guarantee first frame
if not success or im is None:
- raise ConnectionError(f'{st}Failed to read images from {s}')
+ raise ConnectionError(f"{st}Failed to read images from {s}")
self.imgs[i].append(im)
self.shape[i] = im.shape
self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
- LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
+ LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
- LOGGER.info('') # newline
+ LOGGER.info("") # newline
# Check for common shapes
self.bs = self.__len__()
@@ -121,7 +125,7 @@ class LoadStreams:
success, im = cap.retrieve()
if not success:
im = np.zeros(self.shape[i], dtype=np.uint8)
- LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
+ LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
cap.open(stream) # re-open stream if signal was lost
if self.buffer:
self.imgs[i].append(im)
@@ -140,7 +144,7 @@ class LoadStreams:
try:
cap.release() # release video capture
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}')
+ LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
cv2.destroyAllWindows()
def __iter__(self):
@@ -154,16 +158,15 @@ class LoadStreams:
images = []
for i, x in enumerate(self.imgs):
-
# Wait until a frame is available in each buffer
while not x:
- if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit
+ if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit
self.close()
raise StopIteration
time.sleep(1 / min(self.fps))
x = self.imgs[i]
if not x:
- LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}')
+ LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")
# Get and remove the first frame from imgs buffer
if self.buffer:
@@ -174,7 +177,7 @@ class LoadStreams:
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
x.clear()
- return self.sources, images, None, ''
+ return self.sources, images, None, ""
def __len__(self):
"""Return the length of the sources object."""
@@ -209,7 +212,7 @@ class LoadScreenshots:
def __init__(self, source, imgsz=640):
"""Source = [screen_number left top width height] (pixels)."""
- check_requirements('mss')
+ check_requirements("mss")
import mss # noqa
source, *params = source.split()
@@ -221,18 +224,18 @@ class LoadScreenshots:
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
self.imgsz = imgsz
- self.mode = 'stream'
+ self.mode = "stream"
self.frame = 0
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
- self.top = monitor['top'] if top is None else (monitor['top'] + top)
- self.left = monitor['left'] if left is None else (monitor['left'] + left)
- self.width = width or monitor['width']
- self.height = height or monitor['height']
- self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
+ self.top = monitor["top"] if top is None else (monitor["top"] + top)
+ self.left = monitor["left"] if left is None else (monitor["left"] + left)
+ self.width = width or monitor["width"]
+ self.height = height or monitor["height"]
+ self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
def __iter__(self):
"""Returns an iterator of the object."""
@@ -241,7 +244,7 @@ class LoadScreenshots:
def __next__(self):
"""mss screen capture: get raw pixels from the screen as np array."""
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
- s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
+ s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
self.frame += 1
return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
@@ -274,32 +277,32 @@ class LoadImages:
def __init__(self, path, imgsz=640, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
parent = None
- if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
+ if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
parent = Path(path).parent
path = Path(path).read_text().splitlines() # list of sources
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
- if '*' in a:
+ if "*" in a:
files.extend(sorted(glob.glob(a, recursive=True))) # glob
elif os.path.isdir(a):
- files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir
+ files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir
elif os.path.isfile(a):
files.append(a) # files (absolute or relative to CWD)
elif parent and (parent / p).is_file():
files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
else:
- raise FileNotFoundError(f'{p} does not exist')
+ raise FileNotFoundError(f"{p} does not exist")
- images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
- videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
+ images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
+ videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.imgsz = imgsz
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
- self.mode = 'image'
+ self.mode = "image"
self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos):
@@ -307,8 +310,10 @@ class LoadImages:
else:
self.cap = None
if self.nf == 0:
- raise FileNotFoundError(f'No images or videos found in {p}. '
- f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
+ raise FileNotFoundError(
+ f"No images or videos found in {p}. "
+ f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
+ )
def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""
@@ -323,7 +328,7 @@ class LoadImages:
if self.video_flag[self.count]:
# Read video
- self.mode = 'video'
+ self.mode = "video"
for _ in range(self.vid_stride):
self.cap.grab()
success, im0 = self.cap.retrieve()
@@ -338,15 +343,15 @@ class LoadImages:
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
- s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
+ s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
else:
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
if im0 is None:
- raise FileNotFoundError(f'Image Not Found {path}')
- s = f'image {self.count}/{self.nf} {path}: '
+ raise FileNotFoundError(f"Image Not Found {path}")
+ s = f"image {self.count}/{self.nf} {path}: "
return [path], [im0], self.cap, s
@@ -385,20 +390,20 @@ class LoadPilAndNumpy:
"""Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list):
im0 = [im0]
- self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
+ self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
self.im0 = [self._single_check(im) for im in im0]
self.imgsz = imgsz
- self.mode = 'image'
+ self.mode = "image"
# Generate fake paths
self.bs = len(self.im0)
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
- assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
+ assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
if isinstance(im, Image.Image):
- if im.mode != 'RGB':
- im = im.convert('RGB')
+ if im.mode != "RGB":
+ im = im.convert("RGB")
im = np.asarray(im)[:, :, ::-1]
im = np.ascontiguousarray(im) # contiguous
return im
@@ -412,7 +417,7 @@ class LoadPilAndNumpy:
if self.count == 1: # loop only once as it's batch inference
raise StopIteration
self.count += 1
- return self.paths, self.im0, None, ''
+ return self.paths, self.im0, None, ""
def __iter__(self):
"""Enables iteration for class LoadPilAndNumpy."""
@@ -441,14 +446,16 @@ class LoadTensor:
"""Initialize Tensor Dataloader."""
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0]
- self.mode = 'image'
- self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
+ self.mode = "image"
+ self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
@staticmethod
def _single_check(im, stride=32):
"""Validate and format an image to torch.Tensor."""
- s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
- f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
+ s = (
+ f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
+ f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
+ )
if len(im.shape) != 4:
if len(im.shape) != 3:
raise ValueError(s)
@@ -457,8 +464,10 @@ class LoadTensor:
if im.shape[2] % stride or im.shape[3] % stride:
raise ValueError(s)
if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07
- LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
- f'Dividing input by 255.')
+ LOGGER.warning(
+ f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
+ f"Dividing input by 255."
+ )
im = im.float() / 255.0
return im
@@ -473,7 +482,7 @@ class LoadTensor:
if self.count == 1:
raise StopIteration
self.count += 1
- return self.paths, self.im0, None, ''
+ return self.paths, self.im0, None, ""
def __len__(self):
"""Returns the batch size."""
@@ -485,12 +494,14 @@ def autocast_list(source):
files = []
for im in source:
if isinstance(im, (str, Path)): # filename or uri
- files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
+ files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
- raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
- f'See https://docs.ultralytics.com/modes/predict for supported source types.')
+ raise TypeError(
+ f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
+ f"See https://docs.ultralytics.com/modes/predict for supported source types."
+ )
return files
@@ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True):
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
"""
if use_pafy:
- check_requirements(('pafy', 'youtube_dl==2020.12.2'))
+ check_requirements(("pafy", "youtube_dl==2020.12.2"))
import pafy # noqa
- return pafy.new(url).getbestvideo(preftype='mp4').url
+
+ return pafy.new(url).getbestvideo(preftype="mp4").url
else:
- check_requirements('yt-dlp')
+ check_requirements("yt-dlp")
import yt_dlp
- with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
+
+ with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
info_dict = ydl.extract_info(url, download=False) # extract info
- for f in reversed(info_dict.get('formats', [])): # reversed because best is usually last
+ for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last
# Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
- good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080
- if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
- return f.get('url')
+ good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
+ if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
+ return f.get("url")
diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py
index aea97493..6e7169b7 100644
--- a/ultralytics/data/split_dota.py
+++ b/ultralytics/data/split_dota.py
@@ -14,7 +14,7 @@ from tqdm import tqdm
from ultralytics.data.utils import exif_size, img2label_paths
from ultralytics.utils.checks import check_requirements
-check_requirements('shapely')
+check_requirements("shapely")
from shapely.geometry import Polygon
@@ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
return outputs
-def load_yolo_dota(data_root, split='train'):
+def load_yolo_dota(data_root, split="train"):
"""
Load DOTA dataset.
@@ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'):
- train
- val
"""
- assert split in ['train', 'val']
- im_dir = os.path.join(data_root, f'images/{split}')
+ assert split in ["train", "val"]
+ im_dir = os.path.join(data_root, f"images/{split}")
assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root."
- im_files = glob(os.path.join(data_root, f'images/{split}/*'))
+ im_files = glob(os.path.join(data_root, f"images/{split}/*"))
lb_files = img2label_paths(im_files)
annos = []
for im_file, lb_file in zip(im_files, lb_files):
@@ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
h, w = im_size
windows = []
for crop_size, gap in zip(crop_sizes, gaps):
- assert crop_size > gap, f'invaild crop_size gap pair [{crop_size} {gap}]'
+ assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
step = crop_size - gap
xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
@@ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
def get_window_obj(anno, windows, iof_thr=0.7):
"""Get objects for each window."""
- h, w = anno['ori_size']
- label = anno['label']
+ h, w = anno["ori_size"]
+ label = anno["label"]
if len(label):
label[:, 1::2] *= w
label[:, 2::2] *= h
@@ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
- train
- val
"""
- im = cv2.imread(anno['filepath'])
- name = Path(anno['filepath']).stem
+ im = cv2.imread(anno["filepath"])
+ name = Path(anno["filepath"]).stem
for i, window in enumerate(windows):
x_start, y_start, x_stop, y_stop = window.tolist()
- new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)
+ new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
patch_im = im[y_start:y_stop, x_start:x_stop]
ph, pw = patch_im.shape[:2]
- cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im)
+ cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im)
label = window_objs[i]
if len(label) == 0:
continue
@@ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
label[:, 1::2] /= pw
label[:, 2::2] /= ph
- with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f:
+ with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f:
for lb in label:
- formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]]
+ formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
-def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]):
+def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]):
"""
Split both images and labels.
@@ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024
- labels
- split
"""
- im_dir = Path(save_dir) / 'images' / split
+ im_dir = Path(save_dir) / "images" / split
im_dir.mkdir(parents=True, exist_ok=True)
- lb_dir = Path(save_dir) / 'labels' / split
+ lb_dir = Path(save_dir) / "labels" / split
lb_dir.mkdir(parents=True, exist_ok=True)
annos = load_yolo_dota(data_root, split=split)
for anno in tqdm(annos, total=len(annos), desc=split):
- windows = get_windows(anno['ori_size'], crop_sizes, gaps)
+ windows = get_windows(anno["ori_size"], crop_sizes, gaps)
window_objs = get_window_obj(anno, windows)
crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
@@ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
- for split in ['train', 'val']:
+ for split in ["train", "val"]:
split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
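The window grid behind these splits comes from get_windows(); a sketch equivalent to its x-axis math, assuming crop_size=1024 and gap=200 on a 3000 px wide image:

    from math import ceil

    crop_size, gap, w = 1024, 200, 3000
    step = crop_size - gap  # 824 px stride between window left edges
    xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)  # 4 windows
    xs = [min(step * i, w - crop_size) for i in range(xn)]  # [0, 824, 1648, 1976]
    # the final window is clamped so it ends exactly at the image border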
@@ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
- save_dir = Path(save_dir) / 'images' / 'test'
+ save_dir = Path(save_dir) / "images" / "test"
save_dir.mkdir(parents=True, exist_ok=True)
- im_dir = Path(os.path.join(data_root, 'images/test'))
+ im_dir = Path(os.path.join(data_root, "images/test"))
assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
- im_files = glob(str(im_dir / '*'))
- for im_file in tqdm(im_files, total=len(im_files), desc='test'):
+ im_files = glob(str(im_dir / "*"))
+ for im_file in tqdm(im_files, total=len(im_files), desc="test"):
w, h = exif_size(Image.open(im_file))
windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
im = cv2.imread(im_file)
name = Path(im_file).stem
for window in windows:
x_start, y_start, x_stop, y_stop = window.tolist()
- new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start))
+ new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
patch_im = im[y_start:y_stop, x_start:x_stop]
- cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im)
+ cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)
-if __name__ == '__main__':
+if __name__ == "__main__":
split_trainval(
- data_root='DOTAv2',
- save_dir='DOTAv2-split',
+ data_root="DOTAv2",
+ save_dir="DOTAv2-split",
)
split_test(
- data_root='DOTAv2',
- save_dir='DOTAv2-split',
+ data_root="DOTAv2",
+ save_dir="DOTAv2-split",
)
diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py
index c3447394..9214f8ce 100644
--- a/ultralytics/data/utils.py
+++ b/ultralytics/data/utils.py
@@ -17,36 +17,47 @@ import numpy as np
from PIL import Image, ImageOps
from ultralytics.nn.autobackend import check_class_names
-from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr,
- emojis, yaml_load, yaml_save)
+from ultralytics.utils import (
+ DATASETS_DIR,
+ LOGGER,
+ NUM_THREADS,
+ ROOT,
+ SETTINGS_YAML,
+ TQDM,
+ clean_url,
+ colorstr,
+ emojis,
+ yaml_load,
+ yaml_save,
+)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
-HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
-VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
-PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
+HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
+IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes
+VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes
+PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
def img2label_paths(img_paths):
"""Define label paths as a function of image paths."""
- sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
- return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
+ sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
+ return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
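For example, with POSIX separators the substitution maps:

    img2label_paths(["coco8/images/train/000000000009.jpg"])
    # -> ["coco8/labels/train/000000000009.txt"]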
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
h = hashlib.sha256(str(size).encode()) # hash sizes
- h.update(''.join(paths).encode()) # hash paths
+ h.update("".join(paths).encode()) # hash paths
return h.hexdigest() # return hash
def exif_size(img: Image.Image):
"""Returns exif-corrected PIL size."""
s = img.size # (width, height)
- if img.format == 'JPEG': # only support JPEG images
+ if img.format == "JPEG": # only support JPEG images
with contextlib.suppress(Exception):
exif = img.getexif()
if exif:
@@ -60,24 +71,24 @@ def verify_image(args):
"""Verify one image."""
(im_file, cls), prefix = args
# Number (found, corrupt), message
- nf, nc, msg = 0, 0, ''
+ nf, nc, msg = 0, 0, ""
try:
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
- assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
- assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
- if im.format.lower() in ('jpg', 'jpeg'):
- with open(im_file, 'rb') as f:
+ assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+ assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+ if im.format.lower() in ("jpg", "jpeg"):
+ with open(im_file, "rb") as f:
f.seek(-2, 2)
- if f.read() != b'\xff\xd9': # corrupt JPEG
- ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
- msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+ if f.read() != b"\xff\xd9": # corrupt JPEG
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+ msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
nf = 1
except Exception as e:
nc = 1
- msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+ msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
return (im_file, cls), nf, nc, msg
@@ -85,21 +96,21 @@ def verify_image_label(args):
"""Verify one image-label pair."""
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
- nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
+ nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
try:
# Verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
- assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
- assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
- if im.format.lower() in ('jpg', 'jpeg'):
- with open(im_file, 'rb') as f:
+ assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+ assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+ if im.format.lower() in ("jpg", "jpeg"):
+ with open(im_file, "rb") as f:
f.seek(-2, 2)
- if f.read() != b'\xff\xd9': # corrupt JPEG
- ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
- msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+ if f.read() != b"\xff\xd9": # corrupt JPEG
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+ msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
# Verify labels
if os.path.isfile(lb_file):
@@ -114,25 +125,26 @@ def verify_image_label(args):
nl = len(lb)
if nl:
if keypoint:
- assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
+ assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
points = lb[:, 5:].reshape(-1, ndim)[:, :2]
else:
- assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+ assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
points = lb[:, 1:]
- assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}'
- assert lb.min() >= 0, f'negative label values {lb[lb < 0]}'
+ assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
+ assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
# All labels
max_cls = lb[:, 0].max() # max label count
- assert max_cls <= num_cls, \
- f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \
- f'Possible class labels are 0-{num_cls - 1}'
+ assert max_cls <= num_cls, (
+ f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
+ f"Possible class labels are 0-{num_cls - 1}"
+ )
_, i = np.unique(lb, axis=0, return_index=True)
if len(i) < nl: # duplicate row check
lb = lb[i] # remove duplicates
if segments:
segments = [segments[x] for x in i]
- msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
+ msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
else:
ne = 1 # label empty
lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
@@ -148,7 +160,7 @@ def verify_image_label(args):
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
except Exception as e:
nc = 1
- msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+ msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
return [None, None, None, None, None, nm, nf, ne, nc, msg]
@@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
- masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
- dtype=np.int32 if len(segments) > 255 else np.uint8)
+ masks = np.zeros(
+ (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
+ dtype=np.int32 if len(segments) > 255 else np.uint8,
+ )
areas = []
ms = []
for si in range(len(segments)):
@@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path:
Returns:
(Path): The path of the found YAML file.
"""
- files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive
+ files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive
assert files, f"No YAML file found in '{path.resolve()}'"
if len(files) > 1:
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
@@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True):
file = check_file(dataset)
# Download (optional)
- extract_dir = ''
+ extract_dir = ""
if zipfile.is_zipfile(file) or is_tarfile(file):
new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
file = find_dataset_yaml(DATASETS_DIR / new_dir)
@@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True):
data = yaml_load(file, append_filename=True) # dictionary
# Checks
- for k in 'train', 'val':
+ for k in "train", "val":
if k not in data:
- if k != 'val' or 'validation' not in data:
+ if k != "val" or "validation" not in data:
raise SyntaxError(
- emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
+ emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
+ )
LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
- data['val'] = data.pop('validation') # replace 'validation' key with 'val' key
- if 'names' not in data and 'nc' not in data:
+ data["val"] = data.pop("validation") # replace 'validation' key with 'val' key
+ if "names" not in data and "nc" not in data:
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
- if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
+ if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
- if 'names' not in data:
- data['names'] = [f'class_{i}' for i in range(data['nc'])]
+ if "names" not in data:
+ data["names"] = [f"class_{i}" for i in range(data["nc"])]
else:
- data['nc'] = len(data['names'])
+ data["nc"] = len(data["names"])
- data['names'] = check_class_names(data['names'])
+ data["names"] = check_class_names(data["names"])
# Resolve paths
- path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
+ path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
if not path.is_absolute():
path = (DATASETS_DIR / path).resolve()
# Set paths
- data['path'] = path # download scripts
- for k in 'train', 'val', 'test':
+ data["path"] = path # download scripts
+ for k in "train", "val", "test":
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
- if not x.exists() and data[k].startswith('../'):
+ if not x.exists() and data[k].startswith("../"):
x = (path / data[k][3:]).resolve()
data[k] = str(x)
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse YAML
- val, s = (data.get(x) for x in ('val', 'download'))
+ val, s = (data.get(x) for x in ("val", "download"))
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
@@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True):
raise FileNotFoundError(m)
t = time.time()
r = None # success
- if s.startswith('http') and s.endswith('.zip'): # URL
+ if s.startswith("http") and s.endswith(".zip"): # URL
safe_download(url=s, dir=DATASETS_DIR, delete=True)
- elif s.startswith('bash '): # bash script
- LOGGER.info(f'Running {s} ...')
+ elif s.startswith("bash "): # bash script
+ LOGGER.info(f"Running {s} ...")
r = os.system(s)
else: # python script
- exec(s, {'yaml': data})
- dt = f'({round(time.time() - t, 1)}s)'
- s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
- LOGGER.info(f'Dataset download {s}\n')
- check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
+ exec(s, {"yaml": data})
+ dt = f"({round(time.time() - t, 1)}s)"
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
+ LOGGER.info(f"Dataset download {s}\n")
+ check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
return data # dictionary
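For reference, a minimal data YAML that passes these checks; paths and names are illustrative, and 'nc' is derived from 'names' when omitted.

    path: coco8          # dataset root, resolved against DATASETS_DIR when relative
    train: images/train  # prepended with `path` during resolution
    val: images/val      # a `validation:` key would be renamed to `val` as above
    names:
      0: person
      1: bicycle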
-def check_cls_dataset(dataset, split=''):
+def check_cls_dataset(dataset, split=""):
"""
Checks a classification dataset such as Imagenet.
@@ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''):
"""
# Download (optional if dataset=https://file.zip is passed directly)
- if str(dataset).startswith(('http:/', 'https:/')):
+ if str(dataset).startswith(("http:/", "https:/")):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
if not data_dir.is_dir():
- LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+ LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
t = time.time()
- if str(dataset) == 'imagenet':
+ if str(dataset) == "imagenet":
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
- url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
+ url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
download(url, dir=data_dir.parent)
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
- train_set = data_dir / 'train'
- val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
- (data_dir / 'validation').exists() else None # data/test or data/val
- test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
- if split == 'val' and not val_set:
+ train_set = data_dir / "train"
+ val_set = (
+ data_dir / "val"
+ if (data_dir / "val").exists()
+ else data_dir / "validation"
+ if (data_dir / "validation").exists()
+ else None
+ ) # data/val or data/validation
+ test_set = data_dir / "test" if (data_dir / "test").exists() else None # optional data/test
+ if split == "val" and not val_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
- elif split == 'test' and not test_set:
+ elif split == "test" and not test_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
- nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
- names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
+ nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
+ names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
# Print to console
- for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+ for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
prefix = f'{colorstr(f"{k}:")} {v}...'
if v is None:
LOGGER.info(prefix)
else:
- files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
+ files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
nf = len(files) # number of files
nd = len({file.parent for file in files}) # number of directories
if nf == 0:
- if k == 'train':
+ if k == "train":
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
else:
- LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
+ LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
elif nd != nc:
- LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
+ LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
else:
- LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
+ LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")
- return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
+ return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
class HUBDatasetStats:
@@ -423,42 +443,43 @@ class HUBDatasetStats:
```
"""
- def __init__(self, path='coco8.yaml', task='detect', autodownload=False):
+ def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
"""Initialize class."""
path = Path(path).resolve()
- LOGGER.info(f'Starting HUB dataset checks for {path}....')
+ LOGGER.info(f"Starting HUB dataset checks for {path}....")
self.task = task # detect, segment, pose, classify
- if self.task == 'classify':
+ if self.task == "classify":
unzip_dir = unzip_file(path)
data = check_cls_dataset(unzip_dir)
- data['path'] = unzip_dir
+ data["path"] = unzip_dir
else: # detect, segment, pose
_, data_dir, yaml_path = self._unzip(Path(path))
try:
# Load YAML with checks
data = yaml_load(yaml_path)
- data['path'] = '' # strip path since YAML should be in dataset root for all HUB datasets
+ data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets
yaml_save(yaml_path, data)
data = check_det_dataset(yaml_path, autodownload) # dict
- data['path'] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
+ data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
except Exception as e:
- raise Exception('error/HUB/dataset_stats/init') from e
+ raise Exception("error/HUB/dataset_stats/init") from e
self.hub_dir = Path(f'{data["path"]}-hub')
- self.im_dir = self.hub_dir / 'images'
+ self.im_dir = self.hub_dir / "images"
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
- self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
+ self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
self.data = data
@staticmethod
def _unzip(path):
"""Unzip data.zip."""
- if not str(path).endswith('.zip'): # path is data.yaml
+ if not str(path).endswith(".zip"): # path is data.yaml
return False, None, path
unzip_dir = unzip_file(path, path=path.parent)
- assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
- f'path/to/abc.zip MUST unzip to path/to/abc/'
+ assert unzip_dir.is_dir(), (
+ f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
+ )
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f):
@@ -470,31 +491,31 @@ class HUBDatasetStats:
def _round(labels):
"""Update labels to integer class and 4 decimal place floats."""
- if self.task == 'detect':
- coordinates = labels['bboxes']
- elif self.task == 'segment':
- coordinates = [x.flatten() for x in labels['segments']]
- elif self.task == 'pose':
- n = labels['keypoints'].shape[0]
- coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
+ if self.task == "detect":
+ coordinates = labels["bboxes"]
+ elif self.task == "segment":
+ coordinates = [x.flatten() for x in labels["segments"]]
+ elif self.task == "pose":
+ n = labels["keypoints"].shape[0]
+ coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
else:
- raise ValueError('Undefined dataset task.')
- zipped = zip(labels['cls'], coordinates)
+ raise ValueError("Undefined dataset task.")
+ zipped = zip(labels["cls"], coordinates)
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
- for split in 'train', 'val', 'test':
+ for split in "train", "val", "test":
self.stats[split] = None # predefine
path = self.data.get(split)
# Check split
if path is None: # no split
continue
- files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
+ files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
if not files: # no images
continue
# Get dataset statistics
- if self.task == 'classify':
+ if self.task == "classify":
from torchvision.datasets import ImageFolder
dataset = ImageFolder(self.data[split])
@@ -504,38 +525,35 @@ class HUBDatasetStats:
x[im[1]] += 1
self.stats[split] = {
- 'instance_stats': {
- 'total': len(dataset),
- 'per_class': x.tolist()},
- 'image_stats': {
- 'total': len(dataset),
- 'unlabelled': 0,
- 'per_class': x.tolist()},
- 'labels': [{
- Path(k).name: v} for k, v in dataset.imgs]}
+ "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
+ "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
+ "labels": [{Path(k).name: v} for k, v in dataset.imgs],
+ }
else:
from ultralytics.data import YOLODataset
dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
- x = np.array([
- np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
- for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
+ x = np.array(
+ [
+ np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
+ for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
+ ]
+ ) # shape(128x80)
self.stats[split] = {
- 'instance_stats': {
- 'total': int(x.sum()),
- 'per_class': x.sum(0).tolist()},
- 'image_stats': {
- 'total': len(dataset),
- 'unlabelled': int(np.all(x == 0, 1).sum()),
- 'per_class': (x > 0).sum(0).tolist()},
- 'labels': [{
- Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
+ "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
+ "image_stats": {
+ "total": len(dataset),
+ "unlabelled": int(np.all(x == 0, 1).sum()),
+ "per_class": (x > 0).sum(0).tolist(),
+ },
+ "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
+ }
# Save, print and return
if save:
- stats_path = self.hub_dir / 'stats.json'
- LOGGER.info(f'Saving {stats_path.resolve()}...')
- with open(stats_path, 'w') as f:
+ stats_path = self.hub_dir / "stats.json"
+ LOGGER.info(f"Saving {stats_path.resolve()}...")
+ with open(stats_path, "w") as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
@@ -545,14 +563,14 @@ class HUBDatasetStats:
"""Compress images for Ultralytics HUB."""
from ultralytics.data import YOLODataset # ClassificationDataset
- for split in 'train', 'val', 'test':
+ for split in "train", "val", "test":
if self.data.get(split) is None:
continue
dataset = YOLODataset(img_path=self.data[split], data=self.data)
with ThreadPool(NUM_THREADS) as pool:
- for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
+ for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
pass
- LOGGER.info(f'Done. All images saved to {self.im_dir}')
+ LOGGER.info(f"Done. All images saved to {self.im_dir}")
return self.im_dir
@@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
- im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
+ im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save
except Exception as e: # use OpenCV
- LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+ LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
@@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
cv2.imwrite(str(f_new or f), im)
-def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
+def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
@@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
"""
path = Path(path) # images dir
- files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
+ files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
n = len(files) # number of files
random.seed(0) # for reproducibility
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
- txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
+ txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files
for x in txt:
if (path.parent / x).exists():
(path.parent / x).unlink() # remove existing
- LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
+ LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
for i, img in TQDM(zip(indices, files), total=n):
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
- with open(path.parent / txt[i], 'a') as f:
- f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
+ with open(path.parent / txt[i], "a") as f:
+ f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 6819088f..e3d2a86a 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -67,8 +67,20 @@ from ultralytics.data.utils import check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
-from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
- colorstr, get_default_args, yaml_save)
+from ultralytics.utils import (
+ ARM64,
+ DEFAULT_CFG,
+ LINUX,
+ LOGGER,
+ MACOS,
+ ROOT,
+ WINDOWS,
+ __version__,
+ callbacks,
+ colorstr,
+ get_default_args,
+ yaml_save,
+)
from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.utils.files import file_size, spaces_in_path
@@ -79,21 +91,23 @@ from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart
def export_formats():
"""YOLOv8 export formats."""
import pandas
+
x = [
- ['PyTorch', '-', '.pt', True, True],
- ['TorchScript', 'torchscript', '.torchscript', True, True],
- ['ONNX', 'onnx', '.onnx', True, True],
- ['OpenVINO', 'openvino', '_openvino_model', True, False],
- ['TensorRT', 'engine', '.engine', False, True],
- ['CoreML', 'coreml', '.mlpackage', True, False],
- ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
- ['TensorFlow GraphDef', 'pb', '.pb', True, True],
- ['TensorFlow Lite', 'tflite', '.tflite', True, False],
- ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
- ['TensorFlow.js', 'tfjs', '_web_model', True, False],
- ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
- ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
- return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
+ ["PyTorch", "-", ".pt", True, True],
+ ["TorchScript", "torchscript", ".torchscript", True, True],
+ ["ONNX", "onnx", ".onnx", True, True],
+ ["OpenVINO", "openvino", "_openvino_model", True, False],
+ ["TensorRT", "engine", ".engine", False, True],
+ ["CoreML", "coreml", ".mlpackage", True, False],
+ ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
+ ["TensorFlow GraphDef", "pb", ".pb", True, True],
+ ["TensorFlow Lite", "tflite", ".tflite", True, False],
+ ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
+ ["TensorFlow.js", "tfjs", "_web_model", True, False],
+ ["PaddlePaddle", "paddle", "_paddle_model", True, True],
+ ["ncnn", "ncnn", "_ncnn_model", True, True],
+ ]
+ return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
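Note: since export_formats() returns a pandas DataFrame, per-format properties can be looked up by the "Argument" column. A minimal sketch:

    from ultralytics.engine.exporter import export_formats

    fmts = export_formats()
    row = fmts[fmts["Argument"] == "onnx"].iloc[0]
    print(row["Suffix"], row["CPU"], row["GPU"])  # .onnx True True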
def gd_outputs(gd):
@@ -102,7 +116,7 @@ def gd_outputs(gd):
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
name_list.append(node.name)
input_list.extend(node.input)
- return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
+ return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
def try_export(inner_func):
@@ -111,14 +125,14 @@ def try_export(inner_func):
def outer_func(*args, **kwargs):
"""Export a model."""
- prefix = inner_args['prefix']
+ prefix = inner_args["prefix"]
try:
with Profile() as dt:
f, model = inner_func(*args, **kwargs)
LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
return f, model
except Exception as e:
- LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+ LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
raise e
return outer_func
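Note: a standalone sketch of the pattern try_export implements: time the wrapped call, log success or failure under the export's prefix, and re-raise on error. This is an analogous plain-Python version, not the library decorator itself:

    import time
    from functools import wraps

    def timed_export(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                out = fn(*args, **kwargs)
                print(f"export success ✅ {time.perf_counter() - t0:.1f}s")
                return out
            except Exception as e:
                print(f"export failure ❌ {time.perf_counter() - t0:.1f}s: {e}")
                raise  # propagate after logging, as try_export does
        return wrapper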
@@ -143,8 +157,8 @@ class Exporter:
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
- if self.args.format.lower() in ('coreml', 'mlmodel'): # fix attempt for protobuf<3.20.x errors
- os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # must run before TensorBoard callback
+ if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
self.callbacks = _callbacks or callbacks.get_default_callbacks()
callbacks.add_integration_callbacks(self)
@@ -152,45 +166,46 @@ class Exporter:
@smart_inference_mode()
def __call__(self, model=None):
"""Returns list of exported files/dirs after running callbacks."""
- self.run_callbacks('on_export_start')
+ self.run_callbacks("on_export_start")
t = time.time()
fmt = self.args.format.lower() # to lowercase
- if fmt in ('tensorrt', 'trt'): # 'engine' aliases
- fmt = 'engine'
- if fmt in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios', 'coreml'): # 'coreml' aliases
- fmt = 'coreml'
- fmts = tuple(export_formats()['Argument'][1:]) # available export formats
+ if fmt in ("tensorrt", "trt"): # 'engine' aliases
+ fmt = "engine"
+ if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
+ fmt = "coreml"
+ fmts = tuple(export_formats()["Argument"][1:]) # available export formats
flags = [x == fmt for x in fmts]
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
# Device
- if fmt == 'engine' and self.args.device is None:
- LOGGER.warning('WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0')
- self.args.device = '0'
- self.device = select_device('cpu' if self.args.device is None else self.args.device)
+ if fmt == "engine" and self.args.device is None:
+ LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
+ self.args.device = "0"
+ self.device = select_device("cpu" if self.args.device is None else self.args.device)
# Checks
- if not hasattr(model, 'names'):
+ if not hasattr(model, "names"):
model.names = default_class_names()
model.names = check_class_names(model.names)
- if self.args.half and onnx and self.device.type == 'cpu':
- LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
+ if self.args.half and onnx and self.device.type == "cpu":
+ LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
self.args.half = False
- assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
+ assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if self.args.optimize:
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
- assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
+ assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
if edgetpu and not LINUX:
- raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
+ raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
# Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
file = Path(
- getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
- if file.suffix in {'.yaml', '.yml'}:
+ getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
+ )
+ if file.suffix in {".yaml", ".yml"}:
file = Path(file.name)
# Update model
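Note: a minimal sketch of the alias handling in __call__ above: user-facing format strings collapse to one canonical key, which then expands to per-format booleans (format list truncated for illustration):

    ALIASES = {"tensorrt": "engine", "trt": "engine",
               "mlmodel": "coreml", "mlpackage": "coreml",
               "mlprogram": "coreml", "apple": "coreml", "ios": "coreml"}
    fmt = ALIASES.get("trt", "trt")  # -> "engine"
    fmts = ("torchscript", "onnx", "openvino", "engine", "coreml")
    flags = [x == fmt for x in fmts]  # exactly one True expected
    assert sum(flags) == 1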
@@ -212,42 +227,48 @@ class Exporter:
y = None
for _ in range(2):
y = model(im) # dry runs
- if self.args.half and (engine or onnx) and self.device.type != 'cpu':
+ if self.args.half and (engine or onnx) and self.device.type != "cpu":
im, model = im.half(), model.half() # to FP16
# Filter warnings
- warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
- warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
- warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning
+ warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
# Assign
self.im = im
self.model = model
self.file = file
- self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(
- tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
- self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
- data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else ''
+ self.output_shape = (
+ tuple(y.shape)
+ if isinstance(y, torch.Tensor)
+ else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
+ )
+ self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
+ data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
self.metadata = {
- 'description': description,
- 'author': 'Ultralytics',
- 'license': 'AGPL-3.0 https://ultralytics.com/license',
- 'date': datetime.now().isoformat(),
- 'version': __version__,
- 'stride': int(max(model.stride)),
- 'task': model.task,
- 'batch': self.args.batch,
- 'imgsz': self.imgsz,
- 'names': model.names} # model metadata
- if model.task == 'pose':
- self.metadata['kpt_shape'] = model.model[-1].kpt_shape
-
- LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
- f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
+ "description": description,
+ "author": "Ultralytics",
+ "license": "AGPL-3.0 https://ultralytics.com/license",
+ "date": datetime.now().isoformat(),
+ "version": __version__,
+ "stride": int(max(model.stride)),
+ "task": model.task,
+ "batch": self.args.batch,
+ "imgsz": self.imgsz,
+ "names": model.names,
+ } # model metadata
+ if model.task == "pose":
+ self.metadata["kpt_shape"] = model.model[-1].kpt_shape
+
+ LOGGER.info(
+ f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
+ f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
+ )
# Exports
- f = [''] * len(fmts) # exported filenames
+ f = [""] * len(fmts) # exported filenames
if jit or ncnn: # TorchScript
f[0], _ = self.export_torchscript()
if engine: # TensorRT required before ONNX
@@ -266,7 +287,7 @@ class Exporter:
if tflite:
f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
if edgetpu:
- f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
+ f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
if tfjs:
f[9], _ = self.export_tfjs()
if paddle: # PaddlePaddle
@@ -279,58 +300,65 @@ class Exporter:
if any(f):
f = str(Path(f[-1]))
square = self.imgsz[0] == self.imgsz[1]
- s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
- f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
- imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
- predict_data = f'data={data}' if model.task == 'segment' and fmt == 'pb' else ''
- q = 'int8' if self.args.int8 else 'half' if self.args.half else '' # quantization
- LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
- f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
- f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
- f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
- f'\nVisualize: https://netron.app')
-
- self.run_callbacks('on_export_end')
+ s = (
+ ""
+ if square
+ else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
+ f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
+ )
+ imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
+ predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
+ q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization
+ LOGGER.info(
+ f'\nExport complete ({time.time() - t:.1f}s)'
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+ f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
+ f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
+ f'\nVisualize: https://netron.app'
+ )
+
+ self.run_callbacks("on_export_end")
return f # return list of exported files/dirs
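Note: exports are normally driven through the model API rather than by instantiating Exporter directly. A minimal usage sketch, with illustrative arguments:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    path = model.export(format="onnx", imgsz=640)  # returns the exported file path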
@try_export
- def export_torchscript(self, prefix=colorstr('TorchScript:')):
+ def export_torchscript(self, prefix=colorstr("TorchScript:")):
"""YOLOv8 TorchScript model export."""
- LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
- f = self.file.with_suffix('.torchscript')
+ LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
+ f = self.file.with_suffix(".torchscript")
ts = torch.jit.trace(self.model, self.im, strict=False)
- extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
+ extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
- LOGGER.info(f'{prefix} optimizing for mobile...')
+ LOGGER.info(f"{prefix} optimizing for mobile...")
from torch.utils.mobile_optimizer import optimize_for_mobile
+
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
else:
ts.save(str(f), _extra_files=extra_files)
return f, None
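Note: the metadata embedded via _extra_files can be recovered at load time. A minimal sketch, with an illustrative file name:

    import json
    import torch

    extra_files = {"config.txt": ""}  # keys to extract; values are filled on load
    ts = torch.jit.load("yolov8n.torchscript", _extra_files=extra_files)
    metadata = json.loads(extra_files["config.txt"])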
@try_export
- def export_onnx(self, prefix=colorstr('ONNX:')):
+ def export_onnx(self, prefix=colorstr("ONNX:")):
"""YOLOv8 ONNX export."""
- requirements = ['onnx>=1.12.0']
+ requirements = ["onnx>=1.12.0"]
if self.args.simplify:
- requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
+ requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
check_requirements(requirements)
import onnx # noqa
opset_version = self.args.opset or get_latest_opset()
- LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
- f = str(self.file.with_suffix('.onnx'))
+ LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
+ f = str(self.file.with_suffix(".onnx"))
- output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
+ output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
dynamic = self.args.dynamic
if dynamic:
- dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
+ dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
if isinstance(self.model, SegmentationModel):
- dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400)
- dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400)
+ dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
elif isinstance(self.model, DetectionModel):
- dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400)
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
torch.onnx.export(
self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
@@ -339,9 +367,10 @@ class Exporter:
verbose=False,
opset_version=opset_version,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
- input_names=['images'],
+ input_names=["images"],
output_names=output_names,
- dynamic_axes=dynamic or None)
+ dynamic_axes=dynamic or None,
+ )
# Checks
model_onnx = onnx.load(f) # load onnx model
@@ -352,12 +381,12 @@ class Exporter:
try:
import onnxsim
- LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
+ LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
# subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
model_onnx, check = onnxsim.simplify(model_onnx)
- assert check, 'Simplified ONNX model could not be validated'
+ assert check, "Simplified ONNX model could not be validated"
except Exception as e:
- LOGGER.info(f'{prefix} simplifier failure: {e}')
+ LOGGER.info(f"{prefix} simplifier failure: {e}")
# Metadata
for k, v in self.metadata.items():
@@ -368,58 +397,56 @@ class Exporter:
return f, model_onnx
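Note: a self-contained sketch of the dynamic_axes mechanism used above: axes that may vary at runtime are given symbolic names per input/output (toy module; names illustrative):

    import torch

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2

    torch.onnx.export(
        Tiny(),
        torch.zeros(1, 3, 64, 64),
        "tiny.onnx",
        input_names=["images"],
        output_names=["output0"],
        dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"},
                      "output0": {0: "batch", 2: "height", 3: "width"}},
    )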
@try_export
- def export_openvino(self, prefix=colorstr('OpenVINO:')):
+ def export_openvino(self, prefix=colorstr("OpenVINO:")):
"""YOLOv8 OpenVINO export."""
- check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
+ check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa
- LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
- f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
- fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}')
- f_onnx = self.file.with_suffix('.onnx')
- f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
- fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name)
+ LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
+ f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
+ fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
+ f_onnx = self.file.with_suffix(".onnx")
+ f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
+ fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
def serialize(ov_model, file):
"""Set RT info, serialize and save metadata YAML."""
- ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
- ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
- ov_model.set_rt_info(114, ['model_info', 'pad_value'])
- ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
- ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
- ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels'])
- if self.model.task != 'classify':
- ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
+ ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"])
+ ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
+ ov_model.set_rt_info(114, ["model_info", "pad_value"])
+ ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
+ ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
+ ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
+ if self.model.task != "classify":
+ ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
ov.serialize(ov_model, file) # save
- yaml_save(Path(file).parent / 'metadata.yaml', self.metadata) # add metadata.yaml
+ yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
- ov_model = mo.convert_model(f_onnx,
- model_name=self.pretty_name,
- framework='onnx',
- compress_to_fp16=self.args.half) # export
+ ov_model = mo.convert_model(
+ f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
+ ) # export
if self.args.int8:
assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 'data=coco8.yaml'"
- check_requirements('nncf>=2.5.0')
+ check_requirements("nncf>=2.5.0")
import nncf
def transform_fn(data_item):
"""Quantization transform function."""
- im = data_item['img'].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
+ im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
return np.expand_dims(im, 0) if im.ndim == 3 else im
# Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = check_det_dataset(self.args.data)
- dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
quantization_dataset = nncf.Dataset(dataset, transform_fn)
- ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid']) # ignore operation
- quantized_ov_model = nncf.quantize(ov_model,
- quantization_dataset,
- preset=nncf.QuantizationPreset.MIXED,
- ignored_scope=ignored_scope)
+ ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"]) # ignore operation
+ quantized_ov_model = nncf.quantize(
+ ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
+ )
serialize(quantized_ov_model, fq_ov)
return fq, None
@@ -427,48 +454,49 @@ class Exporter:
return f, None
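Note: what the NNCF transform_fn above produces for a single dataset item: uint8 CHW in, float32 NCHW in [0, 1] out. A minimal sketch with a stand-in tensor:

    import numpy as np
    import torch

    item = {"img": torch.randint(0, 256, (3, 640, 640), dtype=torch.uint8)}
    im = item["img"].numpy().astype(np.float32) / 255.0
    im = np.expand_dims(im, 0) if im.ndim == 3 else im  # add batch dim if missing
    print(im.shape, im.dtype, float(im.max()) <= 1.0)  # (1, 3, 640, 640) float32 True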
@try_export
- def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
+ def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
"""YOLOv8 Paddle export."""
- check_requirements(('paddlepaddle', 'x2paddle'))
+ check_requirements(("paddlepaddle", "x2paddle"))
import x2paddle # noqa
from x2paddle.convert import pytorch2paddle # noqa
- LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
- f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
+ LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
+ f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
- pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
- yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
+ pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export
+ yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
return f, None
@try_export
- def export_ncnn(self, prefix=colorstr('ncnn:')):
+ def export_ncnn(self, prefix=colorstr("ncnn:")):
"""
YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
"""
- check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
+ check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn
import ncnn # noqa
- LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...')
- f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
- f_ts = self.file.with_suffix('.torchscript')
+ LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...")
+ f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
+ f_ts = self.file.with_suffix(".torchscript")
- name = Path('pnnx.exe' if WINDOWS else 'pnnx') # PNNX filename
+ name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename
pnnx = name if name.is_file() else ROOT / name
if not pnnx.is_file():
LOGGER.warning(
- f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
- 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
- f'or in {ROOT}. See PNNX repo for full installation instructions.')
- system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux'] # operating system
+ f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
+ "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
+ f"or in {ROOT}. See PNNX repo for full installation instructions."
+ )
+ system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"] # operating system
try:
- _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
+ _, assets = get_github_assets(repo="pnnx/pnnx", retry=True)
url = [x for x in assets if any(s in x for s in system)][0]
except Exception as e:
- url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip'
- LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}')
- asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest')
+ url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip"
+ LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}")
+ asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest")
if check_is_path_safe(Path.cwd(), asset): # avoid path traversal security vulnerability
- unzip_dir = Path(asset).with_suffix('')
+ unzip_dir = Path(asset).with_suffix("")
(unzip_dir / name).rename(pnnx) # move binary to ROOT
shutil.rmtree(unzip_dir) # delete unzip dir
Path(asset).unlink() # delete zip
@@ -477,53 +505,56 @@ class Exporter:
ncnn_args = [
f'ncnnparam={f / "model.ncnn.param"}',
f'ncnnbin={f / "model.ncnn.bin"}',
- f'ncnnpy={f / "model_ncnn.py"}', ]
+ f'ncnnpy={f / "model_ncnn.py"}',
+ ]
pnnx_args = [
f'pnnxparam={f / "model.pnnx.param"}',
f'pnnxbin={f / "model.pnnx.bin"}',
f'pnnxpy={f / "model_pnnx.py"}',
- f'pnnxonnx={f / "model.pnnx.onnx"}', ]
+ f'pnnxonnx={f / "model.pnnx.onnx"}',
+ ]
cmd = [
str(pnnx),
str(f_ts),
*ncnn_args,
*pnnx_args,
- f'fp16={int(self.args.half)}',
- f'device={self.device.type}',
- f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
+ f"fp16={int(self.args.half)}",
+ f"device={self.device.type}",
+ f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
+ ]
f.mkdir(exist_ok=True) # make ncnn_model directory
LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
subprocess.run(cmd, check=True)
# Remove debug files
- pnnx_files = [x.split('=')[-1] for x in pnnx_args]
- for f_debug in ('debug.bin', 'debug.param', 'debug2.bin', 'debug2.param', *pnnx_files):
+ pnnx_files = [x.split("=")[-1] for x in pnnx_args]
+ for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
Path(f_debug).unlink(missing_ok=True)
- yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
return str(f), None
@try_export
- def export_coreml(self, prefix=colorstr('CoreML:')):
+ def export_coreml(self, prefix=colorstr("CoreML:")):
"""YOLOv8 CoreML export."""
- mlmodel = self.args.format.lower() == 'mlmodel' # legacy *.mlmodel export format requested
- check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0')
+ mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
+ check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
import coremltools as ct # noqa
- LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
- f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage')
+ LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
+ f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
if f.is_dir():
shutil.rmtree(f)
bias = [0.0, 0.0, 0.0]
scale = 1 / 255
classifier_config = None
- if self.model.task == 'classify':
+ if self.model.task == "classify":
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
model = self.model
- elif self.model.task == 'detect':
+ elif self.model.task == "detect":
model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
else:
if self.args.nms:
@@ -532,69 +563,73 @@ class Exporter:
model = self.model
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
- ct_model = ct.convert(ts,
- inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
- classifier_config=classifier_config,
- convert_to='neuralnetwork' if mlmodel else 'mlprogram')
- bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
+ ct_model = ct.convert(
+ ts,
+ inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+ classifier_config=classifier_config,
+ convert_to="neuralnetwork" if mlmodel else "mlprogram",
+ )
+ bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
if bits < 32:
- if 'kmeans' in mode:
- check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
+ if "kmeans" in mode:
+ check_requirements("scikit-learn") # scikit-learn package required for k-means quantization
if mlmodel:
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
elif bits == 8: # mlprogram already quantized to FP16
import coremltools.optimize.coreml as cto
- op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512)
+
+ op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
config = cto.OptimizationConfig(global_config=op_config)
ct_model = cto.palettize_weights(ct_model, config=config)
- if self.args.nms and self.model.task == 'detect':
+ if self.args.nms and self.model.task == "detect":
if mlmodel:
import platform
# coremltools<=6.2 NMS export requires Python<3.11
- check_version(platform.python_version(), '<3.11', name='Python ', hard=True)
+ check_version(platform.python_version(), "<3.11", name="Python ", hard=True)
weights_dir = None
else:
ct_model.save(str(f)) # save otherwise weights_dir does not exist
- weights_dir = str(f / 'Data/com.apple.CoreML/weights')
+ weights_dir = str(f / "Data/com.apple.CoreML/weights")
ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)
m = self.metadata # metadata dict
- ct_model.short_description = m.pop('description')
- ct_model.author = m.pop('author')
- ct_model.license = m.pop('license')
- ct_model.version = m.pop('version')
+ ct_model.short_description = m.pop("description")
+ ct_model.author = m.pop("author")
+ ct_model.license = m.pop("license")
+ ct_model.version = m.pop("version")
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
try:
ct_model.save(str(f)) # save *.mlpackage
except Exception as e:
LOGGER.warning(
- f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. '
- f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.')
- f = f.with_suffix('.mlmodel')
+ f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
+ f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
+ )
+ f = f.with_suffix(".mlmodel")
ct_model.save(str(f))
return f, ct_model
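Note: a minimal sketch of the quantization selection in export_coreml: int8 takes priority over half, otherwise weights stay FP32:

    def coreml_quant(int8: bool, half: bool):
        return (8, "kmeans") if int8 else (16, "linear") if half else (32, None)

    assert coreml_quant(True, True) == (8, "kmeans")   # int8 wins over half
    assert coreml_quant(False, True) == (16, "linear")
    assert coreml_quant(False, False) == (32, None)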
@try_export
- def export_engine(self, prefix=colorstr('TensorRT:')):
+ def export_engine(self, prefix=colorstr("TensorRT:")):
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
- assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
+ assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
try:
import tensorrt as trt # noqa
except ImportError:
if LINUX:
- check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+ check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
import tensorrt as trt # noqa
- check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
+ check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
self.args.simplify = True
- LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
- assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
- f = self.file.with_suffix('.engine') # TensorRT engine file
+ LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+ assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
+ f = self.file.with_suffix(".engine") # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if self.args.verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
@@ -604,11 +639,11 @@ class Exporter:
config.max_workspace_size = self.args.workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
- flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+ flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(f_onnx):
- raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
+ raise RuntimeError(f"failed to load ONNX file: {f_onnx}")
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
@@ -627,7 +662,8 @@ class Exporter:
config.add_optimization_profile(profile)
LOGGER.info(
- f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
+ f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
+ )
if builder.platform_has_fast_fp16 and self.args.half:
config.set_flag(trt.BuilderFlag.FP16)
@@ -635,10 +671,10 @@ class Exporter:
torch.cuda.empty_cache()
# Write file
- with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+ with builder.build_engine(network, config) as engine, open(f, "wb") as t:
# Metadata
meta = json.dumps(self.metadata)
- t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
+ t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
t.write(meta.encode())
# Model
t.write(engine.serialize())
@@ -646,7 +682,7 @@ class Exporter:
return f, None
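Note: the engine file written above length-prefixes the JSON metadata with a 4-byte little-endian header, so it can be read back with the mirror operation. A minimal sketch, with an illustrative file name:

    import json

    with open("yolov8n.engine", "rb") as t:
        meta_len = int.from_bytes(t.read(4), byteorder="little", signed=True)
        metadata = json.loads(t.read(meta_len))
        engine_bytes = t.read()  # remainder is the serialized TensorRT engine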
@try_export
- def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
+ def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
"""YOLOv8 TensorFlow SavedModel export."""
cuda = torch.cuda.is_available()
try:
@@ -655,44 +691,55 @@ class Exporter:
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
import tensorflow as tf # noqa
check_requirements(
- ('onnx', 'onnx2tf>=1.15.4,<=1.17.5', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26',
- 'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'),
- cmds='--extra-index-url https://pypi.ngc.nvidia.com') # onnx_graphsurgeon only on NVIDIA
-
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
- check_version(tf.__version__,
- '<=2.13.1',
- name='tensorflow',
- verbose=True,
- msg='https://github.com/ultralytics/ultralytics/issues/5161')
- f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
+ (
+ "onnx",
+ "onnx2tf>=1.15.4,<=1.17.5",
+ "sng4onnx>=1.0.1",
+ "onnxsim>=0.4.33",
+ "onnx_graphsurgeon>=0.3.26",
+ "tflite_support",
+ "onnxruntime-gpu" if cuda else "onnxruntime",
+ ),
+ cmds="--extra-index-url https://pypi.ngc.nvidia.com",
+ ) # onnx_graphsurgeon only on NVIDIA
+
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+ check_version(
+ tf.__version__,
+ "<=2.13.1",
+ name="tensorflow",
+ verbose=True,
+ msg="https://github.com/ultralytics/ultralytics/issues/5161",
+ )
+ f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
if f.is_dir():
import shutil
+
shutil.rmtree(f) # delete output folder
# Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
- onnx2tf_file = Path('calibration_image_sample_data_20x128x128x3_float32.npy')
+ onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
if not onnx2tf_file.exists():
- attempt_download_asset(f'{onnx2tf_file}.zip', unzip=True, delete=True)
+ attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
# Export to ONNX
self.args.simplify = True
f_onnx, _ = self.export_onnx()
# Export to TF
- tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
+ tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
if self.args.int8:
- verbosity = '--verbosity info'
+ verbosity = "--verbosity info"
if self.args.data:
# Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = check_det_dataset(self.args.data)
- dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False)
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
images = []
for i, batch in enumerate(dataset):
if i >= 100: # maximum number of calibration images
break
- im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
+ im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
images.append(im)
f.mkdir()
images = torch.cat(images, 0).float()
@@ -701,38 +748,38 @@ class Exporter:
np.save(str(tmp_file), images.numpy()) # BHWC
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
else:
- int8 = '-oiqt -qt per-tensor'
+ int8 = "-oiqt -qt per-tensor"
else:
- verbosity = '--non_verbose'
- int8 = ''
+ verbosity = "--non_verbose"
+ int8 = ""
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
LOGGER.info(f"{prefix} running '{cmd}'")
subprocess.run(cmd, shell=True)
- yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
# Remove/rename TFLite models
if self.args.int8:
tmp_file.unlink(missing_ok=True)
- for file in f.rglob('*_dynamic_range_quant.tflite'):
- file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
- for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
+ for file in f.rglob("*_dynamic_range_quant.tflite"):
+ file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+ for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
file.unlink() # delete extra fp16 activation TFLite files
# Add TFLite metadata
- for file in f.rglob('*.tflite'):
- f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
+ for file in f.rglob("*.tflite"):
+ f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model
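Note: a minimal sketch of the INT8 calibration-file preparation above: CHW uint8 images become a stacked BHWC float array saved as .npy for onnx2tf (random tensors stand in for dataset batches; shapes illustrative):

    import numpy as np
    import torch

    images = [torch.randint(0, 256, (3, 640, 640), dtype=torch.uint8).permute(1, 2, 0)[None]
              for _ in range(8)]  # CHW -> HWC, add batch dim
    images = torch.cat(images, 0).float()  # shape (8, 640, 640, 3)
    np.save("tmp_tflite_int8_calibration_images.npy", images.numpy())  # BHWC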
@try_export
- def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
+ def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
"""YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
import tensorflow as tf # noqa
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
- f = self.file.with_suffix('.pb')
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+ f = self.file.with_suffix(".pb")
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
@@ -742,40 +789,43 @@ class Exporter:
return f, None
@try_export
- def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+ def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
"""YOLOv8 TensorFlow Lite export."""
import tensorflow as tf # noqa
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
- saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+ saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
if self.args.int8:
- f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out
+ f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out
elif self.args.half:
- f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out
+ f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out
else:
- f = saved_model / f'{self.file.stem}_float32.tflite'
+ f = saved_model / f"{self.file.stem}_float32.tflite"
return str(f), None
@try_export
- def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
+ def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
"""YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
- LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
+ LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")
- cmd = 'edgetpu_compiler --version'
- help_url = 'https://coral.ai/docs/edgetpu/compiler/'
- assert LINUX, f'export only supported on Linux. See {help_url}'
+ cmd = "edgetpu_compiler --version"
+ help_url = "https://coral.ai/docs/edgetpu/compiler/"
+ assert LINUX, f"export only supported on Linux. See {help_url}"
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
- LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
- sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
- for c in ('curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
- 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
- 'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 'sudo apt-get update',
- 'sudo apt-get install edgetpu-compiler'):
- subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
+ LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
+ sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0 # sudo installed on system
+ for c in (
+ "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
+ "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
+ "sudo apt-get update",
+ "sudo apt-get install edgetpu-compiler",
+ ):
+ subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
- LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
- f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
+ LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
+ f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
LOGGER.info(f"{prefix} running '{cmd}'")
@@ -784,30 +834,30 @@ class Exporter:
return f, None
@try_export
- def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
+ def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
"""YOLOv8 TensorFlow.js export."""
# JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
- check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs'])
+ check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"])
import tensorflow as tf
import tensorflowjs as tfjs # noqa
- LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
- f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
- f_pb = str(self.file.with_suffix('.pb')) # *.pb path
+ LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
+ f = str(self.file).replace(self.file.suffix, "_web_model") # js dir
+ f_pb = str(self.file.with_suffix(".pb")) # *.pb path
gd = tf.Graph().as_graph_def() # TF GraphDef
- with open(f_pb, 'rb') as file:
+ with open(f_pb, "rb") as file:
gd.ParseFromString(file.read())
- outputs = ','.join(gd_outputs(gd))
- LOGGER.info(f'\n{prefix} output node names: {outputs}')
+ outputs = ",".join(gd_outputs(gd))
+ LOGGER.info(f"\n{prefix} output node names: {outputs}")
- quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else ''
+ quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
LOGGER.info(f"{prefix} running '{cmd}'")
subprocess.run(cmd, shell=True)
- if ' ' in f:
+ if " " in f:
LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")
# f_json = Path(f) / 'model.json' # *.json path
@@ -824,7 +874,7 @@ class Exporter:
# f_json.read_text(),
# )
# j.write(subst)
- yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
+ yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
return f, None
def _add_tflite_metadata(self, file):
@@ -835,14 +885,14 @@ class Exporter:
# Create model info
model_meta = _metadata_fb.ModelMetadataT()
- model_meta.name = self.metadata['description']
- model_meta.version = self.metadata['version']
- model_meta.author = self.metadata['author']
- model_meta.license = self.metadata['license']
+ model_meta.name = self.metadata["description"]
+ model_meta.version = self.metadata["version"]
+ model_meta.author = self.metadata["author"]
+ model_meta.license = self.metadata["license"]
# Label file
- tmp_file = Path(file).parent / 'temp_meta.txt'
- with open(tmp_file, 'w') as f:
+ tmp_file = Path(file).parent / "temp_meta.txt"
+ with open(tmp_file, "w") as f:
f.write(str(self.metadata))
label_file = _metadata_fb.AssociatedFileT()
@@ -851,8 +901,8 @@ class Exporter:
# Create input info
input_meta = _metadata_fb.TensorMetadataT()
- input_meta.name = 'image'
- input_meta.description = 'Input image to be detected.'
+ input_meta.name = "image"
+ input_meta.description = "Input image to be detected."
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
@@ -860,19 +910,19 @@ class Exporter:
# Create output info
output1 = _metadata_fb.TensorMetadataT()
- output1.name = 'output'
- output1.description = 'Coordinates of detected objects, class labels, and confidence score'
+ output1.name = "output"
+ output1.description = "Coordinates of detected objects, class labels, and confidence score"
output1.associatedFiles = [label_file]
- if self.model.task == 'segment':
+ if self.model.task == "segment":
output2 = _metadata_fb.TensorMetadataT()
- output2.name = 'output'
- output2.description = 'Mask protos'
+ output2.name = "output"
+ output2.description = "Mask protos"
output2.associatedFiles = [label_file]
# Create subgraph info
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
- subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
+ subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
@@ -885,11 +935,11 @@ class Exporter:
populator.populate()
tmp_file.unlink()
- def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')):
+ def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
"""YOLOv8 CoreML pipeline."""
import coremltools as ct # noqa
- LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
+ LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
_, _, h, w = list(self.im.shape) # BCHW
# Output shapes
@@ -897,8 +947,9 @@ class Exporter:
out0, out1 = iter(spec.description.output)
if MACOS:
from PIL import Image
- img = Image.new('RGB', (w, h)) # w=192, h=320
- out = model.predict({'image': img})
+
+ img = Image.new("RGB", (w, h)) # w=192, h=320
+ out = model.predict({"image": img})
out0_shape = out[out0.name].shape # (3780, 80)
out1_shape = out[out1.name].shape # (3780, 4)
else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y
@@ -906,11 +957,11 @@ class Exporter:
out1_shape = self.output_shape[2], 4 # (3780, 4)
# Checks
- names = self.metadata['names']
+ names = self.metadata["names"]
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
_, nc = out0_shape # number of anchors, number of classes
# _, nc = out0.type.multiArrayType.shape
- assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
+ assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check
# Define output shapes (missing)
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
@@ -944,8 +995,8 @@ class Exporter:
nms_spec.description.output.add()
nms_spec.description.output[i].ParseFromString(decoder_output)
- nms_spec.description.output[0].name = 'confidence'
- nms_spec.description.output[1].name = 'coordinates'
+ nms_spec.description.output[0].name = "confidence"
+ nms_spec.description.output[1].name = "coordinates"
output_sizes = [nc, 4]
for i in range(2):
@@ -961,10 +1012,10 @@ class Exporter:
nms = nms_spec.nonMaximumSuppression
nms.confidenceInputFeatureName = out0.name # 1x507x80
nms.coordinatesInputFeatureName = out1.name # 1x507x4
- nms.confidenceOutputFeatureName = 'confidence'
- nms.coordinatesOutputFeatureName = 'coordinates'
- nms.iouThresholdInputFeatureName = 'iouThreshold'
- nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
+ nms.confidenceOutputFeatureName = "confidence"
+ nms.coordinatesOutputFeatureName = "coordinates"
+ nms.iouThresholdInputFeatureName = "iouThreshold"
+ nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
nms.iouThreshold = 0.45
nms.confidenceThreshold = 0.25
nms.pickTop.perClass = True
@@ -972,10 +1023,14 @@ class Exporter:
nms_model = ct.models.MLModel(nms_spec)
# 4. Pipeline models together
- pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
- ('iouThreshold', ct.models.datatypes.Double()),
- ('confidenceThreshold', ct.models.datatypes.Double())],
- output_features=['confidence', 'coordinates'])
+ pipeline = ct.models.pipeline.Pipeline(
+ input_features=[
+ ("image", ct.models.datatypes.Array(3, ny, nx)),
+ ("iouThreshold", ct.models.datatypes.Double()),
+ ("confidenceThreshold", ct.models.datatypes.Double()),
+ ],
+ output_features=["confidence", "coordinates"],
+ )
pipeline.add_model(model)
pipeline.add_model(nms_model)
@@ -986,19 +1041,20 @@ class Exporter:
# Update metadata
pipeline.spec.specificationVersion = 5
- pipeline.spec.description.metadata.userDefined.update({
- 'IoU threshold': str(nms.iouThreshold),
- 'Confidence threshold': str(nms.confidenceThreshold)})
+ pipeline.spec.description.metadata.userDefined.update(
+ {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
+ )
# Save the model
model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
- model.input_description['image'] = 'Input image'
- model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
- model.input_description['confidenceThreshold'] = \
- f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
- model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
- model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
- LOGGER.info(f'{prefix} pipeline success')
+ model.input_description["image"] = "Input image"
+ model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})"
+ model.input_description[
+ "confidenceThreshold"
+ ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
+ model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
+ model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
+ LOGGER.info(f"{prefix} pipeline success")
return model
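Note: because the pipeline wires "iouThreshold" and "confidenceThreshold" as Double inputs, the NMS defaults (0.45 / 0.25) should be overridable per prediction. A minimal sketch, assuming model is the ct.models.MLModel returned above (macOS only, since predict() is unavailable on Linux/Windows per the comment earlier; sizes illustrative):

    from PIL import Image

    img = Image.new("RGB", (640, 640))
    out = model.predict({"image": img, "iouThreshold": 0.5, "confidenceThreshold": 0.3})
    # out is keyed by the pipeline outputs: "confidence" and "coordinates"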
def add_callback(self, event: str, callback):
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 66052f30..01c81682 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -53,7 +53,7 @@ class Model(nn.Module):
list(ultralytics.engine.results.Results): The prediction results.
"""
- def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
+ def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None:
"""
Initializes the YOLO model.
@@ -89,7 +89,7 @@ class Model(nn.Module):
# Load or create new YOLO model
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
- if Path(model).suffix in ('.yaml', '.yml'):
+ if Path(model).suffix in (".yaml", ".yml"):
self._new(model, task)
else:
self._load(model, task)
@@ -112,16 +112,20 @@ class Model(nn.Module):
def is_triton_model(model):
"""Is model a Triton Server URL string, i.e. :////"""
from urllib.parse import urlsplit
+
url = urlsplit(model)
- return url.netloc and url.path and url.scheme in {'http', 'grpc'}
+ return url.netloc and url.path and url.scheme in {"http", "grpc"}
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
- return any((
- model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
- [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
- len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
+ return any(
+ (
+ model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
+ [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
+ len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
+ )
+ ) # MODELID
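Note: a minimal sketch of the Triton check above: a model string qualifies when it splits into scheme + netloc + path with an http/grpc scheme (URL illustrative):

    from urllib.parse import urlsplit

    url = urlsplit("http://localhost:8000/yolov8n")
    print(bool(url.netloc and url.path and url.scheme in {"http", "grpc"}))  # True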
def _new(self, cfg: str, task=None, model=None, verbose=True):
"""
@@ -136,9 +140,9 @@ class Model(nn.Module):
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = task or guess_model_task(cfg_dict)
- self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1) # build model
- self.overrides['model'] = self.cfg
- self.overrides['task'] = self.task
+ self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
+ self.overrides["model"] = self.cfg
+ self.overrides["task"] = self.task
# Below added to allow export from YAMLs
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
@@ -153,9 +157,9 @@ class Model(nn.Module):
task (str | None): model task
"""
suffix = Path(weights).suffix
- if suffix == '.pt':
+ if suffix == ".pt":
self.model, self.ckpt = attempt_load_one_weight(weights)
- self.task = self.model.args['task']
+ self.task = self.model.args["task"]
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
@@ -163,12 +167,12 @@ class Model(nn.Module):
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
- self.overrides['model'] = weights
- self.overrides['task'] = self.task
+ self.overrides["model"] = weights
+ self.overrides["task"] = self.task
def _check_is_pytorch_model(self):
"""Raises TypeError is model is not a PyTorch model."""
- pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
+ pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(
@@ -176,19 +180,20 @@ class Model(nn.Module):
f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
- f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")
+ f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
+ )
def reset_weights(self):
"""Resets the model modules parameters to randomly initialized values, losing all training information."""
self._check_is_pytorch_model()
for m in self.model.modules():
- if hasattr(m, 'reset_parameters'):
+ if hasattr(m, "reset_parameters"):
m.reset_parameters()
for p in self.model.parameters():
p.requires_grad = True
return self
- def load(self, weights='yolov8n.pt'):
+ def load(self, weights="yolov8n.pt"):
"""Transfers parameters with matching names and shapes from 'weights' to model."""
self._check_is_pytorch_model()
if isinstance(weights, (str, Path)):
@@ -226,8 +231,8 @@ class Model(nn.Module):
Returns:
(List[torch.Tensor]): A list of image embeddings.
"""
- if not kwargs.get('embed'):
- kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
+ if not kwargs.get("embed"):
+ kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
return self.predict(source, stream, **kwargs)
def predict(self, source=None, stream=False, predictor=None, **kwargs):
@@ -249,21 +254,22 @@ class Model(nn.Module):
source = ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
- is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
- x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
+ is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
+ x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
+ )
- custom = {'conf': 0.25, 'save': is_cli} # method defaults
- args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right
- prompts = args.pop('prompts', None) # for SAM-type models
+ custom = {"conf": 0.25, "save": is_cli} # method defaults
+ args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right
+ prompts = args.pop("prompts", None) # for SAM-type models
if not self.predictor:
- self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks)
+ self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)
- if 'project' in args or 'name' in args:
+ if "project" in args or "name" in args:
self.predictor.save_dir = get_save_dir(self.predictor.args)
- if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models
+ if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
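The merge `{**self.overrides, **custom, **kwargs, "mode": "predict"}` means call-site kwargs beat method defaults, which beat stored overrides, and `mode` can never be overridden. A self-contained illustration of that priority rule:

    overrides = {"conf": 0.5, "imgsz": 640}  # persisted model overrides
    custom = {"conf": 0.25, "save": False}   # method defaults
    kwargs = {"conf": 0.9}                   # user call-site args
    args = {**overrides, **custom, **kwargs, "mode": "predict"}  # rightmost wins
    assert args == {"conf": 0.9, "imgsz": 640, "save": False, "mode": "predict"}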
@@ -280,11 +286,12 @@ class Model(nn.Module):
Returns:
(List[ultralytics.engine.results.Results]): The tracking results.
"""
- if not hasattr(self.predictor, 'trackers'):
+ if not hasattr(self.predictor, "trackers"):
from ultralytics.trackers import register_tracker
+
register_tracker(self, persist)
- kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input
- kwargs['mode'] = 'track'
+ kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
+ kwargs["mode"] = "track"
return self.predict(source=source, stream=stream, **kwargs)
def val(self, validator=None, **kwargs):
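As the comment in the `track()` hunk notes, ByteTrack-style association needs low-confidence detections, so `conf` silently floors at 0.1 unless the caller passes one. A hedged sketch (the video path is a placeholder):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    for result in model.track("video.mp4", stream=True, persist=True):  # conf defaults to 0.1 here
        print(result.boxes.id)  # tracker-assigned IDs, or None before association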
@@ -295,10 +302,10 @@ class Model(nn.Module):
validator (BaseValidator): Customized validator.
**kwargs: Any other args accepted by the validators. To see all args, check the 'configuration' section in the docs.
"""
- custom = {'rect': True} # method defaults
- args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
+ custom = {"rect": True} # method defaults
+ args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
- validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
+ validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
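`val()` merges `rect=True` as its only method default and caches the validator's metrics on the model. A sketch, assuming a stock dataset config:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    metrics = model.val(data="coco128.yaml")  # rect=True unless overridden in kwargs
    assert metrics is model.metrics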
@@ -313,16 +320,17 @@ class Model(nn.Module):
self._check_is_pytorch_model()
from ultralytics.utils.benchmarks import benchmark
- custom = {'verbose': False} # method defaults
- args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
+ custom = {"verbose": False} # method defaults
+ args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
return benchmark(
model=self,
- data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
- imgsz=args['imgsz'],
- half=args['half'],
- int8=args['int8'],
- device=args['device'],
- verbose=kwargs.get('verbose'))
+ data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
+ imgsz=args["imgsz"],
+ half=args["half"],
+ int8=args["int8"],
+ device=args["device"],
+ verbose=kwargs.get("verbose"),
+ )
def export(self, **kwargs):
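`benchmark()` above forwards only a handful of the merged args and deliberately passes `data=kwargs.get("data")`, so omitting `data` selects each task's default dataset. A sketch:

    from ultralytics import YOLO

    results = YOLO("yolov8n.pt").benchmark(imgsz=320, half=False)  # data=None -> task default dataset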
"""
@@ -334,8 +342,8 @@ class Model(nn.Module):
self._check_is_pytorch_model()
from .exporter import Exporter
- custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults
- args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
+ custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
+ args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, trainer=None, **kwargs):
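`export()` pins `batch=1` and reuses the model's training `imgsz` as defaults. A minimal sketch (`format="onnx"` is one of the standard export targets):

    from ultralytics import YOLO

    path = YOLO("yolov8n.pt").export(format="onnx")  # returns the path of the exported file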
@@ -347,32 +355,32 @@ class Model(nn.Module):
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
- if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model
+ if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
if any(kwargs):
- LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
+ LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
kwargs = self.session.train_args # overwrite kwargs
checks.check_pip_update_available()
- overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
- custom = {'data': DEFAULT_CFG_DICT['data'] or TASK2DATA[self.task]} # method defaults
- args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
- if args.get('resume'):
- args['resume'] = self.ckpt_path
+ overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
+ custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
+ args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
+ if args.get("resume"):
+ args["resume"] = self.ckpt_path
- self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
- if not args.get('resume'): # manually set model only if not resuming
+ self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
+ if not args.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
- if SETTINGS['hub'] is True and not self.session:
+ if SETTINGS["hub"] is True and not self.session:
# Create a model in HUB
try:
self.session = self._get_hub_session(self.model_name)
if self.session:
self.session.create_model(args)
# Check model was created
- if not getattr(self.session.model, 'id', None):
+ if not getattr(self.session.model, "id", None):
self.session = None
except PermissionError:
# Ignore permission error
@@ -385,7 +393,7 @@ class Model(nn.Module):
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
- self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
+ self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
return self.metrics
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
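In `train()`, passing `cfg` replaces the stored overrides wholesale, `resume=True` is rewritten to the last checkpoint path, and call-site kwargs again sit rightmost in the merge. A sketch with a stock dataset config:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    metrics = model.train(data="coco128.yaml", epochs=3)  # returns validator metrics (None under DDP)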
@@ -398,12 +406,13 @@ class Model(nn.Module):
self._check_is_pytorch_model()
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
+
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
else:
from .tuner import Tuner
custom = {} # method defaults
- args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
+ args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
def _apply(self, fn):
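`tune()` dispatches to Ray Tune when `use_ray=True` and otherwise runs the native `Tuner` for `iterations` evolution steps. Sketch:

    from ultralytics import YOLO

    YOLO("yolov8n.pt").tune(data="coco128.yaml", iterations=10, use_ray=False)  # native evolution loop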
@@ -411,13 +420,13 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self = super()._apply(fn) # noqa
self.predictor = None # reset predictor as device may have changed
- self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
+ self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
return self
@property
def names(self):
"""Returns class names of the loaded model."""
- return self.model.names if hasattr(self.model, 'names') else None
+ return self.model.names if hasattr(self.model, "names") else None
@property
def device(self):
@@ -427,7 +436,7 @@ class Model(nn.Module):
@property
def transforms(self):
"""Returns transform of the loaded model."""
- return self.model.transforms if hasattr(self.model, 'transforms') else None
+ return self.model.transforms if hasattr(self.model, "transforms") else None
def add_callback(self, event: str, func):
"""Add a callback."""
@@ -445,7 +454,7 @@ class Model(nn.Module):
@staticmethod
def _reset_ckpt_args(args):
"""Reset arguments when loading a PyTorch model."""
- include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
+ include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
# def __getattr__(self, attr):
@@ -461,7 +470,8 @@ class Model(nn.Module):
name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError(
- emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
+ emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
+ ) from e
@property
def task_map(self):
@@ -471,4 +481,4 @@ class Model(nn.Module):
Returns:
task_map (dict): The map of model task to mode classes.
"""
- raise NotImplementedError('Please provide task map for your model!')
+ raise NotImplementedError("Please provide task map for your model!")
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index f78de2fa..f6c029c4 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -132,8 +132,11 @@ class BasePredictor:
def inference(self, im, *args, **kwargs):
"""Runs inference on a given image using the specified model and arguments."""
- visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
- mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
+ visualize = (
+ increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
+ if self.args.visualize and (not self.source_type.tensor)
+ else False
+ )
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
def pre_transform(self, im):
@@ -153,35 +156,38 @@ class BasePredictor:
def write_results(self, idx, results, batch):
"""Write inference results to a file or directory."""
p, im, _ = batch
- log_string = ''
+ log_string = ""
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
- log_string += f'{idx}: '
+ log_string += f"{idx}: "
frame = self.dataset.count
else:
- frame = getattr(self.dataset, 'frame', 0)
+ frame = getattr(self.dataset, "frame", 0)
self.data_path = p
- self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
- log_string += '%gx%g ' % im.shape[2:] # print string
+ self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
+ log_string += "%gx%g " % im.shape[2:] # print string
result = results[idx]
log_string += result.verbose()
if self.args.save or self.args.show: # Add bbox to image
plot_args = {
- 'line_width': self.args.line_width,
- 'boxes': self.args.show_boxes,
- 'conf': self.args.show_conf,
- 'labels': self.args.show_labels}
+ "line_width": self.args.line_width,
+ "boxes": self.args.show_boxes,
+ "conf": self.args.show_conf,
+ "labels": self.args.show_labels,
+ }
if not self.args.retina_masks:
- plot_args['im_gpu'] = im[idx]
+ plot_args["im_gpu"] = im[idx]
self.plotted_img = result.plot(**plot_args)
# Write
if self.args.save_txt:
- result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
+ result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
if self.args.save_crop:
- result.save_crop(save_dir=self.save_dir / 'crops',
- file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
+ result.save_crop(
+ save_dir=self.save_dir / "crops",
+ file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
+ )
return log_string
@@ -210,17 +216,24 @@ class BasePredictor:
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
- self.transforms = getattr(
- self.model.model, 'transforms', classify_transforms(
- self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None
- self.dataset = load_inference_source(source=source,
- imgsz=self.imgsz,
- vid_stride=self.args.vid_stride,
- buffer=self.args.stream_buffer)
+ self.transforms = (
+ getattr(
+ self.model.model,
+ "transforms",
+ classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
+ )
+ if self.args.task == "classify"
+ else None
+ )
+ self.dataset = load_inference_source(
+ source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
+ )
self.source_type = self.dataset.source_type
- if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
- len(self.dataset) > 1000 or # images
- any(getattr(self.dataset, 'video_flag', [False]))): # videos
+ if not getattr(self, "stream", True) and (
+ self.dataset.mode == "stream" # streams
+ or len(self.dataset) > 1000 # images
+ or any(getattr(self.dataset, "video_flag", [False]))
+ ): # videos
LOGGER.warning(STREAM_WARNING)
self.vid_path = [None] * self.dataset.bs
self.vid_writer = [None] * self.dataset.bs
@@ -230,7 +243,7 @@ class BasePredictor:
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
- LOGGER.info('')
+ LOGGER.info("")
# Setup model
if not self.model:
@@ -242,7 +255,7 @@ class BasePredictor:
# Ensure save_dir (and the labels subdir if saving txt) exists
if self.args.save or self.args.save_txt:
- (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+ (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
if not self.done_warmup:
@@ -250,10 +263,10 @@ class BasePredictor:
self.done_warmup = True
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
- self.run_callbacks('on_predict_start')
+ self.run_callbacks("on_predict_start")
for batch in self.dataset:
- self.run_callbacks('on_predict_batch_start')
+ self.run_callbacks("on_predict_batch_start")
self.batch = batch
path, im0s, vid_cap, s = batch
@@ -272,15 +285,16 @@ class BasePredictor:
with profilers[2]:
self.results = self.postprocess(preds, im, im0s)
- self.run_callbacks('on_predict_postprocess_end')
+ self.run_callbacks("on_predict_postprocess_end")
# Visualize, save, write results
n = len(im0s)
for i in range(n):
self.seen += 1
self.results[i].speed = {
- 'preprocess': profilers[0].dt * 1E3 / n,
- 'inference': profilers[1].dt * 1E3 / n,
- 'postprocess': profilers[2].dt * 1E3 / n}
+ "preprocess": profilers[0].dt * 1e3 / n,
+ "inference": profilers[1].dt * 1e3 / n,
+ "postprocess": profilers[2].dt * 1e3 / n,
+ }
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
p = Path(p)
@@ -293,12 +307,12 @@ class BasePredictor:
if self.args.save and self.plotted_img is not None:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
- self.run_callbacks('on_predict_batch_end')
+ self.run_callbacks("on_predict_batch_end")
yield from self.results
# Print time (inference-only)
if self.args.verbose:
- LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
+ LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")
# Release assets
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
@@ -306,25 +320,29 @@ class BasePredictor:
# Print results
if self.args.verbose and self.seen:
- t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
- LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
- f'{(1, 3, *im.shape[2:])}' % t)
+ t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
+ LOGGER.info(
+ f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
+ f"{(1, 3, *im.shape[2:])}" % t
+ )
if self.args.save or self.args.save_txt or self.args.save_crop:
- nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
- s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
+ nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
+ s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
- self.run_callbacks('on_predict_end')
+ self.run_callbacks("on_predict_end")
def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
- self.model = AutoBackend(model or self.args.model,
- device=select_device(self.args.device, verbose=verbose),
- dnn=self.args.dnn,
- data=self.args.data,
- fp16=self.args.half,
- fuse=True,
- verbose=verbose)
+ self.model = AutoBackend(
+ model or self.args.model,
+ device=select_device(self.args.device, verbose=verbose),
+ dnn=self.args.dnn,
+ data=self.args.data,
+ fp16=self.args.half,
+ fuse=True,
+ verbose=verbose,
+ )
self.device = self.model.device # update device
self.args.half = self.model.fp16 # update half
@@ -333,18 +351,18 @@ class BasePredictor:
def show(self, p):
"""Display an image in a window using OpenCV imshow()."""
im0 = self.plotted_img
- if platform.system() == 'Linux' and p not in self.windows:
+ if platform.system() == "Linux" and p not in self.windows:
self.windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
- cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
+ cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path):
"""Save video predictions as mp4 at specified path."""
im0 = self.plotted_img
# Save imgs
- if self.dataset.mode == 'image':
+ if self.dataset.mode == "image":
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
@@ -361,15 +379,16 @@ class BasePredictor:
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
- suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
- self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)),
- cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
+ suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
+ self.vid_writer[idx] = cv2.VideoWriter(
+ str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
+ )
# Write video
self.vid_writer[idx].write(im0)
# Write frame
if self.args.save_frames:
- cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0)
+ cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
self.vid_frame[idx] += 1
def run_callbacks(self, event: str):
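`stream_inference()` is a generator, which is why `setup_source()` warns when streams, long videos, or >1000 images are consumed without `stream=True`: non-stream mode accumulates every `Results` object in RAM. Lazy consumption through the public API looks like this (the path is a placeholder):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    for result in model.predict("long_video.mp4", stream=True):  # yields one Results at a time
        pass  # handle each frame without holding the full list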
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index 21e11078..82d654ed 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -98,15 +98,15 @@ class Results(SimpleClass):
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
- self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
+ self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
self.names = names
self.path = path
self.save_dir = None
- self._keys = 'boxes', 'masks', 'probs', 'keypoints', 'obb'
+ self._keys = "boxes", "masks", "probs", "keypoints", "obb"
def __getitem__(self, idx):
"""Return a Results object for the specified index."""
- return self._apply('__getitem__', idx)
+ return self._apply("__getitem__", idx)
def __len__(self):
"""Return the number of detections in the Results object."""
@@ -146,19 +146,19 @@ class Results(SimpleClass):
def cpu(self):
"""Return a copy of the Results object with all tensors on CPU memory."""
- return self._apply('cpu')
+ return self._apply("cpu")
def numpy(self):
"""Return a copy of the Results object with all tensors as numpy arrays."""
- return self._apply('numpy')
+ return self._apply("numpy")
def cuda(self):
"""Return a copy of the Results object with all tensors on GPU memory."""
- return self._apply('cuda')
+ return self._apply("cuda")
def to(self, *args, **kwargs):
"""Return a copy of the Results object with tensors on the specified device and dtype."""
- return self._apply('to', *args, **kwargs)
+ return self._apply("to", *args, **kwargs)
def new(self):
"""Return a new Results object with the same image, path, and names."""
@@ -169,7 +169,7 @@ class Results(SimpleClass):
conf=True,
line_width=None,
font_size=None,
- font='Arial.ttf',
+ font="Arial.ttf",
pil=False,
img=None,
im_gpu=None,
@@ -229,14 +229,20 @@ class Results(SimpleClass):
font_size,
font,
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
- example=names)
+ example=names,
+ )
# Plot Segment results
if pred_masks and show_masks:
if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
- im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
- 2, 0, 1).flip(0).contiguous() / 255
+ im_gpu = (
+ torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
+ .permute(2, 0, 1)
+ .flip(0)
+ .contiguous()
+ / 255
+ )
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
@@ -244,14 +250,14 @@ class Results(SimpleClass):
if pred_boxes is not None and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
- name = ('' if id is None else f'id:{id} ') + names[c]
- label = (f'{name} {conf:.2f}' if conf else name) if labels else None
+ name = ("" if id is None else f"id:{id} ") + names[c]
+ label = (f"{name} {conf:.2f}" if conf else name) if labels else None
box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)
# Plot Classify results
if pred_probs is not None and show_probs:
- text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
+ text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
x = round(self.orig_shape[0] * 0.03)
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
@@ -264,11 +270,11 @@ class Results(SimpleClass):
def verbose(self):
"""Return log string for each task."""
- log_string = ''
+ log_string = ""
probs = self.probs
boxes = self.boxes
if len(self) == 0:
- return log_string if probs is not None else f'{log_string}(no detections), '
+ return log_string if probs is not None else f"{log_string}(no detections), "
if probs is not None:
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
if boxes:
@@ -293,7 +299,7 @@ class Results(SimpleClass):
texts = []
if probs is not None:
# Classify
- [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
+ [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
elif boxes:
# Detect/segment/pose
for j, d in enumerate(boxes):
@@ -304,16 +310,16 @@ class Results(SimpleClass):
line = (c, *seg)
if kpts is not None:
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
- line += (*kpt.reshape(-1).tolist(), )
- line += (conf, ) * save_conf + (() if id is None else (id, ))
- texts.append(('%g ' * len(line)).rstrip() % line)
+ line += (*kpt.reshape(-1).tolist(),)
+ line += (conf,) * save_conf + (() if id is None else (id,))
+ texts.append(("%g " * len(line)).rstrip() % line)
if texts:
Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory
- with open(txt_file, 'a') as f:
- f.writelines(text + '\n' for text in texts)
+ with open(txt_file, "a") as f:
+ f.writelines(text + "\n" for text in texts)
- def save_crop(self, save_dir, file_name=Path('im.jpg')):
+ def save_crop(self, save_dir, file_name=Path("im.jpg")):
"""
Save cropped predictions to `save_dir/cls/file_name.jpg`.
@@ -322,21 +328,23 @@ class Results(SimpleClass):
file_name (str | pathlib.Path): File name.
"""
if self.probs is not None:
- LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
+ LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.")
return
if self.obb is not None:
- LOGGER.warning('WARNING ⚠️ OBB task do not support `save_crop`.')
+ LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.")
return
for d in self.boxes:
- save_one_box(d.xyxy,
- self.orig_img.copy(),
- file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name)}.jpg',
- BGR=True)
+ save_one_box(
+ d.xyxy,
+ self.orig_img.copy(),
+ file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
+ BGR=True,
+ )
def tojson(self, normalize=False):
"""Convert the object to JSON format."""
if self.probs is not None:
- LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
+ LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
return
import json
@@ -346,19 +354,19 @@ class Results(SimpleClass):
data = self.boxes.data.cpu().tolist()
h, w = self.orig_shape if normalize else (1, 1)
for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id
- box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
+ box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
conf = row[-2]
class_id = int(row[-1])
name = self.names[class_id]
- result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box}
+ result = {"name": name, "class": class_id, "confidence": conf, "box": box}
if self.boxes.is_track:
- result['track_id'] = int(row[-3]) # track ID
+ result["track_id"] = int(row[-3]) # track ID
if self.masks:
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
- result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
+ result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
if self.keypoints is not None:
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
- result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
+ result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
results.append(result)
# Convert detections to JSON
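With `normalize=True`, the `h, w = self.orig_shape ...` line above divides box, segment, and keypoint coordinates by the original image size, so everything lands in the 0-1 range. A hedged reading of the payload (paths are placeholders):

    import json

    from ultralytics import YOLO

    res = YOLO("yolov8n.pt").predict("bus.jpg")[0]
    payload = json.loads(res.tojson(normalize=True))
    print(payload[0]["name"], payload[0]["confidence"], payload[0]["box"])  # box coords now in 0-1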
@@ -397,7 +405,7 @@ class Boxes(BaseTensor):
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
- assert n in (6, 7), f'expected 6 or 7 values but got {n}' # xyxy, track_id, conf, cls
+ assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.orig_shape = orig_shape
@@ -474,7 +482,8 @@ class Masks(BaseTensor):
"""Return normalized segments."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
- for x in ops.masks2segments(self.data)]
+ for x in ops.masks2segments(self.data)
+ ]
@property
@lru_cache(maxsize=1)
@@ -482,7 +491,8 @@ class Masks(BaseTensor):
"""Return segments in pixel coordinates."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
- for x in ops.masks2segments(self.data)]
+ for x in ops.masks2segments(self.data)
+ ]
class Keypoints(BaseTensor):
@@ -610,7 +620,7 @@ class OBB(BaseTensor):
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
- assert n in (7, 8), f'expected 7 or 8 values but got {n}' # xywh, rotation, track_id, conf, cls
+ assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 8
self.orig_shape = orig_shape
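The two asserts in this file pin down the column layouts: `Boxes` expects `(x1, y1, x2, y2, [track_id], conf, cls)` and `OBB` expects `(x, y, w, h, rotation, [track_id], conf, cls)`, with the optional track id flagged by `is_track`. A quick shape illustration:

    import torch

    det = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.91, 0.0]])       # n=6: xyxy, conf, cls
    trk = torch.tensor([[10.0, 20.0, 110.0, 220.0, 7.0, 0.91, 0.0]])  # n=7: xyxy, track_id, conf, cls
    assert det.shape[-1] in (6, 7) and trk.shape[-1] in (6, 7)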
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 15eefb23..0fa98396 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -23,14 +23,31 @@ from torch import nn, optim
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
-from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
- yaml_save)
+from ultralytics.utils import (
+ DEFAULT_CFG,
+ LOGGER,
+ RANK,
+ TQDM,
+ __version__,
+ callbacks,
+ clean_url,
+ colorstr,
+ emojis,
+ yaml_save,
+)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
-from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
- strip_optimizer)
+from ultralytics.utils.torch_utils import (
+ EarlyStopping,
+ ModelEMA,
+ de_parallel,
+ init_seeds,
+ one_cycle,
+ select_device,
+ strip_optimizer,
+)
class BaseTrainer:
@@ -89,12 +106,12 @@ class BaseTrainer:
# Dirs
self.save_dir = get_save_dir(self.args)
self.args.name = self.save_dir.name # update name for loggers
- self.wdir = self.save_dir / 'weights' # weights dir
+ self.wdir = self.save_dir / "weights" # weights dir
if RANK in (-1, 0):
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.args.save_dir = str(self.save_dir)
- yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
- self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
+ yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
+ self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
self.save_period = self.args.save_period
self.batch_size = self.args.batch
@@ -104,18 +121,18 @@ class BaseTrainer:
print_args(vars(self.args))
# Device
- if self.device.type in ('cpu', 'mps'):
+ if self.device.type in ("cpu", "mps"):
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
- if self.args.task == 'classify':
+ if self.args.task == "classify":
self.data = check_cls_dataset(self.args.data)
- elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
+ elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
self.data = check_det_dataset(self.args.data)
- if 'yaml_file' in self.data:
- self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
+ if "yaml_file" in self.data:
+ self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
@@ -131,8 +148,8 @@ class BaseTrainer:
self.fitness = None
self.loss = None
self.tloss = None
- self.loss_names = ['Loss']
- self.csv = self.save_dir / 'results.csv'
+ self.loss_names = ["Loss"]
+ self.csv = self.save_dir / "results.csv"
self.plot_idx = [0, 1, 2]
# Callbacks
@@ -156,7 +173,7 @@ class BaseTrainer:
def train(self):
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
- world_size = len(self.args.device.split(','))
+ world_size = len(self.args.device.split(","))
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
world_size = len(self.args.device)
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
@@ -165,14 +182,16 @@ class BaseTrainer:
world_size = 0
# Run subprocess if DDP training, else train normally
- if world_size > 1 and 'LOCAL_RANK' not in os.environ:
+ if world_size > 1 and "LOCAL_RANK" not in os.environ:
# Argument checks
if self.args.rect:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
self.args.rect = False
if self.args.batch == -1:
- LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
- "default 'batch=16'")
+ LOGGER.warning(
+ "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+ "default 'batch=16'"
+ )
self.args.batch = 16
# Command
@@ -199,37 +218,45 @@ class BaseTrainer:
def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
- self.device = torch.device('cuda', RANK)
+ self.device = torch.device("cuda", RANK)
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
- os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
+ os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
dist.init_process_group(
- 'nccl' if dist.is_nccl_available() else 'gloo',
+ "nccl" if dist.is_nccl_available() else "gloo",
timeout=timedelta(seconds=10800), # 3 hours
rank=RANK,
- world_size=world_size)
+ world_size=world_size,
+ )
def _setup_train(self, world_size):
"""Builds dataloaders and optimizer on correct rank process."""
# Model
- self.run_callbacks('on_pretrain_routine_start')
+ self.run_callbacks("on_pretrain_routine_start")
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
# Freeze layers
- freeze_list = self.args.freeze if isinstance(
- self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
- always_freeze_names = ['.dfl'] # always freeze these layers
- freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
+ freeze_list = (
+ self.args.freeze
+ if isinstance(self.args.freeze, list)
+ else range(self.args.freeze)
+ if isinstance(self.args.freeze, int)
+ else []
+ )
+ always_freeze_names = [".dfl"] # always freeze these layers
+ freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
for k, v in self.model.named_parameters():
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
if any(x in k for x in freeze_layer_names):
LOGGER.info(f"Freezing layer '{k}'")
v.requires_grad = False
elif not v.requires_grad:
- LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
- 'See ultralytics.engine.trainer for customization of frozen layers.')
+ LOGGER.info(
+ f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+ "See ultralytics.engine.trainer for customization of frozen layers."
+ )
v.requires_grad = True
# Check AMP
@@ -246,7 +273,7 @@ class BaseTrainer:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
# Check imgsz
- gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
+ gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
self.stride = gs # for multi-scale training
@@ -256,15 +283,14 @@ class BaseTrainer:
# Dataloaders
batch_size = self.batch_size // max(world_size, 1)
- self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
+ self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
if RANK in (-1, 0):
# NOTE: When training the DOTA dataset, doubling the batch size can cause OOM since some images contain more than 2000 objects.
- self.test_loader = self.get_dataloader(self.testset,
- batch_size=batch_size if self.args.task == 'obb' else batch_size * 2,
- rank=-1,
- mode='val')
+ self.test_loader = self.get_dataloader(
+ self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+ )
self.validator = self.get_validator()
- metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
+ metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.ema = ModelEMA(self.model)
if self.args.plots:
@@ -274,18 +300,20 @@ class BaseTrainer:
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
- self.optimizer = self.build_optimizer(model=self.model,
- name=self.args.optimizer,
- lr=self.args.lr0,
- momentum=self.args.momentum,
- decay=weight_decay,
- iterations=iterations)
+ self.optimizer = self.build_optimizer(
+ model=self.model,
+ name=self.args.optimizer,
+ lr=self.args.lr0,
+ momentum=self.args.momentum,
+ decay=weight_decay,
+ iterations=iterations,
+ )
# Scheduler
self._setup_scheduler()
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
- self.run_callbacks('on_pretrain_routine_end')
+ self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
@@ -299,19 +327,23 @@ class BaseTrainer:
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
- self.run_callbacks('on_train_start')
- LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
- f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
- f"Logging results to {colorstr('bold', self.save_dir)}\n"
- f'Starting training for '
- f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
+ self.run_callbacks("on_train_start")
+ LOGGER.info(
+ f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+ f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
+ f"Logging results to {colorstr('bold', self.save_dir)}\n"
+ f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+ )
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
epoch = self.epochs # predefine for resume fully trained model edge cases
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
- self.run_callbacks('on_train_epoch_start')
+ self.run_callbacks("on_train_epoch_start")
self.model.train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
@@ -327,7 +359,7 @@ class BaseTrainer:
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
- self.run_callbacks('on_train_batch_start')
+ self.run_callbacks("on_train_batch_start")
# Warmup
ni = i + nb * epoch
if ni <= nw:
@@ -335,10 +367,11 @@ class BaseTrainer:
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
for j, x in enumerate(self.optimizer.param_groups):
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
- x['lr'] = np.interp(
- ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
- if 'momentum' in x:
- x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+ x["lr"] = np.interp(
+ ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+ )
+ if "momentum" in x:
+ x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward
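The warmup block ramps every param group with `np.interp` over the first `nw` iterations: the bias group falls from `warmup_bias_lr` to the scheduled lr while the others rise from 0. A standalone numeric illustration with made-up values:

    import numpy as np

    xi = [0, 100]                             # warmup window in iterations (nw)
    lr_bias = np.interp(25, xi, [0.1, 0.01])  # bias group: falls 0.1 -> lr0
    lr_rest = np.interp(25, xi, [0.0, 0.01])  # other groups: rise 0.0 -> lr0
    print(float(lr_bias), float(lr_rest))     # 0.0775 0.0025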
with torch.cuda.amp.autocast(self.amp):
@@ -346,8 +379,9 @@ class BaseTrainer:
self.loss, self.loss_items = self.model(batch)
if RANK != -1:
self.loss *= world_size
- self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
- else self.loss_items
+ self.tloss = (
+ (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+ )
# Backward
self.scaler.scale(self.loss).backward()
@@ -368,24 +402,25 @@ class BaseTrainer:
break
# Log
- mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
+ mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if RANK in (-1, 0):
pbar.set_description(
- ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
- (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
- self.run_callbacks('on_batch_end')
+ ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+ % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+ )
+ self.run_callbacks("on_batch_end")
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
- self.run_callbacks('on_train_batch_end')
+ self.run_callbacks("on_train_batch_end")
- self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
- self.run_callbacks('on_train_epoch_end')
+ self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
+ self.run_callbacks("on_train_epoch_end")
if RANK in (-1, 0):
final_epoch = epoch + 1 == self.epochs
- self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
+ self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
@@ -398,14 +433,14 @@ class BaseTrainer:
# Save model
if self.args.save or final_epoch:
self.save_model()
- self.run_callbacks('on_model_save')
+ self.run_callbacks("on_model_save")
# Scheduler
t = time.time()
self.epoch_time = t - self.epoch_time_start
self.epoch_time_start = t
with warnings.catch_warnings():
- warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+ warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
@@ -413,7 +448,7 @@ class BaseTrainer:
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.scheduler.step()
- self.run_callbacks('on_fit_epoch_end')
+ self.run_callbacks("on_fit_epoch_end")
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
# Early Stopping
@@ -426,39 +461,43 @@ class BaseTrainer:
if RANK in (-1, 0):
# Do final val with best.pt
- LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
- f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
+ LOGGER.info(
+ f"\n{epoch - self.start_epoch + 1} epochs completed in "
+ f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
+ )
self.final_eval()
if self.args.plots:
self.plot_metrics()
- self.run_callbacks('on_train_end')
+ self.run_callbacks("on_train_end")
torch.cuda.empty_cache()
- self.run_callbacks('teardown')
+ self.run_callbacks("teardown")
def save_model(self):
"""Save model training checkpoints with additional metadata."""
import pandas as pd # scope for faster startup
- metrics = {**self.metrics, **{'fitness': self.fitness}}
- results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()}
+
+ metrics = {**self.metrics, **{"fitness": self.fitness}}
+ results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
ckpt = {
- 'epoch': self.epoch,
- 'best_fitness': self.best_fitness,
- 'model': deepcopy(de_parallel(self.model)).half(),
- 'ema': deepcopy(self.ema.ema).half(),
- 'updates': self.ema.updates,
- 'optimizer': self.optimizer.state_dict(),
- 'train_args': vars(self.args), # save as dict
- 'train_metrics': metrics,
- 'train_results': results,
- 'date': datetime.now().isoformat(),
- 'version': __version__}
+ "epoch": self.epoch,
+ "best_fitness": self.best_fitness,
+ "model": deepcopy(de_parallel(self.model)).half(),
+ "ema": deepcopy(self.ema.ema).half(),
+ "updates": self.ema.updates,
+ "optimizer": self.optimizer.state_dict(),
+ "train_args": vars(self.args), # save as dict
+ "train_metrics": metrics,
+ "train_results": results,
+ "date": datetime.now().isoformat(),
+ "version": __version__,
+ }
# Save last and best
torch.save(ckpt, self.last)
if self.best_fitness == self.fitness:
torch.save(ckpt, self.best)
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
- torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
+ torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
@staticmethod
def get_dataset(data):
@@ -467,7 +506,7 @@ class BaseTrainer:
Returns None if data format is not recognized.
"""
- return data['train'], data.get('val') or data.get('test')
+ return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""Load/create/download model for any task."""
@@ -476,9 +515,9 @@ class BaseTrainer:
model, weights = self.model, None
ckpt = None
- if str(model).endswith('.pt'):
+ if str(model).endswith(".pt"):
weights, ckpt = attempt_load_one_weight(model)
- cfg = ckpt['model'].yaml
+ cfg = ckpt["model"].yaml
else:
cfg = model
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
@@ -505,7 +544,7 @@ class BaseTrainer:
The returned dict is expected to contain the "fitness" key.
"""
metrics = self.validator(self)
- fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
+ fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
return metrics, fitness
@@ -516,24 +555,24 @@ class BaseTrainer:
def get_validator(self):
"""Returns a NotImplementedError when the get_validator function is called."""
- raise NotImplementedError('get_validator function not implemented in trainer')
+ raise NotImplementedError("get_validator function not implemented in trainer")
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Returns dataloader derived from torch.data.Dataloader."""
- raise NotImplementedError('get_dataloader function not implemented in trainer')
+ raise NotImplementedError("get_dataloader function not implemented in trainer")
- def build_dataset(self, img_path, mode='train', batch=None):
+ def build_dataset(self, img_path, mode="train", batch=None):
"""Build dataset."""
- raise NotImplementedError('build_dataset function not implemented in trainer')
+ raise NotImplementedError("build_dataset function not implemented in trainer")
- def label_loss_items(self, loss_items=None, prefix='train'):
+ def label_loss_items(self, loss_items=None, prefix="train"):
"""Returns a loss dict with labelled training loss items tensor."""
# Not needed for classification but necessary for segmentation & detection
- return {'loss': loss_items} if loss_items is not None else ['loss']
+ return {"loss": loss_items} if loss_items is not None else ["loss"]
def set_model_attributes(self):
"""To set or update model parameters before training."""
- self.model.names = self.data['names']
+ self.model.names = self.data["names"]
def build_targets(self, preds, targets):
"""Builds target tensors for training YOLO model."""
@@ -541,7 +580,7 @@ class BaseTrainer:
def progress_string(self):
"""Returns a string describing training progress."""
- return ''
+ return ""
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
@@ -556,9 +595,9 @@ class BaseTrainer:
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
- s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
- with open(self.csv, 'a') as f:
- f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n')
+ s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
+ with open(self.csv, "a") as f:
+ f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
def plot_metrics(self):
"""Plot and display metrics visually."""
@@ -567,7 +606,7 @@ class BaseTrainer:
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
path = Path(name)
- self.plots[path] = {'data': data, 'timestamp': time.time()}
+ self.plots[path] = {"data": data, "timestamp": time.time()}
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
@@ -575,11 +614,11 @@ class BaseTrainer:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
- LOGGER.info(f'\nValidating {f}...')
+ LOGGER.info(f"\nValidating {f}...")
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
- self.metrics.pop('fitness', None)
- self.run_callbacks('on_fit_epoch_end')
+ self.metrics.pop("fitness", None)
+ self.run_callbacks("on_fit_epoch_end")
def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
@@ -591,19 +630,21 @@ class BaseTrainer:
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
ckpt_args = attempt_load_weights(last).args
- if not Path(ckpt_args['data']).exists():
- ckpt_args['data'] = self.args.data
+ if not Path(ckpt_args["data"]).exists():
+ ckpt_args["data"] = self.args.data
resume = True
self.args = get_cfg(ckpt_args)
self.args.model = str(last) # reinstate model
- for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+ for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
if k in overrides:
setattr(self.args, k, overrides[k])
except Exception as e:
- raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
- "i.e. 'yolo train resume model=path/to/last.pt'") from e
+ raise FileNotFoundError(
+ "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+ "i.e. 'yolo train resume model=path/to/last.pt'"
+ ) from e
self.resume = resume
def resume_training(self, ckpt):
@@ -611,23 +652,26 @@ class BaseTrainer:
if ckpt is None:
return
best_fitness = 0.0
- start_epoch = ckpt['epoch'] + 1
- if ckpt['optimizer'] is not None:
- self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
- best_fitness = ckpt['best_fitness']
- if self.ema and ckpt.get('ema'):
- self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
- self.ema.updates = ckpt['updates']
+ start_epoch = ckpt["epoch"] + 1
+ if ckpt["optimizer"] is not None:
+ self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
+ best_fitness = ckpt["best_fitness"]
+ if self.ema and ckpt.get("ema"):
+ self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
+ self.ema.updates = ckpt["updates"]
if self.resume:
- assert start_epoch > 0, \
- f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+ assert start_epoch > 0, (
+ f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+ )
LOGGER.info(
- f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
+ f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
+ )
if self.epochs < start_epoch:
LOGGER.info(
- f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
- self.epochs += ckpt['epoch'] # finetune additional epochs
+ f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+ )
+ self.epochs += ckpt["epoch"] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
@@ -635,13 +679,13 @@ class BaseTrainer:
def _close_dataloader_mosaic(self):
"""Update dataloaders to stop using mosaic augmentation."""
- if hasattr(self.train_loader.dataset, 'mosaic'):
+ if hasattr(self.train_loader.dataset, "mosaic"):
self.train_loader.dataset.mosaic = False
- if hasattr(self.train_loader.dataset, 'close_mosaic'):
- LOGGER.info('Closing dataloader mosaic')
+ if hasattr(self.train_loader.dataset, "close_mosaic"):
+ LOGGER.info("Closing dataloader mosaic")
self.train_loader.dataset.close_mosaic(hyp=self.args)
- def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+ def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
weight decay, and number of iterations.
@@ -661,41 +705,45 @@ class BaseTrainer:
"""
g = [], [], [] # optimizer parameter groups
- bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
- if name == 'auto':
- LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
- f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
- f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
- nc = getattr(model, 'nc', 10) # number of classes
+ bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
+ if name == "auto":
+ LOGGER.info(
+ f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+ f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+ f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+ )
+ nc = getattr(model, "nc", 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
- name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
+ name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
- fullname = f'{module_name}.{param_name}' if module_name else param_name
- if 'bias' in fullname: # bias (no decay)
+ fullname = f"{module_name}.{param_name}" if module_name else param_name
+ if "bias" in fullname: # bias (no decay)
g[2].append(param)
elif isinstance(module, bn): # weight (no decay)
g[1].append(param)
else: # weight (with decay)
g[0].append(param)
- if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+ if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
- elif name == 'RMSProp':
+ elif name == "RMSProp":
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
- elif name == 'SGD':
+ elif name == "SGD":
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
else:
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers "
- f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
- 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
+ f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+ "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+ )
- optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
- optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
+ optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
+ optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
- f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
+ f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+ )
return optimizer
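
For readers skimming this hunk: the `optimizer=auto` branch reduces to a small closed-form rule. A minimal standalone sketch of that rule (the function name is illustrative, not part of the library API):

```python
def pick_optimizer(nc: int, iterations: float):
    """Sketch of the 'auto' rule: SGD for long schedules, AdamW with a fitted lr0 otherwise."""
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation, rounded to 6 decimal places
    return ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)


print(pick_optimizer(nc=80, iterations=5e4))  # ('SGD', 0.01, 0.9)
print(pick_optimizer(nc=80, iterations=5e3))  # ('AdamW', 0.000119, 0.9)
```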
diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py
index 5765bce1..80a554d4 100644
--- a/ultralytics/engine/tuner.py
+++ b/ultralytics/engine/tuner.py
@@ -73,40 +73,43 @@ class Tuner:
Args:
args (dict, optional): Configuration for hyperparameter evolution.
"""
- self.space = args.pop('space', None) or { # key: (min, max, gain(optional))
+ self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
- 'lr0': (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
- 'lrf': (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
- 'momentum': (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
- 'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4
- 'warmup_epochs': (0.0, 5.0), # warmup epochs (fractions ok)
- 'warmup_momentum': (0.0, 0.95), # warmup initial momentum
- 'box': (1.0, 20.0), # box loss gain
- 'cls': (0.2, 4.0), # cls loss gain (scale with pixels)
- 'dfl': (0.4, 6.0), # dfl loss gain
- 'hsv_h': (0.0, 0.1), # image HSV-Hue augmentation (fraction)
- 'hsv_s': (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
- 'hsv_v': (0.0, 0.9), # image HSV-Value augmentation (fraction)
- 'degrees': (0.0, 45.0), # image rotation (+/- deg)
- 'translate': (0.0, 0.9), # image translation (+/- fraction)
- 'scale': (0.0, 0.95), # image scale (+/- gain)
- 'shear': (0.0, 10.0), # image shear (+/- deg)
- 'perspective': (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
- 'flipud': (0.0, 1.0), # image flip up-down (probability)
- 'fliplr': (0.0, 1.0), # image flip left-right (probability)
- 'mosaic': (0.0, 1.0), # image mixup (probability)
- 'mixup': (0.0, 1.0), # image mixup (probability)
- 'copy_paste': (0.0, 1.0)} # segment copy-paste (probability)
+ "lr0": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+ "lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
+ "momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
+ "weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4
+ "warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
+ "warmup_momentum": (0.0, 0.95), # warmup initial momentum
+ "box": (1.0, 20.0), # box loss gain
+ "cls": (0.2, 4.0), # cls loss gain (scale with pixels)
+ "dfl": (0.4, 6.0), # dfl loss gain
+ "hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
+ "hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
+ "hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction)
+ "degrees": (0.0, 45.0), # image rotation (+/- deg)
+ "translate": (0.0, 0.9), # image translation (+/- fraction)
+ "scale": (0.0, 0.95), # image scale (+/- gain)
+ "shear": (0.0, 10.0), # image shear (+/- deg)
+ "perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
+ "flipud": (0.0, 1.0), # image flip up-down (probability)
+ "fliplr": (0.0, 1.0), # image flip left-right (probability)
+ "mosaic": (0.0, 1.0), # image mixup (probability)
+ "mixup": (0.0, 1.0), # image mixup (probability)
+ "copy_paste": (0.0, 1.0), # segment copy-paste (probability)
+ }
self.args = get_cfg(overrides=args)
- self.tune_dir = get_save_dir(self.args, name='tune')
- self.tune_csv = self.tune_dir / 'tune_results.csv'
+ self.tune_dir = get_save_dir(self.args, name="tune")
+ self.tune_csv = self.tune_dir / "tune_results.csv"
self.callbacks = _callbacks or callbacks.get_default_callbacks()
- self.prefix = colorstr('Tuner: ')
+ self.prefix = colorstr("Tuner: ")
callbacks.add_integration_callbacks(self)
- LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
- f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
+ LOGGER.info(
+ f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+ f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
+ )
- def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
+ def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
"""
Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
@@ -121,15 +124,15 @@ class Tuner:
"""
if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
# Select parent(s)
- x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
+ x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
fitness = x[:, 0] # first column
n = min(n, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness)][:n] # top n mutations
- w = x[:, 0] - x[:, 0].min() + 1E-6 # weights (sum > 0)
- if parent == 'single' or len(x) == 1:
+ w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)
+ if parent == "single" or len(x) == 1:
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
- elif parent == 'weighted':
+ elif parent == "weighted":
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate
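
The mutation step that follows this hunk (not shown here) scales each selected hyperparameter by a random factor and clips it to the `(min, max)` bounds declared in `self.space`. A simplified sketch of that idea, assuming `space` values are `(min, max)` or `(min, max, gain)` tuples as above; this is illustrative, not the library's exact routine:

```python
import numpy as np


def mutate(hyp: dict, space: dict, mutation=0.8, sigma=0.2, rng=np.random.default_rng()):
    """Perturb each hyperparameter by a Gaussian factor, then clip to its (min, max) bounds."""
    out = {}
    for k, bounds in space.items():
        lo, hi = bounds[0], bounds[1]
        gain = bounds[2] if len(bounds) > 2 else 1.0  # optional per-key gain
        factor = 1.0 + gain * sigma * rng.standard_normal() if rng.random() < mutation else 1.0
        out[k] = float(np.clip(hyp[k] * factor, lo, hi))  # respect search-space bounds
    return out
```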
@@ -174,44 +177,44 @@ class Tuner:
t0 = time.time()
best_save_dir, best_metrics = None, None
- (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
+ (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
for i in range(iterations):
# Mutate hyperparameters
mutated_hyp = self._mutate()
- LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+ LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
metrics = {}
train_args = {**vars(self.args), **mutated_hyp}
save_dir = get_save_dir(get_cfg(train_args))
- weights_dir = save_dir / 'weights'
- ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+ weights_dir = save_dir / "weights"
+ ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
try:
# Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
- cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+ cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
return_code = subprocess.run(cmd, check=True).returncode
- metrics = torch.load(ckpt_file)['train_metrics']
- assert return_code == 0, 'training failed'
+ metrics = torch.load(ckpt_file)["train_metrics"]
+ assert return_code == 0, "training failed"
except Exception as e:
- LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
+ LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")
# Save results and mutated_hyp to CSV
- fitness = metrics.get('fitness', 0.0)
+ fitness = metrics.get("fitness", 0.0)
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
- headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
- with open(self.tune_csv, 'a') as f:
- f.write(headers + ','.join(map(str, log_row)) + '\n')
+ headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+ with open(self.tune_csv, "a") as f:
+ f.write(headers + ",".join(map(str, log_row)) + "\n")
# Get best results
- x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
+ x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
fitness = x[:, 0] # first column
best_idx = fitness.argmax()
best_is_current = best_idx == i
if best_is_current:
best_save_dir = save_dir
best_metrics = {k: round(v, 5) for k, v in metrics.items()}
- for ckpt in weights_dir.glob('*.pt'):
- shutil.copy2(ckpt, self.tune_dir / 'weights')
+ for ckpt in weights_dir.glob("*.pt"):
+ shutil.copy2(ckpt, self.tune_dir / "weights")
elif cleanup:
shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space
@@ -219,15 +222,19 @@ class Tuner:
plot_tune_results(self.tune_csv)
# Save and print tune results
- header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
- f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
- f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
- f'{self.prefix}Best fitness metrics are {best_metrics}\n'
- f'{self.prefix}Best fitness model is {best_save_dir}\n'
- f'{self.prefix}Best fitness hyperparameters are printed below.\n')
- LOGGER.info('\n' + header)
+ header = (
+ f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+ f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+ f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+ f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+ f'{self.prefix}Best fitness model is {best_save_dir}\n'
+ f'{self.prefix}Best fitness hyperparameters are printed below.\n'
+ )
+ LOGGER.info("\n" + header)
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
- yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
- data=data,
- header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
- yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
+ yaml_save(
+ self.tune_dir / "best_hyperparameters.yaml",
+ data=data,
+ header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
+ )
+ yaml_print(self.tune_dir / "best_hyperparameters.yaml")
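
End-to-end, the Tuner is normally driven through `model.tune()`; a typical invocation (the dataset name is an example):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# Each iteration trains in a subprocess and appends a row to tune_results.csv
model.tune(data="coco8.yaml", epochs=10, iterations=30, plots=False, save=False, val=False)
```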
diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index cf906d9e..5761113f 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -89,10 +89,10 @@ class BaseValidator:
self.nc = None
self.iouv = None
self.jdict = None
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.save_dir = save_dir or get_save_dir(self.args)
- (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+ (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
@@ -110,7 +110,7 @@ class BaseValidator:
if self.training:
self.device = trainer.device
self.data = trainer.data
- self.args.half = self.device.type != 'cpu' # force FP16 val during training
+ self.args.half = self.device.type != "cpu" # force FP16 val during training
model = trainer.ema.ema or trainer.model
model = model.half() if self.args.half else model.float()
# self.model = model
@@ -119,11 +119,13 @@ class BaseValidator:
model.eval()
else:
callbacks.add_integration_callbacks(self)
- model = AutoBackend(model or self.args.model,
- device=select_device(self.args.device, self.args.batch),
- dnn=self.args.dnn,
- data=self.args.data,
- fp16=self.args.half)
+ model = AutoBackend(
+ model or self.args.model,
+ device=select_device(self.args.device, self.args.batch),
+ dnn=self.args.dnn,
+ data=self.args.data,
+ fp16=self.args.half,
+ )
# self.model = model
self.device = model.device # update device
self.args.half = model.fp16 # update half
@@ -133,16 +135,16 @@ class BaseValidator:
self.args.batch = model.batch_size
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
- LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+ LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
- if str(self.args.data).split('.')[-1] in ('yaml', 'yml'):
+ if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
self.data = check_det_dataset(self.args.data)
- elif self.args.task == 'classify':
+ elif self.args.task == "classify":
self.data = check_cls_dataset(self.args.data, split=self.args.split)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
- if self.device.type in ('cpu', 'mps'):
+ if self.device.type in ("cpu", "mps"):
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
if not pt:
self.args.rect = False
@@ -152,13 +154,13 @@ class BaseValidator:
model.eval()
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
- self.run_callbacks('on_val_start')
+ self.run_callbacks("on_val_start")
dt = Profile(), Profile(), Profile(), Profile()
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
- self.run_callbacks('on_val_batch_start')
+ self.run_callbacks("on_val_batch_start")
self.batch_i = batch_i
# Preprocess
with dt[0]:
@@ -166,7 +168,7 @@ class BaseValidator:
# Inference
with dt[1]:
- preds = model(batch['img'], augment=augment)
+ preds = model(batch["img"], augment=augment)
# Loss
with dt[2]:
@@ -182,23 +184,25 @@ class BaseValidator:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
- self.run_callbacks('on_val_batch_end')
+ self.run_callbacks("on_val_batch_end")
stats = self.get_stats()
self.check_stats(stats)
- self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
+ self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
self.finalize_metrics()
self.print_results()
- self.run_callbacks('on_val_end')
+ self.run_callbacks("on_val_end")
if self.training:
model.float()
- results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
+ results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
- LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
- tuple(self.speed.values()))
+ LOGGER.info(
+ "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
+ % tuple(self.speed.values())
+ )
if self.args.save_json and self.jdict:
- with open(str(self.save_dir / 'predictions.json'), 'w') as f:
- LOGGER.info(f'Saving {f.name}...')
+ with open(str(self.save_dir / "predictions.json"), "w") as f:
+ LOGGER.info(f"Saving {f.name}...")
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json:
@@ -228,6 +232,7 @@ class BaseValidator:
if use_scipy:
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
import scipy # scope import to avoid importing for all commands
+
cost_matrix = iou * (iou >= threshold)
if cost_matrix.any():
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
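
The scipy branch above solves a maximum-IoU bipartite assignment between labels and detections; a toy standalone example of the same call pattern:

```python
import numpy as np
import scipy.optimize

iou = np.array([[0.90, 0.10], [0.20, 0.55]])  # rows: ground-truth labels, cols: detections
threshold = 0.45
cost_matrix = iou * (iou >= threshold)  # zero out pairs below the IoU threshold
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
valid = cost_matrix[labels_idx, detections_idx] > 0  # drop assignments that only matched zeros
print(list(zip(labels_idx[valid], detections_idx[valid])))  # matched (label, detection) pairs
```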
@@ -257,11 +262,11 @@ class BaseValidator:
def get_dataloader(self, dataset_path, batch_size):
"""Get data loader from dataset path and batch size."""
- raise NotImplementedError('get_dataloader function not implemented for this validator')
+ raise NotImplementedError("get_dataloader function not implemented for this validator")
def build_dataset(self, img_path):
"""Build dataset."""
- raise NotImplementedError('build_dataset function not implemented in validator')
+ raise NotImplementedError("build_dataset function not implemented in validator")
def preprocess(self, batch):
"""Preprocesses an input batch."""
@@ -306,7 +311,7 @@ class BaseValidator:
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
- self.plots[Path(name)] = {'data': data, 'timestamp': time.time()}
+ self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py
index 13bc9246..df17da5e 100644
--- a/ultralytics/hub/__init__.py
+++ b/ultralytics/hub/__init__.py
@@ -21,10 +21,10 @@ def login(api_key: str = None, save=True) -> bool:
Returns:
bool: True if authentication is successful, False otherwise.
"""
- api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys' # Set the redirect URL
- saved_key = SETTINGS.get('api_key')
+ api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
+ saved_key = SETTINGS.get("api_key")
active_key = api_key or saved_key
- credentials = {'api_key': active_key} if active_key and active_key != '' else None # Set credentials
+ credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials
client = HUBClient(credentials) # initialize HUBClient
@@ -32,17 +32,18 @@ def login(api_key: str = None, save=True) -> bool:
# Successfully authenticated with HUB
if save and client.api_key != saved_key:
- SETTINGS.update({'api_key': client.api_key}) # update settings with valid API key
+ SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key
# Set message based on whether key was provided or retrieved from settings
- log_message = ('New authentication successful ✅'
- if client.api_key == api_key or not credentials else 'Authenticated ✅')
- LOGGER.info(f'{PREFIX}{log_message}')
+ log_message = (
+ "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
+ )
+ LOGGER.info(f"{PREFIX}{log_message}")
return True
else:
# Failed to authenticate with HUB
- LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}')
+ LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
return False
@@ -57,50 +58,50 @@ def logout():
hub.logout()
```
"""
- SETTINGS['api_key'] = ''
+ SETTINGS["api_key"] = ""
SETTINGS.save()
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
-def reset_model(model_id=''):
+def reset_model(model_id=""):
"""Reset a trained model to an untrained state."""
- r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key})
+ r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
if r.status_code == 200:
- LOGGER.info(f'{PREFIX}Model reset successfully')
+ LOGGER.info(f"{PREFIX}Model reset successfully")
return
- LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
+ LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")
def export_fmts_hub():
"""Returns a list of HUB-supported export formats."""
from ultralytics.engine.exporter import export_formats
- return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
+ return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
-def export_model(model_id='', format='torchscript'):
+
+def export_model(model_id="", format="torchscript"):
"""Export a model to all formats."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
- r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
- json={'format': format},
- headers={'x-api-key': Auth().api_key})
- assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
- LOGGER.info(f'{PREFIX}{format} export started ✅')
+ r = requests.post(
+ f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
+ )
+ assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
+ LOGGER.info(f"{PREFIX}{format} export started ✅")
-def get_export(model_id='', format='torchscript'):
+def get_export(model_id="", format="torchscript"):
"""Get an exported model dictionary with download URL."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
- r = requests.post(f'{HUB_API_ROOT}/get-export',
- json={
- 'apiKey': Auth().api_key,
- 'modelId': model_id,
- 'format': format},
- headers={'x-api-key': Auth().api_key})
- assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
+ r = requests.post(
+ f"{HUB_API_ROOT}/get-export",
+ json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
+ headers={"x-api-key": Auth().api_key},
+ )
+ assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
return r.json()
-def check_dataset(path='', task='detect'):
+def check_dataset(path="", task="detect"):
"""
    Function for error-checking a HUB dataset Zip file before upload, verifying the dataset is error-free before it is
    uploaded to the HUB. Usage examples are given below.
@@ -119,4 +120,4 @@ def check_dataset(path='', task='detect'):
```
"""
HUBDatasetStats(path=path, task=task).get_json()
- LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
+ LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py
index 202ca8c4..72aad109 100644
--- a/ultralytics/hub/auth.py
+++ b/ultralytics/hub/auth.py
@@ -6,7 +6,7 @@ from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT
from ultralytics.hub.utils import PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
-API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
+API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
class Auth:
@@ -23,9 +23,10 @@ class Auth:
api_key (str or bool): API key for authentication, initialized as False.
model_key (bool): Placeholder for model key, initialized as False.
"""
+
id_token = api_key = model_key = False
- def __init__(self, api_key='', verbose=False):
+ def __init__(self, api_key="", verbose=False):
"""
Initialize the Auth class with an optional API key.
@@ -33,18 +34,18 @@ class Auth:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
"""
# Split the input API key in case it contains a combined key_model and keep only the API key part
- api_key = api_key.split('_')[0]
+ api_key = api_key.split("_")[0]
# Set API key attribute as value passed or SETTINGS API key if none passed
- self.api_key = api_key or SETTINGS.get('api_key', '')
+ self.api_key = api_key or SETTINGS.get("api_key", "")
# If an API key is provided
if self.api_key:
# If the provided API key matches the API key in the SETTINGS
- if self.api_key == SETTINGS.get('api_key'):
+ if self.api_key == SETTINGS.get("api_key"):
# Log that the user is already logged in
if verbose:
- LOGGER.info(f'{PREFIX}Authenticated ✅')
+ LOGGER.info(f"{PREFIX}Authenticated ✅")
return
else:
# Attempt to authenticate with the provided API key
@@ -59,12 +60,12 @@ class Auth:
# Update SETTINGS with the new API key after successful authentication
if success:
- SETTINGS.update({'api_key': self.api_key})
+ SETTINGS.update({"api_key": self.api_key})
# Log that the new login was successful
if verbose:
- LOGGER.info(f'{PREFIX}New authentication successful ✅')
+ LOGGER.info(f"{PREFIX}New authentication successful ✅")
elif verbose:
- LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
+ LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")
def request_api_key(self, max_attempts=3):
"""
@@ -73,13 +74,14 @@ class Auth:
Returns the model ID.
"""
import getpass
+
for attempts in range(max_attempts):
- LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
- input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
- self.api_key = input_key.split('_')[0] # remove model id if present
+ LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
+ input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
+ self.api_key = input_key.split("_")[0] # remove model id if present
if self.authenticate():
return True
- raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
+ raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
def authenticate(self) -> bool:
"""
@@ -90,14 +92,14 @@ class Auth:
"""
try:
if header := self.get_auth_header():
- r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
- if not r.json().get('success', False):
- raise ConnectionError('Unable to authenticate.')
+ r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
+ if not r.json().get("success", False):
+ raise ConnectionError("Unable to authenticate.")
return True
- raise ConnectionError('User has not authenticated locally.')
+ raise ConnectionError("User has not authenticated locally.")
except ConnectionError:
self.id_token = self.api_key = False # reset invalid
- LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
+ LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
return False
def auth_with_cookies(self) -> bool:
@@ -111,12 +113,12 @@ class Auth:
if not is_colab():
return False # Currently only works with Colab
try:
- authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
- if authn.get('success', False):
- self.id_token = authn.get('data', {}).get('idToken', None)
+ authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
+ if authn.get("success", False):
+ self.id_token = authn.get("data", {}).get("idToken", None)
self.authenticate()
return True
- raise ConnectionError('Unable to fetch browser authentication details.')
+ raise ConnectionError("Unable to fetch browser authentication details.")
except ConnectionError:
self.id_token = False # reset invalid
return False
@@ -129,7 +131,7 @@ class Auth:
(dict): The authentication header if id_token or API key is set, None otherwise.
"""
if self.id_token:
- return {'authorization': f'Bearer {self.id_token}'}
+ return {"authorization": f"Bearer {self.id_token}"}
elif self.api_key:
- return {'x-api-key': self.api_key}
+ return {"x-api-key": self.api_key}
# else returns None
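
The header precedence in `get_auth_header()` is easy to miss in the hunk: a browser `id_token` wins over an API key. A distilled sketch:

```python
def get_auth_header(id_token=None, api_key=None):
    """Return an auth header, preferring a browser id_token over an API key."""
    if id_token:
        return {"authorization": f"Bearer {id_token}"}
    if api_key:
        return {"x-api-key": api_key}
    return None  # unauthenticated
```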
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index dd3d01c0..150f4eb4 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError
-AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
+AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
- Args:
- url (str): Model identifier used to initialize the HUB training session.
-
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLO model being trained.
@@ -40,17 +37,18 @@ class HUBTrainingSession:
Initialize the HUBTrainingSession with the provided model identifier.
Args:
- url (str): Model identifier used to initialize the HUB training session.
- It can be a URL string or a model key with specific format.
+ identifier (str): Model identifier used to initialize the HUB training session.
+ It can be a URL string or a model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
self.rate_limits = {
- 'metrics': 3.0,
- 'ckpt': 900.0,
- 'heartbeat': 300.0, } # rate limits (seconds)
+ "metrics": 3.0,
+ "ckpt": 900.0,
+ "heartbeat": 300.0,
+ } # rate limits (seconds)
self.metrics_queue = {} # holds metrics for each epoch until upload
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
@@ -58,8 +56,8 @@ class HUBTrainingSession:
api_key, model_id, self.filename = self._parse_identifier(identifier)
# Get credentials
- active_key = api_key or SETTINGS.get('api_key')
- credentials = {'api_key': active_key} if active_key else None # set credentials
+ active_key = api_key or SETTINGS.get("api_key")
+ credentials = {"api_key": active_key} if active_key else None # set credentials
# Initialize client
self.client = HUBClient(credentials)
@@ -72,35 +70,37 @@ class HUBTrainingSession:
def load_model(self, model_id):
# Initialize model
self.model = self.client.model(model_id)
- self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+ self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
self._set_train_args()
# Start heartbeats for HUB to monitor agent
- self.model.start_heartbeat(self.rate_limits['heartbeat'])
- LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
+ self.model.start_heartbeat(self.rate_limits["heartbeat"])
+ LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def create_model(self, model_args):
# Initialize model
payload = {
- 'config': {
- 'batchSize': model_args.get('batch', -1),
- 'epochs': model_args.get('epochs', 300),
- 'imageSize': model_args.get('imgsz', 640),
- 'patience': model_args.get('patience', 100),
- 'device': model_args.get('device', ''),
- 'cache': model_args.get('cache', 'ram'), },
- 'dataset': {
- 'name': model_args.get('data')},
- 'lineage': {
- 'architecture': {
- 'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
- 'parent': {}, },
- 'meta': {
- 'name': self.filename}, }
-
- if self.filename.endswith('.pt'):
- payload['lineage']['parent']['name'] = self.filename
+ "config": {
+ "batchSize": model_args.get("batch", -1),
+ "epochs": model_args.get("epochs", 300),
+ "imageSize": model_args.get("imgsz", 640),
+ "patience": model_args.get("patience", 100),
+ "device": model_args.get("device", ""),
+ "cache": model_args.get("cache", "ram"),
+ },
+ "dataset": {"name": model_args.get("data")},
+ "lineage": {
+ "architecture": {
+ "name": self.filename.replace(".pt", "").replace(".yaml", ""),
+ },
+ "parent": {},
+ },
+ "meta": {"name": self.filename},
+ }
+
+ if self.filename.endswith(".pt"):
+ payload["lineage"]["parent"]["name"] = self.filename
self.model.create_model(payload)
@@ -109,12 +109,12 @@ class HUBTrainingSession:
if not self.model.id:
return
- self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+ self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
# Start heartbeats for HUB to monitor agent
- self.model.start_heartbeat(self.rate_limits['heartbeat'])
+ self.model.start_heartbeat(self.rate_limits["heartbeat"])
- LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
+ LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def _parse_identifier(self, identifier):
"""
@@ -125,13 +125,13 @@ class HUBTrainingSession:
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
-
+
Args:
identifier (str): The identifier string to be parsed.
-
+
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
-
+
Raises:
HUBModelError: If the identifier format is not recognized.
"""
@@ -140,12 +140,12 @@ class HUBTrainingSession:
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
- if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
+ if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
- model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
+ model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
else:
# Split the identifier based on underscores only if it's not a HUB URL
- parts = identifier.split('_')
+ parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
@@ -154,43 +154,46 @@ class HUBTrainingSession:
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
- elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
+ elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
- f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
+ f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
+ )
return api_key, model_id, filename
def _set_train_args(self, **kwargs):
if self.model.is_trained():
# Model is already trained
- raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
+ raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable():
# Model has saved weights
- self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
- self.model_file = self.model.get_weights_url('last')
+ self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
+ self.model_file = self.model.get_weights_url("last")
else:
# Model has no saved weights
def get_train_args(config):
return {
- 'batch': config['batchSize'],
- 'epochs': config['epochs'],
- 'imgsz': config['imageSize'],
- 'patience': config['patience'],
- 'device': config['device'],
- 'cache': config['cache'],
- 'data': self.model.get_dataset_url(), }
-
- self.train_args = get_train_args(self.model.data.get('config'))
+ "batch": config["batchSize"],
+ "epochs": config["epochs"],
+ "imgsz": config["imageSize"],
+ "patience": config["patience"],
+ "device": config["device"],
+ "cache": config["cache"],
+ "data": self.model.get_dataset_url(),
+ }
+
+ self.train_args = get_train_args(self.model.data.get("config"))
# Set the model file as either a *.pt or *.yaml file
- self.model_file = (self.model.get_weights_url('parent')
- if self.model.is_pretrained() else self.model.get_architecture())
+ self.model_file = (
+ self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
+ )
- if not self.train_args.get('data'):
- raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
+ if not self.train_args.get("data"):
+ raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id
@@ -206,12 +209,11 @@ class HUBTrainingSession:
*args,
**kwargs,
):
-
def retry_request():
t0 = time.time() # Record the start time for the timeout
for i in range(retry + 1):
if (time.time() - t0) > timeout:
- LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
+ LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
break # Timeout reached, exit loop
response = request_func(*args, **kwargs)
@@ -219,8 +221,8 @@ class HUBTrainingSession:
self._show_upload_progress(progress_total, response)
if response is None:
- LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
- time.sleep(2 ** i) # Exponential backoff before retrying
+ LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
+ time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retry
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
@@ -231,13 +233,13 @@ class HUBTrainingSession:
message = self._get_failure_message(response, retry, timeout)
if verbose:
- LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
+ LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
if not self._should_retry(response.status_code):
- LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}')
+ LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
break # Not an error that should be retried, exit loop
- time.sleep(2 ** i) # Exponential backoff for retries
+ time.sleep(2**i) # Exponential backoff for retries
return response
@@ -253,7 +255,8 @@ class HUBTrainingSession:
retry_codes = {
HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY,
- HTTPStatus.GATEWAY_TIMEOUT, }
+ HTTPStatus.GATEWAY_TIMEOUT,
+ }
        return status_code in retry_codes
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
@@ -269,16 +272,18 @@ class HUBTrainingSession:
str: The retry message.
"""
if self._should_retry(response.status_code):
- return f'Retrying {retry}x for {timeout}s.' if retry else ''
+ return f"Retrying {retry}x for {timeout}s." if retry else ""
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
headers = response.headers
- return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
- f"Please retry after {headers['Retry-After']}s.")
+ return (
+ f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
+ f"Please retry after {headers['Retry-After']}s."
+ )
else:
try:
- return response.json().get('message', 'No JSON message.')
+ return response.json().get("message", "No JSON message.")
except AttributeError:
- return 'Unable to read JSON.'
+ return "Unable to read JSON."
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
@@ -303,7 +308,7 @@ class HUBTrainingSession:
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
- progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final
+ progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
self.request_queue(
self.model.upload_model,
epoch=epoch,
@@ -317,7 +322,7 @@ class HUBTrainingSession:
progress_total=progress_total,
)
else:
- LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
+ LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
"""
@@ -330,6 +335,6 @@ class HUBTrainingSession:
Returns:
(None)
"""
- with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+ with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
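
The `request_queue` retry loop reformatted above follows a standard pattern: an overall timeout, a per-attempt success check, and `2**i` exponential backoff. A self-contained sketch of the same pattern (names are illustrative):

```python
import time


def retry_with_backoff(request_func, retry=3, timeout=30, retryable=(408, 502, 504)):
    """Retry a request with exponential backoff until success, a permanent failure, or timeout."""
    t0, response = time.time(), None
    for i in range(retry + 1):
        if time.time() - t0 > timeout:
            break  # overall timeout reached
        response = request_func()
        if response is not None and 200 <= response.status_code < 300:
            return response  # success
        if response is not None and response.status_code not in retryable:
            break  # not a retryable error, stop immediately
        time.sleep(2**i)  # exponential backoff: 1s, 2s, 4s, ...
    return response
```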
diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py
index 1277c63a..2cdd0350 100644
--- a/ultralytics/hub/utils.py
+++ b/ultralytics/hub/utils.py
@@ -9,12 +9,26 @@ from pathlib import Path
import requests
-from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
- colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
+from ultralytics.utils import (
+ ENVIRONMENT,
+ LOGGER,
+ ONLINE,
+ RANK,
+ SETTINGS,
+ TESTS_RUNNING,
+ TQDM,
+ TryExcept,
+ __version__,
+ colorstr,
+ get_git_origin_url,
+ is_colab,
+ is_git_dir,
+ is_pip_package,
+)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
-PREFIX = colorstr('Ultralytics HUB: ')
-HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
+PREFIX = colorstr("Ultralytics HUB: ")
+HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
def request_with_credentials(url: str) -> any:
@@ -31,11 +45,13 @@ def request_with_credentials(url: str) -> any:
OSError: If the function is not run in a Google Colab environment.
"""
if not is_colab():
- raise OSError('request_with_credentials() must run in a Colab environment')
+ raise OSError("request_with_credentials() must run in a Colab environment")
from google.colab import output # noqa
from IPython import display # noqa
+
display.display(
- display.Javascript("""
+ display.Javascript(
+ """
window._hub_tmp = new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
fetch("%s", {
@@ -50,8 +66,11 @@ def request_with_credentials(url: str) -> any:
reject(err);
});
});
- """ % url))
- return output.eval_js('_hub_tmp')
+ """
+ % url
+ )
+ )
+ return output.eval_js("_hub_tmp")
def requests_with_progress(method, url, **kwargs):
@@ -71,13 +90,13 @@ def requests_with_progress(method, url, **kwargs):
content length.
    - If 'progress' is a number then the progress bar will display assuming content length = progress.
"""
- progress = kwargs.pop('progress', False)
+ progress = kwargs.pop("progress", False)
if not progress:
return requests.request(method, url, **kwargs)
response = requests.request(method, url, stream=True, **kwargs)
- total = int(response.headers.get('content-length', 0) if isinstance(progress, bool) else progress) # total size
+ total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size
try:
- pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
+ pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
pbar.close()
@@ -118,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
break
try:
- m = r.json().get('message', 'No JSON message.')
+ m = r.json().get("message", "No JSON message.")
except AttributeError:
- m = 'Unable to read JSON.'
+ m = "Unable to read JSON."
if i == 0:
if r.status_code in retry_codes:
- m += f' Retrying {retry}x for {timeout}s.' if retry else ''
+ m += f" Retrying {retry}x for {timeout}s." if retry else ""
elif r.status_code == 429: # rate limit
h = r.headers # response headers
- m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
+ m = (
+ f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
f"Please retry after {h['Retry-After']}s."
+ )
if verbose:
- LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
+ LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
if r.status_code not in retry_codes:
return r
- time.sleep(2 ** i) # exponential standoff
+                    time.sleep(2**i) # exponential backoff
return r
args = method, url
- kwargs['progress'] = progress
+ kwargs["progress"] = progress
if thread:
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
else:
@@ -155,7 +176,7 @@ class Events:
enabled (bool): A flag to enable or disable Events based on certain conditions.
"""
- url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
+ url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
def __init__(self):
"""Initializes the Events object with default values for events, rate_limit, and metadata."""
@@ -163,19 +184,21 @@ class Events:
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
- 'cli': Path(sys.argv[0]).name == 'yolo',
- 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
- 'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
- 'version': __version__,
- 'env': ENVIRONMENT,
- 'session_id': round(random.random() * 1E15),
- 'engagement_time_msec': 1000}
- self.enabled = \
- SETTINGS['sync'] and \
- RANK in (-1, 0) and \
- not TESTS_RUNNING and \
- ONLINE and \
- (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
+ "cli": Path(sys.argv[0]).name == "yolo",
+ "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
+ "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
+ "version": __version__,
+ "env": ENVIRONMENT,
+ "session_id": round(random.random() * 1e15),
+ "engagement_time_msec": 1000,
+ }
+ self.enabled = (
+ SETTINGS["sync"]
+ and RANK in (-1, 0)
+ and not TESTS_RUNNING
+ and ONLINE
+ and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
+ )
def __call__(self, cfg):
"""
@@ -191,11 +214,13 @@ class Events:
# Attempt to add to events
if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
params = {
- **self.metadata, 'task': cfg.task,
- 'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'}
- if cfg.mode == 'export':
- params['format'] = cfg.format
- self.events.append({'name': cfg.mode, 'params': params})
+ **self.metadata,
+ "task": cfg.task,
+ "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
+ }
+ if cfg.mode == "export":
+ params["format"] = cfg.format
+ self.events.append({"name": cfg.mode, "params": params})
# Check rate limit
t = time.time()
@@ -204,10 +229,10 @@ class Events:
return
# Time is over rate limiter, send now
- data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
+ data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list
# POST equivalent to requests.post(self.url, json=data)
- smart_request('post', self.url, json=data, retry=0, verbose=False)
+ smart_request("post", self.url, json=data, retry=0, verbose=False)
# Reset events and rate limit timer
self.events = []
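
The `Events` class implements a simple client-side rate limiter: buffer up to 25 events and flush at most once per `rate_limit` window. Distilled into a sketch (the 25-event cap and 60 s window come from the code above):

```python
import time


class RateLimitedQueue:
    """Buffer items and flush them through `send` at most once per rate-limit window."""

    def __init__(self, rate_limit=60.0):
        self.items, self.rate_limit, self.t = [], rate_limit, 0.0

    def add(self, item, send):
        if len(self.items) < 25:  # drop anything past the buffer cap
            self.items.append(item)
        if time.time() - self.t < self.rate_limit:
            return  # still inside the rate-limit window
        send(self.items)  # flush buffered items
        self.items, self.t = [], time.time()
```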
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index e96f893e..ef49faff 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -4,4 +4,4 @@ from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO
-__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM" # allow simpler import
diff --git a/ultralytics/models/fastsam/__init__.py b/ultralytics/models/fastsam/__init__.py
index 8f47772f..eabf5b9f 100644
--- a/ultralytics/models/fastsam/__init__.py
+++ b/ultralytics/models/fastsam/__init__.py
@@ -5,4 +5,4 @@ from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator
-__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
+__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"
diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py
index e6475faa..8904f8e9 100644
--- a/ultralytics/models/fastsam/model.py
+++ b/ultralytics/models/fastsam/model.py
@@ -21,14 +21,14 @@ class FastSAM(Model):
```
"""
- def __init__(self, model='FastSAM-x.pt'):
+ def __init__(self, model="FastSAM-x.pt"):
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
- if str(model) == 'FastSAM.pt':
- model = 'FastSAM-x.pt'
- assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
- super().__init__(model=model, task='segment')
+ if str(model) == "FastSAM.pt":
+ model = "FastSAM-x.pt"
+ assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
+ super().__init__(model=model, task="segment")
@property
def task_map(self):
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
- return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
+ return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py
index 4a3c2e9e..0ef18032 100644
--- a/ultralytics/models/fastsam/predict.py
+++ b/ultralytics/models/fastsam/predict.py
@@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor):
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'segment'
+ self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
@@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor):
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
- classes=self.args.classes)
+ classes=self.args.classes,
+ )
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 0f43441a..921d8322 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -23,7 +23,7 @@ class FastSAMPrompt:
clip: CLIP model for linear assignment.
"""
- def __init__(self, source, results, device='cuda') -> None:
+ def __init__(self, source, results, device="cuda") -> None:
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
self.device = device
self.results = results
@@ -34,7 +34,8 @@ class FastSAMPrompt:
import clip # for linear_assignment
except ImportError:
from ultralytics.utils.checks import check_requirements
- check_requirements('git+https://github.com/openai/CLIP.git')
+
+ check_requirements("git+https://github.com/openai/CLIP.git")
import clip
self.clip = clip
@@ -46,11 +47,11 @@ class FastSAMPrompt:
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
- black_image = Image.new('RGB', image.size, (255, 255, 255))
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
- transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
@@ -65,11 +66,12 @@ class FastSAMPrompt:
mask = result.masks.data[i] == 1.0
if torch.sum(mask) >= filter:
annotation = {
- 'id': i,
- 'segmentation': mask.cpu().numpy(),
- 'bbox': result.boxes.data[i],
- 'score': result.boxes.conf[i]}
- annotation['area'] = annotation['segmentation'].sum()
+ "id": i,
+ "segmentation": mask.cpu().numpy(),
+ "bbox": result.boxes.data[i],
+ "score": result.boxes.conf[i],
+ }
+ annotation["area"] = annotation["segmentation"].sum()
annotations.append(annotation)
return annotations
@@ -91,16 +93,18 @@ class FastSAMPrompt:
y2 = max(y2, y_t + h_t)
return [x1, y1, x2, y2]
- def plot(self,
- annotations,
- output,
- bbox=None,
- points=None,
- point_label=None,
- mask_random_color=True,
- better_quality=True,
- retina=False,
- with_contours=True):
+ def plot(
+ self,
+ annotations,
+ output,
+ bbox=None,
+ points=None,
+ point_label=None,
+ mask_random_color=True,
+ better_quality=True,
+ retina=False,
+ with_contours=True,
+ ):
"""
Plots annotations, bounding boxes, and points on images and saves the output.
@@ -139,15 +143,17 @@ class FastSAMPrompt:
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
- self.fast_show_mask(masks,
- plt.gca(),
- random_color=mask_random_color,
- bbox=bbox,
- points=points,
- pointlabel=point_label,
- retinamask=retina,
- target_height=original_h,
- target_width=original_w)
+ self.fast_show_mask(
+ masks,
+ plt.gca(),
+ random_color=mask_random_color,
+ bbox=bbox,
+ points=points,
+ pointlabel=point_label,
+ retinamask=retina,
+ target_height=original_h,
+ target_width=original_w,
+ )
if with_contours:
contour_all = []
@@ -166,10 +172,10 @@ class FastSAMPrompt:
# Save the figure
save_path = Path(output) / result_name
save_path.parent.mkdir(exist_ok=True, parents=True)
- plt.axis('off')
- plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True)
+ plt.axis("off")
+ plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
plt.close()
- pbar.set_description(f'Saving {result_name} to {save_path}')
+ pbar.set_description(f"Saving {result_name} to {save_path}")
@staticmethod
def fast_show_mask(
@@ -212,26 +218,26 @@ class FastSAMPrompt:
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((h, w, 4))
- h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
+ h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
# Draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
- c='y',
+ c="y",
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
- c='m',
+ c="m",
)
if not retinamask:
@@ -258,7 +264,7 @@ class FastSAMPrompt:
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
- mask_h, mask_w = annotations[0]['segmentation'].shape
+ mask_h, mask_w = annotations[0]["segmentation"].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
@@ -266,19 +272,19 @@ class FastSAMPrompt:
not_crop = []
filter_id = []
for _, mask in enumerate(annotations):
- if np.sum(mask['segmentation']) <= 100:
+ if np.sum(mask["segmentation"]) <= 100:
filter_id.append(_)
continue
- bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
- cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
- cropped_images.append(bbox) # 保存裁剪的图片的bbox
+ bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask
+ cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image
+ cropped_images.append(bbox) # save cropped image bbox
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox):
"""Modifies the bounding box properties and calculates IoU between masks and bounding box."""
if self.results[0].masks is not None:
- assert (bbox[2] != 0 and bbox[3] != 0)
+ assert bbox[2] != 0 and bbox[3] != 0
if os.path.isdir(self.source):
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self.results[0].masks.data
@@ -290,7 +296,8 @@ class FastSAMPrompt:
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
- int(bbox[3] * h / target_height), ]
+ int(bbox[3] * h / target_height),
+ ]
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
@@ -299,7 +306,7 @@ class FastSAMPrompt:
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
- masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
@@ -316,13 +323,13 @@ class FastSAMPrompt:
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self._format_results(self.results[0], 0)
target_height, target_width = self.results[0].orig_shape
- h = masks[0]['segmentation'].shape[0]
- w = masks[0]['segmentation'].shape[1]
+ h = masks[0]["segmentation"].shape[0]
+ w = masks[0]["segmentation"].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for annotation in masks:
- mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+ mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
@@ -337,12 +344,12 @@ class FastSAMPrompt:
if self.results[0].masks is not None:
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
- clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+ clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
- self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']]))
+ self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
return self.results
def everything_prompt(self):
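For context, a minimal prompt-API sketch matching the reformatted FastSAMPrompt methods above (assumes a local FastSAM-s.pt checkpoint and bus.jpg image; text_prompt additionally needs the CLIP package):

```python
from ultralytics.models.fastsam import FastSAM, FastSAMPrompt

model = FastSAM("FastSAM-s.pt")
results = model("bus.jpg", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
prompt = FastSAMPrompt("bus.jpg", results, device="cpu")

ann = prompt.box_prompt(bbox=[200, 200, 300, 300])              # xyxy box prompt
ann = prompt.point_prompt(points=[[250, 250]], pointlabel=[1])  # 1 = foreground point
ann = prompt.text_prompt(text="a photo of a bus")               # CLIP-based retrieval
ann = prompt.everything_prompt()                                # all masks, no prompt
```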
diff --git a/ultralytics/models/fastsam/val.py b/ultralytics/models/fastsam/val.py
index 4e1e0b01..9014b27a 100644
--- a/ultralytics/models/fastsam/val.py
+++ b/ultralytics/models/fastsam/val.py
@@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator):
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
"""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
- self.args.task = 'segment'
+ self.args.task = "segment"
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
diff --git a/ultralytics/models/nas/__init__.py b/ultralytics/models/nas/__init__.py
index eec3837d..b095a050 100644
--- a/ultralytics/models/nas/__init__.py
+++ b/ultralytics/models/nas/__init__.py
@@ -4,4 +4,4 @@ from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
-__all__ = 'NASPredictor', 'NASValidator', 'NAS'
+__all__ = "NASPredictor", "NASValidator", "NAS"
diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py
index 00d0b6ed..f990cb4e 100644
--- a/ultralytics/models/nas/model.py
+++ b/ultralytics/models/nas/model.py
@@ -44,20 +44,21 @@ class NAS(Model):
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
"""
- def __init__(self, model='yolo_nas_s.pt') -> None:
+ def __init__(self, model="yolo_nas_s.pt") -> None:
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
- assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
- super().__init__(model, task='detect')
+ assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
+ super().__init__(model, task="detect")
@smart_inference_mode()
def _load(self, weights: str, task: str):
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
import super_gradients
+
suffix = Path(weights).suffix
- if suffix == '.pt':
+ if suffix == ".pt":
self.model = torch.load(weights)
- elif suffix == '':
- self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
+ elif suffix == "":
+ self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
@@ -65,7 +66,7 @@ class NAS(Model):
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
- self.model.task = 'detect' # for export()
+ self.model.task = "detect" # for export()
def info(self, detailed=False, verbose=True):
"""
@@ -80,4 +81,4 @@ class NAS(Model):
@property
def task_map(self):
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
- return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
+ return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py
index 0118527a..2e485462 100644
--- a/ultralytics/models/nas/predict.py
+++ b/ultralytics/models/nas/predict.py
@@ -39,12 +39,14 @@ class NASPredictor(BasePredictor):
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
- preds = ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- classes=self.args.classes)
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ classes=self.args.classes,
+ )
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py
index 41f60c19..a4a4f990 100644
--- a/ultralytics/models/nas/val.py
+++ b/ultralytics/models/nas/val.py
@@ -5,7 +5,7 @@ import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
-__all__ = ['NASValidator']
+__all__ = ["NASValidator"]
class NASValidator(DetectionValidator):
@@ -38,11 +38,13 @@ class NASValidator(DetectionValidator):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
- return ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=False,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det,
- max_time_img=0.5)
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=False,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ max_time_img=0.5,
+ )
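Both the predictor and validator convert xyxy boxes to xywh before NMS; a quick check of that conversion using the same `ops` utility:

```python
import torch
from ultralytics.utils import ops

xyxy = torch.tensor([[10.0, 20.0, 110.0, 220.0]])  # x1, y1, x2, y2
print(ops.xyxy2xywh(xyxy))  # tensor([[ 60., 120., 100., 200.]]) -> cx, cy, w, h
```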
diff --git a/ultralytics/models/rtdetr/__init__.py b/ultralytics/models/rtdetr/__init__.py
index 4d121156..172c74b4 100644
--- a/ultralytics/models/rtdetr/__init__.py
+++ b/ultralytics/models/rtdetr/__init__.py
@@ -4,4 +4,4 @@ from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
-__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
+__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"
diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py
index 6e582a8b..b8def485 100644
--- a/ultralytics/models/rtdetr/model.py
+++ b/ultralytics/models/rtdetr/model.py
@@ -24,7 +24,7 @@ class RTDETR(Model):
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
"""
- def __init__(self, model='rtdetr-l.pt') -> None:
+ def __init__(self, model="rtdetr-l.pt") -> None:
"""
Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
@@ -34,9 +34,9 @@ class RTDETR(Model):
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
- if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
- raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
- super().__init__(model=model, task='detect')
+ if model and model.split(".")[-1] not in ("pt", "yaml", "yml"):
+ raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
+ super().__init__(model=model, task="detect")
@property
def task_map(self) -> dict:
@@ -47,8 +47,10 @@ class RTDETR(Model):
dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
"""
return {
- 'detect': {
- 'predictor': RTDETRPredictor,
- 'validator': RTDETRValidator,
- 'trainer': RTDETRTrainer,
- 'model': RTDETRDetectionModel}}
+ "detect": {
+ "predictor": RTDETRPredictor,
+ "validator": RTDETRValidator,
+ "trainer": RTDETRTrainer,
+ "model": RTDETRDetectionModel,
+ }
+ }
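A matching usage sketch for the RT-DETR entry point (suffix check as above; weights assumed auto-downloadable):

```python
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")  # .pt, .yaml, or .yml only
results = model.predict("bus.jpg")  # placeholder image path
```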
diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py
index 26b7ea68..973af649 100644
--- a/ultralytics/models/rtdetr/train.py
+++ b/ultralytics/models/rtdetr/train.py
@@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDetectionModel): Initialized model.
"""
- model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
- def build_dataset(self, img_path, mode='val', batch=None):
+ def build_dataset(self, img_path, mode="val", batch=None):
"""
Build and return an RT-DETR dataset for training or validation.
@@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDataset): Dataset object for the specific mode.
"""
- return RTDETRDataset(img_path=img_path,
- imgsz=self.args.imgsz,
- batch_size=batch,
- augment=mode == 'train',
- hyp=self.args,
- rect=False,
- cache=self.args.cache or None,
- prefix=colorstr(f'{mode}: '),
- data=self.data)
+ return RTDETRDataset(
+ img_path=img_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch,
+ augment=mode == "train",
+ hyp=self.args,
+ rect=False,
+ cache=self.args.cache or None,
+ prefix=colorstr(f"{mode}: "),
+ data=self.data,
+ )
def get_validator(self):
"""
@@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRValidator): Validator object for model validation.
"""
- self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
+ self.loss_names = "giou_loss", "cls_loss", "l1_loss"
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
@@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
(dict): Preprocessed batch.
"""
batch = super().preprocess_batch(batch)
- bs = len(batch['img'])
- batch_idx = batch['batch_idx']
+ bs = len(batch["img"])
+ batch_idx = batch["batch_idx"]
gt_bbox, gt_class = [], []
for i in range(bs):
- gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
- gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+ gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+ gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py
index c52b5940..359edb5d 100644
--- a/ultralytics/models/rtdetr/val.py
+++ b/ultralytics/models/rtdetr/val.py
@@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
-__all__ = 'RTDETRValidator', # tuple or list
+__all__ = ("RTDETRValidator",) # tuple or list
class RTDETRDataset(YOLODataset):
@@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset):
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append(
- Format(bbox_format='xywh',
- normalize=True,
- return_mask=self.use_segments,
- return_keypoint=self.use_keypoints,
- batch_idx=True,
- mask_ratio=hyp.mask_ratio,
- mask_overlap=hyp.overlap_mask))
+ Format(
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ batch_idx=True,
+ mask_ratio=hyp.mask_ratio,
+ mask_overlap=hyp.overlap_mask,
+ )
+ )
return transforms
@@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator):
For further details on the attributes and methods, refer to the parent DetectionValidator class.
"""
- def build_dataset(self, img_path, mode='val', batch=None):
+ def build_dataset(self, img_path, mode="val", batch=None):
"""
Build an RTDETR Dataset.
@@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator):
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
- prefix=colorstr(f'{mode}: '),
- data=self.data)
+ prefix=colorstr(f"{mode}: "),
+ data=self.data,
+ )
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
@@ -108,12 +112,12 @@ class RTDETRValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by applying transformations."""
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx].squeeze(-1)
- bbox = batch['bboxes'][idx]
- ori_shape = batch['ori_shape'][si]
- imgsz = batch['img'].shape[2:]
- ratio_pad = batch['ratio_pad'][si]
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) # target boxes
bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
@@ -124,6 +128,6 @@ class RTDETRValidator(DetectionValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch with transformed bounding boxes and class labels."""
predn = pred.clone()
- predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred
- predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred
+ predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
+ predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
return predn.float()
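The `_prepare_pred` rescaling above maps square-inference coordinates back to native image space; a toy check with illustrative numbers:

```python
import torch

imgsz = 640
ori_shape = (480, 640)  # native h, w
pred = torch.tensor([[64.0, 96.0, 320.0, 480.0, 0.9, 0.0]])  # x1, y1, x2, y2, conf, cls
pred[..., [0, 2]] *= ori_shape[1] / imgsz  # x coords: * 1.0 here
pred[..., [1, 3]] *= ori_shape[0] / imgsz  # y coords: * 0.75 here
print(pred[0, :4].tolist())  # [64.0, 72.0, 320.0, 360.0]
```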
diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py
index abf2eef5..8701fcce 100644
--- a/ultralytics/models/sam/__init__.py
+++ b/ultralytics/models/sam/__init__.py
@@ -3,4 +3,4 @@
from .model import SAM
from .predict import Predictor
-__all__ = 'SAM', 'Predictor' # tuple or list
+__all__ = "SAM", "Predictor" # tuple or list
diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py
index d7751d6f..c4bb6d1b 100644
--- a/ultralytics/models/sam/amg.py
+++ b/ultralytics/models/sam/amg.py
@@ -8,10 +8,9 @@ import numpy as np
import torch
-def is_box_near_crop_edge(boxes: torch.Tensor,
- crop_box: List[int],
- orig_box: List[int],
- atol: float = 20.0) -> torch.Tensor:
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
@@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
- assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
+ assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
- yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
@@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
- intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
- dtype=torch.int32))
- unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+ intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
return intersections / unions
@@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
- return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
+ return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
-def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
- overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes.
@@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
import cv2 # type: ignore
- assert mode in {'holes', 'islands'}
- correct_holes = mode == 'holes'
+ assert mode in {"holes", "islands"}
+ correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
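The PEP 8 slice spacing in `batch_iterator` is behavior-preserving; a short demonstration of its chunking:

```python
from ultralytics.models.sam.amg import batch_iterator

for (chunk,) in batch_iterator(2, list(range(5))):
    print(chunk)  # [0, 1] then [2, 3] then [4]
```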
diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py
index c27f2d09..cb3a7c68 100644
--- a/ultralytics/models/sam/build.py
+++ b/ultralytics/models/sam/build.py
@@ -64,46 +64,47 @@ def build_mobile_sam(checkpoint=None):
)
-def _build_sam(encoder_embed_dim,
- encoder_depth,
- encoder_num_heads,
- encoder_global_attn_indexes,
- checkpoint=None,
- mobile_sam=False):
+def _build_sam(
+ encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+):
"""Builds the selected SAM model architecture."""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
- image_encoder = (TinyViT(
- img_size=1024,
- in_chans=3,
- num_classes=1000,
- embed_dims=encoder_embed_dim,
- depths=encoder_depth,
- num_heads=encoder_num_heads,
- window_sizes=[7, 7, 14, 7],
- mlp_ratio=4.0,
- drop_rate=0.0,
- drop_path_rate=0.0,
- use_checkpoint=False,
- mbconv_expand_ratio=4.0,
- local_conv_size=3,
- layer_lr_decay=0.8,
- ) if mobile_sam else ImageEncoderViT(
- depth=encoder_depth,
- embed_dim=encoder_embed_dim,
- img_size=image_size,
- mlp_ratio=4,
- norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
- num_heads=encoder_num_heads,
- patch_size=vit_patch_size,
- qkv_bias=True,
- use_rel_pos=True,
- global_attn_indexes=encoder_global_attn_indexes,
- window_size=14,
- out_chans=prompt_embed_dim,
- ))
+ image_encoder = (
+ TinyViT(
+ img_size=1024,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=encoder_embed_dim,
+ depths=encoder_depth,
+ num_heads=encoder_num_heads,
+ window_sizes=[7, 7, 14, 7],
+ mlp_ratio=4.0,
+ drop_rate=0.0,
+ drop_path_rate=0.0,
+ use_checkpoint=False,
+ mbconv_expand_ratio=4.0,
+ local_conv_size=3,
+ layer_lr_decay=0.8,
+ )
+ if mobile_sam
+ else ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ )
+ )
sam = Sam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
@@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim,
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
- with open(checkpoint, 'rb') as f:
+ with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.eval()
@@ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim,
sam_model_map = {
- 'sam_h.pt': build_sam_vit_h,
- 'sam_l.pt': build_sam_vit_l,
- 'sam_b.pt': build_sam_vit_b,
- 'mobile_sam.pt': build_mobile_sam, }
+ "sam_h.pt": build_sam_vit_h,
+ "sam_l.pt": build_sam_vit_l,
+ "sam_b.pt": build_sam_vit_b,
+ "mobile_sam.pt": build_mobile_sam,
+}
-def build_sam(ckpt='sam_b.pt'):
+def build_sam(ckpt="sam_b.pt"):
"""Build a SAM model specified by ckpt."""
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
@@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'):
model_builder = sam_model_map.get(k)
if not model_builder:
- raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
+ raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
return model_builder(ckpt)
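Usage sketch for the rewritten builder map (any key of `sam_model_map` works; unknown names raise the FileNotFoundError above):

```python
from ultralytics.models.sam.build import build_sam

sam = build_sam("mobile_sam.pt")  # checkpoint is fetched on first use
print(type(sam).__name__)  # Sam
```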
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index 68acd22f..8ae9ebaa 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -32,7 +32,7 @@ class SAM(Model):
dataset.
"""
- def __init__(self, model='sam_b.pt') -> None:
+ def __init__(self, model="sam_b.pt") -> None:
"""
Initializes the SAM model with a pre-trained model file.
@@ -42,9 +42,9 @@ class SAM(Model):
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
"""
- if model and Path(model).suffix not in ('.pt', '.pth'):
- raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
- super().__init__(model=model, task='segment')
+ if model and Path(model).suffix not in (".pt", ".pth"):
+ raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+ super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
"""
@@ -70,7 +70,7 @@ class SAM(Model):
Returns:
(list): The model predictions.
"""
- overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
+ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
@@ -112,4 +112,4 @@ class SAM(Model):
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
"""
- return {'segment': {'predictor': Predictor}}
+ return {"segment": {"predictor": Predictor}}
diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py
index 4ad1d9f1..41e3af5d 100644
--- a/ultralytics/models/sam/modules/decoders.py
+++ b/ultralytics/models/sam/modules/decoders.py
@@ -64,8 +64,9 @@ class MaskDecoder(nn.Module):
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
- self.output_hypernetworks_mlps = nn.ModuleList([
- MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+ )
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
@@ -132,13 +133,14 @@ class MaskDecoder(nn.Module):
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
- mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = [
- self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+ ]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
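A shape walk-through of the hypernetwork mask product above, with toy sizes in place of SAM's real dimensions:

```python
import torch

b, c, h, w, num_mask_tokens = 1, 32, 64, 64, 4
hyper_in = torch.randn(b, num_mask_tokens, c)  # one hypernetwork vector per mask token
upscaled_embedding = torch.randn(b, c, h, w)
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([1, 4, 64, 64])
```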
diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py
index f7771380..a43e115c 100644
--- a/ultralytics/models/sam/modules/encoders.py
+++ b/ultralytics/models/sam/modules/encoders.py
@@ -28,23 +28,23 @@ class ImageEncoderViT(nn.Module):
"""
def __init__(
- self,
- img_size: int = 1024,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- out_chans: int = 256,
- qkv_bias: bool = True,
- norm_layer: Type[nn.Module] = nn.LayerNorm,
- act_layer: Type[nn.Module] = nn.GELU,
- use_abs_pos: bool = True,
- use_rel_pos: bool = False,
- rel_pos_zero_init: bool = True,
- window_size: int = 0,
- global_attn_indexes: Tuple[int, ...] = (),
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
@@ -283,9 +283,9 @@ class PromptEncoder(nn.Module):
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
- dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
- 1).expand(bs, -1, self.image_embedding_size[0],
- self.image_embedding_size[1])
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
return sparse_embeddings, dense_embeddings
@@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module):
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
- self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
torch.use_deterministic_algorithms(False)
@@ -425,14 +425,14 @@ class Attention(nn.Module):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = head_dim ** -0.5
+ self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
- assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
+ assert input_size is not None, "Input size must be provided if using relative positional encoding."
# Initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
@@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
return windows, (Hp, Wp)
-def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
- hw: Tuple[int, int]) -> torch.Tensor:
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
"""
Unpartition windows back into original sequences, removing padding.
@@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
- mode='linear',
+ mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
@@ -567,11 +568,12 @@ def add_decomposed_rel_pos(
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
- rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
- rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
- B, q_h * q_w, k_h * k_w)
+ B, q_h * q_w, k_h * k_w
+ )
return attn
@@ -580,12 +582,12 @@ class PatchEmbed(nn.Module):
"""Image to Patch Embedding."""
def __init__(
- self,
- kernel_size: Tuple[int, int] = (16, 16),
- stride: Tuple[int, int] = (16, 16),
- padding: Tuple[int, int] = (0, 0),
- in_chans: int = 3,
- embed_dim: int = 768,
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
) -> None:
"""
Initialize PatchEmbed module.
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 4097a228..95d9bbe6 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -30,8 +30,9 @@ class Sam(nn.Module):
pixel_mean (List[float]): Mean pixel values for image normalization.
pixel_std (List[float]): Standard deviation values for image normalization.
"""
+
mask_threshold: float = 0.0
- image_format: str = 'RGB'
+ image_format: str = "RGB"
def __init__(
self,
@@ -39,7 +40,7 @@ class Sam(nn.Module):
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = (123.675, 116.28, 103.53),
- pixel_std: List[float] = (58.395, 57.12, 57.375)
+ pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
"""
Initialize the Sam class to predict object masks from an image and input prompts.
@@ -60,5 +61,5 @@ class Sam(nn.Module):
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
- self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
- self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py
index 9955a261..ac378680 100644
--- a/ultralytics/models/sam/modules/tiny_encoder.py
+++ b/ultralytics/models/sam/modules/tiny_encoder.py
@@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential):
drop path.
"""
super().__init__()
- self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+ self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
- self.add_module('bn', bn)
+ self.add_module("bn", bn)
class PatchEmbed(nn.Module):
@@ -146,11 +146,11 @@ class ConvLayer(nn.Module):
input_resolution,
depth,
activation,
- drop_path=0.,
+ drop_path=0.0,
downsample=None,
use_checkpoint=False,
out_dim=None,
- conv_expand_ratio=4.,
+ conv_expand_ratio=4.0,
):
"""
Initializes the ConvLayer with the given dimensions and settings.
@@ -173,18 +173,25 @@ class ConvLayer(nn.Module):
self.use_checkpoint = use_checkpoint
# Build blocks
- self.blocks = nn.ModuleList([
- MBConv(
- dim,
- dim,
- conv_expand_ratio,
- activation,
- drop_path[i] if isinstance(drop_path, list) else drop_path,
- ) for i in range(depth)])
+ self.blocks = nn.ModuleList(
+ [
+ MBConv(
+ dim,
+ dim,
+ conv_expand_ratio,
+ activation,
+ drop_path[i] if isinstance(drop_path, list) else drop_path,
+ )
+ for i in range(depth)
+ ]
+ )
# Patch merging layer
- self.downsample = None if downsample is None else downsample(
- input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ self.downsample = (
+ None
+ if downsample is None
+ else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ )
def forward(self, x):
"""Processes the input through a series of convolutional layers and returns the activated output."""
@@ -200,7 +207,7 @@ class Mlp(nn.Module):
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
"""
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
super().__init__()
out_features = out_features or in_features
@@ -232,12 +239,12 @@ class Attention(torch.nn.Module):
"""
def __init__(
- self,
- dim,
- key_dim,
- num_heads=8,
- attn_ratio=4,
- resolution=(14, 14),
+ self,
+ dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=4,
+ resolution=(14, 14),
):
"""
Initializes the Attention module.
@@ -256,7 +263,7 @@ class Attention(torch.nn.Module):
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
- self.scale = key_dim ** -0.5
+ self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
@@ -279,13 +286,13 @@ class Attention(torch.nn.Module):
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
- self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode=True):
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
super().train(mode)
- if mode and hasattr(self, 'ab'):
+ if mode and hasattr(self, "ab"):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
@@ -306,8 +313,9 @@ class Attention(torch.nn.Module):
v = v.permute(0, 2, 1, 3)
self.ab = self.ab.to(self.attention_biases.device)
- attn = ((q @ k.transpose(-2, -1)) * self.scale +
- (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
+ attn = (q @ k.transpose(-2, -1)) * self.scale + (
+ self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+ )
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
return self.proj(x)
@@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module):
input_resolution,
num_heads,
window_size=7,
- mlp_ratio=4.,
- drop=0.,
- drop_path=0.,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
local_conv_size=3,
activation=nn.GELU,
):
@@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module):
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
- assert window_size > 0, 'window_size must be greater than 0'
+ assert window_size > 0, "window_size must be greater than 0"
self.window_size = window_size
self.mlp_ratio = mlp_ratio
@@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module):
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
- assert dim % num_heads == 0, 'dim must be divisible by num_heads'
+ assert dim % num_heads == 0, "dim must be divisible by num_heads"
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
@@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module):
"""
H, W = self.input_resolution
B, L, C = x.shape
- assert L == H * W, 'input feature has wrong size'
+ assert L == H * W, "input feature has wrong size"
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
@@ -394,8 +402,11 @@ class TinyViTBlock(nn.Module):
nH = pH // self.window_size
nW = pW // self.window_size
# Window partition
- x = x.view(B, nH, self.window_size, nW, self.window_size,
- C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
+ x = (
+ x.view(B, nH, self.window_size, nW, self.window_size, C)
+ .transpose(2, 3)
+ .reshape(B * nH * nW, self.window_size * self.window_size, C)
+ )
x = self.attn(x)
# Window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
@@ -417,8 +428,10 @@ class TinyViTBlock(nn.Module):
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
attention heads, window size, and MLP ratio.
"""
- return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
- f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
+ return (
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+ )
class BasicLayer(nn.Module):
@@ -431,9 +444,9 @@ class BasicLayer(nn.Module):
depth,
num_heads,
window_size,
- mlp_ratio=4.,
- drop=0.,
- drop_path=0.,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
@@ -468,22 +481,29 @@ class BasicLayer(nn.Module):
self.use_checkpoint = use_checkpoint
# Build blocks
- self.blocks = nn.ModuleList([
- TinyViTBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- drop=drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- local_conv_size=local_conv_size,
- activation=activation,
- ) for i in range(depth)])
+ self.blocks = nn.ModuleList(
+ [
+ TinyViTBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ local_conv_size=local_conv_size,
+ activation=activation,
+ )
+ for i in range(depth)
+ ]
+ )
# Patch merging layer
- self.downsample = None if downsample is None else downsample(
- input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ self.downsample = (
+ None
+ if downsample is None
+ else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ )
def forward(self, x):
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
@@ -493,7 +513,7 @@ class BasicLayer(nn.Module):
def extra_repr(self) -> str:
"""Returns a string representation of the extra_repr function with the layer's parameters."""
- return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class LayerNorm2d(nn.Module):
@@ -549,8 +569,8 @@ class TinyViT(nn.Module):
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
- mlp_ratio=4.,
- drop_rate=0.,
+ mlp_ratio=4.0,
+ drop_rate=0.0,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
@@ -585,10 +605,9 @@ class TinyViT(nn.Module):
activation = nn.GELU
- self.patch_embed = PatchEmbed(in_chans=in_chans,
- embed_dim=embed_dims[0],
- resolution=img_size,
- activation=activation)
+ self.patch_embed = PatchEmbed(
+ in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
+ )
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
@@ -601,27 +620,30 @@ class TinyViT(nn.Module):
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
- input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
- patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
+ input_resolution=(
+ patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+ patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+ ),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
- out_dim=embed_dims[min(i_layer + 1,
- len(embed_dims) - 1)],
+ out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
- layer = BasicLayer(num_heads=num_heads[i_layer],
- window_size=window_sizes[i_layer],
- mlp_ratio=self.mlp_ratio,
- drop=drop_rate,
- local_conv_size=local_conv_size,
- **kwargs)
+ layer = BasicLayer(
+ num_heads=num_heads[i_layer],
+ window_size=window_sizes[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ drop=drop_rate,
+ local_conv_size=local_conv_size,
+ **kwargs,
+ )
self.layers.append(layer)
# Classifier head
@@ -680,7 +702,7 @@ class TinyViT(nn.Module):
def _check_lr_scale(m):
"""Checks if the learning rate scale attribute is present in module's parameters."""
for p in m.parameters():
- assert hasattr(p, 'lr_scale'), p.param_name
+ assert hasattr(p, "lr_scale"), p.param_name
self.apply(_check_lr_scale)
@@ -698,7 +720,7 @@ class TinyViT(nn.Module):
@torch.jit.ignore
def no_weight_decay_keywords(self):
"""Returns a dictionary of parameter names where weight decay should not be applied."""
- return {'attention_biases'}
+ return {"attention_biases"}
def forward_features(self, x):
"""Runs the input through the model layers and returns the transformed output."""
diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py
index 5c06acd9..1ad07418 100644
--- a/ultralytics/models/sam/modules/transformer.py
+++ b/ultralytics/models/sam/modules/transformer.py
@@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module):
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
- ))
+ )
+ )
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
@@ -227,7 +228,7 @@ class Attention(nn.Module):
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
- assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 94362ecc..540ed6c6 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -19,8 +19,17 @@ from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
-from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
- generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
+from .amg import (
+ batch_iterator,
+ batched_mask_to_box,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ remove_small_regions,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+)
from .build import build_sam
@@ -58,7 +67,7 @@ class Predictor(BasePredictor):
"""
if overrides is None:
overrides = {}
- overrides.update(dict(task='segment', mode='predict', imgsz=1024))
+ overrides.update(dict(task="segment", mode="predict", imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
self.args.retina_masks = True
self.im = None
@@ -107,7 +116,7 @@ class Predictor(BasePredictor):
Returns:
(List[np.ndarray]): List of transformed images.
"""
- assert len(im) == 1, 'SAM model does not currently support batched inference'
+ assert len(im) == 1, "SAM model does not currently support batched inference"
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
return [letterbox(image=x) for x in im]
@@ -132,9 +141,9 @@ class Predictor(BasePredictor):
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
"""
# Override prompts if any stored in self.prompts
- bboxes = self.prompts.pop('bboxes', bboxes)
- points = self.prompts.pop('points', points)
- masks = self.prompts.pop('masks', masks)
+ bboxes = self.prompts.pop("bboxes", bboxes)
+ points = self.prompts.pop("points", points)
+ masks = self.prompts.pop("masks", masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
@@ -199,18 +208,20 @@ class Predictor(BasePredictor):
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
- def generate(self,
- im,
- crop_n_layers=0,
- crop_overlap_ratio=512 / 1500,
- crop_downscale_factor=1,
- point_grids=None,
- points_stride=32,
- points_batch_size=64,
- conf_thres=0.88,
- stability_score_thresh=0.95,
- stability_score_offset=0.95,
- crop_nms_thresh=0.7):
+ def generate(
+ self,
+ im,
+ crop_n_layers=0,
+ crop_overlap_ratio=512 / 1500,
+ crop_downscale_factor=1,
+ point_grids=None,
+ points_stride=32,
+ points_batch_size=64,
+ conf_thres=0.88,
+ stability_score_thresh=0.95,
+ stability_score_offset=0.95,
+ crop_nms_thresh=0.7,
+ ):
"""
Perform image segmentation using the Segment Anything Model (SAM).
@@ -248,19 +259,20 @@ class Predictor(BasePredictor):
area = torch.tensor(w * h, device=im.device)
points_scale = np.array([[w, h]]) # w, h
# Crop image and interpolate to input size
- crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
+ crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
# (num_points, 2)
points_for_image = point_grids[layer_idx] * points_scale
crop_masks, crop_scores, crop_bboxes = [], [], []
- for (points, ) in batch_iterator(points_batch_size, points_for_image):
+ for (points,) in batch_iterator(points_batch_size, points_for_image):
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
# Interpolate predicted masks to input size
- pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
+ pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
idx = pred_score > conf_thres
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
- stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
- stability_score_offset)
+ stability_score = calculate_stability_score(
+ pred_mask, self.model.mask_threshold, stability_score_offset
+ )
idx = stability_score > stability_score_thresh
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
# Bool type is much more memory-efficient.
@@ -404,7 +416,7 @@ class Predictor(BasePredictor):
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_source(image)
- assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
+ assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
@@ -446,9 +458,9 @@ class Predictor(BasePredictor):
scores = []
for mask in masks:
mask = mask.cpu().numpy().astype(np.uint8)
- mask, changed = remove_small_regions(mask, min_area, mode='holes')
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
- mask, changed = remove_small_regions(mask, min_area, mode='islands')
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
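The two `remove_small_regions` passes above fill small holes and then drop small islands; a minimal check of the islands pass (OpenCV is imported lazily inside the function):

```python
import numpy as np
from ultralytics.models.sam.amg import remove_small_regions

mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:40, 8:40] = 1    # large island: kept
mask[50:52, 50:52] = 1  # 4-pixel island: below area_thresh, removed
mask, changed = remove_small_regions(mask, area_thresh=25, mode="islands")
print(changed, mask[50:52, 50:52].sum())  # True 0
```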
diff --git a/ultralytics/models/utils/loss.py b/ultralytics/models/utils/loss.py
index abb54958..1251cc96 100644
--- a/ultralytics/models/utils/loss.py
+++ b/ultralytics/models/utils/loss.py
@@ -30,14 +30,9 @@ class DETRLoss(nn.Module):
device (torch.device): Device on which tensors are stored.
"""
- def __init__(self,
- nc=80,
- loss_gain=None,
- aux_loss=True,
- use_fl=True,
- use_vfl=False,
- use_uni_match=False,
- uni_match_ind=0):
+ def __init__(
+ self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
+ ):
"""
DETR loss function.
@@ -52,9 +47,9 @@ class DETRLoss(nn.Module):
super().__init__()
if loss_gain is None:
- loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
+ loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
self.nc = nc
- self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
+ self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss() if use_fl else None
@@ -64,10 +59,10 @@ class DETRLoss(nn.Module):
self.uni_match_ind = uni_match_ind
self.device = None
- def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
+ def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
- name_class = f'loss_class{postfix}'
+ name_class = f"loss_class{postfix}"
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
@@ -82,28 +77,28 @@ class DETRLoss(nn.Module):
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
- loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
+ loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
- return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
+ return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
- def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
+ def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
boxes.
"""
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
- name_bbox = f'loss_bbox{postfix}'
- name_giou = f'loss_giou{postfix}'
+ name_bbox = f"loss_bbox{postfix}"
+ name_giou = f"loss_giou{postfix}"
loss = {}
if len(gt_bboxes) == 0:
- loss[name_bbox] = torch.tensor(0., device=self.device)
- loss[name_giou] = torch.tensor(0., device=self.device)
+ loss[name_bbox] = torch.tensor(0.0, device=self.device)
+ loss[name_giou] = torch.tensor(0.0, device=self.device)
return loss
- loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
+ loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
- loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
+ loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
return {k: v.squeeze() for k, v in loss.items()}
# This function is for future RT-DETR Segment models
@@ -137,50 +132,57 @@ class DETRLoss(nn.Module):
# loss = 1 - (numerator + 1) / (denominator + 1)
# return loss.sum() / num_gts
- def _get_loss_aux(self,
- pred_bboxes,
- pred_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- match_indices=None,
- postfix='',
- masks=None,
- gt_mask=None):
+ def _get_loss_aux(
+ self,
+ pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ match_indices=None,
+ postfix="",
+ masks=None,
+ gt_mask=None,
+ ):
"""Get auxiliary losses."""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
- match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
- pred_scores[self.uni_match_ind],
- gt_bboxes,
- gt_cls,
- gt_groups,
- masks=masks[self.uni_match_ind] if masks is not None else None,
- gt_mask=gt_mask)
+ match_indices = self.matcher(
+ pred_bboxes[self.uni_match_ind],
+ pred_scores[self.uni_match_ind],
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=masks[self.uni_match_ind] if masks is not None else None,
+ gt_mask=gt_mask,
+ )
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
- loss_ = self._get_loss(aux_bboxes,
- aux_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- masks=aux_masks,
- gt_mask=gt_mask,
- postfix=postfix,
- match_indices=match_indices)
- loss[0] += loss_[f'loss_class{postfix}']
- loss[1] += loss_[f'loss_bbox{postfix}']
- loss[2] += loss_[f'loss_giou{postfix}']
+ loss_ = self._get_loss(
+ aux_bboxes,
+ aux_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=aux_masks,
+ gt_mask=gt_mask,
+ postfix=postfix,
+ match_indices=match_indices,
+ )
+ loss[0] += loss_[f"loss_class{postfix}"]
+ loss[1] += loss_[f"loss_bbox{postfix}"]
+ loss[2] += loss_[f"loss_giou{postfix}"]
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
- f'loss_class_aux{postfix}': loss[0],
- f'loss_bbox_aux{postfix}': loss[1],
- f'loss_giou_aux{postfix}': loss[2]}
+ f"loss_class_aux{postfix}": loss[0],
+ f"loss_bbox_aux{postfix}": loss[1],
+ f"loss_giou_aux{postfix}": loss[2],
+ }
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
@@ -196,33 +198,37 @@ class DETRLoss(nn.Module):
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
- pred_assigned = torch.cat([
- t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
- for t, (I, _) in zip(pred_bboxes, match_indices)])
- gt_assigned = torch.cat([
- t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
- for t, (_, J) in zip(gt_bboxes, match_indices)])
+ pred_assigned = torch.cat(
+ [
+ t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (I, _) in zip(pred_bboxes, match_indices)
+ ]
+ )
+ gt_assigned = torch.cat(
+ [
+ t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (_, J) in zip(gt_bboxes, match_indices)
+ ]
+ )
return pred_assigned, gt_assigned
- def _get_loss(self,
- pred_bboxes,
- pred_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- masks=None,
- gt_mask=None,
- postfix='',
- match_indices=None):
+ def _get_loss(
+ self,
+ pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ masks=None,
+ gt_mask=None,
+ postfix="",
+ match_indices=None,
+ ):
"""Get losses."""
if match_indices is None:
- match_indices = self.matcher(pred_bboxes,
- pred_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- masks=masks,
- gt_mask=gt_mask)
+ match_indices = self.matcher(
+ pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
+ )
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
@@ -242,7 +248,7 @@ class DETRLoss(nn.Module):
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
return loss
- def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
+ def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
"""
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
@@ -254,21 +260,19 @@ class DETRLoss(nn.Module):
postfix (str): postfix of loss name.
"""
self.device = pred_bboxes.device
- match_indices = kwargs.get('match_indices', None)
- gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
+ match_indices = kwargs.get("match_indices", None)
+ gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
- total_loss = self._get_loss(pred_bboxes[-1],
- pred_scores[-1],
- gt_bboxes,
- gt_cls,
- gt_groups,
- postfix=postfix,
- match_indices=match_indices)
+ total_loss = self._get_loss(
+ pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
+ )
if self.aux_loss:
total_loss.update(
- self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
- postfix))
+ self._get_loss_aux(
+ pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
+ )
+ )
return total_loss
@@ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss):
# Check for denoising metadata to compute denoising training loss
if dn_meta is not None:
- dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
- assert len(batch['gt_groups']) == len(dn_pos_idx)
+ dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
+ assert len(batch["gt_groups"]) == len(dn_pos_idx)
# Get the match indices for denoising
- match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
+ match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
# Compute the denoising training loss
- dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
+ dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
total_loss.update(dn_loss)
else:
# If no denoising metadata is provided, set denoising loss to zero
- total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
+ total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
return total_loss
@@ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
- assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
- f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
+ assert len(dn_pos_idx[i]) == len(gt_idx), f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
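The denoising branch above pairs each positive denoising query with its ground-truth box, repeating the gt indices once per denoising group. A minimal standalone sketch of that pairing logic (illustrative values, not the repository code):

```python
import torch


def dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
    """Pair positive denoising queries with repeated ground-truth indices, per image."""
    indices = []
    idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # gt index offset per image
    for i, num_gt in enumerate(gt_groups):
        if num_gt > 0:
            gt_idx = (torch.arange(num_gt, dtype=torch.long) + idx_groups[i]).repeat(dn_num_group)
            assert len(dn_pos_idx[i]) == len(gt_idx)
            indices.append((dn_pos_idx[i], gt_idx))
        else:
            indices.append((torch.zeros(0, dtype=torch.long), torch.zeros(0, dtype=torch.long)))
    return indices


# Two images with 2 and 1 gt boxes, 2 denoising groups (hypothetical positive-query indices):
pos = [torch.tensor([0, 1, 6, 7]), torch.tensor([0, 6])]
print(dn_match_indices(pos, dn_num_group=2, gt_groups=[2, 1]))
# image 0 maps queries [0, 1, 6, 7] -> gts [0, 1, 0, 1]; image 1 maps [0, 6] -> [2, 2]
```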
diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py
index 902756db..4f66feef 100644
--- a/ultralytics/models/utils/ops.py
+++ b/ultralytics/models/utils/ops.py
@@ -37,7 +37,7 @@ class HungarianMatcher(nn.Module):
"""
super().__init__()
if cost_gain is None:
- cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
+ cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
@@ -86,7 +86,7 @@ class HungarianMatcher(nn.Module):
# Compute the classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
- neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
+ neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
@@ -99,9 +99,11 @@ class HungarianMatcher(nn.Module):
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Final cost matrix
- C = self.cost_gain['class'] * cost_class + \
- self.cost_gain['bbox'] * cost_bbox + \
- self.cost_gain['giou'] * cost_giou
+ C = (
+ self.cost_gain["class"] * cost_class
+ + self.cost_gain["bbox"] * cost_bbox
+ + self.cost_gain["giou"] * cost_giou
+ )
# Compute the mask cost and dice cost
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
@@ -111,10 +113,11 @@ class HungarianMatcher(nn.Module):
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
- gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
- # (idx for queries, idx for gt)
- return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
- for k, (i, j) in enumerate(indices)]
+ gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+ # (idx for queries, idx for gt)
+ return [
+ (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+ for k, (i, j) in enumerate(indices)
+ ]
# This function is for future RT-DETR Segment models
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
@@ -147,14 +150,9 @@ class HungarianMatcher(nn.Module):
# return C
-def get_cdn_group(batch,
- num_classes,
- num_queries,
- class_embed,
- num_dn=100,
- cls_noise_ratio=0.5,
- box_noise_scale=1.0,
- training=False):
+def get_cdn_group(
+ batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+):
"""
Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
@@ -180,7 +178,7 @@ def get_cdn_group(batch,
if (not training) or num_dn <= 0:
return None, None, None, None
- gt_groups = batch['gt_groups']
+ gt_groups = batch["gt_groups"]
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
@@ -190,9 +188,9 @@ def get_cdn_group(batch,
num_group = 1 if num_group == 0 else num_group
# Pad gt to max_num of a batch
bs = len(gt_groups)
- gt_cls = batch['cls'] # (bs*num, )
- gt_bbox = batch['bboxes'] # bs*num, 4
- b_idx = batch['batch_idx']
+ gt_cls = batch["cls"] # (bs*num, )
+ gt_bbox = batch["bboxes"] # bs*num, 4
+ b_idx = batch["batch_idx"]
# Each group has positive and negative queries.
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
@@ -245,16 +243,21 @@ def get_cdn_group(batch,
# Queries in one denoising group cannot see those in other groups
for i in range(num_group):
if i == 0:
- attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
if i == num_group - 1:
- attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
else:
- attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
- attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
dn_meta = {
- 'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
- 'dn_num_group': num_group,
- 'dn_num_split': [num_dn, num_queries]}
-
- return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
- class_embed.device), dn_meta
+ "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+ "dn_num_group": num_group,
+ "dn_num_split": [num_dn, num_queries],
+ }
+
+ return (
+ padding_cls.to(class_embed.device),
+ padding_bbox.to(class_embed.device),
+ attn_mask.to(class_embed.device),
+ dn_meta,
+ )
diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py
index 602307b8..230f53a8 100644
--- a/ultralytics/models/yolo/__init__.py
+++ b/ultralytics/models/yolo/__init__.py
@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment
from .model import YOLO
-__all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO'
+__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO"
diff --git a/ultralytics/models/yolo/classify/__init__.py b/ultralytics/models/yolo/classify/__init__.py
index 33d72e68..ca92f892 100644
--- a/ultralytics/models/yolo/classify/__init__.py
+++ b/ultralytics/models/yolo/classify/__init__.py
@@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
-__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py
index 9047d8df..853ef048 100644
--- a/ultralytics/models/yolo/classify/predict.py
+++ b/ultralytics/models/yolo/classify/predict.py
@@ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes ClassificationPredictor setting the task to 'classify'."""
super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'classify'
- self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor'
+ self.args.task = "classify"
+ self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
- is_legacy_transform = any(self._legacy_transform_name in str(transform)
- for transform in self.transforms.transforms)
+ is_legacy_transform = any(
+ self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+ )
if is_legacy_transform: # to handle legacy transforms
img = torch.stack([self.transforms(im) for im in img], dim=0)
else:
- img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
- dim=0)
+ img = torch.stack(
+ [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+ )
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
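The non-legacy path above converts each BGR numpy frame to RGB, wraps it as a PIL image, applies the configured transforms, and stacks the batch. A minimal sketch of that pipeline, assuming a plain torchvision transform in place of the predictor's own:

```python
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

tfms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
imgs = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)]  # BGR frames, as from cv2

# BGR -> RGB -> PIL -> transform, then stack into a (N, C, H, W) batch
batch = torch.stack([tfms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in imgs], dim=0)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```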
diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py
index 3c23b2a0..a7bbe772 100644
--- a/ultralytics/models/yolo/classify/train.py
+++ b/ultralytics/models/yolo/classify/train.py
@@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None:
overrides = {}
- overrides['task'] = 'classify'
- if overrides.get('imgsz') is None:
- overrides['imgsz'] = 224
+ overrides["task"] = "classify"
+ if overrides.get("imgsz") is None:
+ overrides["imgsz"] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
- self.model.names = self.data['names']
+ self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
- model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
- if not self.args.pretrained and hasattr(m, 'reset_parameters'):
+ if not self.args.pretrained and hasattr(m, "reset_parameters"):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
@@ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer):
model, ckpt = str(self.model), None
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
- if model.endswith('.pt'):
- self.model, ckpt = attempt_load_one_weight(model, device='cpu')
+ if model.endswith(".pt"):
+ self.model, ckpt = attempt_load_one_weight(model, device="cpu")
for p in self.model.parameters():
p.requires_grad = True # for training
- elif model.split('.')[-1] in ('yaml', 'yml'):
+ elif model.split(".")[-1] in ("yaml", "yml"):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
- self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
+ self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
else:
- FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
- ClassificationModel.reshape_outputs(self.model, self.data['nc'])
+ FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
+ ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt
- def build_dataset(self, img_path, mode='train', batch=None):
+ def build_dataset(self, img_path, mode="train", batch=None):
"""Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
- return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
+ return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
# Attach inference transforms
- if mode != 'train':
+ if mode != "train":
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
@@ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer):
def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
- batch['img'] = batch['img'].to(self.device)
- batch['cls'] = batch['cls'].to(self.device)
+ batch["img"] = batch["img"].to(self.device)
+ batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
"""Returns a formatted string showing training progress."""
- return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
- ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+ return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+ "Epoch",
+ "GPU_mem",
+ *self.loss_names,
+ "Instances",
+ "Size",
+ )
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
- self.loss_names = ['loss']
+ self.loss_names = ["loss"]
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
- def label_loss_items(self, loss_items=None, prefix='train'):
+ def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection
"""
- keys = [f'{prefix}/{x}' for x in self.loss_names]
+ keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
@@ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer):
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
- LOGGER.info(f'\nValidating {f}...')
+ LOGGER.info(f"\nValidating {f}...")
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
- self.metrics.pop('fitness', None)
- self.run_callbacks('on_fit_epoch_end')
+ self.metrics.pop("fitness", None)
+ self.run_callbacks("on_fit_epoch_end")
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(
- images=batch['img'],
- batch_idx=torch.arange(len(batch['img'])),
- cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
- fname=self.save_dir / f'train_batch{ni}.jpg',
- on_plot=self.on_plot)
+ images=batch["img"],
+ batch_idx=torch.arange(len(batch["img"])),
+ cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
+ fname=self.save_dir / f"train_batch{ni}.jpg",
+ on_plot=self.on_plot,
+ )
diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py
index 3ebf3808..de3cff2b 100644
--- a/ultralytics/models/yolo/classify/val.py
+++ b/ultralytics/models/yolo/classify/val.py
@@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.targets = None
self.pred = None
- self.args.task = 'classify'
+ self.args.task = "classify"
self.metrics = ClassifyMetrics()
def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
- return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
+ return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names
self.nc = len(model.names)
- self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
self.pred = []
self.targets = []
def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
- batch['img'] = batch['img'].to(self.device, non_blocking=True)
- batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
- batch['cls'] = batch['cls'].to(self.device)
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+ batch["cls"] = batch["cls"].to(self.device)
return batch
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
- self.targets.append(batch['cls'])
+ self.targets.append(batch["cls"])
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
- self.confusion_matrix.plot(save_dir=self.save_dir,
- names=self.names.values(),
- normalize=normalize,
- on_plot=self.on_plot)
+ self.confusion_matrix.plot(
+ save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+ )
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
@@ -88,24 +87,27 @@ class ClassificationValidator(BaseValidator):
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
- pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
- LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
+ pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
+ LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(
- images=batch['img'],
- batch_idx=torch.arange(len(batch['img'])),
- cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
+ images=batch["img"],
+ batch_idx=torch.arange(len(batch["img"])),
+ cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
- on_plot=self.on_plot)
+ on_plot=self.on_plot,
+ )
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
- plot_images(batch['img'],
- batch_idx=torch.arange(len(batch['img'])),
- cls=torch.argmax(preds, dim=1),
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
- names=self.names,
- on_plot=self.on_plot) # pred
+ plot_images(
+ batch["img"],
+ batch_idx=torch.arange(len(batch["img"])),
+ cls=torch.argmax(preds, dim=1),
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ ) # pred
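`update_metrics` above stores the top-5 class indices per image via `argsort`; top-1 and top-5 accuracy then follow by checking whether the target appears in the first k ranked columns. A minimal sketch of that reduction (illustrative, not ClassifyMetrics itself):

```python
import torch

preds = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])  # (n_images, n_classes) scores
targets = torch.tensor([1, 2])

n5 = min(preds.shape[1], 5)
top = preds.argsort(1, descending=True)[:, :n5]  # ranked class indices per image
correct = (top == targets.view(-1, 1)).float()
top1 = correct[:, 0].mean()            # 0.5: image 0 is right, image 1 is wrong
topk = correct.max(1).values.mean()    # 1.0: both targets appear within the top k
print(top1.item(), topk.item())
```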
diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py
index 20fc0c48..5f3e62c1 100644
--- a/ultralytics/models/yolo/detect/__init__.py
+++ b/ultralytics/models/yolo/detect/__init__.py
@@ -4,4 +4,4 @@ from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
-__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py
index 28cbd7ce..3a0c6287 100644
--- a/ultralytics/models/yolo/detect/predict.py
+++ b/ultralytics/models/yolo/detect/predict.py
@@ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
- preds = ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- classes=self.args.classes)
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ classes=self.args.classes,
+ )
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py
index fc656984..3326512b 100644
--- a/ultralytics/models/yolo/detect/train.py
+++ b/ultralytics/models/yolo/detect/train.py
@@ -30,7 +30,7 @@ class DetectionTrainer(BaseTrainer):
```
"""
- def build_dataset(self, img_path, mode='train', batch=None):
+ def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
@@ -40,33 +40,37 @@ class DetectionTrainer(BaseTrainer):
batch (int, optional): Size of batches; this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
- return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
+ return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Construct and return dataloader."""
- assert mode in ['train', 'val']
+ assert mode in ["train", "val"]
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
- shuffle = mode == 'train'
- if getattr(dataset, 'rect', False) and shuffle:
+ shuffle = mode == "train"
+ if getattr(dataset, "rect", False) and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
- workers = self.args.workers if mode == 'train' else self.args.workers * 2
+ workers = self.args.workers if mode == "train" else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
- batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
+ batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
if self.args.multi_scale:
- imgs = batch['img']
- sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride *
- self.stride) # size
+ imgs = batch["img"]
+ sz = (
+ random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))  # randrange requires ints
+ // self.stride
+ * self.stride
+ ) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
- ns = [math.ceil(x * sf / self.stride) * self.stride
- for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
- imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
- batch['img'] = imgs
+ ns = [
+ math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+ ] # new shape (stretched to gs-multiple)
+ imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+ batch["img"] = imgs
return batch
def set_model_attributes(self):
@@ -74,33 +78,32 @@ class DetectionTrainer(BaseTrainer):
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
- self.model.nc = self.data['nc'] # attach number of classes to model
- self.model.names = self.data['names'] # attach class names to model
+ self.model.nc = self.data["nc"] # attach number of classes to model
+ self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
- model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
- self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
- return yolo.detect.DetectionValidator(self.test_loader,
- save_dir=self.save_dir,
- args=copy(self.args),
- _callbacks=self.callbacks)
+ self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+ return yolo.detect.DetectionValidator(
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+ )
- def label_loss_items(self, loss_items=None, prefix='train'):
+ def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection
"""
- keys = [f'{prefix}/{x}' for x in self.loss_names]
+ keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
@@ -109,18 +112,25 @@ class DetectionTrainer(BaseTrainer):
def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
- return ('\n' + '%11s' *
- (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+ return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+ "Epoch",
+ "GPU_mem",
+ *self.loss_names,
+ "Instances",
+ "Size",
+ )
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
- plot_images(images=batch['img'],
- batch_idx=batch['batch_idx'],
- cls=batch['cls'].squeeze(-1),
- bboxes=batch['bboxes'],
- paths=batch['im_file'],
- fname=self.save_dir / f'train_batch{ni}.jpg',
- on_plot=self.on_plot)
+ plot_images(
+ images=batch["img"],
+ batch_idx=batch["batch_idx"],
+ cls=batch["cls"].squeeze(-1),
+ bboxes=batch["bboxes"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"train_batch{ni}.jpg",
+ on_plot=self.on_plot,
+ )
def plot_metrics(self):
"""Plots metrics from a CSV file."""
@@ -128,6 +138,6 @@ class DetectionTrainer(BaseTrainer):
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
- boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
- cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
- plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
+ boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+ cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+ plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
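The multi-scale branch in `preprocess_batch` picks a random size that is a multiple of the model stride, then rescales the batch so every side stays stride-aligned. A minimal sketch of the same arithmetic, assuming imgsz=640 and stride=32 (with the int casts noted above, since randrange needs integer bounds):

```python
import math
import random

import torch
import torch.nn.functional as F

imgsz, stride = 640, 32
imgs = torch.rand(2, 3, 480, 640)  # (batch, channels, h, w)

sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
sf = sz / max(imgs.shape[2:])  # scale factor from the longest side
if sf != 1:
    ns = [math.ceil(x * sf / stride) * stride for x in imgs.shape[2:]]  # stride-multiple h, w
    imgs = F.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
print(imgs.shape)  # e.g. torch.Size([2, 3, 384, 512])
```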
diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py
index 295c482c..7ac68e50 100644
--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -34,7 +34,7 @@ class DetectionValidator(BaseValidator):
self.nt_per_class = None
self.is_coco = False
self.class_map = None
- self.args.task = 'detect'
+ self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
@@ -42,25 +42,30 @@ class DetectionValidator(BaseValidator):
def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
- batch['img'] = batch['img'].to(self.device, non_blocking=True)
- batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
- for k in ['batch_idx', 'cls', 'bboxes']:
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+ for k in ["batch_idx", "cls", "bboxes"]:
batch[k] = batch[k].to(self.device)
if self.args.save_hybrid:
- height, width = batch['img'].shape[2:]
- nb = len(batch['img'])
- bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
- self.lb = [
- torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
- for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
+ height, width = batch["img"].shape[2:]
+ nb = len(batch["img"])
+ bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
+ self.lb = (
+ [
+ torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+ for i in range(nb)
+ ]
+ if self.args.save_hybrid
+ else []
+ ) # for autolabelling
return batch
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
- val = self.data.get(self.args.split, '') # validation path
- self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
+ val = self.data.get(self.args.split, "") # validation path
+ self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.names = model.names
@@ -74,26 +79,28 @@ class DetectionValidator(BaseValidator):
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
- return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
+ return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
- return ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=True,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det)
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ )
def _prepare_batch(self, si, batch):
"""Prepares a batch of images and annotations for validation."""
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx].squeeze(-1)
- bbox = batch['bboxes'][idx]
- ori_shape = batch['ori_shape'][si]
- imgsz = batch['img'].shape[2:]
- ratio_pad = batch['ratio_pad'][si]
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
@@ -103,8 +110,9 @@ class DetectionValidator(BaseValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""
predn = pred.clone()
- ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
- ratio_pad=pbatch['ratio_pad']) # native-space pred
+ ops.scale_boxes(
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+ ) # native-space pred
return predn
def update_metrics(self, preds, batch):
@@ -112,19 +120,21 @@ class DetectionValidator(BaseValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
- stat = dict(conf=torch.zeros(0, device=self.device),
- pred_cls=torch.zeros(0, device=self.device),
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ stat = dict(
+ conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ )
pbatch = self._prepare_batch(si, batch)
- cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
- stat['target_cls'] = cls
+ stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb does not support confusion_matrix yet.
- if self.args.plots and self.args.task != 'obb':
+ if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
@@ -132,24 +142,24 @@ class DetectionValidator(BaseValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn = self._prepare_pred(pred, pbatch)
- stat['conf'] = predn[:, 4]
- stat['pred_cls'] = predn[:, 5]
+ stat["conf"] = predn[:, 4]
+ stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
- stat['tp'] = self._process_batch(predn, bbox, cls)
+ stat["tp"] = self._process_batch(predn, bbox, cls)
# TODO: obb does not support confusion_matrix yet.
- if self.args.plots and self.args.task != 'obb':
+ if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])
# Save
if self.args.save_json:
- self.pred_to_json(predn, batch['im_file'][si])
+ self.pred_to_json(predn, batch["im_file"][si])
if self.args.save_txt:
- file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
- self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file)
+ file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
+ self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
@@ -159,19 +169,19 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
- if len(stats) and stats['tp'].any():
+ if len(stats) and stats["tp"].any():
self.metrics.process(**stats)
- self.nt_per_class = np.bincount(stats['target_cls'].astype(int),
- minlength=self.nc) # number of targets per class
+ self.nt_per_class = np.bincount(
+ stats["target_cls"].astype(int), minlength=self.nc
+ ) # number of targets per class
return self.metrics.results_dict
def print_results(self):
"""Prints training/validation set metrics per class."""
- pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
- LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+ pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
+ LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
- LOGGER.warning(
- f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
+ LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
@@ -180,10 +190,9 @@ class DetectionValidator(BaseValidator):
if self.args.plots:
for normalize in True, False:
- self.confusion_matrix.plot(save_dir=self.save_dir,
- names=self.names.values(),
- normalize=normalize,
- on_plot=self.on_plot)
+ self.confusion_matrix.plot(
+ save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+ )
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@@ -201,7 +210,7 @@ class DetectionValidator(BaseValidator):
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], gt_cls, iou)
- def build_dataset(self, img_path, mode='val', batch=None):
+ def build_dataset(self, img_path, mode="val", batch=None):
"""
Build YOLO Dataset.
@@ -214,28 +223,32 @@ class DetectionValidator(BaseValidator):
def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
- dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
+ dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
- plot_images(batch['img'],
- batch['batch_idx'],
- batch['cls'].squeeze(-1),
- batch['bboxes'],
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
- names=self.names,
- on_plot=self.on_plot)
+ plot_images(
+ batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ )
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
- plot_images(batch['img'],
- *output_to_target(preds, max_det=self.args.max_det),
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
- names=self.names,
- on_plot=self.on_plot) # pred
+ plot_images(
+ batch["img"],
+ *output_to_target(preds, max_det=self.args.max_det),
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ ) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -243,8 +256,8 @@ class DetectionValidator(BaseValidator):
for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
- with open(file, 'a') as f:
- f.write(('%g ' * len(line)).rstrip() % line + '\n')
+ with open(file, "a") as f:
+ f.write(("%g " * len(line)).rstrip() % line + "\n")
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
@@ -253,28 +266,31 @@ class DetectionValidator(BaseValidator):
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
- self.jdict.append({
- 'image_id': image_id,
- 'category_id': self.class_map[int(p[5])],
- 'bbox': [round(x, 3) for x in b],
- 'score': round(p[4], 5)})
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(p[5])],
+ "bbox": [round(x, 3) for x in b],
+ "score": round(p[4], 5),
+ }
+ )
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
- anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
- pred_json = self.save_dir / 'predictions.json' # predictions
- LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+ anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
+ pred_json = self.save_dir / "predictions.json" # predictions
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
- check_requirements('pycocotools>=2.0.6')
+ check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
- assert x.is_file(), f'{x} file not found'
+ assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
- eval = COCOeval(anno, pred, 'bbox')
+ eval = COCOeval(anno, pred, "bbox")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
@@ -282,5 +298,5 @@ class DetectionValidator(BaseValidator):
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
- LOGGER.warning(f'pycocotools unable to run: {e}')
+ LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
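`save_one_txt` above writes one detection per line as space-separated `%g` fields: class, normalized xywh, and optionally confidence. A quick sketch of the exact formatting expression with illustrative values (the `ops` import path is an assumption based on the repository layout):

```python
import torch

from ultralytics.utils import ops  # assumed import path

xyxy, conf, cls = [120.0, 80.0, 200.0, 160.0], 0.91, 0.0
gn = torch.tensor([640, 480, 640, 480])  # whwh normalization gain
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
line = (cls, *xywh, conf)
print(("%g " * len(line)).rstrip() % line)  # -> "0 0.25 0.25 0.125 0.166667 0.91"
```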
diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py
index eb2225d8..858e42c5 100644
--- a/ultralytics/models/yolo/model.py
+++ b/ultralytics/models/yolo/model.py
@@ -12,28 +12,34 @@ class YOLO(Model):
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes."""
return {
- 'classify': {
- 'model': ClassificationModel,
- 'trainer': yolo.classify.ClassificationTrainer,
- 'validator': yolo.classify.ClassificationValidator,
- 'predictor': yolo.classify.ClassificationPredictor, },
- 'detect': {
- 'model': DetectionModel,
- 'trainer': yolo.detect.DetectionTrainer,
- 'validator': yolo.detect.DetectionValidator,
- 'predictor': yolo.detect.DetectionPredictor, },
- 'segment': {
- 'model': SegmentationModel,
- 'trainer': yolo.segment.SegmentationTrainer,
- 'validator': yolo.segment.SegmentationValidator,
- 'predictor': yolo.segment.SegmentationPredictor, },
- 'pose': {
- 'model': PoseModel,
- 'trainer': yolo.pose.PoseTrainer,
- 'validator': yolo.pose.PoseValidator,
- 'predictor': yolo.pose.PosePredictor, },
- 'obb': {
- 'model': OBBModel,
- 'trainer': yolo.obb.OBBTrainer,
- 'validator': yolo.obb.OBBValidator,
- 'predictor': yolo.obb.OBBPredictor, }, }
+ "classify": {
+ "model": ClassificationModel,
+ "trainer": yolo.classify.ClassificationTrainer,
+ "validator": yolo.classify.ClassificationValidator,
+ "predictor": yolo.classify.ClassificationPredictor,
+ },
+ "detect": {
+ "model": DetectionModel,
+ "trainer": yolo.detect.DetectionTrainer,
+ "validator": yolo.detect.DetectionValidator,
+ "predictor": yolo.detect.DetectionPredictor,
+ },
+ "segment": {
+ "model": SegmentationModel,
+ "trainer": yolo.segment.SegmentationTrainer,
+ "validator": yolo.segment.SegmentationValidator,
+ "predictor": yolo.segment.SegmentationPredictor,
+ },
+ "pose": {
+ "model": PoseModel,
+ "trainer": yolo.pose.PoseTrainer,
+ "validator": yolo.pose.PoseValidator,
+ "predictor": yolo.pose.PosePredictor,
+ },
+ "obb": {
+ "model": OBBModel,
+ "trainer": yolo.obb.OBBTrainer,
+ "validator": yolo.obb.OBBValidator,
+ "predictor": yolo.obb.OBBPredictor,
+ },
+ }
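The `task_map` above is a plain dict-of-dicts keyed by task string; selecting a component is a two-step lookup. A minimal sketch of the dispatch pattern with generic stand-in classes (not the Ultralytics ones):

```python
class DetModel: ...
class DetTrainer: ...

TASK_MAP = {
    "detect": {"model": DetModel, "trainer": DetTrainer},
}

def component(task: str, kind: str):
    """Resolve e.g. ('detect', 'trainer') -> DetTrainer; unknown keys raise KeyError."""
    return TASK_MAP[task][kind]

print(component("detect", "trainer"))  # <class '__main__.DetTrainer'>
```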
diff --git a/ultralytics/models/yolo/obb/__init__.py b/ultralytics/models/yolo/obb/__init__.py
index 09f10481..f60349a7 100644
--- a/ultralytics/models/yolo/obb/__init__.py
+++ b/ultralytics/models/yolo/obb/__init__.py
@@ -4,4 +4,4 @@ from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator
-__all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
+__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py
index 662266f7..572082b8 100644
--- a/ultralytics/models/yolo/obb/predict.py
+++ b/ultralytics/models/yolo/obb/predict.py
@@ -25,26 +25,27 @@ class OBBPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes OBBPredictor with optional model and data configuration overrides."""
super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'obb'
+ self.args.task = "obb"
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
- preds = ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=len(self.model.names),
- classes=self.args.classes,
- rotated=True)
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nc=len(self.model.names),
+ classes=self.args.classes,
+ rotated=True,
+ )
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
- for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
- img_path = self.batch[0][i]
# xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
diff --git a/ultralytics/models/yolo/obb/train.py b/ultralytics/models/yolo/obb/train.py
index 0d1284a7..43ebaecd 100644
--- a/ultralytics/models/yolo/obb/train.py
+++ b/ultralytics/models/yolo/obb/train.py
@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
"""Initialize a OBBTrainer object with given arguments."""
if overrides is None:
overrides = {}
- overrides['task'] = 'obb'
+ overrides["task"] = "obb"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return OBBModel initialized with specified config and weights."""
- model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
def get_validator(self):
"""Return an instance of OBBValidator for validation of YOLO model."""
- self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
+ self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py
index a5c030a8..40dbb604 100644
--- a/ultralytics/models/yolo/obb/val.py
+++ b/ultralytics/models/yolo/obb/val.py
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
- self.args.task = 'obb'
+ self.args.task = "obb"
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
super().init_metrics(model)
- val = self.data.get(self.args.split, '') # validation path
- self.is_dota = isinstance(val, str) and 'DOTA' in val # is COCO
+ val = self.data.get(self.args.split, "") # validation path
+ self.is_dota = isinstance(val, str) and "DOTA" in val # is DOTA
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
- return ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- nc=self.nc,
- multi_label=True,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det,
- rotated=True)
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ nc=self.nc,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ rotated=True,
+ )
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@@ -66,12 +68,12 @@ class OBBValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares and returns a batch for OBB validation."""
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx].squeeze(-1)
- bbox = batch['bboxes'][idx]
- ori_shape = batch['ori_shape'][si]
- imgsz = batch['img'].shape[2:]
- ratio_pad = batch['ratio_pad'][si]
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
@@ -81,18 +83,21 @@ class OBBValidator(DetectionValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
predn = pred.clone()
- ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
- xywh=True) # native-space pred
+ ops.scale_boxes(
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+ ) # native-space pred
return predn
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
- plot_images(batch['img'],
- *output_to_rotated_target(preds, max_det=self.args.max_det),
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
- names=self.names,
- on_plot=self.on_plot) # pred
+ plot_images(
+ batch["img"],
+ *output_to_rotated_target(preds, max_det=self.args.max_det),
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ ) # pred
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
@@ -101,12 +106,15 @@ class OBBValidator(DetectionValidator):
rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
- self.jdict.append({
- 'image_id': image_id,
- 'category_id': self.class_map[int(predn[i, 5].item())],
- 'score': round(predn[i, 4].item(), 5),
- 'rbox': [round(x, 3) for x in r],
- 'poly': [round(x, 3) for x in b]})
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(predn[i, 5].item())],
+ "score": round(predn[i, 4].item(), 5),
+ "rbox": [round(x, 3) for x in r],
+ "poly": [round(x, 3) for x in b],
+ }
+ )
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -116,8 +124,8 @@ class OBBValidator(DetectionValidator):
xywha[:, :4] /= gn
xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xyxyxyxy
line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
- with open(file, 'a') as f:
- f.write(('%g ' * len(line)).rstrip() % line + '\n')
+ with open(file, "a") as f:
+ f.write(("%g " * len(line)).rstrip() % line + "\n")
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -125,42 +133,43 @@ class OBBValidator(DetectionValidator):
import json
import re
from collections import defaultdict
- pred_json = self.save_dir / 'predictions.json' # predictions
- pred_txt = self.save_dir / 'predictions_txt' # predictions
+
+ pred_json = self.save_dir / "predictions.json" # predictions
+ pred_txt = self.save_dir / "predictions_txt" # predictions
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
- LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
+ LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
for d in data:
- image_id = d['image_id']
- score = d['score']
- classname = self.names[d['category_id']].replace(' ', '-')
+ image_id = d["image_id"]
+ score = d["score"]
+ classname = self.names[d["category_id"]].replace(" ", "-")
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
- d['poly'][0],
- d['poly'][1],
- d['poly'][2],
- d['poly'][3],
- d['poly'][4],
- d['poly'][5],
- d['poly'][6],
- d['poly'][7],
+ d["poly"][0],
+ d["poly"][1],
+ d["poly"][2],
+ d["poly"][3],
+ d["poly"][4],
+ d["poly"][5],
+ d["poly"][6],
+ d["poly"][7],
)
- with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+ with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
# Save merged results; this could result in a slightly lower mAP than using the official merging script,
# because of the probiou calculation.
- pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions
+ pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
- LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
+ LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
for d in data:
- image_id = d['image_id'].split('__')[0]
- pattern = re.compile(r'\d+___\d+')
- x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
- bbox, score, cls = d['rbox'], d['score'], d['category_id']
+ image_id = d["image_id"].split("__")[0]
+ pattern = re.compile(r"\d+___\d+")
+ x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+ bbox, score, cls = d["rbox"], d["score"], d["category_id"]
bbox[0] += x
bbox[1] += y
bbox.extend([score, cls])
@@ -178,11 +187,11 @@ class OBBValidator(DetectionValidator):
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
- classname = self.names[int(x[-1])].replace(' ', '-')
+ classname = self.names[int(x[-1])].replace(" ", "-")
poly = [round(i, 3) for i in x[:-2]]
score = round(x[-2], 3)
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
poly[0],
@@ -194,7 +203,7 @@ class OBBValidator(DetectionValidator):
poly[6],
poly[7],
)
- with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+ with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
return stats
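Merging split DOTA predictions relies on the patch offsets encoded in the image id; the regex above pulls the `x___y` suffix and shifts box centers back into source-image coordinates. A sketch with a hypothetical split id:

```python
import re

image_id = "P0006__1024__0___512"  # hypothetical DOTA split id: name__scale__x___y
base = image_id.split("__")[0]     # "P0006"
x, y = (int(c) for c in re.findall(r"\d+___\d+", image_id)[0].split("___"))

rbox = [100.0, 40.0, 30.0, 20.0, 0.1]  # xywhr box in patch coordinates
rbox[0] += x  # shift center x back into the full image
rbox[1] += y  # shift center y
print(base, rbox)  # P0006 [100.0, 552.0, 30.0, 20.0, 0.1]
```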
diff --git a/ultralytics/models/yolo/pose/__init__.py b/ultralytics/models/yolo/pose/__init__.py
index 2a79f0f3..d5669430 100644
--- a/ultralytics/models/yolo/pose/__init__.py
+++ b/ultralytics/models/yolo/pose/__init__.py
@@ -4,4 +4,4 @@ from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
-__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
+__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py
index d00cea02..7c55709f 100644
--- a/ultralytics/models/yolo/pose/predict.py
+++ b/ultralytics/models/yolo/pose/predict.py
@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'pose'
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
+ self.args.task = "pose"
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+ LOGGER.warning(
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
+ )
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
- preds = ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- classes=self.args.classes,
- nc=len(self.model.names))
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ classes=self.args.classes,
+ nc=len(self.model.names),
+ )
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
img_path = self.batch[0][i]
results.append(
- Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
+ Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
+ )
return results
diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py
index c9ccf52e..f5229e50 100644
--- a/ultralytics/models/yolo/pose/train.py
+++ b/ultralytics/models/yolo/pose/train.py
@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None:
overrides = {}
- overrides['task'] = 'pose'
+ overrides["task"] = "pose"
super().__init__(cfg, overrides, _callbacks)
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+ LOGGER.warning(
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
+ )
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
- model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
+ model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
if weights:
model.load(weights)
@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes()
- self.model.kpt_shape = self.data['kpt_shape']
+ self.model.kpt_shape = self.data["kpt_shape"]
def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
- self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
- return yolo.pose.PoseValidator(self.test_loader,
- save_dir=self.save_dir,
- args=copy(self.args),
- _callbacks=self.callbacks)
+ self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+ return yolo.pose.PoseValidator(
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+ )
def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
- images = batch['img']
- kpts = batch['keypoints']
- cls = batch['cls'].squeeze(-1)
- bboxes = batch['bboxes']
- paths = batch['im_file']
- batch_idx = batch['batch_idx']
- plot_images(images,
- batch_idx,
- cls,
- bboxes,
- kpts=kpts,
- paths=paths,
- fname=self.save_dir / f'train_batch{ni}.jpg',
- on_plot=self.on_plot)
+ images = batch["img"]
+ kpts = batch["keypoints"]
+ cls = batch["cls"].squeeze(-1)
+ bboxes = batch["bboxes"]
+ paths = batch["im_file"]
+ batch_idx = batch["batch_idx"]
+ plot_images(
+ images,
+ batch_idx,
+ cls,
+ bboxes,
+ kpts=kpts,
+ paths=paths,
+ fname=self.save_dir / f"train_batch{ni}.jpg",
+ on_plot=self.on_plot,
+ )
def plot_metrics(self):
"""Plots training/val metrics."""
diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py
index f7855bf2..84056863 100644
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -31,38 +31,53 @@ class PoseValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
- self.args.task = 'pose'
+ self.args.task = "pose"
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+ LOGGER.warning(
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
+ )
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
- batch['keypoints'] = batch['keypoints'].to(self.device).float()
+ batch["keypoints"] = batch["keypoints"].to(self.device).float()
return batch
def get_desc(self):
"""Returns description of evaluation metrics in string format."""
- return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
- 'R', 'mAP50', 'mAP50-95)')
+ return ("%22s" + "%11s" * 10) % (
+ "Class",
+ "Images",
+ "Instances",
+ "Box(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ "Pose(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ )
def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
- return ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=True,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det,
- nc=self.nc)
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ nc=self.nc,
+ )
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
- self.kpt_shape = self.data['kpt_shape']
+ self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
@@ -71,21 +86,21 @@ class PoseValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares a batch for processing by converting keypoints to float and moving to device."""
pbatch = super()._prepare_batch(si, batch)
- kpts = batch['keypoints'][batch['batch_idx'] == si]
- h, w = pbatch['imgsz']
+ kpts = batch["keypoints"][batch["batch_idx"] == si]
+ h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
- kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
- pbatch['kpts'] = kpts
+ kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+ pbatch["kpts"] = kpts
return pbatch
def _prepare_pred(self, pred, pbatch):
"""Prepares and scales keypoints in a batch for pose processing."""
predn = super()._prepare_pred(pred, pbatch)
- nk = pbatch['kpts'].shape[1]
+ nk = pbatch["kpts"].shape[1]
pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
- ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
+ ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
return predn, pred_kpts
def update_metrics(self, preds, batch):
@@ -93,14 +108,16 @@ class PoseValidator(DetectionValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
- stat = dict(conf=torch.zeros(0, device=self.device),
- pred_cls=torch.zeros(0, device=self.device),
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
- tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ stat = dict(
+ conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ )
pbatch = self._prepare_batch(si, batch)
- cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
- stat['target_cls'] = cls
+ stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
@@ -113,13 +130,13 @@ class PoseValidator(DetectionValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn, pred_kpts = self._prepare_pred(pred, pbatch)
- stat['conf'] = predn[:, 4]
- stat['pred_cls'] = predn[:, 5]
+ stat["conf"] = predn[:, 4]
+ stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
- stat['tp'] = self._process_batch(predn, bbox, cls)
- stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
+ stat["tp"] = self._process_batch(predn, bbox, cls)
+ stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
@@ -128,7 +145,7 @@ class PoseValidator(DetectionValidator):
# Save
if self.args.save_json:
- self.pred_to_json(predn, batch['im_file'][si])
+ self.pred_to_json(predn, batch["im_file"][si])
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -159,26 +176,30 @@ class PoseValidator(DetectionValidator):
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
- plot_images(batch['img'],
- batch['batch_idx'],
- batch['cls'].squeeze(-1),
- batch['bboxes'],
- kpts=batch['keypoints'],
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
- names=self.names,
- on_plot=self.on_plot)
+ plot_images(
+ batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ kpts=batch["keypoints"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ )
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
- plot_images(batch['img'],
- *output_to_target(preds, max_det=self.args.max_det),
- kpts=pred_kpts,
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
- names=self.names,
- on_plot=self.on_plot) # pred
+ plot_images(
+ batch["img"],
+ *output_to_target(preds, max_det=self.args.max_det),
+ kpts=pred_kpts,
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ ) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
@@ -187,37 +208,41 @@ class PoseValidator(DetectionValidator):
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
- self.jdict.append({
- 'image_id': image_id,
- 'category_id': self.class_map[int(p[5])],
- 'bbox': [round(x, 3) for x in b],
- 'keypoints': p[6:],
- 'score': round(p[4], 5)})
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(p[5])],
+ "bbox": [round(x, 3) for x in b],
+ "keypoints": p[6:],
+ "score": round(p[4], 5),
+ }
+ )
def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict):
- anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
- pred_json = self.save_dir / 'predictions.json' # predictions
- LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+ anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
+ pred_json = self.save_dir / "predictions.json" # predictions
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
- check_requirements('pycocotools>=2.0.6')
+ check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
- assert x.is_file(), f'{x} file not found'
+ assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
- for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
- stats[self.metrics.keys[idx + 1]], stats[
- self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
+ stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+ :2
+ ] # update mAP50-95 and mAP50
except Exception as e:
- LOGGER.warning(f'pycocotools unable to run: {e}')
+ LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
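
The sigma branch in init_metrics is worth a worked example: only the COCO [17, 3] shape gets the tuned OKS_SIGMA constants, anything else falls back to uniform weights. A sketch with a hypothetical 5-keypoint dataset; the OKS_SIGMA import path matches the one this module already uses:

import numpy as np
from ultralytics.utils.metrics import OKS_SIGMA  # 17 per-keypoint COCO sigmas

kpt_shape = [5, 3]  # hypothetical non-COCO dataset: 5 keypoints with (x, y, visibility)
nkpt = kpt_shape[0]
sigma = OKS_SIGMA if kpt_shape == [17, 3] else np.ones(nkpt) / nkpt
print(sigma)  # [0.2 0.2 0.2 0.2 0.2] -> every keypoint weighted equally in OKS
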
diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py
index c84a570e..ec1ac799 100644
--- a/ultralytics/models/yolo/segment/__init__.py
+++ b/ultralytics/models/yolo/segment/__init__.py
@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
-__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py
index ba44a482..8d8cd59a 100644
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'segment'
+ self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""Applies non-max suppression and processes detections for each image in an input batch."""
- p = ops.non_max_suppression(preds[0],
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=len(self.model.names),
- classes=self.args.classes)
+ p = ops.non_max_suppression(
+ preds[0],
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nc=len(self.model.names),
+ classes=self.args.classes,
+ )
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
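
A reminder of the tuple shape this postprocess consumes: segmentation models emit (detections, protos), so NMS runs on preds[0] while the prototypes are combined with per-box mask coefficients downstream. A rough sketch of that mask step with dummy shapes, assuming the process_mask signature visible in the validator later in this patch:

import torch
from ultralytics.utils import ops

proto = torch.zeros(32, 160, 160)  # mask prototypes for one image (dummy values)
pred = torch.zeros(3, 6 + 32)      # 3 detections: box(4) + conf + cls + 32 mask coefficients

masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], shape=(640, 640), upsample=True)
print(masks.shape)  # torch.Size([3, 640, 640]) at the network input resolution
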
diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py
index 1d1227da..126baf20 100644
--- a/ultralytics/models/yolo/segment/train.py
+++ b/ultralytics/models/yolo/segment/train.py
@@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None:
overrides = {}
- overrides['task'] = 'segment'
+ overrides["task"] = "segment"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
- model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
@@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
- self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
- return yolo.segment.SegmentationValidator(self.test_loader,
- save_dir=self.save_dir,
- args=copy(self.args),
- _callbacks=self.callbacks)
+ self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+ return yolo.segment.SegmentationValidator(
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+ )
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
- plot_images(batch['img'],
- batch['batch_idx'],
- batch['cls'].squeeze(-1),
- batch['bboxes'],
- masks=batch['masks'],
- paths=batch['im_file'],
- fname=self.save_dir / f'train_batch{ni}.jpg',
- on_plot=self.on_plot)
+ plot_images(
+ batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ masks=batch["masks"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"train_batch{ni}.jpg",
+ on_plot=self.on_plot,
+ )
def plot_metrics(self):
"""Plots training/val metrics."""
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 10ac374f..72c04b2c 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.plot_masks = None
self.process = None
- self.args.task = 'segment'
+ self.args.task = "segment"
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
- batch['masks'] = batch['masks'].to(self.device).float()
+ batch["masks"] = batch["masks"].to(self.device).float()
return batch
def init_metrics(self, model):
@@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator):
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
- check_requirements('pycocotools>=2.0.6')
+ check_requirements("pycocotools>=2.0.6")
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
@@ -55,33 +55,46 @@ class SegmentationValidator(DetectionValidator):
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
- return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
- 'R', 'mAP50', 'mAP50-95)')
+ return ("%22s" + "%11s" * 10) % (
+ "Class",
+ "Images",
+ "Instances",
+ "Box(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ "Mask(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ )
def postprocess(self, preds):
"""Post-processes YOLO predictions and returns output detections with proto."""
- p = ops.non_max_suppression(preds[0],
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=True,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det,
- nc=self.nc)
+ p = ops.non_max_suppression(
+ preds[0],
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ nc=self.nc,
+ )
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by processing images and targets."""
prepared_batch = super()._prepare_batch(si, batch)
- midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
- prepared_batch['masks'] = batch['masks'][midx]
+ midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
+ prepared_batch["masks"] = batch["masks"][midx]
return prepared_batch
def _prepare_pred(self, pred, pbatch, proto):
"""Prepares a batch for training or inference by processing images and targets."""
predn = super()._prepare_pred(pred, pbatch)
- pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
+ pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
return predn, pred_masks
def update_metrics(self, preds, batch):
@@ -89,14 +102,16 @@ class SegmentationValidator(DetectionValidator):
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
self.seen += 1
npr = len(pred)
- stat = dict(conf=torch.zeros(0, device=self.device),
- pred_cls=torch.zeros(0, device=self.device),
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
- tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+ stat = dict(
+ conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ )
pbatch = self._prepare_batch(si, batch)
- cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
- stat['target_cls'] = cls
+ stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
@@ -106,24 +121,20 @@ class SegmentationValidator(DetectionValidator):
continue
# Masks
- gt_masks = pbatch.pop('masks')
+ gt_masks = pbatch.pop("masks")
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
- stat['conf'] = predn[:, 4]
- stat['pred_cls'] = predn[:, 5]
+ stat["conf"] = predn[:, 4]
+ stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
- stat['tp'] = self._process_batch(predn, bbox, cls)
- stat['tp_m'] = self._process_batch(predn,
- bbox,
- cls,
- pred_masks,
- gt_masks,
- self.args.overlap_mask,
- masks=True)
+ stat["tp"] = self._process_batch(predn, bbox, cls)
+ stat["tp_m"] = self._process_batch(
+ predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
+ )
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
@@ -136,10 +147,12 @@ class SegmentationValidator(DetectionValidator):
# Save
if self.args.save_json:
- pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
- pbatch['ori_shape'],
- ratio_pad=batch['ratio_pad'][si])
- self.pred_to_json(predn, batch['im_file'][si], pred_masks)
+ pred_masks = ops.scale_image(
+ pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+ pbatch["ori_shape"],
+ ratio_pad=batch["ratio_pad"][si],
+ )
+ self.pred_to_json(predn, batch["im_file"][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -166,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
- gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
+ gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
@@ -176,26 +189,29 @@ class SegmentationValidator(DetectionValidator):
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
- plot_images(batch['img'],
- batch['batch_idx'],
- batch['cls'].squeeze(-1),
- batch['bboxes'],
- masks=batch['masks'],
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
- names=self.names,
- on_plot=self.on_plot)
+ plot_images(
+ batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ masks=batch["masks"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ )
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
- batch['img'],
+ batch["img"],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
- paths=batch['im_file'],
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
- on_plot=self.on_plot) # pred
+ on_plot=self.on_plot,
+ ) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
@@ -205,8 +221,8 @@ class SegmentationValidator(DetectionValidator):
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
- rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
- rle['counts'] = rle['counts'].decode('utf-8')
+ rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+ rle["counts"] = rle["counts"].decode("utf-8")
return rle
stem = Path(filename).stem
@@ -217,37 +233,41 @@ class SegmentationValidator(DetectionValidator):
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
- self.jdict.append({
- 'image_id': image_id,
- 'category_id': self.class_map[int(p[5])],
- 'bbox': [round(x, 3) for x in b],
- 'score': round(p[4], 5),
- 'segmentation': rles[i]})
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(p[5])],
+ "bbox": [round(x, 3) for x in b],
+ "score": round(p[4], 5),
+ "segmentation": rles[i],
+ }
+ )
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
- anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
- pred_json = self.save_dir / 'predictions.json' # predictions
- LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+ anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
+ pred_json = self.save_dir / "predictions.json" # predictions
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
- check_requirements('pycocotools>=2.0.6')
+ check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
- assert x.is_file(), f'{x} file not found'
+ assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
- for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
- stats[self.metrics.keys[idx + 1]], stats[
- self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
+ stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+ :2
+ ] # update mAP50-95 and mAP50
except Exception as e:
- LOGGER.warning(f'pycocotools unable to run: {e}')
+ LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
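
The single_encode helper above leans on a pycocotools quirk worth spelling out: mask.encode wants a Fortran-ordered uint8 array of shape (h, w, n), and the returned counts are bytes that must be decoded before json.dump can serialize them. A standalone sketch with a toy mask:

import numpy as np
from pycocotools import mask as mask_utils

m = np.zeros((4, 4), dtype=np.uint8)
m[1:3, 1:3] = 1  # toy 2x2 square instance mask

rle = mask_utils.encode(np.asarray(m[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str, as in single_encode above
print(rle)  # {'size': [4, 4], 'counts': '...'} ready for the COCO predictions JSON
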
diff --git a/ultralytics/nn/__init__.py b/ultralytics/nn/__init__.py
index 9889b7ef..6905d349 100644
--- a/ultralytics/nn/__init__.py
+++ b/ultralytics/nn/__init__.py
@@ -1,9 +1,29 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
- attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
- yaml_model_load)
+from .tasks import (
+ BaseModel,
+ ClassificationModel,
+ DetectionModel,
+ SegmentationModel,
+ attempt_load_one_weight,
+ attempt_load_weights,
+ guess_model_scale,
+ guess_model_task,
+ parse_model,
+ torch_safe_load,
+ yaml_model_load,
+)
-__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
- 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
- 'BaseModel')
+__all__ = (
+ "attempt_load_one_weight",
+ "attempt_load_weights",
+ "parse_model",
+ "yaml_model_load",
+ "guess_model_task",
+ "guess_model_scale",
+ "torch_safe_load",
+ "DetectionModel",
+ "SegmentationModel",
+ "ClassificationModel",
+ "BaseModel",
+)
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index a5c964ce..8f55c3ba 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -32,10 +32,12 @@ def check_class_names(names):
names = {int(k): str(v) for k, v in names.items()}
n = len(names)
if max(names.keys()) >= n:
- raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
- f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
- if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
- names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
+ raise KeyError(
+ f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
+ f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
+ )
+ if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 'n01440764'
+ names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names
names = {k: names_map[v] for k, v in names.items()}
return names
@@ -44,8 +46,8 @@ def default_class_names(data=None):
"""Applies default class names to an input YAML file or returns numerical class names."""
if data:
with contextlib.suppress(Exception):
- return yaml_load(check_yaml(data))['names']
- return {i: f'class{i}' for i in range(999)} # return default if above errors
+ return yaml_load(check_yaml(data))["names"]
+ return {i: f"class{i}" for i in range(999)} # return default if above errors
class AutoBackend(nn.Module):
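
The index check in check_class_names is easy to misread: a names dict is valid only when every key fits in range(n). A quick sketch of both branches with hypothetical dicts:

# Valid: 2 classes indexed 0..1
names = {0: "person", 1: "bicycle"}
assert max(names.keys()) < len(names)  # passes the check above

# Invalid: 2 classes but an index of 2 -> check_class_names() raises KeyError
bad = {0: "person", 2: "car"}
assert max(bad.keys()) >= len(bad)  # the exact condition that triggers the raise
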
@@ -77,14 +79,16 @@ class AutoBackend(nn.Module):
"""
@torch.no_grad()
- def __init__(self,
- weights='yolov8n.pt',
- device=torch.device('cpu'),
- dnn=False,
- data=None,
- fp16=False,
- fuse=True,
- verbose=True):
+ def __init__(
+ self,
+ weights="yolov8n.pt",
+ device=torch.device("cpu"),
+ dnn=False,
+ data=None,
+ fp16=False,
+ fuse=True,
+ verbose=True,
+ ):
"""
Initialize the AutoBackend for inference.
@@ -100,17 +104,31 @@ class AutoBackend(nn.Module):
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
- pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
- self._model_type(w)
+ (
+ pt,
+ jit,
+ onnx,
+ xml,
+ engine,
+ coreml,
+ saved_model,
+ pb,
+ tflite,
+ edgetpu,
+ tfjs,
+ paddle,
+ ncnn,
+ triton,
+ ) = self._model_type(w)
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
stride = 32 # default stride
model, metadata = None, None
# Set device
- cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
+ cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA
if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats
- device = torch.device('cpu')
+ device = torch.device("cpu")
cuda = False
# Download if not local
@@ -121,77 +139,79 @@ class AutoBackend(nn.Module):
if nn_module: # in-memory PyTorch model
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
- if hasattr(model, 'kpt_shape'):
+ if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
- names = model.module.names if hasattr(model, 'module') else model.names # get class names
+ names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
pt = True
elif pt: # PyTorch
from ultralytics.nn.tasks import attempt_load_weights
- model = attempt_load_weights(weights if isinstance(weights, list) else w,
- device=device,
- inplace=True,
- fuse=fuse)
- if hasattr(model, 'kpt_shape'):
+
+ model = attempt_load_weights(
+ weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
+ )
+ if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
- names = model.module.names if hasattr(model, 'module') else model.names # get class names
+ names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript
- LOGGER.info(f'Loading {w} for TorchScript inference...')
- extra_files = {'config.txt': ''} # model metadata
+ LOGGER.info(f"Loading {w} for TorchScript inference...")
+ extra_files = {"config.txt": ""} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
- if extra_files['config.txt']: # load metadata dict
- metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
+ if extra_files["config.txt"]: # load metadata dict
+ metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
elif dnn: # ONNX OpenCV DNN
- LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
- check_requirements('opencv-python>=4.5.4')
+ LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
+ check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
- LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
- check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
+ LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
+ check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
import onnxruntime
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map # metadata
elif xml: # OpenVINO
- LOGGER.info(f'Loading {w} for OpenVINO inference...')
- check_requirements('openvino>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
+ LOGGER.info(f"Loading {w} for OpenVINO inference...")
+ check_requirements("openvino>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch # noqa
+
core = Core()
w = Path(w)
if not w.is_file(): # if not *.xml
- w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir
- ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin'))
+ w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
+ ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
if ov_model.get_parameters()[0].get_layout().empty:
- ov_model.get_parameters()[0].set_layout(Layout('NCHW'))
+ ov_model.get_parameters()[0].set_layout(Layout("NCHW"))
batch_dim = get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
- ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # AUTO selects best available device
- metadata = w.parent / 'metadata.yaml'
+ ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device
+ metadata = w.parent / "metadata.yaml"
elif engine: # TensorRT
- LOGGER.info(f'Loading {w} for TensorRT inference...')
+ LOGGER.info(f"Loading {w} for TensorRT inference...")
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
except ImportError:
if LINUX:
- check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+ check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
import tensorrt as trt # noqa
- check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
- if device.type == 'cpu':
- device = torch.device('cuda:0')
- Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+ check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
+ if device.type == "cpu":
+ device = torch.device("cuda:0")
+ Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
logger = trt.Logger(trt.Logger.INFO)
# Read file
- with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
- meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
- metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
+ with open(w, "rb") as f, trt.Runtime(logger) as runtime:
+ meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
+ metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
context = model.create_execution_context()
bindings = OrderedDict()
@@ -213,116 +233,124 @@ class AutoBackend(nn.Module):
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
- batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
+ batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
elif coreml: # CoreML
- LOGGER.info(f'Loading {w} for CoreML inference...')
+ LOGGER.info(f"Loading {w} for CoreML inference...")
import coremltools as ct
+
model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
elif saved_model: # TF SavedModel
- LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+ LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
import tensorflow as tf
+
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
- metadata = Path(w) / 'metadata.yaml'
+ metadata = Path(w) / "metadata.yaml"
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
- LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+ LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
import tensorflow as tf
from ultralytics.engine.exporter import gd_outputs
def wrap_frozen_graph(gd, inputs, outputs):
"""Wrap frozen graphs for deployment."""
- x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
gd = tf.Graph().as_graph_def() # TF GraphDef
- with open(w, 'rb') as f:
+ with open(w, "rb") as f:
gd.ParseFromString(f.read())
- frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
- LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
- delegate = {
- 'Linux': 'libedgetpu.so.1',
- 'Darwin': 'libedgetpu.1.dylib',
- 'Windows': 'edgetpu.dll'}[platform.system()]
+ LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
+ delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
+ platform.system()
+ ]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
- LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+ LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# Load metadata
with contextlib.suppress(zipfile.BadZipFile):
- with zipfile.ZipFile(w, 'r') as model:
+ with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
- metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
+ metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
elif tfjs: # TF.js
- raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
+ raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
elif paddle: # PaddlePaddle
- LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
- check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+ LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
+ check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
import paddle.inference as pdi # noqa
+
w = Path(w)
if not w.is_file(): # if not *.pdmodel
- w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
- config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
+ w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir
+ config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
- metadata = w.parents[1] / 'metadata.yaml'
+ metadata = w.parents[1] / "metadata.yaml"
elif ncnn: # ncnn
- LOGGER.info(f'Loading {w} for ncnn inference...')
- check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn
+ LOGGER.info(f"Loading {w} for ncnn inference...")
+ check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn
import ncnn as pyncnn
+
net = pyncnn.Net()
net.opt.use_vulkan_compute = cuda
w = Path(w)
if not w.is_file(): # if not *.param
- w = next(w.glob('*.param')) # get *.param file from *_ncnn_model dir
+ w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir
net.load_param(str(w))
- net.load_model(str(w.with_suffix('.bin')))
- metadata = w.parent / 'metadata.yaml'
+ net.load_model(str(w.with_suffix(".bin")))
+ metadata = w.parent / "metadata.yaml"
elif triton: # NVIDIA Triton Inference Server
- check_requirements('tritonclient[all]')
+ check_requirements("tritonclient[all]")
from ultralytics.utils.triton import TritonRemoteModel
+
model = TritonRemoteModel(w)
else:
from ultralytics.engine.exporter import export_formats
- raise TypeError(f"model='{w}' is not a supported model format. "
- 'See https://docs.ultralytics.com/modes/predict for help.'
- f'\n\n{export_formats()}')
+
+ raise TypeError(
+ f"model='{w}' is not a supported model format. "
+ "See https://docs.ultralytics.com/modes/predict for help."
+ f"\n\n{export_formats()}"
+ )
# Load external metadata YAML
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
metadata = yaml_load(metadata)
if metadata:
for k, v in metadata.items():
- if k in ('stride', 'batch'):
+ if k in ("stride", "batch"):
metadata[k] = int(v)
- elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
+ elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
metadata[k] = eval(v)
- stride = metadata['stride']
- task = metadata['task']
- batch = metadata['batch']
- imgsz = metadata['imgsz']
- names = metadata['names']
- kpt_shape = metadata.get('kpt_shape')
+ stride = metadata["stride"]
+ task = metadata["task"]
+ batch = metadata["batch"]
+ imgsz = metadata["imgsz"]
+ names = metadata["names"]
+ kpt_shape = metadata.get("kpt_shape")
elif not (pt or triton or nn_module):
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
# Check names
- if 'names' not in locals(): # names missing
+ if "names" not in locals(): # names missing
names = default_class_names(data)
names = check_class_names(names)
@@ -367,26 +395,28 @@ class AutoBackend(nn.Module):
im = im.cpu().numpy() # FP32
y = list(self.ov_compiled_model(im).values())
elif self.engine: # TensorRT
- if self.dynamic and im.shape != self.bindings['images'].shape:
- i = self.model.get_binding_index('images')
+ if self.dynamic and im.shape != self.bindings["images"].shape:
+ i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
- self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
+ self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
- s = self.bindings['images'].shape
+ s = self.bindings["images"].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
- self.binding_addrs['images'] = int(im.data_ptr())
+ self.binding_addrs["images"] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
elif self.coreml: # CoreML
im = im[0].cpu().numpy()
- im_pil = Image.fromarray((im * 255).astype('uint8'))
+ im_pil = Image.fromarray((im * 255).astype("uint8"))
# im = im.resize((192, 320), Image.BILINEAR)
- y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
- if 'confidence' in y:
- raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
- f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
+ y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
+ if "confidence" in y:
+ raise TypeError(
+ "Ultralytics only supports inference of non-pipelined CoreML models exported with "
+ f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
+ )
# TODO: CoreML NMS inference handling
# from ultralytics.utils.ops import xywh2xyxy
# box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
@@ -425,20 +455,20 @@ class AutoBackend(nn.Module):
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
- self.names = {i: f'class{i}' for i in range(nc)}
+ self.names = {i: f"class{i}" for i in range(nc)}
else: # Lite or Edge TPU
details = self.input_details[0]
- integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
+ integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
if integer:
- scale, zero_point = details['quantization']
- im = (im / scale + zero_point).astype(details['dtype']) # de-scale
- self.interpreter.set_tensor(details['index'], im)
+ scale, zero_point = details["quantization"]
+ im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
+ self.interpreter.set_tensor(details["index"], im)
self.interpreter.invoke()
y = []
for output in self.output_details:
- x = self.interpreter.get_tensor(output['index'])
+ x = self.interpreter.get_tensor(output["index"])
if integer:
- scale, zero_point = output['quantization']
+ scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
if x.ndim > 2: # if task is not classification
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
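
The quantization round-trip in this hunk is the standard affine scheme: ints and floats are related by x_float = (x_int - zero_point) * scale. A worked example with made-up parameters:

import numpy as np

scale, zero_point = 0.005, -128  # illustrative values; real ones come from details["quantization"]
x = np.array([0.0, 0.25, 0.5], dtype=np.float32)

x_int = (x / scale + zero_point).astype(np.int8)          # de-scale, as for interpreter inputs
x_back = (x_int.astype(np.float32) - zero_point) * scale  # re-scale, as for interpreter outputs
print(x_int, x_back)  # round-trip error is bounded by roughly one scale step
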
@@ -483,13 +513,13 @@ class AutoBackend(nn.Module):
             (None): This method runs the forward pass and doesn't return any value.
"""
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
- if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
+ if any(warmup_types) and (self.device.type != "cpu" or self.triton):
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1):
self.forward(im) # warmup
@staticmethod
- def _model_type(p='path/to/model.pt'):
+ def _model_type(p="path/to/model.pt"):
"""
This function takes a path to a model file and returns the model type.
@@ -499,18 +529,20 @@ class AutoBackend(nn.Module):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from ultralytics.engine.exporter import export_formats
+
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False) and not isinstance(p, str):
check_suffix(p, sf) # checks
name = Path(p).name
types = [s in name for s in sf]
- types[5] |= name.endswith('.mlmodel') # retain support for older Apple CoreML *.mlmodel formats
+ types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats
types[8] &= not types[9] # tflite &= not edgetpu
if any(types):
triton = False
else:
from urllib.parse import urlsplit
+
url = urlsplit(p)
- triton = url.netloc and url.path and url.scheme in {'http', 'grpc'}
+ triton = url.netloc and url.path and url.scheme in {"http", "grpc"}
return types + [triton]
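
The fallback branch above treats any http/grpc URL with a host and a path as a Triton endpoint. The same check in isolation; the endpoint URL is a placeholder:

from urllib.parse import urlsplit

url = urlsplit("http://localhost:8000/yolov8n")  # hypothetical Triton Inference Server endpoint
triton = bool(url.netloc and url.path and url.scheme in {"http", "grpc"})
print(triton)  # True -> AutoBackend routes this weight string to TritonRemoteModel
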
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index bbc7a5b1..3c6739e5 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -17,18 +17,101 @@ Example:
```
"""
-from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
- HGBlock, HGStem, Proto, RepC3, ResNetLayer)
-from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
- GhostConv, LightConv, RepConv, SpatialAttention)
+from .block import (
+ C1,
+ C2,
+ C3,
+ C3TR,
+ DFL,
+ SPP,
+ SPPF,
+ Bottleneck,
+ BottleneckCSP,
+ C2f,
+ C3Ghost,
+ C3x,
+ GhostBottleneck,
+ HGBlock,
+ HGStem,
+ Proto,
+ RepC3,
+ ResNetLayer,
+)
+from .conv import (
+ CBAM,
+ ChannelAttention,
+ Concat,
+ Conv,
+ Conv2,
+ ConvTranspose,
+ DWConv,
+ DWConvTranspose2d,
+ Focus,
+ GhostConv,
+ LightConv,
+ RepConv,
+ SpatialAttention,
+)
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
-from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
- MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
+from .transformer import (
+ AIFI,
+ MLP,
+ DeformableTransformerDecoder,
+ DeformableTransformerDecoderLayer,
+ LayerNorm2d,
+ MLPBlock,
+ MSDeformAttn,
+ TransformerBlock,
+ TransformerEncoderLayer,
+ TransformerLayer,
+)
-__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
- 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
- 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
- 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
- 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
- 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer',
- 'OBB')
+__all__ = (
+ "Conv",
+ "Conv2",
+ "LightConv",
+ "RepConv",
+ "DWConv",
+ "DWConvTranspose2d",
+ "ConvTranspose",
+ "Focus",
+ "GhostConv",
+ "ChannelAttention",
+ "SpatialAttention",
+ "CBAM",
+ "Concat",
+ "TransformerLayer",
+ "TransformerBlock",
+ "MLPBlock",
+ "LayerNorm2d",
+ "DFL",
+ "HGBlock",
+ "HGStem",
+ "SPP",
+ "SPPF",
+ "C1",
+ "C2",
+ "C3",
+ "C2f",
+ "C3x",
+ "C3TR",
+ "C3Ghost",
+ "GhostBottleneck",
+ "Bottleneck",
+ "BottleneckCSP",
+ "Proto",
+ "Detect",
+ "Segment",
+ "Pose",
+ "Classify",
+ "TransformerEncoderLayer",
+ "RepC3",
+ "RTDETRDecoder",
+ "AIFI",
+ "DeformableTransformerDecoder",
+ "DeformableTransformerDecoderLayer",
+ "MSDeformAttn",
+ "MLP",
+ "ResNetLayer",
+ "OBB",
+)
diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py
index c754cdfc..5a6d12f7 100644
--- a/ultralytics/nn/modules/block.py
+++ b/ultralytics/nn/modules/block.py
@@ -8,8 +8,26 @@ import torch.nn.functional as F
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
from .transformer import TransformerBlock
-__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
- 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3', 'ResNetLayer')
+__all__ = (
+ "DFL",
+ "HGBlock",
+ "HGStem",
+ "SPP",
+ "SPPF",
+ "C1",
+ "C2",
+ "C3",
+ "C2f",
+ "C3x",
+ "C3TR",
+ "C3Ghost",
+ "GhostBottleneck",
+ "Bottleneck",
+ "BottleneckCSP",
+ "Proto",
+ "RepC3",
+ "ResNetLayer",
+)
class DFL(nn.Module):
@@ -284,9 +302,11 @@ class GhostBottleneck(nn.Module):
self.conv = nn.Sequential(
GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
- GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
- self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
- act=False)) if s == 2 else nn.Identity()
+ GhostConv(c_, c2, 1, 1, act=False), # pw-linear
+ )
+ self.shortcut = (
+ nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
+ )
def forward(self, x):
"""Applies skip connection and concatenation to input tensor."""
@@ -359,8 +379,9 @@ class ResNetLayer(nn.Module):
self.is_first = is_first
if self.is_first:
- self.layer = nn.Sequential(Conv(c1, c2, k=7, s=2, p=3, act=True),
- nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+ self.layer = nn.Sequential(
+ Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ )
else:
blocks = [ResNetBlock(c1, c2, s, e=e)]
blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])
diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py
index 7fe615d9..399c4225 100644
--- a/ultralytics/nn/modules/conv.py
+++ b/ultralytics/nn/modules/conv.py
@@ -7,8 +7,21 @@ import numpy as np
import torch
import torch.nn as nn
-__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
- 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
+__all__ = (
+ "Conv",
+ "Conv2",
+ "LightConv",
+ "DWConv",
+ "DWConvTranspose2d",
+ "ConvTranspose",
+ "Focus",
+ "GhostConv",
+ "ChannelAttention",
+ "SpatialAttention",
+ "CBAM",
+ "Concat",
+ "RepConv",
+)
def autopad(k, p=None, d=1): # kernel, padding, dilation
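
Only autopad's signature appears in this hunk; for reference, a sketch of the 'same'-padding rule its call sites rely on (dilated kernels are first expanded to their effective size, then padding defaults to half the kernel):

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' output shape for stride-1 convolutions."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # effective kernel
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # half the kernel size
    return p

print(autopad(3), autopad(3, d=2), autopad([3, 5]))  # 1 2 [1, 2]
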
@@ -22,6 +35,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
+
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
@@ -60,9 +74,9 @@ class Conv2(Conv):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]]
- w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
+ w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
self.conv.weight.data += w
- self.__delattr__('cv2')
+ self.__delattr__("cv2")
self.forward = self.forward_fuse
@@ -102,6 +116,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose(nn.Module):
"""Convolution transpose 2d layer."""
+
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
@@ -164,6 +179,7 @@ class RepConv(nn.Module):
This module is used in RT-DETR.
Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
"""
+
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
@@ -214,7 +230,7 @@ class RepConv(nn.Module):
beta = branch.bn.bias
eps = branch.bn.eps
elif isinstance(branch, nn.BatchNorm2d):
- if not hasattr(self, 'id_tensor'):
+ if not hasattr(self, "id_tensor"):
input_dim = self.c1 // self.g
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
for i in range(self.c1):
@@ -232,29 +248,31 @@ class RepConv(nn.Module):
def fuse_convs(self):
"""Combines two convolution layers into a single layer and removes unused attributes from the class."""
- if hasattr(self, 'conv'):
+ if hasattr(self, "conv"):
return
kernel, bias = self.get_equivalent_kernel_bias()
- self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
- out_channels=self.conv1.conv.out_channels,
- kernel_size=self.conv1.conv.kernel_size,
- stride=self.conv1.conv.stride,
- padding=self.conv1.conv.padding,
- dilation=self.conv1.conv.dilation,
- groups=self.conv1.conv.groups,
- bias=True).requires_grad_(False)
+ self.conv = nn.Conv2d(
+ in_channels=self.conv1.conv.in_channels,
+ out_channels=self.conv1.conv.out_channels,
+ kernel_size=self.conv1.conv.kernel_size,
+ stride=self.conv1.conv.stride,
+ padding=self.conv1.conv.padding,
+ dilation=self.conv1.conv.dilation,
+ groups=self.conv1.conv.groups,
+ bias=True,
+ ).requires_grad_(False)
self.conv.weight.data = kernel
self.conv.bias.data = bias
for para in self.parameters():
para.detach_()
- self.__delattr__('conv1')
- self.__delattr__('conv2')
- if hasattr(self, 'nm'):
- self.__delattr__('nm')
- if hasattr(self, 'bn'):
- self.__delattr__('bn')
- if hasattr(self, 'id_tensor'):
- self.__delattr__('id_tensor')
+ self.__delattr__("conv1")
+ self.__delattr__("conv2")
+ if hasattr(self, "nm"):
+ self.__delattr__("nm")
+ if hasattr(self, "bn"):
+ self.__delattr__("bn")
+ if hasattr(self, "id_tensor"):
+ self.__delattr__("id_tensor")
class ChannelAttention(nn.Module):
@@ -278,7 +296,7 @@ class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
"""Initialize Spatial-attention module with kernel size argument."""
super().__init__()
- assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+ assert kernel_size in (3, 7), "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.act = nn.Sigmoid()
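
Note on the fuse_convs reflow above: RepConv fusion folds each branch (3x3 conv, 1x1 conv, optional identity BN) into one 3x3 kernel, and the BN-folding step is plain algebra, so the reformat cannot change behavior. A minimal self-contained sketch of that conv+BN fold, assuming the standard eval-mode BatchNorm formula (an illustration, not the Ultralytics fuse_conv_and_bn helper):

    import torch
    import torch.nn as nn

    def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
        """Fold eval-mode BatchNorm statistics into the preceding convolution."""
        fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                          conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
        with torch.no_grad():
            scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-channel gamma / std
            bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
            fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
            fused.bias.copy_(bn.bias + (bias - bn.running_mean) * scale)
        return fused

    conv = nn.Conv2d(8, 16, 3, padding=1)
    bn = nn.BatchNorm2d(16).eval()  # eval mode: uses running statistics
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    x = torch.randn(1, 8, 32, 32)
    with torch.no_grad():
        assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-4)

The same identity is why requires_grad_(False) is safe on the fused conv: it is a frozen reparameterization of the already-trained branches.
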
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index cf4cff2f..d5951baa 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -14,11 +14,12 @@ from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_
-__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder'
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
class Detect(nn.Module):
"""YOLOv8 Detect head for detection models."""
+
dynamic = False # force grid reconstruction
export = False # export mode
shape = None
@@ -35,7 +36,8 @@ class Detect(nn.Module):
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
- nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
+ )
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
@@ -53,14 +55,14 @@ class Detect(nn.Module):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
- if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
- box = x_cat[:, :self.reg_max * 4]
- cls = x_cat[:, self.reg_max * 4:]
+ if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
+ box = x_cat[:, : self.reg_max * 4]
+ cls = x_cat[:, self.reg_max * 4 :]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = self.decode_bboxes(box)
- if self.export and self.format in ('tflite', 'edgetpu'):
+ if self.export and self.format in ("tflite", "edgetpu"):
# Precompute normalization factor to increase numerical stability
# See https://github.com/ultralytics/ultralytics/issues/7371
img_h = shape[2]
@@ -79,7 +81,7 @@ class Detect(nn.Module):
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes):
"""Decode bounding boxes."""
@@ -214,26 +216,28 @@ class RTDETRDecoder(nn.Module):
and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
Transformer decoder layers to output the final predictions.
"""
+
export = False # export mode
def __init__(
- self,
- nc=80,
- ch=(512, 1024, 2048),
- hd=256, # hidden dim
- nq=300, # num queries
- ndp=4, # num decoder points
- nh=8, # num head
- ndl=6, # num decoder layers
- d_ffn=1024, # dim of feedforward
- dropout=0.,
- act=nn.ReLU(),
- eval_idx=-1,
- # Training args
- nd=100, # num denoising
- label_noise_ratio=0.5,
- box_noise_scale=1.0,
- learnt_init_query=False):
+ self,
+ nc=80,
+ ch=(512, 1024, 2048),
+ hd=256, # hidden dim
+ nq=300, # num queries
+ ndp=4, # num decoder points
+ nh=8, # num head
+ ndl=6, # num decoder layers
+ d_ffn=1024, # dim of feedforward
+ dropout=0.0,
+ act=nn.ReLU(),
+ eval_idx=-1,
+ # Training args
+ nd=100, # num denoising
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ learnt_init_query=False,
+ ):
"""
Initializes the RTDETRDecoder module with the given parameters.
@@ -302,28 +306,30 @@ class RTDETRDecoder(nn.Module):
feats, shapes = self._get_encoder_input(x)
# Prepare denoising training
- dn_embed, dn_bbox, attn_mask, dn_meta = \
- get_cdn_group(batch,
- self.nc,
- self.num_queries,
- self.denoising_class_embed.weight,
- self.num_denoising,
- self.label_noise_ratio,
- self.box_noise_scale,
- self.training)
-
- embed, refer_bbox, enc_bboxes, enc_scores = \
- self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
+ dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
+ batch,
+ self.nc,
+ self.num_queries,
+ self.denoising_class_embed.weight,
+ self.num_denoising,
+ self.label_noise_ratio,
+ self.box_noise_scale,
+ self.training,
+ )
+
+ embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
# Decoder
- dec_bboxes, dec_scores = self.decoder(embed,
- refer_bbox,
- feats,
- shapes,
- self.dec_bbox_head,
- self.dec_score_head,
- self.query_pos_head,
- attn_mask=attn_mask)
+ dec_bboxes, dec_scores = self.decoder(
+ embed,
+ refer_bbox,
+ feats,
+ shapes,
+ self.dec_bbox_head,
+ self.dec_score_head,
+ self.query_pos_head,
+ attn_mask=attn_mask,
+ )
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
if self.training:
return x
@@ -331,24 +337,24 @@ class RTDETRDecoder(nn.Module):
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
return y if self.export else (y, x)
- def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
+ def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
"""Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
anchors = []
for i, (h, w) in enumerate(shapes):
sy = torch.arange(end=h, dtype=dtype, device=device)
sx = torch.arange(end=w, dtype=dtype, device=device)
- grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
+ grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
- wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
+ wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
anchors = torch.log(anchors / (1 - anchors))
- anchors = anchors.masked_fill(~valid_mask, float('inf'))
+ anchors = anchors.masked_fill(~valid_mask, float("inf"))
return anchors, valid_mask
def _get_encoder_input(self, x):
@@ -415,13 +421,13 @@ class RTDETRDecoder(nn.Module):
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
# linear_init_(self.enc_score_head)
constant_(self.enc_score_head.bias, bias_cls)
- constant_(self.enc_bbox_head.layers[-1].weight, 0.)
- constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+ constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
+ constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
# linear_init_(cls_)
constant_(cls_.bias, bias_cls)
- constant_(reg_.layers[-1].weight, 0.)
- constant_(reg_.layers[-1].bias, 0.)
+ constant_(reg_.layers[-1].weight, 0.0)
+ constant_(reg_.layers[-1].bias, 0.0)
linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight)
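
Note on the bias_init hunk above: the class-logit bias math.log(5 / nc / (640 / s) ** 2) encodes the usual YOLO prior of roughly 5 objects per 640-pixel image spread over the grid cells at each stride, so initial class probabilities start near that rate rather than at 0.5. A quick numeric check of that reading, with P3-P5 strides assumed for illustration:

    import math

    nc = 80  # number of classes
    for s in (8, 16, 32):  # assumed P3-P5 strides
        cells = (640 / s) ** 2  # grid cells at this stride
        b = math.log(5 / nc / cells)  # initial class-logit bias
        p = 1 / (1 + math.exp(-b))  # sigmoid of the bias
        # For very negative b, sigmoid(b) ~= exp(b) = 5 / (nc * cells)
        assert abs(p - 5 / (nc * cells)) / p < 0.01
        print(f"stride {s}: bias {b:.2f} -> initial class prob {p:.2e}")
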
diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py
index 9fe9597f..465b7170 100644
--- a/ultralytics/nn/modules/transformer.py
+++ b/ultralytics/nn/modules/transformer.py
@@ -11,8 +11,18 @@ from torch.nn.init import constant_, xavier_uniform_
from .conv import Conv
from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
-__all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
- 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
+__all__ = (
+ "TransformerEncoderLayer",
+ "TransformerLayer",
+ "TransformerBlock",
+ "MLPBlock",
+ "LayerNorm2d",
+ "AIFI",
+ "DeformableTransformerDecoder",
+ "DeformableTransformerDecoderLayer",
+ "MSDeformAttn",
+ "MLP",
+)
class TransformerEncoderLayer(nn.Module):
@@ -22,9 +32,11 @@ class TransformerEncoderLayer(nn.Module):
"""Initialize the TransformerEncoderLayer with specified parameters."""
super().__init__()
from ...utils.torch_utils import TORCH_1_9
+
if not TORCH_1_9:
raise ModuleNotFoundError(
- 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).')
+ "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
+ )
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
# Implementation of Feedforward model
self.fc1 = nn.Linear(c1, cm)
@@ -91,12 +103,11 @@ class AIFI(TransformerEncoderLayer):
"""Builds 2D sine-cosine position embedding."""
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
- grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
- assert embed_dim % 4 == 0, \
- 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+ assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
- omega = 1. / (temperature ** omega)
+ omega = 1.0 / (temperature**omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
@@ -213,10 +224,10 @@ class MSDeformAttn(nn.Module):
"""Initialize MSDeformAttn with the given parameters."""
super().__init__()
if d_model % n_heads != 0:
- raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
+ raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
_d_per_head = d_model // n_heads
# Better to set _d_per_head to a power of 2, which is more efficient in a CUDA implementation
- assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
+ assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"
self.im2col_step = 64
@@ -234,21 +245,24 @@ class MSDeformAttn(nn.Module):
def _reset_parameters(self):
"""Reset module parameters."""
- constant_(self.sampling_offsets.weight.data, 0.)
+ constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
- grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
- 1, self.n_levels, self.n_points, 1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.n_heads, 1, 1, 2)
+ .repeat(1, self.n_levels, self.n_points, 1)
+ )
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
- constant_(self.attention_weights.weight.data, 0.)
- constant_(self.attention_weights.bias.data, 0.)
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
- constant_(self.value_proj.bias.data, 0.)
+ constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
- constant_(self.output_proj.bias.data, 0.)
+ constant_(self.output_proj.bias.data, 0.0)
def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
"""
@@ -288,7 +302,7 @@ class MSDeformAttn(nn.Module):
add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
else:
- raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
return self.output_proj(output)
@@ -301,7 +315,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
"""
- def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
+ def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
"""Initialize the DeformableTransformerDecoderLayer with the given parameters."""
super().__init__()
@@ -339,14 +353,16 @@ class DeformableTransformerDecoderLayer(nn.Module):
# Self attention
q = k = self.with_pos_embed(embed, query_pos)
- tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
- attn_mask=attn_mask)[0].transpose(0, 1)
+ tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
+ 0
+ ].transpose(0, 1)
embed = embed + self.dropout1(tgt)
embed = self.norm1(embed)
# Cross attention
- tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
- padding_mask)
+ tgt = self.cross_attn(
+ self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
+ )
embed = embed + self.dropout2(tgt)
embed = self.norm2(embed)
@@ -370,16 +386,17 @@ class DeformableTransformerDecoder(nn.Module):
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
def forward(
- self,
- embed, # decoder embeddings
- refer_bbox, # anchor
- feats, # image features
- shapes, # feature shapes
- bbox_head,
- score_head,
- pos_mlp,
- attn_mask=None,
- padding_mask=None):
+ self,
+ embed, # decoder embeddings
+ refer_bbox, # anchor
+ feats, # image features
+ shapes, # feature shapes
+ bbox_head,
+ score_head,
+ pos_mlp,
+ attn_mask=None,
+ padding_mask=None,
+ ):
"""Perform the forward pass through the entire decoder."""
output = embed
dec_bboxes = []
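
Note on the AIFI hunk above: the reflowed build_2d_sincos_position_embedding splits the embedding into four pos_dim-wide bands and fills them with sin/cos of the x and y grid coordinates at geometrically spaced frequencies. A standalone sketch of the same construction, with shapes inferred from the visible code:

    import torch

    def sincos_2d(w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0):
        """2D sine-cosine position embedding of shape (1, w*h, embed_dim)."""
        assert embed_dim % 4 == 0, "embed_dim must be divisible by 4"
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")  # torch>=1.10
        pos_dim = embed_dim // 4
        omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # (w*h, pos_dim)
        out_h = grid_h.flatten()[..., None] @ omega[None]
        return torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None]

    print(sincos_2d(20, 20).shape)  # torch.Size([1, 400, 256])
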
diff --git a/ultralytics/nn/modules/utils.py b/ultralytics/nn/modules/utils.py
index c7bec7af..2cb615a6 100644
--- a/ultralytics/nn/modules/utils.py
+++ b/ultralytics/nn/modules/utils.py
@@ -10,7 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import uniform_
-__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
+__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"
def _get_clones(module, n):
@@ -27,7 +27,7 @@ def linear_init_(module):
"""Initialize the weights and biases of a linear module."""
bound = 1 / math.sqrt(module.weight.shape[0])
uniform_(module.weight, -bound, bound)
- if hasattr(module, 'bias') and module.bias is not None:
+ if hasattr(module, "bias") and module.bias is not None:
uniform_(module.bias, -bound, bound)
@@ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5):
return torch.log(x1 / x2)
-def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
- sampling_locations: torch.Tensor,
- attention_weights: torch.Tensor) -> torch.Tensor:
+def multi_scale_deformable_attn_pytorch(
+ value: torch.Tensor,
+ value_spatial_shapes: torch.Tensor,
+ sampling_locations: torch.Tensor,
+ attention_weights: torch.Tensor,
+) -> torch.Tensor:
"""
Multi-scale deformable attention.
@@ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
- value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
+ value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
- sampling_value_l_ = F.grid_sample(value_l_,
- sampling_grid_l_,
- mode='bilinear',
- padding_mode='zeros',
- align_corners=False)
+ sampling_value_l_ = F.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
- attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
- num_levels * num_points)
- output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
- bs, num_heads * embed_dims, num_queries))
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ bs * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(bs, num_heads * embed_dims, num_queries)
+ )
return output.transpose(1, 2).contiguous()
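
For reference, inverse_sigmoid above is a clamped logit, so sigmoid(inverse_sigmoid(p)) recovers p away from the clamp boundaries; RT-DETR relies on this to refine anchors in logit space. A standalone sketch of the round trip, with the clamping reconstructed from the visible eps parameter:

    import torch

    def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        """Clamped logit: maps probabilities in (0, 1) back to logit space."""
        x = x.clamp(min=0, max=1)
        x1 = x.clamp(min=eps)
        x2 = (1 - x).clamp(min=eps)
        return torch.log(x1 / x2)

    p = torch.tensor([0.01, 0.25, 0.5, 0.75, 0.99])
    assert torch.allclose(torch.sigmoid(inverse_sigmoid(p)), p, atol=1e-6)
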
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 8261ed04..8efd8e33 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -7,16 +7,54 @@ from pathlib import Path
import torch
import torch.nn as nn
-from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
- C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
- DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
- RepConv, ResNetLayer, RTDETRDecoder, Segment)
+from ultralytics.nn.modules import (
+ AIFI,
+ C1,
+ C2,
+ C3,
+ C3TR,
+ OBB,
+ SPP,
+ SPPF,
+ Bottleneck,
+ BottleneckCSP,
+ C2f,
+ C3Ghost,
+ C3x,
+ Classify,
+ Concat,
+ Conv,
+ Conv2,
+ ConvTranspose,
+ Detect,
+ DWConv,
+ DWConvTranspose2d,
+ Focus,
+ GhostBottleneck,
+ GhostConv,
+ HGBlock,
+ HGStem,
+ Pose,
+ RepC3,
+ RepConv,
+ ResNetLayer,
+ RTDETRDecoder,
+ Segment,
+)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
-from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
- make_divisible, model_info, scale_img, time_sync)
+from ultralytics.utils.torch_utils import (
+ fuse_conv_and_bn,
+ fuse_deconv_and_bn,
+ initialize_weights,
+ intersect_dicts,
+ make_divisible,
+ model_info,
+ scale_img,
+ time_sync,
+)
try:
import thop
@@ -90,8 +128,10 @@ class BaseModel(nn.Module):
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
- LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
- f'Reverting to single-scale inference instead.')
+ LOGGER.warning(
+ f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
+ f"Reverting to single-scale inference instead."
+ )
return self._predict_once(x)
def _profile_one_layer(self, m, x, dt):
@@ -108,14 +148,14 @@ class BaseModel(nn.Module):
None
"""
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
- flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
+ flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
t = time_sync()
for _ in range(10):
m(x.copy() if c else x)
dt.append((time_sync() - t) * 100)
if m == self.model[0]:
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
- LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
+ LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
if c:
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
@@ -129,15 +169,15 @@ class BaseModel(nn.Module):
"""
if not self.is_fused():
for m in self.model.modules():
- if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
+ if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
if isinstance(m, Conv2):
m.fuse_convs()
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
- delattr(m, 'bn') # remove batchnorm
+ delattr(m, "bn") # remove batchnorm
m.forward = m.forward_fuse # update forward
- if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
+ if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
- delattr(m, 'bn') # remove batchnorm
+ delattr(m, "bn") # remove batchnorm
m.forward = m.forward_fuse # update forward
if isinstance(m, RepConv):
m.fuse_convs()
@@ -156,7 +196,7 @@ class BaseModel(nn.Module):
Returns:
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
- bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
+ bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, detailed=False, verbose=True, imgsz=640):
@@ -196,12 +236,12 @@ class BaseModel(nn.Module):
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
"""
- model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
+ model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
csd = model.float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
if verbose:
- LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
+ LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
def loss(self, batch, preds=None):
"""
@@ -211,33 +251,33 @@ class BaseModel(nn.Module):
batch (dict): Batch to compute loss on
preds (torch.Tensor | List[torch.Tensor]): Predictions.
"""
- if not hasattr(self, 'criterion'):
+ if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
- preds = self.forward(batch['img']) if preds is None else preds
+ preds = self.forward(batch["img"]) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
"""Initialize the loss criterion for the BaseModel."""
- raise NotImplementedError('compute_loss() needs to be implemented by task heads')
+ raise NotImplementedError("compute_loss() needs to be implemented by task heads")
class DetectionModel(BaseModel):
"""YOLOv8 detection model."""
- def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
+ def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
"""Initialize the YOLOv8 detection model with the given config and parameters."""
super().__init__()
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model
- ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
- if nc and nc != self.yaml['nc']:
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
+ if nc and nc != self.yaml["nc"]:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
- self.yaml['nc'] = nc # override YAML value
+ self.yaml["nc"] = nc # override YAML value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
- self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
- self.inplace = self.yaml.get('inplace', True)
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
+ self.inplace = self.yaml.get("inplace", True)
# Build strides
m = self.model[-1] # Detect()
@@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
initialize_weights(self)
if verbose:
self.info()
- LOGGER.info('')
+ LOGGER.info("")
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
@@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
def _clip_augmented(self, y):
"""Clip YOLO augmented inference tails."""
nl = self.model[-1].nl # number of detection layers (P3-P5)
- g = sum(4 ** x for x in range(nl)) # grid points
+ g = sum(4**x for x in range(nl)) # grid points
e = 1 # exclude layer count
- i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
+ i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
y[0] = y[0][..., :-i] # large
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
y[-1] = y[-1][..., i:] # small
@@ -301,7 +341,7 @@ class DetectionModel(BaseModel):
class OBBModel(DetectionModel):
""""YOLOv8 Oriented Bounding Box (OBB) model."""
- def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
+ def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 OBB model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@@ -313,7 +353,7 @@ class OBBModel(DetectionModel):
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
- def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
+ def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 segmentation model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@@ -325,13 +365,13 @@ class SegmentationModel(DetectionModel):
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
- def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
+ def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
"""Initialize YOLOv8 Pose model."""
if not isinstance(cfg, dict):
cfg = yaml_model_load(cfg) # load model YAML
- if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
+ if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
- cfg['kpt_shape'] = data_kpt_shape
+ cfg["kpt_shape"] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
@@ -342,7 +382,7 @@ class PoseModel(DetectionModel):
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
- def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
+ def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
super().__init__()
self._from_yaml(cfg, ch, nc, verbose)
@@ -352,21 +392,21 @@ class ClassificationModel(BaseModel):
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model
- ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
- if nc and nc != self.yaml['nc']:
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
+ if nc and nc != self.yaml["nc"]:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
- self.yaml['nc'] = nc # override YAML value
- elif not nc and not self.yaml.get('nc', None):
- raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
+ self.yaml["nc"] = nc # override YAML value
+ elif not nc and not self.yaml.get("nc", None):
+ raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.stride = torch.Tensor([1]) # no stride constraints
- self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
+ self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
self.info()
@staticmethod
def reshape_outputs(model, nc):
"""Update a TorchVision classification model to class count 'n' if required."""
- name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
+ name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
if isinstance(m, Classify): # YOLO Classify() head
if m.linear.out_features != nc:
m.linear = nn.Linear(m.linear.in_features, nc)
@@ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
predict: Performs a forward pass through the network and returns the output.
"""
- def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
+ def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
"""
Initialize the RTDETRDetectionModel.
@@ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
Returns:
(tuple): A tuple containing the total loss and main three losses in a tensor.
"""
- if not hasattr(self, 'criterion'):
+ if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
- img = batch['img']
+ img = batch["img"]
# NOTE: preprocess gt_bbox and gt_labels to list.
bs = len(img)
- batch_idx = batch['batch_idx']
+ batch_idx = batch["batch_idx"]
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
targets = {
- 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
- 'bboxes': batch['bboxes'].to(device=img.device),
- 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
- 'gt_groups': gt_groups}
+ "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
+ "bboxes": batch["bboxes"].to(device=img.device),
+ "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
+ "gt_groups": gt_groups,
+ }
preds = self.predict(img, batch=targets) if preds is None else preds
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
if dn_meta is None:
dn_bboxes, dn_scores = None, None
else:
- dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
- dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
+ dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
+ dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
- loss = self.criterion((dec_bboxes, dec_scores),
- targets,
- dn_bboxes=dn_bboxes,
- dn_scores=dn_scores,
- dn_meta=dn_meta)
+ loss = self.criterion(
+ (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
+ )
# NOTE: RTDETR has about 12 losses; backward with all of them but only show the main three.
- return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
- device=img.device)
+ return sum(loss.values()), torch.as_tensor(
+ [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
+ )
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
"""
@@ -553,6 +593,7 @@ def temporary_modules(modules=None):
import importlib
import sys
+
try:
# Set modules in sys.modules under their old name
for old, new in modules.items():
@@ -580,30 +621,38 @@ def torch_safe_load(weight):
"""
from ultralytics.utils.downloads import attempt_download_asset
- check_suffix(file=weight, suffix='.pt')
+ check_suffix(file=weight, suffix=".pt")
file = attempt_download_asset(weight) # search online if missing locally
try:
- with temporary_modules({
- 'ultralytics.yolo.utils': 'ultralytics.utils',
- 'ultralytics.yolo.v8': 'ultralytics.models.yolo',
- 'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
- return torch.load(file, map_location='cpu'), file # load
+ with temporary_modules(
+ {
+ "ultralytics.yolo.utils": "ultralytics.utils",
+ "ultralytics.yolo.v8": "ultralytics.models.yolo",
+ "ultralytics.yolo.data": "ultralytics.data",
+ }
+ ): # for legacy 8.0 Classify and Pose models
+ return torch.load(file, map_location="cpu"), file # load
except ModuleNotFoundError as e: # e.name is missing module name
- if e.name == 'models':
+ if e.name == "models":
raise TypeError(
- emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
- f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
- f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
- f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
- LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
- f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
- f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
+ emojis(
+ f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
+ f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
+ f"YOLOv8 at https://github.com/ultralytics/ultralytics."
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+ )
+ ) from e
+ LOGGER.warning(
+ f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
+ f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
+ )
check_requirements(e.name) # install missing module
- return torch.load(file, map_location='cpu'), file # load
+ return torch.load(file, map_location="cpu"), file # load
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
@@ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt, w = torch_safe_load(w) # load ckpt
- args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
- model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
+ args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
+ model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
# Model compatibility updates
model.args = args # attach args to model
model.pt_path = w # attach *.pt file path to model
model.task = guess_model_task(model)
- if not hasattr(model, 'stride'):
- model.stride = torch.tensor([32.])
+ if not hasattr(model, "stride"):
+ model.stride = torch.tensor([32.0])
# Append
- ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
+ ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
# Module updates
for m in ensemble.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
- elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+ elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model
@@ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
return ensemble[-1]
# Return ensemble
- LOGGER.info(f'Ensemble created with {weights}\n')
- for k in 'names', 'nc', 'yaml':
+ LOGGER.info(f"Ensemble created with {weights}\n")
+ for k in "names", "nc", "yaml":
setattr(ensemble, k, getattr(ensemble[0], k))
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
- assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
+ assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
return ensemble
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
"""Loads a single model weights."""
ckpt, weight = torch_safe_load(weight) # load ckpt
- args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
- model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
+ args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
+ model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
model.task = guess_model_task(model)
- if not hasattr(model, 'stride'):
- model.stride = torch.tensor([32.])
+ if not hasattr(model, "stride"):
+ model.stride = torch.tensor([32.0])
- model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
+ model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
# Module updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
- elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+ elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt
@@ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
import ast
# Args
- max_channels = float('inf')
- nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
- depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
+ max_channels = float("inf")
+ nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
+ depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
if scales:
- scale = d.get('scale')
+ scale = d.get("scale")
if not scale:
scale = tuple(scales.keys())[0]
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
@@ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
- for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
- m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
+ m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
- if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
- BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
+ if m in (
+ Classify,
+ Conv,
+ ConvTranspose,
+ GhostConv,
+ Bottleneck,
+ GhostBottleneck,
+ SPP,
+ SPPF,
+ DWConv,
+ Focus,
+ BottleneckCSP,
+ C1,
+ C2,
+ C2f,
+ C3,
+ C3TR,
+ C3Ghost,
+ nn.ConvTranspose2d,
+ DWConvTranspose2d,
+ C3x,
+ RepC3,
+ ):
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
c2 = ch[f]
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
- t = str(m)[8:-2].replace('__main__.', '') # module type
+ t = str(m)[8:-2].replace("__main__.", "") # module type
m.np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
if verbose:
- LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
+ LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
@@ -757,16 +827,16 @@ def yaml_model_load(path):
import re
path = Path(path)
- if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
- new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
- LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
+ if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
+ new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
+ LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
path = path.with_name(new_stem + path.suffix)
- unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
+ unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
d = yaml_load(yaml_file) # model dict
- d['scale'] = guess_model_scale(path)
- d['yaml_file'] = str(path)
+ d["scale"] = guess_model_scale(path)
+ d["yaml_file"] = str(path)
return d
@@ -784,8 +854,9 @@ def guess_model_scale(model_path):
"""
with contextlib.suppress(AttributeError):
import re
- return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
- return ''
+
+ return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
+ return ""
def guess_model_task(model):
@@ -804,17 +875,17 @@ def guess_model_task(model):
def cfg2task(cfg):
"""Guess from YAML dictionary."""
- m = cfg['head'][-1][-2].lower() # output module name
- if m in ('classify', 'classifier', 'cls', 'fc'):
- return 'classify'
- if m == 'detect':
- return 'detect'
- if m == 'segment':
- return 'segment'
- if m == 'pose':
- return 'pose'
- if m == 'obb':
- return 'obb'
+ m = cfg["head"][-1][-2].lower() # output module name
+ if m in ("classify", "classifier", "cls", "fc"):
+ return "classify"
+ if m == "detect":
+ return "detect"
+ if m == "segment":
+ return "segment"
+ if m == "pose":
+ return "pose"
+ if m == "obb":
+ return "obb"
# Guess from model cfg
if isinstance(model, dict):
@@ -823,40 +894,42 @@ def guess_model_task(model):
# Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model
- for x in 'model.args', 'model.model.args', 'model.model.model.args':
+ for x in "model.args", "model.model.args", "model.model.model.args":
with contextlib.suppress(Exception):
- return eval(x)['task']
- for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
+ return eval(x)["task"]
+ for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
with contextlib.suppress(Exception):
return cfg2task(eval(x))
for m in model.modules():
if isinstance(m, Detect):
- return 'detect'
+ return "detect"
elif isinstance(m, Segment):
- return 'segment'
+ return "segment"
elif isinstance(m, Classify):
- return 'classify'
+ return "classify"
elif isinstance(m, Pose):
- return 'pose'
+ return "pose"
elif isinstance(m, OBB):
- return 'obb'
+ return "obb"
# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model)
- if '-seg' in model.stem or 'segment' in model.parts:
- return 'segment'
- elif '-cls' in model.stem or 'classify' in model.parts:
- return 'classify'
- elif '-pose' in model.stem or 'pose' in model.parts:
- return 'pose'
- elif '-obb' in model.stem or 'obb' in model.parts:
- return 'obb'
- elif 'detect' in model.parts:
- return 'detect'
+ if "-seg" in model.stem or "segment" in model.parts:
+ return "segment"
+ elif "-cls" in model.stem or "classify" in model.parts:
+ return "classify"
+ elif "-pose" in model.stem or "pose" in model.parts:
+ return "pose"
+ elif "-obb" in model.stem or "obb" in model.parts:
+ return "obb"
+ elif "detect" in model.parts:
+ return "detect"
# Unable to determine task from model
- LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
- "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
- return 'detect' # assume detect
+ LOGGER.warning(
+ "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
+ "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
+ )
+ return "detect" # assume detect
diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py
index df9b422a..0cd59620 100644
--- a/ultralytics/solutions/ai_gym.py
+++ b/ultralytics/solutions/ai_gym.py
@@ -26,7 +26,7 @@ class AIGym:
self.angle = None
self.count = None
self.stage = None
- self.pose_type = 'pushup'
+ self.pose_type = "pushup"
self.kpts_to_check = None
# Visual Information
@@ -36,13 +36,15 @@ class AIGym:
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
- def set_args(self,
- kpts_to_check,
- line_thickness=2,
- view_img=False,
- pose_up_angle=145.0,
- pose_down_angle=90.0,
- pose_type='pullup'):
+ def set_args(
+ self,
+ kpts_to_check,
+ line_thickness=2,
+ view_img=False,
+ pose_up_angle=145.0,
+ pose_down_angle=90.0,
+ pose_type="pullup",
+ ):
"""
Configures the AIGym line_thickness, save-image and view-image parameters.
Args:
@@ -72,65 +74,75 @@ class AIGym:
if frame_count == 1:
self.count = [0] * len(results[0])
self.angle = [0] * len(results[0])
- self.stage = ['-' for _ in results[0]]
+ self.stage = ["-" for _ in results[0]]
self.keypoints = results[0].keypoints.data
self.annotator = Annotator(im0, line_width=2)
for ind, k in enumerate(reversed(self.keypoints)):
- if self.pose_type == 'pushup' or self.pose_type == 'pullup':
- self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(),
- k[int(self.kpts_to_check[1])].cpu(),
- k[int(self.kpts_to_check[2])].cpu())
+ if self.pose_type == "pushup" or self.pose_type == "pullup":
+ self.angle[ind] = self.annotator.estimate_pose_angle(
+ k[int(self.kpts_to_check[0])].cpu(),
+ k[int(self.kpts_to_check[1])].cpu(),
+ k[int(self.kpts_to_check[2])].cpu(),
+ )
self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
- if self.pose_type == 'abworkout':
- self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(),
- k[int(self.kpts_to_check[1])].cpu(),
- k[int(self.kpts_to_check[2])].cpu())
+ if self.pose_type == "abworkout":
+ self.angle[ind] = self.annotator.estimate_pose_angle(
+ k[int(self.kpts_to_check[0])].cpu(),
+ k[int(self.kpts_to_check[1])].cpu(),
+ k[int(self.kpts_to_check[2])].cpu(),
+ )
self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
if self.angle[ind] > self.poseup_angle:
- self.stage[ind] = 'down'
- if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down':
- self.stage[ind] = 'up'
+ self.stage[ind] = "down"
+ if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
+ self.stage[ind] = "up"
self.count[ind] += 1
- self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
- count_text=self.count[ind],
- stage_text=self.stage[ind],
- center_kpt=k[int(self.kpts_to_check[1])],
- line_thickness=self.tf)
-
- if self.pose_type == 'pushup':
+ self.annotator.plot_angle_and_count_and_stage(
+ angle_text=self.angle[ind],
+ count_text=self.count[ind],
+ stage_text=self.stage[ind],
+ center_kpt=k[int(self.kpts_to_check[1])],
+ line_thickness=self.tf,
+ )
+
+ if self.pose_type == "pushup":
if self.angle[ind] > self.poseup_angle:
- self.stage[ind] = 'up'
- if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'up':
- self.stage[ind] = 'down'
+ self.stage[ind] = "up"
+ if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
+ self.stage[ind] = "down"
self.count[ind] += 1
- self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
- count_text=self.count[ind],
- stage_text=self.stage[ind],
- center_kpt=k[int(self.kpts_to_check[1])],
- line_thickness=self.tf)
- if self.pose_type == 'pullup':
+ self.annotator.plot_angle_and_count_and_stage(
+ angle_text=self.angle[ind],
+ count_text=self.count[ind],
+ stage_text=self.stage[ind],
+ center_kpt=k[int(self.kpts_to_check[1])],
+ line_thickness=self.tf,
+ )
+ if self.pose_type == "pullup":
if self.angle[ind] > self.poseup_angle:
- self.stage[ind] = 'down'
- if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down':
- self.stage[ind] = 'up'
+ self.stage[ind] = "down"
+ if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
+ self.stage[ind] = "up"
self.count[ind] += 1
- self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind],
- count_text=self.count[ind],
- stage_text=self.stage[ind],
- center_kpt=k[int(self.kpts_to_check[1])],
- line_thickness=self.tf)
+ self.annotator.plot_angle_and_count_and_stage(
+ angle_text=self.angle[ind],
+ count_text=self.count[ind],
+ stage_text=self.stage[ind],
+ center_kpt=k[int(self.kpts_to_check[1])],
+ line_thickness=self.tf,
+ )
self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)
if self.env_check and self.view_img:
- cv2.imshow('Ultralytics YOLOv8 AI GYM', self.im0)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0)
+ if cv2.waitKey(1) & 0xFF == ord("q"):
return
return self.im0
-if __name__ == '__main__':
+if __name__ == "__main__":
AIGym()
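
Note on the AIGym rep counter above: the state machine hinges on estimate_pose_angle, which returns the joint angle at the middle keypoint; crossing pose_up_angle or pose_down_angle flips the stage and increments the count. A minimal sketch of such a three-point angle with atan2 (my own reconstruction for illustration, not the Annotator source):

    import math

    def pose_angle(a, b, c):
        """Angle in degrees at vertex b, formed by points a-b-c (each an (x, y) pair)."""
        ang = math.degrees(
            math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])
        )
        ang = abs(ang)
        return 360 - ang if ang > 180 else ang

    print(pose_angle((0, 0), (1, 0), (2, 0)))  # 180.0: fully extended joint
    print(pose_angle((0, 0), (1, 0), (1, 1)))  # 90.0: right angle at the joint
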
diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py
index 306a5bd2..1770a7c0 100644
--- a/ultralytics/solutions/distance_calculation.py
+++ b/ultralytics/solutions/distance_calculation.py
@@ -41,13 +41,15 @@ class DistanceCalculation:
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
- def set_args(self,
- names,
- pixels_per_meter=10,
- view_img=False,
- line_thickness=2,
- line_color=(255, 255, 0),
- centroid_color=(255, 0, 255)):
+ def set_args(
+ self,
+ names,
+ pixels_per_meter=10,
+ view_img=False,
+ line_thickness=2,
+ line_color=(255, 255, 0),
+ centroid_color=(255, 0, 255),
+ ):
"""
Configures the distance calculation and display parameters.
@@ -129,8 +131,9 @@ class DistanceCalculation:
distance (float): Distance between two centroids
"""
cv2.rectangle(self.im0, (15, 25), (280, 70), (255, 255, 255), -1)
- cv2.putText(self.im0, f'Distance : {distance:.2f}m', (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2,
- cv2.LINE_AA)
+ cv2.putText(
+ self.im0, f"Distance : {distance:.2f}m", (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, cv2.LINE_AA
+ )
cv2.line(self.im0, self.centroids[0], self.centroids[1], self.line_color, 3)
cv2.circle(self.im0, self.centroids[0], 6, self.centroid_color, -1)
cv2.circle(self.im0, self.centroids[1], 6, self.centroid_color, -1)
@@ -179,13 +182,13 @@ class DistanceCalculation:
def display_frames(self):
"""Display frame."""
- cv2.namedWindow('Ultralytics Distance Estimation')
- cv2.setMouseCallback('Ultralytics Distance Estimation', self.mouse_event_for_distance)
- cv2.imshow('Ultralytics Distance Estimation', self.im0)
+ cv2.namedWindow("Ultralytics Distance Estimation")
+ cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance)
+ cv2.imshow("Ultralytics Distance Estimation", self.im0)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ if cv2.waitKey(1) & 0xFF == ord("q"):
return
-if __name__ == '__main__':
+if __name__ == "__main__":
DistanceCalculation()
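
Note on the distance readout above: it is the Euclidean distance between the two selected box centroids scaled by pixels_per_meter, so it assumes a roughly planar scene with a known, uniform pixel scale. A one-function sketch of that conversion:

    import math

    def centroid(box):
        """Center (x, y) of an xyxy box."""
        return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2

    def distance_m(box1, box2, pixels_per_meter=10):
        """Euclidean centroid-to-centroid distance, converted to meters."""
        (x1, y1), (x2, y2) = centroid(box1), centroid(box2)
        return math.hypot(x2 - x1, y2 - y1) / pixels_per_meter

    print(f"{distance_m((0, 0, 100, 100), (300, 400, 400, 500)):.2f} m")  # 50.00 m
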
diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py
index 5d5b6d1f..c60eaa76 100644
--- a/ultralytics/solutions/heatmap.py
+++ b/ultralytics/solutions/heatmap.py
@@ -8,7 +8,7 @@ import numpy as np
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator
-check_requirements('shapely>=2.0.0')
+check_requirements("shapely>=2.0.0")
from shapely.geometry import LineString, Point, Polygon
@@ -22,7 +22,7 @@ class Heatmap:
# Visual information
self.annotator = None
self.view_img = False
- self.shape = 'circle'
+ self.shape = "circle"
# Image information
self.imw = None
@@ -63,23 +63,25 @@ class Heatmap:
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
- def set_args(self,
- imw,
- imh,
- colormap=cv2.COLORMAP_JET,
- heatmap_alpha=0.5,
- view_img=False,
- view_in_counts=True,
- view_out_counts=True,
- count_reg_pts=None,
- count_txt_thickness=2,
- count_txt_color=(0, 0, 0),
- count_color=(255, 255, 255),
- count_reg_color=(255, 0, 255),
- region_thickness=5,
- line_dist_thresh=15,
- decay_factor=0.99,
- shape='circle'):
+ def set_args(
+ self,
+ imw,
+ imh,
+ colormap=cv2.COLORMAP_JET,
+ heatmap_alpha=0.5,
+ view_img=False,
+ view_in_counts=True,
+ view_out_counts=True,
+ count_reg_pts=None,
+ count_txt_thickness=2,
+ count_txt_color=(0, 0, 0),
+ count_color=(255, 255, 255),
+ count_reg_color=(255, 0, 255),
+ region_thickness=5,
+ line_dist_thresh=15,
+ decay_factor=0.99,
+ shape="circle",
+ ):
"""
Configures the heatmap colormap, width, height and display parameters.
@@ -111,20 +113,19 @@ class Heatmap:
# Region and line selection
if count_reg_pts is not None:
-
if len(count_reg_pts) == 2:
- print('Line Counter Initiated.')
+ print("Line Counter Initiated.")
self.count_reg_pts = count_reg_pts
self.counting_region = LineString(count_reg_pts)
elif len(count_reg_pts) == 4:
- print('Region Counter Initiated.')
+ print("Region Counter Initiated.")
self.count_reg_pts = count_reg_pts
self.counting_region = Polygon(self.count_reg_pts)
else:
- print('Region or line points Invalid, 2 or 4 points supported')
- print('Using Line Counter Now')
+ print("Region or line points Invalid, 2 or 4 points supported")
+ print("Using Line Counter Now")
self.counting_region = Polygon([(20, 400), (1260, 400)]) # dummy points
# Heatmap new frame
@@ -140,10 +141,10 @@ class Heatmap:
self.shape = shape
# shape of heatmap, if not selected
- if self.shape not in ['circle', 'rect']:
+ if self.shape not in ["circle", "rect"]:
print("Unknown shape value provided, 'circle' & 'rect' supported")
- print('Using Circular shape now')
- self.shape = 'circle'
+ print("Using Circular shape now")
+ self.shape = "circle"
def extract_results(self, tracks):
"""
@@ -177,27 +178,26 @@ class Heatmap:
self.annotator = Annotator(self.im0, self.count_txt_thickness, None)
if self.count_reg_pts is not None:
-
# Draw counting region
if self.view_in_counts or self.view_out_counts:
- self.annotator.draw_region(reg_pts=self.count_reg_pts,
- color=self.region_color,
- thickness=self.region_thickness)
+ self.annotator.draw_region(
+ reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness
+ )
for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
-
- if self.shape == 'circle':
+ if self.shape == "circle":
center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
- y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
- mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+ y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
+ mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
- self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
- (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+ self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
+ 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
+ )
else:
- self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
+ self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
# Store tracking hist
track_line = self.track_history[track_id]
@@ -226,26 +226,26 @@ class Heatmap:
self.in_counts += 1
else:
for box, cls in zip(self.boxes, self.clss):
-
- if self.shape == 'circle':
+ if self.shape == "circle":
center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
- y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
- mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+ y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
+ mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
- self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
- (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+ self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
+ 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
+ )
else:
- self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
+ self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
# Normalize, apply colormap to heatmap and combine with original image
heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
- incount_label = 'In Count : ' + f'{self.in_counts}'
- outcount_label = 'OutCount : ' + f'{self.out_counts}'
+ incount_label = "In Count : " + f"{self.in_counts}"
+ outcount_label = "OutCount : " + f"{self.out_counts}"
# Display counts based on user choice
counts_label = None
@@ -256,13 +256,15 @@ class Heatmap:
elif not self.view_out_counts:
counts_label = incount_label
else:
- counts_label = incount_label + ' ' + outcount_label
+ counts_label = incount_label + " " + outcount_label
if self.count_reg_pts is not None and counts_label is not None:
- self.annotator.count_labels(counts=counts_label,
- count_txt_size=self.count_txt_thickness,
- txt_color=self.count_txt_color,
- color=self.count_color)
+ self.annotator.count_labels(
+ counts=counts_label,
+ count_txt_size=self.count_txt_thickness,
+ txt_color=self.count_txt_color,
+ color=self.count_color,
+ )
self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
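
The normalize, colormap, and blend steps just above can be exercised in isolation; a minimal sketch with a hypothetical frame and accumulator:

    import cv2
    import numpy as np

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical BGR frame
    heatmap = np.random.rand(480, 640).astype(np.float32)  # hypothetical accumulator
    alpha = 0.5  # heatmap_alpha default from set_args

    normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
    colored = cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_JET)
    blended = cv2.addWeighted(frame, 1 - alpha, colored, alpha, 0)
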
@@ -273,11 +275,11 @@ class Heatmap:
def display_frames(self):
"""Display frame."""
- cv2.imshow('Ultralytics Heatmap', self.im0)
+ cv2.imshow("Ultralytics Heatmap", self.im0)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ if cv2.waitKey(1) & 0xFF == ord("q"):
return
-if __name__ == '__main__':
+if __name__ == "__main__":
Heatmap()
diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py
index f817c798..1257553f 100644
--- a/ultralytics/solutions/object_counter.py
+++ b/ultralytics/solutions/object_counter.py
@@ -7,7 +7,7 @@ import cv2
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator, colors
-check_requirements('shapely>=2.0.0')
+check_requirements("shapely>=2.0.0")
from shapely.geometry import LineString, Point, Polygon
@@ -56,22 +56,24 @@ class ObjectCounter:
# Check if environment supports imshow
self.env_check = check_imshow(warn=True)
- def set_args(self,
- classes_names,
- reg_pts,
- count_reg_color=(255, 0, 255),
- line_thickness=2,
- track_thickness=2,
- view_img=False,
- view_in_counts=True,
- view_out_counts=True,
- draw_tracks=False,
- count_txt_thickness=2,
- count_txt_color=(0, 0, 0),
- count_color=(255, 255, 255),
- track_color=(0, 255, 0),
- region_thickness=5,
- line_dist_thresh=15):
+ def set_args(
+ self,
+ classes_names,
+ reg_pts,
+ count_reg_color=(255, 0, 255),
+ line_thickness=2,
+ track_thickness=2,
+ view_img=False,
+ view_in_counts=True,
+ view_out_counts=True,
+ draw_tracks=False,
+ count_txt_thickness=2,
+ count_txt_color=(0, 0, 0),
+ count_color=(255, 255, 255),
+ track_color=(0, 255, 0),
+ region_thickness=5,
+ line_dist_thresh=15,
+ ):
"""
Configures the Counter's image, bounding box line thickness, and counting region points.
@@ -101,16 +103,16 @@ class ObjectCounter:
# Region and line selection
if len(reg_pts) == 2:
- print('Line Counter Initiated.')
+ print("Line Counter Initiated.")
self.reg_pts = reg_pts
self.counting_region = LineString(self.reg_pts)
elif len(reg_pts) == 4:
- print('Region Counter Initiated.')
+ print("Region Counter Initiated.")
self.reg_pts = reg_pts
self.counting_region = Polygon(self.reg_pts)
else:
- print('Invalid Region points provided, region_points can be 2 or 4')
- print('Using Line Counter Now')
+ print("Invalid Region points provided, region_points can be 2 or 4")
+ print("Using Line Counter Now")
self.counting_region = LineString(self.reg_pts)
self.names = classes_names
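
A minimal sketch of the Shapely logic behind the 2-point/4-point split above, using hypothetical coordinates:

    from shapely.geometry import LineString, Point, Polygon

    region = Polygon([(50, 50), (600, 50), (600, 400), (50, 400)])  # 4 points -> region counter
    line = LineString([(20, 400), (1260, 400)])  # 2 points -> line counter
    centroid = Point(320, 240)  # hypothetical box center

    print(region.contains(centroid))  # True: the centroid falls inside the counting region
    print(line.distance(centroid) < 15)  # line counting compares distance to line_dist_thresh
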
@@ -164,8 +166,9 @@ class ObjectCounter:
# Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss):
- self.annotator.box_label(box, label=str(track_id) + ':' + self.names[cls],
- color=colors(int(cls), True)) # Draw bounding box
+ self.annotator.box_label(
+ box, label=str(track_id) + ":" + self.names[cls], color=colors(int(cls), True)
+ ) # Draw bounding box
# Draw Tracks
track_line = self.track_history[track_id]
@@ -175,9 +178,9 @@ class ObjectCounter:
# Draw track trails
if self.draw_tracks:
- self.annotator.draw_centroid_and_tracks(track_line,
- color=self.track_color,
- track_thickness=self.track_thickness)
+ self.annotator.draw_centroid_and_tracks(
+ track_line, color=self.track_color, track_thickness=self.track_thickness
+ )
# Count objects
if len(self.reg_pts) == 4:
@@ -199,8 +202,8 @@ class ObjectCounter:
else:
self.in_counts += 1
- incount_label = 'In Count : ' + f'{self.in_counts}'
- outcount_label = 'OutCount : ' + f'{self.out_counts}'
+ incount_label = "In Count : " + f"{self.in_counts}"
+ outcount_label = "OutCount : " + f"{self.out_counts}"
# Display counts based on user choice
counts_label = None
@@ -211,24 +214,27 @@ class ObjectCounter:
elif not self.view_out_counts:
counts_label = incount_label
else:
- counts_label = incount_label + ' ' + outcount_label
+ counts_label = incount_label + " " + outcount_label
if counts_label is not None:
- self.annotator.count_labels(counts=counts_label,
- count_txt_size=self.count_txt_thickness,
- txt_color=self.count_txt_color,
- color=self.count_color)
+ self.annotator.count_labels(
+ counts=counts_label,
+ count_txt_size=self.count_txt_thickness,
+ txt_color=self.count_txt_color,
+ color=self.count_color,
+ )
def display_frames(self):
"""Display frame."""
if self.env_check:
- cv2.namedWindow('Ultralytics YOLOv8 Object Counter')
+ cv2.namedWindow("Ultralytics YOLOv8 Object Counter")
if len(self.reg_pts) == 4:  # only add mouse event if region is user-drawn
- cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region,
- {'region_points': self.reg_pts})
- cv2.imshow('Ultralytics YOLOv8 Object Counter', self.im0)
+ cv2.setMouseCallback(
+ "Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts}
+ )
+ cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
# Break Window
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ if cv2.waitKey(1) & 0xFF == ord("q"):
return
def start_counting(self, im0, tracks):
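
For reference, a hedged standalone sketch of the mouse-callback wiring above; the handler body here is hypothetical:

    import cv2

    def mouse_event_for_region(event, x, y, flags, params):
        """Hypothetical handler: reacts to clicks using the region points passed via params."""
        if event == cv2.EVENT_LBUTTONDOWN:
            print("clicked at", (x, y), "region:", params["region_points"])

    window = "Ultralytics YOLOv8 Object Counter"
    cv2.namedWindow(window)
    cv2.setMouseCallback(window, mouse_event_for_region, {"region_points": [(50, 50), (600, 50), (600, 400), (50, 400)]})
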
@@ -254,5 +260,5 @@ class ObjectCounter:
return self.im0
-if __name__ == '__main__':
+if __name__ == "__main__":
ObjectCounter()
diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py
index 7260141f..55e209f9 100644
--- a/ultralytics/solutions/speed_estimation.py
+++ b/ultralytics/solutions/speed_estimation.py
@@ -66,7 +66,7 @@ class SpeedEstimator:
spdl_dist_thresh (int): Euclidean distance threshold for speed line
"""
if reg_pts is None:
- print('Region points not provided, using default values')
+ print("Region points not provided, using default values")
else:
self.reg_pts = reg_pts
self.names = names
@@ -114,8 +114,9 @@ class SpeedEstimator:
cls (str): object class name
track (list): tracking history for tracks path drawing
"""
- speed_label = str(int(
- self.dist_data[track_id])) + 'km/ph' if track_id in self.dist_data else self.names[int(cls)]
+ speed_label = (
+            str(int(self.dist_data[track_id])) + " km/h" if track_id in self.dist_data else self.names[int(cls)]
+ )
bbox_color = colors(int(track_id)) if track_id in self.dist_data else (255, 0, 255)
self.annotator.box_label(box, speed_label, bbox_color)
@@ -132,19 +133,16 @@ class SpeedEstimator:
"""
if self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
+ if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh:
+ direction = "known"
- if (self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh):
- direction = 'known'
-
- elif (self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] <
- self.reg_pts[0][1] + self.spdl_dist_thresh):
- direction = 'known'
+ elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh:
+ direction = "known"
else:
- direction = 'unknown'
-
- if self.trk_previous_times[trk_id] != 0 and direction != 'unknown':
+ direction = "unknown"
+ if self.trk_previous_times[trk_id] != 0 and direction != "unknown":
if trk_id not in self.trk_idslist:
self.trk_idslist.append(trk_id)
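
The speed value itself is computed a few lines below this hunk from pixel displacement over elapsed time; a toy sketch of that idea, with all numbers hypothetical:

    import time

    previous_time, previous_point = time.time() - 0.5, (640, 300)  # observed 0.5 s ago
    current_point = (640, 340)  # track moved 40 px vertically since then

    elapsed = time.time() - previous_time
    pixels_per_second = abs(current_point[1] - previous_point[1]) / elapsed
    # the real estimator scales pixel speed by a calibration factor before reporting km/h
    print(f"{pixels_per_second:.1f} px/s")
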
@@ -178,7 +176,6 @@ class SpeedEstimator:
self.annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 0), thickness=self.region_thickness)
for box, trk_id, cls in zip(self.boxes, self.trk_ids, self.clss):
-
track = self.store_track_info(trk_id, box)
if trk_id not in self.trk_previous_times:
@@ -194,10 +191,10 @@ class SpeedEstimator:
def display_frames(self):
"""Display frame."""
- cv2.imshow('Ultralytics Speed Estimation', self.im0)
- if cv2.waitKey(1) & 0xFF == ord('q'):
+ cv2.imshow("Ultralytics Speed Estimation", self.im0)
+ if cv2.waitKey(1) & 0xFF == ord("q"):
return
-if __name__ == '__main__':
+if __name__ == "__main__":
SpeedEstimator()
diff --git a/ultralytics/trackers/__init__.py b/ultralytics/trackers/__init__.py
index 46e178e4..bf51b8df 100644
--- a/ultralytics/trackers/__init__.py
+++ b/ultralytics/trackers/__init__.py
@@ -4,4 +4,4 @@ from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
from .track import register_tracker
-__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
+__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import
diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py
index 778786b8..31d5e1be 100644
--- a/ultralytics/trackers/bot_sort.py
+++ b/ultralytics/trackers/bot_sort.py
@@ -39,6 +39,7 @@ class BOTrack(STrack):
bo_track.predict()
bo_track.update(new_track, frame_id)
"""
+
shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
@@ -176,7 +177,7 @@ class BOTSORT(BYTETracker):
def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
dists = matching.iou_distance(tracks, detections)
- dists_mask = (dists > self.proximity_thresh)
+ dists_mask = dists > self.proximity_thresh
# TODO: mot20
# if not self.args.mot20:
diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py
index 5331511e..619ea12c 100644
--- a/ultralytics/trackers/byte_tracker.py
+++ b/ultralytics/trackers/byte_tracker.py
@@ -112,8 +112,9 @@ class STrack(BaseTrack):
def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a previously lost track with a new detection."""
- self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
- self.convert_coords(new_track.tlwh))
+ self.mean, self.covariance = self.kalman_filter.update(
+ self.mean, self.covariance, self.convert_coords(new_track.tlwh)
+ )
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
@@ -136,8 +137,9 @@ class STrack(BaseTrack):
self.tracklet_len += 1
new_tlwh = new_track.tlwh
- self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
- self.convert_coords(new_tlwh))
+ self.mean, self.covariance = self.kalman_filter.update(
+ self.mean, self.covariance, self.convert_coords(new_tlwh)
+ )
self.state = TrackState.Tracked
self.is_activated = True
@@ -192,7 +194,7 @@ class STrack(BaseTrack):
def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
- return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
+ return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
class BYTETracker:
@@ -275,7 +277,7 @@ class BYTETracker:
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
self.multi_predict(strack_pool)
- if hasattr(self, 'gmc') and img is not None:
+ if hasattr(self, "gmc") and img is not None:
warp = self.gmc.apply(img, dets)
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)
@@ -349,7 +351,8 @@ class BYTETracker:
self.removed_stracks = self.removed_stracks[-999:]  # clip removed stracks to 1000 maximum
return np.asarray(
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
- dtype=np.float32)
+ dtype=np.float32,
+ )
def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py
index 2ad4f5d7..c0dd4e45 100644
--- a/ultralytics/trackers/track.py
+++ b/ultralytics/trackers/track.py
@@ -10,7 +10,7 @@ from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker
# A mapping of tracker types to corresponding tracker classes
-TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
+TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
def on_predict_start(predictor: object, persist: bool = False) -> None:
@@ -24,15 +24,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
- if predictor.args.task == 'obb':
- raise NotImplementedError('ERROR ❌ OBB task does not support track mode!')
- if hasattr(predictor, 'trackers') and persist:
+ if predictor.args.task == "obb":
+ raise NotImplementedError("ERROR ❌ OBB task does not support track mode!")
+ if hasattr(predictor, "trackers") and persist:
return
tracker = check_yaml(predictor.args.tracker)
cfg = IterableSimpleNamespace(**yaml_load(tracker))
- if cfg.tracker_type not in ['bytetrack', 'botsort']:
+ if cfg.tracker_type not in ["bytetrack", "botsort"]:
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
trackers = []
@@ -76,5 +76,5 @@ def register_tracker(model: object, persist: bool) -> None:
model (object): The model object to register tracking callbacks for.
persist (bool): Whether to persist the trackers if they already exist.
"""
- model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
- model.add_callback('on_predict_postprocess_end', partial(on_predict_postprocess_end, persist=persist))
+ model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
+ model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py
index 1e801b4a..07726d73 100644
--- a/ultralytics/trackers/utils/gmc.py
+++ b/ultralytics/trackers/utils/gmc.py
@@ -33,7 +33,7 @@ class GMC:
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
"""
- def __init__(self, method: str = 'sparseOptFlow', downscale: int = 2) -> None:
+ def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
"""
Initialize a video tracker with specified parameters.
@@ -46,34 +46,31 @@ class GMC:
self.method = method
self.downscale = max(1, int(downscale))
- if self.method == 'orb':
+ if self.method == "orb":
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
- elif self.method == 'sift':
+ elif self.method == "sift":
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
- elif self.method == 'ecc':
+ elif self.method == "ecc":
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
- elif self.method == 'sparseOptFlow':
- self.feature_params = dict(maxCorners=1000,
- qualityLevel=0.01,
- minDistance=1,
- blockSize=3,
- useHarrisDetector=False,
- k=0.04)
+ elif self.method == "sparseOptFlow":
+ self.feature_params = dict(
+ maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
+ )
- elif self.method in ['none', 'None', None]:
+ elif self.method in ["none", "None", None]:
self.method = None
else:
- raise ValueError(f'Error: Unknown GMC method:{method}')
+ raise ValueError(f"Error: Unknown GMC method:{method}")
self.prevFrame = None
self.prevKeyPoints = None
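
In the sparseOptFlow branch these feature_params feed cv2.goodFeaturesToTrack plus pyramidal LK flow; a hedged standalone sketch with toy frames (the affine estimate mirrors what the class computes downstream):

    import cv2
    import numpy as np

    feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04)

    prev = np.random.randint(0, 255, (240, 320), dtype=np.uint8)  # toy grayscale frames
    curr = np.roll(prev, 2, axis=1)  # same content shifted 2 px right

    p0 = cv2.goodFeaturesToTrack(prev, mask=None, **feature_params)
    p1, status, err = cv2.calcOpticalFlowPyrLK(prev, curr, p0, None)
    good0, good1 = p0[status.flatten() == 1], p1[status.flatten() == 1]
    H, inliers = cv2.estimateAffinePartial2D(good0, good1, method=cv2.RANSAC)  # 2x3 warp
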
@@ -97,11 +94,11 @@ class GMC:
array([[1, 2, 3],
[4, 5, 6]])
"""
- if self.method in ['orb', 'sift']:
+ if self.method in ["orb", "sift"]:
return self.applyFeatures(raw_frame, detections)
- elif self.method == 'ecc':
+ elif self.method == "ecc":
return self.applyEcc(raw_frame, detections)
- elif self.method == 'sparseOptFlow':
+ elif self.method == "sparseOptFlow":
return self.applySparseOptFlow(raw_frame, detections)
else:
return np.eye(2, 3)
@@ -149,7 +146,7 @@ class GMC:
try:
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
except Exception as e:
- LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
+ LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}")
return H
@@ -182,11 +179,11 @@ class GMC:
# Find the keypoints
mask = np.zeros_like(frame)
- mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
+ mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
- mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
+ mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
@@ -228,11 +225,14 @@ class GMC:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
- spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
- prevKeyPointLocation[1] - currKeyPointLocation[1])
+ spatialDistance = (
+ prevKeyPointLocation[0] - currKeyPointLocation[0],
+ prevKeyPointLocation[1] - currKeyPointLocation[1],
+ )
- if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
- (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
+ if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
+ np.abs(spatialDistance[1]) < maxSpatialDistance[1]
+ ):
spatialDistances.append(spatialDistance)
matches.append(m)
@@ -283,7 +283,7 @@ class GMC:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
- LOGGER.warning('WARNING: not enough matching points')
+ LOGGER.warning("WARNING: not enough matching points")
# Store to next iteration
self.prevFrame = frame.copy()
@@ -350,7 +350,7 @@ class GMC:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
- LOGGER.warning('WARNING: not enough matching points')
+ LOGGER.warning("WARNING: not enough matching points")
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py
index b4fb91fc..4ae68be2 100644
--- a/ultralytics/trackers/utils/kalman_filter.py
+++ b/ultralytics/trackers/utils/kalman_filter.py
@@ -17,7 +17,7 @@ class KalmanFilterXYAH:
def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
- ndim, dt = 4, 1.
+ ndim, dt = 4, 1.0
# Create Kalman filter model matrices
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
@@ -27,8 +27,8 @@ class KalmanFilterXYAH:
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
# the amount of uncertainty in the model.
- self._std_weight_position = 1. / 20
- self._std_weight_velocity = 1. / 160
+ self._std_weight_position = 1.0 / 20
+ self._std_weight_velocity = 1.0 / 160
def initiate(self, measurement: np.ndarray) -> tuple:
"""
@@ -47,9 +47,15 @@ class KalmanFilterXYAH:
mean = np.r_[mean_pos, mean_vel]
std = [
- 2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
- 2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
- 10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
+ 2 * self._std_weight_position * measurement[3],
+ 2 * self._std_weight_position * measurement[3],
+ 1e-2,
+ 2 * self._std_weight_position * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 10 * self._std_weight_velocity * measurement[3],
+ 1e-5,
+ 10 * self._std_weight_velocity * measurement[3],
+ ]
covariance = np.diag(np.square(std))
return mean, covariance
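
A hedged usage sketch of the XYAH filter's initiate/predict/update cycle, with a made-up measurement:

    import numpy as np
    from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

    kf = KalmanFilterXYAH()
    measurement = np.array([320.0, 240.0, 0.5, 80.0])  # x, y, aspect ratio, height

    mean, cov = kf.initiate(measurement)  # 8-dim state: position plus zero-initialized velocities
    mean, cov = kf.predict(mean, cov)  # project the state one frame forward
    mean, cov = kf.update(mean, cov, measurement)  # correct with a fresh detection
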
@@ -66,11 +72,17 @@ class KalmanFilterXYAH:
velocities are initialized to 0 mean.
"""
std_pos = [
- self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
- self._std_weight_position * mean[3]]
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-2,
+ self._std_weight_position * mean[3],
+ ]
std_vel = [
- self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
- self._std_weight_velocity * mean[3]]
+ self._std_weight_velocity * mean[3],
+ self._std_weight_velocity * mean[3],
+ 1e-5,
+ self._std_weight_velocity * mean[3],
+ ]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(mean, self._motion_mat.T)
@@ -90,8 +102,11 @@ class KalmanFilterXYAH:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
- self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
- self._std_weight_position * mean[3]]
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[3],
+ 1e-1,
+ self._std_weight_position * mean[3],
+ ]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
@@ -111,11 +126,17 @@ class KalmanFilterXYAH:
velocities are initialized to 0 mean.
"""
std_pos = [
- self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
- 1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
+ self._std_weight_position * mean[:, 3],
+ self._std_weight_position * mean[:, 3],
+ 1e-2 * np.ones_like(mean[:, 3]),
+ self._std_weight_position * mean[:, 3],
+ ]
std_vel = [
- self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
- 1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
+ self._std_weight_velocity * mean[:, 3],
+ self._std_weight_velocity * mean[:, 3],
+ 1e-5 * np.ones_like(mean[:, 3]),
+ self._std_weight_velocity * mean[:, 3],
+ ]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
@@ -143,21 +164,23 @@ class KalmanFilterXYAH:
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
- kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
- np.dot(covariance, self._update_mat.T).T,
- check_finite=False).T
+ kalman_gain = scipy.linalg.cho_solve(
+ (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
+ ).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
- def gating_distance(self,
- mean: np.ndarray,
- covariance: np.ndarray,
- measurements: np.ndarray,
- only_position: bool = False,
- metric: str = 'maha') -> np.ndarray:
+ def gating_distance(
+ self,
+ mean: np.ndarray,
+ covariance: np.ndarray,
+ measurements: np.ndarray,
+ only_position: bool = False,
+ metric: str = "maha",
+ ) -> np.ndarray:
"""
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
@@ -183,14 +206,14 @@ class KalmanFilterXYAH:
measurements = measurements[:, :2]
d = measurements - mean
- if metric == 'gaussian':
+ if metric == "gaussian":
return np.sum(d * d, axis=1)
- elif metric == 'maha':
+ elif metric == "maha":
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
return np.sum(z * z, axis=0) # square maha
else:
- raise ValueError('Invalid distance metric')
+ raise ValueError("Invalid distance metric")
class KalmanFilterXYWH(KalmanFilterXYAH):
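
Continuing the Kalman sketch above: gating_distance pairs with the chi-square threshold the docstring mentions, 9.4877 being the standard 95% value for 4 degrees of freedom:

    import numpy as np

    CHI2INV95_4DOF = 9.4877
    candidates = np.tile(measurement, (3, 1))  # three hypothetical detections
    d = kf.gating_distance(mean, cov, candidates, only_position=False, metric="maha")
    keep = d <= CHI2INV95_4DOF  # gate out implausible track-detection pairs
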
@@ -220,10 +243,15 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
mean = np.r_[mean_pos, mean_vel]
std = [
- 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
- 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
- 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
- 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
+ 2 * self._std_weight_position * measurement[2],
+ 2 * self._std_weight_position * measurement[3],
+ 2 * self._std_weight_position * measurement[2],
+ 2 * self._std_weight_position * measurement[3],
+ 10 * self._std_weight_velocity * measurement[2],
+ 10 * self._std_weight_velocity * measurement[3],
+ 10 * self._std_weight_velocity * measurement[2],
+ 10 * self._std_weight_velocity * measurement[3],
+ ]
covariance = np.diag(np.square(std))
return mean, covariance
@@ -240,11 +268,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
velocities are initialized to 0 mean.
"""
std_pos = [
- self._std_weight_position * mean[2], self._std_weight_position * mean[3],
- self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
+ self._std_weight_position * mean[2],
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[2],
+ self._std_weight_position * mean[3],
+ ]
std_vel = [
- self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
- self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
+ self._std_weight_velocity * mean[2],
+ self._std_weight_velocity * mean[3],
+ self._std_weight_velocity * mean[2],
+ self._std_weight_velocity * mean[3],
+ ]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(mean, self._motion_mat.T)
@@ -264,8 +298,11 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
- self._std_weight_position * mean[2], self._std_weight_position * mean[3],
- self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
+ self._std_weight_position * mean[2],
+ self._std_weight_position * mean[3],
+ self._std_weight_position * mean[2],
+ self._std_weight_position * mean[3],
+ ]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
@@ -285,11 +322,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
velocities are initialized to 0 mean.
"""
std_pos = [
- self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
- self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
+ self._std_weight_position * mean[:, 2],
+ self._std_weight_position * mean[:, 3],
+ self._std_weight_position * mean[:, 2],
+ self._std_weight_position * mean[:, 3],
+ ]
std_vel = [
- self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
- self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
+ self._std_weight_velocity * mean[:, 2],
+ self._std_weight_velocity * mean[:, 3],
+ self._std_weight_velocity * mean[:, 2],
+ self._std_weight_velocity * mean[:, 3],
+ ]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py
index b2df4bd1..cf59755c 100644
--- a/ultralytics/trackers/utils/matching.py
+++ b/ultralytics/trackers/utils/matching.py
@@ -13,7 +13,7 @@ try:
except (ImportError, AssertionError, AttributeError):
from ultralytics.utils.checks import check_requirements
- check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
+ check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
import lap
@@ -70,8 +70,9 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
(np.ndarray): Cost matrix computed based on IoU.
"""
- if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
- or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
+ if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
+ len(btracks) > 0 and isinstance(btracks[0], np.ndarray)
+ ):
atlbrs = atracks
btlbrs = btracks
else:
@@ -80,13 +81,13 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs):
- ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
- np.ascontiguousarray(btlbrs, dtype=np.float32),
- iou=True)
+ ious = bbox_ioa(
+ np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True
+ )
return 1 - ious # cost matrix
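
These 1 - IoU costs are consumed by lap.lapjv from the lapx package checked above; a toy assignment sketch:

    import lap  # provided by lapx
    import numpy as np

    cost = np.array([[0.2, 0.9], [0.8, 0.1]], dtype=np.float32)  # hypothetical 1 - IoU costs
    # extend_cost pads non-square matrices; cost_limit rejects weak assignments
    _, x, y = lap.lapjv(cost, extend_cost=True, cost_limit=0.8)
    matches = [(i, int(j)) for i, j in enumerate(x) if j >= 0]  # [(0, 0), (1, 1)]
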
-def embedding_distance(tracks: list, detections: list, metric: str = 'cosine') -> np.ndarray:
+def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
"""
Compute distance between tracks and detections based on embeddings.
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 67bc7a2d..07641c65 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -25,23 +25,22 @@ from tqdm import tqdm as tqdm_original
from ultralytics import __version__
# PyTorch Multi-GPU DDP Constants
-RANK = int(os.getenv('RANK', -1))
-LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv("RANK", -1))
+LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
-ASSETS = ROOT / 'assets' # default images
-DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
+ASSETS = ROOT / "assets" # default images
+DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
-AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
-VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
-TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format
-LOGGING_NAME = 'ultralytics'
-MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
-ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
-HELP_MSG = \
- """
+AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
+VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode
+TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
+LOGGING_NAME = "ultralytics"
+MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
+ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
+HELP_MSG = """
Usage examples for running YOLOv8:
1. Install the ultralytics package:
@@ -99,12 +98,12 @@ HELP_MSG = \
"""
# Settings
-torch.set_printoptions(linewidth=320, precision=4, profile='default')
-np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
+torch.set_printoptions(linewidth=320, precision=4, profile="default")
+np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
-os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
-os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
+os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab
class TQDM(tqdm_original):
@@ -119,8 +118,8 @@ class TQDM(tqdm_original):
def __init__(self, *args, **kwargs):
"""Initialize custom Ultralytics tqdm class with different default arguments."""
# Set new default values (these can still be overridden when calling TQDM)
- kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
- kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
+ kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed
+ kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed
super().__init__(*args, **kwargs)
@@ -134,14 +133,14 @@ class SimpleClass:
attr = []
for a in dir(self):
v = getattr(self, a)
- if not callable(v) and not a.startswith('_'):
+ if not callable(v) and not a.startswith("_"):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
- s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
+ s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
else:
- s = f'{a}: {repr(v)}'
+ s = f"{a}: {repr(v)}"
attr.append(s)
- return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
+ return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
def __repr__(self):
"""Return a machine-readable string representation of the object."""
@@ -164,24 +163,26 @@ class IterableSimpleNamespace(SimpleNamespace):
def __str__(self):
"""Return a human-readable string representation of the object."""
- return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
+ return "\n".join(f"{k}={v}" for k, v in vars(self).items())
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
- raise AttributeError(f"""
+ raise AttributeError(
+ f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
{DEFAULT_CFG_PATH} with the latest version from
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
- """)
+ """
+ )
def get(self, key, default=None):
"""Return the value of the specified key if it exists; otherwise, return the default value."""
return getattr(self, key, default)
-def plt_settings(rcparams=None, backend='Agg'):
+def plt_settings(rcparams=None, backend="Agg"):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
@@ -199,7 +200,7 @@ def plt_settings(rcparams=None, backend='Agg'):
"""
if rcparams is None:
- rcparams = {'font.size': 11}
+ rcparams = {"font.size": 11}
def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""
@@ -208,14 +209,14 @@ def plt_settings(rcparams=None, backend='Agg'):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
original_backend = plt.get_backend()
if backend != original_backend:
- plt.close('all') # auto-close()ing of figures upon backend switching is deprecated since 3.8
+ plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8
plt.switch_backend(backend)
with plt.rc_context(rcparams):
result = func(*args, **kwargs)
if backend != original_backend:
- plt.close('all')
+ plt.close("all")
plt.switch_backend(original_backend)
return result
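
A hedged usage sketch of the decorator above:

    import matplotlib.pyplot as plt
    from ultralytics.utils import plt_settings

    @plt_settings(rcparams={"font.size": 9}, backend="Agg")
    def plot_demo():
        """Runs with temporary rc params and backend; both are restored afterwards."""
        plt.plot([0, 1], [0, 1])
        plt.savefig("demo.png")

    plot_demo()
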
@@ -229,26 +230,26 @@ def set_logging(name=LOGGING_NAME, verbose=True):
level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
# Configure the console (stdout) encoding to UTF-8
- formatter = logging.Formatter('%(message)s') # Default formatter
- if WINDOWS and sys.stdout.encoding != 'utf-8':
+ formatter = logging.Formatter("%(message)s") # Default formatter
+ if WINDOWS and sys.stdout.encoding != "utf-8":
try:
- if hasattr(sys.stdout, 'reconfigure'):
- sys.stdout.reconfigure(encoding='utf-8')
- elif hasattr(sys.stdout, 'buffer'):
+ if hasattr(sys.stdout, "reconfigure"):
+ sys.stdout.reconfigure(encoding="utf-8")
+ elif hasattr(sys.stdout, "buffer"):
import io
- sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
else:
- sys.stdout.encoding = 'utf-8'
+ sys.stdout.encoding = "utf-8"
except Exception as e:
- print(f'Creating custom formatter for non UTF-8 environments due to {e}')
+ print(f"Creating custom formatter for non UTF-8 environments due to {e}")
class CustomFormatter(logging.Formatter):
-
def format(self, record):
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
return emojis(super().format(record))
- formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse
+        formatter = CustomFormatter("%(message)s")  # last resort: strip emojis via emojis() for non-UTF-8 output
# Create and configure the StreamHandler
stream_handler = logging.StreamHandler(sys.stdout)
@@ -264,13 +265,13 @@ def set_logging(name=LOGGING_NAME, verbose=True):
# Set logger
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)
-for logger in 'sentry_sdk', 'urllib3.connectionpool':
+for logger in "sentry_sdk", "urllib3.connectionpool":
logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
-def emojis(string=''):
+def emojis(string=""):
"""Return platform-dependent emoji-safe version of string."""
- return string.encode().decode('ascii', 'ignore') if WINDOWS else string
+ return string.encode().decode("ascii", "ignore") if WINDOWS else string
class ThreadingLocked:
@@ -310,7 +311,7 @@ class ThreadingLocked:
return decorated
-def yaml_save(file='data.yaml', data=None, header=''):
+def yaml_save(file="data.yaml", data=None, header=""):
"""
Save YAML data to a file.
@@ -336,13 +337,13 @@ def yaml_save(file='data.yaml', data=None, header=''):
data[k] = str(v)
# Dump data to file in YAML format
- with open(file, 'w', errors='ignore', encoding='utf-8') as f:
+ with open(file, "w", errors="ignore", encoding="utf-8") as f:
if header:
f.write(header)
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
-def yaml_load(file='data.yaml', append_filename=False):
+def yaml_load(file="data.yaml", append_filename=False):
"""
Load YAML data from a file.
@@ -353,18 +354,18 @@ def yaml_load(file='data.yaml', append_filename=False):
Returns:
(dict): YAML data and file name.
"""
- assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
- with open(file, errors='ignore', encoding='utf-8') as f:
+ assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
+ with open(file, errors="ignore", encoding="utf-8") as f:
s = f.read() # string
# Remove special characters
if not s.isprintable():
- s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
+ s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)
# Add YAML filename to dict and return
data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files)
if append_filename:
- data['yaml_file'] = str(file)
+ data["yaml_file"] = str(file)
return data
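
A small round-trip sketch using the two helpers above; the file name is hypothetical:

    from ultralytics.utils import yaml_load, yaml_save

    yaml_save("example.yaml", {"lr0": 0.01, "imgsz": 640}, header="# demo config\n")
    cfg = yaml_load("example.yaml", append_filename=True)
    print(cfg["lr0"], cfg["yaml_file"])  # 0.01 example.yaml
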
@@ -386,7 +387,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
# Default configuration
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k, v in DEFAULT_CFG_DICT.items():
- if isinstance(v, str) and v.lower() == 'none':
+ if isinstance(v, str) and v.lower() == "none":
DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
@@ -400,8 +401,8 @@ def is_ubuntu() -> bool:
(bool): True if OS is Ubuntu, False otherwise.
"""
with contextlib.suppress(FileNotFoundError):
- with open('/etc/os-release') as f:
- return 'ID=ubuntu' in f.read()
+ with open("/etc/os-release") as f:
+ return "ID=ubuntu" in f.read()
return False
@@ -412,7 +413,7 @@ def is_colab():
Returns:
(bool): True if running inside a Colab notebook, False otherwise.
"""
- return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
+ return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
def is_kaggle():
@@ -422,7 +423,7 @@ def is_kaggle():
Returns:
(bool): True if running inside a Kaggle kernel, False otherwise.
"""
- return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
+ return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
def is_jupyter():
@@ -434,6 +435,7 @@ def is_jupyter():
"""
with contextlib.suppress(Exception):
from IPython import get_ipython
+
return get_ipython() is not None
return False
@@ -445,10 +447,10 @@ def is_docker() -> bool:
Returns:
(bool): True if the script is running inside a Docker container, False otherwise.
"""
- file = Path('/proc/self/cgroup')
+ file = Path("/proc/self/cgroup")
if file.exists():
with open(file) as f:
- return 'docker' in f.read()
+ return "docker" in f.read()
else:
return False
@@ -462,7 +464,7 @@ def is_online() -> bool:
"""
import socket
- for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
+ for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
try:
test_connection = socket.create_connection(address=(host, 53), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
@@ -516,7 +518,7 @@ def is_pytest_running():
Returns:
(bool): True if pytest is running, False otherwise.
"""
- return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
+ return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem)
def is_github_action_running() -> bool:
@@ -526,7 +528,7 @@ def is_github_action_running() -> bool:
Returns:
(bool): True if the current environment is a GitHub Actions runner, False otherwise.
"""
- return 'GITHUB_ACTIONS' in os.environ and 'GITHUB_WORKFLOW' in os.environ and 'RUNNER_OS' in os.environ
+ return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
def is_git_dir():
@@ -549,7 +551,7 @@ def get_git_dir():
(Path | None): Git root directory if found or None if not found.
"""
for d in Path(__file__).parents:
- if (d / '.git').is_dir():
+ if (d / ".git").is_dir():
return d
@@ -562,7 +564,7 @@ def get_git_origin_url():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
- origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
+ origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
return origin.decode().strip()
@@ -575,7 +577,7 @@ def get_git_branch():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
- origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
@@ -602,11 +604,11 @@ def get_ubuntu_version():
"""
if is_ubuntu():
with contextlib.suppress(FileNotFoundError, AttributeError):
- with open('/etc/os-release') as f:
+ with open("/etc/os-release") as f:
return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
-def get_user_config_dir(sub_dir='Ultralytics'):
+def get_user_config_dir(sub_dir="Ultralytics"):
"""
Get the user config directory.
@@ -618,19 +620,21 @@ def get_user_config_dir(sub_dir='Ultralytics'):
"""
# Return the appropriate config directory for each operating system
if WINDOWS:
- path = Path.home() / 'AppData' / 'Roaming' / sub_dir
+ path = Path.home() / "AppData" / "Roaming" / sub_dir
elif MACOS: # macOS
- path = Path.home() / 'Library' / 'Application Support' / sub_dir
+ path = Path.home() / "Library" / "Application Support" / sub_dir
elif LINUX:
- path = Path.home() / '.config' / sub_dir
+ path = Path.home() / ".config" / sub_dir
else:
- raise ValueError(f'Unsupported operating system: {platform.system()}')
+ raise ValueError(f"Unsupported operating system: {platform.system()}")
# GCP and AWS lambda fix, only /tmp is writeable
if not is_dir_writeable(path.parent):
- LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
- 'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
- path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir
+ LOGGER.warning(
+ f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
+ "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
+ )
+ path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
# Create the subdirectory if it does not exist
path.mkdir(parents=True, exist_ok=True)
@@ -638,8 +642,8 @@ def get_user_config_dir(sub_dir='Ultralytics'):
return path
-USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir
-SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
+USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
+SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
def colorstr(*input):
@@ -670,28 +674,29 @@ def colorstr(*input):
>>> colorstr('blue', 'bold', 'hello world')
>>> '\033[34m\033[1mhello world\033[0m'
"""
- *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
+ *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
colors = {
- 'black': '\033[30m', # basic colors
- 'red': '\033[31m',
- 'green': '\033[32m',
- 'yellow': '\033[33m',
- 'blue': '\033[34m',
- 'magenta': '\033[35m',
- 'cyan': '\033[36m',
- 'white': '\033[37m',
- 'bright_black': '\033[90m', # bright colors
- 'bright_red': '\033[91m',
- 'bright_green': '\033[92m',
- 'bright_yellow': '\033[93m',
- 'bright_blue': '\033[94m',
- 'bright_magenta': '\033[95m',
- 'bright_cyan': '\033[96m',
- 'bright_white': '\033[97m',
- 'end': '\033[0m', # misc
- 'bold': '\033[1m',
- 'underline': '\033[4m'}
- return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+ "black": "\033[30m", # basic colors
+ "red": "\033[31m",
+ "green": "\033[32m",
+ "yellow": "\033[33m",
+ "blue": "\033[34m",
+ "magenta": "\033[35m",
+ "cyan": "\033[36m",
+ "white": "\033[37m",
+ "bright_black": "\033[90m", # bright colors
+ "bright_red": "\033[91m",
+ "bright_green": "\033[92m",
+ "bright_yellow": "\033[93m",
+ "bright_blue": "\033[94m",
+ "bright_magenta": "\033[95m",
+ "bright_cyan": "\033[96m",
+ "bright_white": "\033[97m",
+ "end": "\033[0m", # misc
+ "bold": "\033[1m",
+ "underline": "\033[4m",
+ }
+ return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
def remove_colorstr(input_string):
@@ -708,8 +713,8 @@ def remove_colorstr(input_string):
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
>>> 'hello world'
"""
- ansi_escape = re.compile(r'\x1B\[[0-9;]*[A-Za-z]')
- return ansi_escape.sub('', input_string)
+ ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
+ return ansi_escape.sub("", input_string)
class TryExcept(contextlib.ContextDecorator):
@@ -719,7 +724,7 @@ class TryExcept(contextlib.ContextDecorator):
Use as @TryExcept() decorator or 'with TryExcept():' context manager.
"""
- def __init__(self, msg='', verbose=True):
+ def __init__(self, msg="", verbose=True):
"""Initialize TryExcept class with optional message and verbosity settings."""
self.msg = msg
self.verbose = verbose
@@ -744,7 +749,7 @@ def threaded(func):
def wrapper(*args, **kwargs):
"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
- if kwargs.pop('threaded', True): # run in thread
+ if kwargs.pop("threaded", True): # run in thread
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
@@ -786,27 +791,28 @@ def set_sentry():
Returns:
dict: The modified event or None if the event should not be sent to Sentry.
"""
- if 'exc_info' in hint:
- exc_type, exc_value, tb = hint['exc_info']
- if exc_type in (KeyboardInterrupt, FileNotFoundError) \
- or 'out of memory' in str(exc_value):
+ if "exc_info" in hint:
+ exc_type, exc_value, tb = hint["exc_info"]
+ if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
return None # do not send event
- event['tags'] = {
- 'sys_argv': sys.argv[0],
- 'sys_argv_name': Path(sys.argv[0]).name,
- 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
- 'os': ENVIRONMENT}
+ event["tags"] = {
+ "sys_argv": sys.argv[0],
+ "sys_argv_name": Path(sys.argv[0]).name,
+ "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
+ "os": ENVIRONMENT,
+ }
return event
- if SETTINGS['sync'] and \
- RANK in (-1, 0) and \
- Path(sys.argv[0]).name == 'yolo' and \
- not TESTS_RUNNING and \
- ONLINE and \
- is_pip_package() and \
- not is_git_dir():
-
+ if (
+ SETTINGS["sync"]
+ and RANK in (-1, 0)
+ and Path(sys.argv[0]).name == "yolo"
+ and not TESTS_RUNNING
+ and ONLINE
+ and is_pip_package()
+ and not is_git_dir()
+ ):
# If sentry_sdk package is not installed then return and do not use Sentry
try:
import sentry_sdk # noqa
@@ -814,14 +820,15 @@ def set_sentry():
return
sentry_sdk.init(
- dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
+ dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",
debug=False,
traces_sample_rate=1.0,
release=__version__,
- environment='production', # 'dev' or 'production'
+ environment="production", # 'dev' or 'production'
before_send=before_send,
- ignore_errors=[KeyboardInterrupt, FileNotFoundError])
- sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash
+ ignore_errors=[KeyboardInterrupt, FileNotFoundError],
+ )
+ sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash
class SettingsManager(dict):
@@ -833,7 +840,7 @@ class SettingsManager(dict):
version (str): Settings version. In case of local version mismatch, new default settings will be saved.
"""
- def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
+ def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
"""Initialize the SettingsManager with default settings, load and validate current settings from the YAML
file.
"""
@@ -850,23 +857,24 @@ class SettingsManager(dict):
self.file = Path(file)
self.version = version
self.defaults = {
- 'settings_version': version,
- 'datasets_dir': str(datasets_root / 'datasets'),
- 'weights_dir': str(root / 'weights'),
- 'runs_dir': str(root / 'runs'),
- 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
- 'sync': True,
- 'api_key': '',
- 'openai_api_key': '',
- 'clearml': True, # integrations
- 'comet': True,
- 'dvc': True,
- 'hub': True,
- 'mlflow': True,
- 'neptune': True,
- 'raytune': True,
- 'tensorboard': True,
- 'wandb': True}
+ "settings_version": version,
+ "datasets_dir": str(datasets_root / "datasets"),
+ "weights_dir": str(root / "weights"),
+ "runs_dir": str(root / "runs"),
+ "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
+ "sync": True,
+ "api_key": "",
+ "openai_api_key": "",
+ "clearml": True, # integrations
+ "comet": True,
+ "dvc": True,
+ "hub": True,
+ "mlflow": True,
+ "neptune": True,
+ "raytune": True,
+ "tensorboard": True,
+ "wandb": True,
+ }
super().__init__(copy.deepcopy(self.defaults))
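
A minimal read-only sketch against the module-level SETTINGS instance created later in this file:

    from ultralytics.utils import SETTINGS  # SettingsManager instance built at import time

    print(SETTINGS["datasets_dir"], SETTINGS["runs_dir"])  # resolved default directories
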
@@ -877,13 +885,14 @@ class SettingsManager(dict):
self.load()
correct_keys = self.keys() == self.defaults.keys()
correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
- correct_version = check_version(self['settings_version'], self.version)
+ correct_version = check_version(self["settings_version"], self.version)
if not (correct_keys and correct_types and correct_version):
LOGGER.warning(
- 'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
- 'with your settings or a recent ultralytics package update. '
+ "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
+ "with your settings or a recent ultralytics package update. "
f"\nView settings with 'yolo settings' or at '{self.file}'"
- "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
+ "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'."
+ )
self.reset()
def load(self):
@@ -910,14 +919,16 @@ def deprecation_warn(arg, new_arg, version=None):
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
if not version:
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
- LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
- f"Please use '{new_arg}' instead.")
+ LOGGER.warning(
+ f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
+ f"Please use '{new_arg}' instead."
+ )
def clean_url(url):
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
- url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows
- return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
+ url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows
+ return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth
def url2file(url):
@@ -928,13 +939,22 @@ def url2file(url):
# Run below code on utils init ------------------------------------------------------------------------------------
# Check first-install steps
-PREFIX = colorstr('Ultralytics: ')
+PREFIX = colorstr("Ultralytics: ")
SETTINGS = SettingsManager() # initialize settings
-DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
-WEIGHTS_DIR = Path(SETTINGS['weights_dir']) # global weights directory
-RUNS_DIR = Path(SETTINGS['runs_dir']) # global runs directory
-ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
- 'Docker' if is_docker() else platform.system()
+DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory
+WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
+RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
+ENVIRONMENT = (
+ "Colab"
+ if is_colab()
+ else "Kaggle"
+ if is_kaggle()
+ else "Jupyter"
+ if is_jupyter()
+ else "Docker"
+ if is_docker()
+ else platform.system()
+)
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()
diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py
index 172a4c12..daea14ed 100644
--- a/ultralytics/utils/autobatch.py
+++ b/ultralytics/utils/autobatch.py
@@ -42,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
"""
# Check device
- prefix = colorstr('AutoBatch: ')
- LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
+ prefix = colorstr("AutoBatch: ")
+ LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
device = next(model.parameters()).device # get model device
- if device.type == 'cpu':
- LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
+ if device.type == "cpu":
+ LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
return batch_size
if torch.backends.cudnn.benchmark:
- LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
+ LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
return batch_size
# Inspect CUDA memory
@@ -60,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
f = t - (r + a) # GiB free
- LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
+ LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16]
@@ -70,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
- p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
+ p = np.polyfit(batch_sizes[: len(y)], y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
if None in results: # some sizes failed
i = results.index(None) # first fail index
@@ -78,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
b = batch_size
- LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
+ LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.")
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
- LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
+ LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
return b
except Exception as e:
- LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
+ LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
return batch_size
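# Illustrative sketch (not part of the patch): the first-degree polynomial fit autobatch() uses
# above to pick a batch size. The memory numbers are hypothetical stand-ins for profiled values.
import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
mem_gib = [0.5, 0.9, 1.7, 3.3, 6.5]  # hypothetical GiB consumed at each batch size
f, fraction = 10.0, 0.60  # free CUDA memory (GiB) and target utilization

p = np.polyfit(batch_sizes[: len(mem_gib)], mem_gib, deg=1)  # mem ≈ p[0] * batch + p[1]
b = int((f * fraction - p[1]) / p[0])  # batch size that lands at the target fraction
print(b)  # 14 for these numbers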
diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py
index 4842ff5b..66621f53 100644
--- a/ultralytics/utils/benchmarks.py
+++ b/ultralytics/utils/benchmarks.py
@@ -42,13 +42,9 @@ from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device
-def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
- data=None,
- imgsz=160,
- half=False,
- int8=False,
- device='cpu',
- verbose=False):
+def benchmark(
+ model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False
+):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
@@ -76,6 +72,7 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
"""
import pandas as pd
+
pd.options.display.max_columns = 10
pd.options.display.width = 120
device = select_device(device, verbose=False)
@@ -85,67 +82,62 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt',
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
- emoji, filename = '❌', None # export defaults
+ emoji, filename = "❌", None # export defaults
try:
- assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
+ assert i != 9 or LINUX, "Edge TPU export only supported on Linux"
if i == 10:
- assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
+ assert MACOS or LINUX, "TF.js export only supported on macOS and Linux"
elif i == 11:
- assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10'
- if 'cpu' in device.type:
- assert cpu, 'inference not supported on CPU'
- if 'cuda' in device.type:
- assert gpu, 'inference not supported on GPU'
+ assert sys.version_info < (3, 11), "PaddlePaddle export only supported on Python<=3.10"
+ if "cpu" in device.type:
+ assert cpu, "inference not supported on CPU"
+ if "cuda" in device.type:
+ assert gpu, "inference not supported on GPU"
# Export
- if format == '-':
+ if format == "-":
filename = model.ckpt_path or model.cfg
exported_model = model # PyTorch format
else:
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
exported_model = YOLO(filename, task=model.task)
- assert suffix in str(filename), 'export failed'
- emoji = '❎' # indicates export succeeded
+ assert suffix in str(filename), "export failed"
+ emoji = "❎" # indicates export succeeded
# Predict
- assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
- assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
- assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
- exported_model.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
+ assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
+ assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
+ assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
+ exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
# Validate
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
- results = exported_model.val(data=data,
- batch=1,
- imgsz=imgsz,
- plots=False,
- device=device,
- half=half,
- int8=int8,
- verbose=False)
- metric, speed = results.results_dict[key], results.speed['inference']
- y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
+ results = exported_model.val(
+ data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
+ )
+ metric, speed = results.results_dict[key], results.speed["inference"]
+ y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e:
if verbose:
- assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}'
- LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
+ assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
+ LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info
- df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
+ df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"])
name = Path(model.ckpt_path).name
- s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
+ s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
LOGGER.info(s)
- with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
+ with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
f.write(s)
if verbose and isinstance(verbose, float):
metrics = df[key].array # values to compare to floor
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
- assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}'
+ assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
return df
@@ -175,15 +167,17 @@ class ProfileModels:
```
"""
- def __init__(self,
- paths: list,
- num_timed_runs=100,
- num_warmup_runs=10,
- min_time=60,
- imgsz=640,
- half=True,
- trt=True,
- device=None):
+ def __init__(
+ self,
+ paths: list,
+ num_timed_runs=100,
+ num_warmup_runs=10,
+ min_time=60,
+ imgsz=640,
+ half=True,
+ trt=True,
+ device=None,
+ ):
"""
Initialize the ProfileModels class for profiling models.
@@ -204,37 +198,32 @@ class ProfileModels:
self.imgsz = imgsz
self.half = half
self.trt = trt # run TensorRT profiling
- self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
+ self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
def profile(self):
"""Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
files = self.get_files()
if not files:
- print('No matching *.pt or *.onnx files found.')
+ print("No matching *.pt or *.onnx files found.")
return
table_rows = []
output = []
for file in files:
- engine_file = file.with_suffix('.engine')
- if file.suffix in ('.pt', '.yaml', '.yml'):
+ engine_file = file.with_suffix(".engine")
+ if file.suffix in (".pt", ".yaml", ".yml"):
model = YOLO(str(file))
model.fuse() # to report correct params and GFLOPs in model.info()
model_info = model.info()
- if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
- engine_file = model.export(format='engine',
- half=self.half,
- imgsz=self.imgsz,
- device=self.device,
- verbose=False)
- onnx_file = model.export(format='onnx',
- half=self.half,
- imgsz=self.imgsz,
- simplify=True,
- device=self.device,
- verbose=False)
- elif file.suffix == '.onnx':
+ if self.trt and self.device.type != "cpu" and not engine_file.is_file():
+ engine_file = model.export(
+ format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False
+ )
+ onnx_file = model.export(
+ format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False
+ )
+ elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
@@ -254,14 +243,14 @@ class ProfileModels:
for path in self.paths:
path = Path(path)
if path.is_dir():
- extensions = ['*.pt', '*.onnx', '*.yaml']
+ extensions = ["*.pt", "*.onnx", "*.yaml"]
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
- elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing
+ elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
- print(f'Profiling: {sorted(files)}')
+ print(f"Profiling: {sorted(files)}")
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
@@ -306,7 +295,7 @@ class ProfileModels:
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
- run_times.append(results[0].speed['inference']) # Convert to milliseconds
+ run_times.append(results[0].speed["inference"]) # Convert to milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
@@ -315,31 +304,31 @@ class ProfileModels:
"""Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
times.
"""
- check_requirements('onnxruntime')
+ check_requirements("onnxruntime")
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8 # Limit the number of threads
- sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
+ sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
# Mapping ONNX datatype to numpy datatype
- if 'float16' in input_type:
+ if "float16" in input_type:
input_dtype = np.float16
- elif 'float' in input_type:
+ elif "float" in input_type:
input_dtype = np.float32
- elif 'double' in input_type:
+ elif "double" in input_type:
input_dtype = np.float64
- elif 'int64' in input_type:
+ elif "int64" in input_type:
input_dtype = np.int64
- elif 'int32' in input_type:
+ elif "int32" in input_type:
input_dtype = np.int32
else:
- raise ValueError(f'Unsupported ONNX datatype {input_type}')
+ raise ValueError(f"Unsupported ONNX datatype {input_type}")
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
input_name = input_tensor.name
@@ -369,25 +358,26 @@ class ProfileModels:
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
- return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
+ return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
- 'model/name': model_name,
- 'model/parameters': params,
- 'model/GFLOPs': round(flops, 3),
- 'model/speed_ONNX(ms)': round(t_onnx[0], 3),
- 'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
+ "model/name": model_name,
+ "model/parameters": params,
+ "model/GFLOPs": round(flops, 3),
+ "model/speed_ONNX(ms)": round(t_onnx[0], 3),
+ "model/speed_TensorRT(ms)": round(t_engine[0], 3),
+ }
def print_table(self, table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
- gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
- header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
- separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
+ gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
+ header = f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |"
+ separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|"
- print(f'\n\n{header}')
+ print(f"\n\n{header}")
print(separator)
for row in table_rows:
print(row)
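# Illustrative usage sketch (not part of the patch) for the two reformatted entry points above.
# Assumes the ultralytics package is installed; yolov8n.pt downloads on first use.
from ultralytics.utils.benchmarks import ProfileModels, benchmark

benchmark(model="yolov8n.pt", imgsz=160, half=False, device="cpu")  # speed/accuracy per export format
ProfileModels(["yolov8n.pt"], imgsz=640).profile()  # ONNX (and TensorRT on GPU) timing table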
diff --git a/ultralytics/utils/callbacks/__init__.py b/ultralytics/utils/callbacks/__init__.py
index 8ad4ad6e..116babe9 100644
--- a/ultralytics/utils/callbacks/__init__.py
+++ b/ultralytics/utils/callbacks/__init__.py
@@ -2,4 +2,4 @@
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
-__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
+__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
diff --git a/ultralytics/utils/callbacks/base.py b/ultralytics/utils/callbacks/base.py
index 211ae5bf..e1e9b42b 100644
--- a/ultralytics/utils/callbacks/base.py
+++ b/ultralytics/utils/callbacks/base.py
@@ -143,37 +143,35 @@ def on_export_end(exporter):
default_callbacks = {
# Run in trainer
- 'on_pretrain_routine_start': [on_pretrain_routine_start],
- 'on_pretrain_routine_end': [on_pretrain_routine_end],
- 'on_train_start': [on_train_start],
- 'on_train_epoch_start': [on_train_epoch_start],
- 'on_train_batch_start': [on_train_batch_start],
- 'optimizer_step': [optimizer_step],
- 'on_before_zero_grad': [on_before_zero_grad],
- 'on_train_batch_end': [on_train_batch_end],
- 'on_train_epoch_end': [on_train_epoch_end],
- 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
- 'on_model_save': [on_model_save],
- 'on_train_end': [on_train_end],
- 'on_params_update': [on_params_update],
- 'teardown': [teardown],
-
+ "on_pretrain_routine_start": [on_pretrain_routine_start],
+ "on_pretrain_routine_end": [on_pretrain_routine_end],
+ "on_train_start": [on_train_start],
+ "on_train_epoch_start": [on_train_epoch_start],
+ "on_train_batch_start": [on_train_batch_start],
+ "optimizer_step": [optimizer_step],
+ "on_before_zero_grad": [on_before_zero_grad],
+ "on_train_batch_end": [on_train_batch_end],
+ "on_train_epoch_end": [on_train_epoch_end],
+ "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
+ "on_model_save": [on_model_save],
+ "on_train_end": [on_train_end],
+ "on_params_update": [on_params_update],
+ "teardown": [teardown],
# Run in validator
- 'on_val_start': [on_val_start],
- 'on_val_batch_start': [on_val_batch_start],
- 'on_val_batch_end': [on_val_batch_end],
- 'on_val_end': [on_val_end],
-
+ "on_val_start": [on_val_start],
+ "on_val_batch_start": [on_val_batch_start],
+ "on_val_batch_end": [on_val_batch_end],
+ "on_val_end": [on_val_end],
# Run in predictor
- 'on_predict_start': [on_predict_start],
- 'on_predict_batch_start': [on_predict_batch_start],
- 'on_predict_postprocess_end': [on_predict_postprocess_end],
- 'on_predict_batch_end': [on_predict_batch_end],
- 'on_predict_end': [on_predict_end],
-
+ "on_predict_start": [on_predict_start],
+ "on_predict_batch_start": [on_predict_batch_start],
+ "on_predict_postprocess_end": [on_predict_postprocess_end],
+ "on_predict_batch_end": [on_predict_batch_end],
+ "on_predict_end": [on_predict_end],
# Run in exporter
- 'on_export_start': [on_export_start],
- 'on_export_end': [on_export_end]}
+ "on_export_start": [on_export_start],
+ "on_export_end": [on_export_end],
+}
def get_default_callbacks():
@@ -197,10 +195,11 @@ def add_integration_callbacks(instance):
# Load HUB callbacks
from .hub import callbacks as hub_cb
+
callbacks_list = [hub_cb]
# Load training callbacks
- if 'Trainer' in instance.__class__.__name__:
+ if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
@@ -209,6 +208,7 @@ def add_integration_callbacks(instance):
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
+
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
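# Illustrative sketch (not part of the patch): default_callbacks above maps event names to lists
# of handlers, and integrations extend those lists. A user-defined handler (hypothetical) is
# registered the same way.
from ultralytics import YOLO

def print_epoch(trainer):  # hypothetical custom handler
    print(f"finished epoch {trainer.epoch}")

model = YOLO("yolov8n.pt")
model.add_callback("on_train_epoch_end", print_epoch)  # appended alongside the defaults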
diff --git a/ultralytics/utils/callbacks/clearml.py b/ultralytics/utils/callbacks/clearml.py
index dc0b2716..a030fc5e 100644
--- a/ultralytics/utils/callbacks/clearml.py
+++ b/ultralytics/utils/callbacks/clearml.py
@@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['clearml'] is True # verify integration is enabled
+ assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
- assert hasattr(clearml, '__version__') # verify package is not directory
+ assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
-def _log_debug_samples(files, title='Debug Samples') -> None:
+def _log_debug_samples(files, title="Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
@@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
if task := Task.current_task():
for f in files:
if f.exists():
- it = re.search(r'_batch(\d+)', f.name)
+ it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
- task.get_logger().report_image(title=title,
- series=f.name.replace(it.group(), ''),
- local_path=str(f),
- iteration=iteration)
+ task.get_logger().report_image(
+ title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
+ )
def _log_plot(title, plot_path) -> None:
@@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None:
img = mpimg.imread(plot_path)
fig = plt.figure()
- ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
+ ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
- Task.current_task().get_logger().report_matplotlib_figure(title=title,
- series='',
- figure=fig,
- report_interactive=False)
+ Task.current_task().get_logger().report_matplotlib_figure(
+ title=title, series="", figure=fig, report_interactive=False
+ )
def on_pretrain_routine_start(trainer):
@@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer):
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
- task = Task.init(project_name=trainer.args.project or 'YOLOv8',
- task_name=trainer.args.name,
- tags=['YOLOv8'],
- output_uri=True,
- reuse_last_task_id=False,
- auto_connect_frameworks={
- 'pytorch': False,
- 'matplotlib': False})
- LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
- 'please add clearml-init and connect your arguments before initializing YOLO.')
- task.connect(vars(trainer.args), name='General')
+ task = Task.init(
+ project_name=trainer.args.project or "YOLOv8",
+ task_name=trainer.args.name,
+ tags=["YOLOv8"],
+ output_uri=True,
+ reuse_last_task_id=False,
+ auto_connect_frameworks={"pytorch": False, "matplotlib": False},
+ )
+ LOGGER.warning(
+ "ClearML Initialized a new task. If you want to run remotely, "
+ "please add clearml-init and connect your arguments before initializing YOLO."
+ )
+ task.connect(vars(trainer.args), name="General")
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
+ LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
@@ -88,26 +88,26 @@ def on_train_epoch_end(trainer):
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
- _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
+ _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
- for k, v in trainer.label_loss_items(trainer.tloss, prefix='train').items():
- task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
+ for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
+ task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
- task.get_logger().report_scalar('lr', k, v, iteration=trainer.epoch)
+ task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
- task.get_logger().report_scalar(title='Epoch Time',
- series='Epoch Time',
- value=trainer.epoch_time,
- iteration=trainer.epoch)
+ task.get_logger().report_scalar(
+ title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
+ )
for k, v in trainer.metrics.items():
- task.get_logger().report_scalar('val', k, v, iteration=trainer.epoch)
+ task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
+
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
@@ -116,7 +116,7 @@ def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task():
# Log val_labels and val_pred
- _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
+ _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer):
@@ -124,8 +124,11 @@ def on_train_end(trainer):
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
- 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
- *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+ "results.png",
+ "confusion_matrix.png",
+ "confusion_matrix_normalized.png",
+ *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+ ]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
@@ -136,9 +139,14 @@ def on_train_end(trainer):
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_train_epoch_end': on_train_epoch_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_val_end': on_val_end,
- 'on_train_end': on_train_end} if clearml else {}
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_val_end": on_val_end,
+ "on_train_end": on_train_end,
+ }
+ if clearml
+ else {}
+)
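# Illustrative sketch (not part of the patch): how _log_debug_samples() above turns a plot
# filename into a ClearML series name and iteration number. The filename is hypothetical.
import re

name = "train_batch2.jpg"
it = re.search(r"_batch(\d+)", name)
iteration = int(it.groups()[0]) if it else 0  # -> 2
series = name.replace(it.group(), "")  # -> "train.jpg"
print(series, iteration)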
diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py
index e8016f4e..2ed60577 100644
--- a/ultralytics/utils/callbacks/comet.py
+++ b/ultralytics/utils/callbacks/comet.py
@@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
try:
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['comet'] is True # verify integration is enabled
+ assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
- assert hasattr(comet_ml, '__version__') # verify package is not directory
+ assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
- COMET_SUPPORTED_TASKS = ['detect']
+ COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by YOLOv8 that are logged to Comet
- EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
- LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
+ EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
+ LABEL_PLOT_NAMES = "labels", "labels_correlogram"
_comet_image_prediction_count = 0
@@ -27,43 +27,43 @@ except (ImportError, AssertionError):
def _get_comet_mode():
"""Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
- return os.getenv('COMET_MODE', 'online')
+ return os.getenv("COMET_MODE", "online")
def _get_comet_model_name():
"""Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
- return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
+ return os.getenv("COMET_MODEL_NAME", "YOLOv8")
def _get_eval_batch_logging_interval():
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
- return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
+ return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log():
"""Get the maximum number of image predictions to log from the environment variables."""
- return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
+ return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score):
"""Scales the given confidence score by a factor specified in an environment variable."""
- scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
+ scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix():
"""Determines if the confusion matrix should be logged based on the environment variable settings."""
- return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
+ return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions():
"""Determines whether to log image predictions based on a specified environment variable."""
- return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
+ return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
- if mode == 'offline':
+ if mode == "offline":
return comet_ml.OfflineExperiment(project_name=project_name)
return comet_ml.Experiment(project_name=project_name)
@@ -75,18 +75,21 @@ def _create_experiment(args):
return
try:
comet_mode = _get_comet_mode()
- _project_name = os.getenv('COMET_PROJECT_NAME', args.project)
+ _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = _get_experiment_type(comet_mode, _project_name)
experiment.log_parameters(vars(args))
- experiment.log_others({
- 'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
- 'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
- 'log_image_predictions': _should_log_image_predictions(),
- 'max_image_predictions': _get_max_image_predictions_to_log(), })
- experiment.log_other('Created from', 'yolov8')
+ experiment.log_others(
+ {
+ "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
+ "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
+ "log_image_predictions": _should_log_image_predictions(),
+ "max_image_predictions": _get_max_image_predictions_to_log(),
+ }
+ )
+ experiment.log_other("Created from", "yolov8")
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
+ LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer):
@@ -134,29 +137,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
- indices = batch['batch_idx'] == img_idx
- bboxes = batch['bboxes'][indices]
+ indices = batch["batch_idx"] == img_idx
+ bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
- LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
+ LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
return None
- cls_labels = batch['cls'][indices].squeeze(1).tolist()
+ cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
- original_image_shape = batch['ori_shape'][img_idx]
- resized_image_shape = batch['resized_shape'][img_idx]
- ratio_pad = batch['ratio_pad'][img_idx]
+ original_image_shape = batch["ori_shape"][img_idx]
+ resized_image_shape = batch["resized_shape"][img_idx]
+ ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
- data.append({
- 'boxes': [box],
- 'label': f'gt_{label}',
- 'score': _scale_confidence_score(1.0), })
+ data.append(
+ {
+ "boxes": [box],
+ "label": f"gt_{label}",
+ "score": _scale_confidence_score(1.0),
+ }
+ )
- return {'name': 'ground_truth', 'data': data}
+ return {"name": "ground_truth", "data": data}
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
@@ -166,31 +172,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
predictions = metadata.get(image_id)
if not predictions:
- LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
+ LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
return None
data = []
for prediction in predictions:
- boxes = prediction['bbox']
- score = _scale_confidence_score(prediction['score'])
- cls_label = prediction['category_id']
+ boxes = prediction["bbox"]
+ score = _scale_confidence_score(prediction["score"])
+ cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
- data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
+ data.append({"boxes": [boxes], "label": cls_label, "score": score})
- return {'name': 'prediction', 'data': data}
+ return {"name": "prediction", "data": data}
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
- ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
- class_label_map)
- prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
- class_label_map)
+ ground_truth_annotations = _format_ground_truth_annotations_for_detection(
+ img_idx, image_path, batch, class_label_map
+ )
+ prediction_annotations = _format_prediction_annotations_for_detection(
+ image_path, prediction_metadata_map, class_label_map
+ )
annotations = [
- annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
+ annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
+ ]
return [annotations] if annotations else None
@@ -198,8 +207,8 @@ def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
- pred_metadata_map.setdefault(prediction['image_id'], [])
- pred_metadata_map[prediction['image_id']].append(prediction)
+ pred_metadata_map.setdefault(prediction["image_id"], [])
+ pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
@@ -207,7 +216,7 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
- names = list(trainer.data['names'].values()) + ['background']
+ names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat,
labels=names,
@@ -251,7 +260,7 @@ def _log_image_predictions(experiment, validator, curr_step):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
- image_paths = batch['im_file']
+ image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
@@ -275,10 +284,10 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
- plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
+ plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None)
- label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
+ label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
@@ -288,7 +297,7 @@ def _log_model(experiment, trainer):
experiment.log_model(
model_name,
file_or_folder=str(trainer.best),
- file_name='best.pt',
+ file_name="best.pt",
overwrite=True,
)
@@ -296,7 +305,7 @@ def _log_model(experiment, trainer):
def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment()
- is_alive = getattr(experiment, 'alive', False)
+ is_alive = getattr(experiment, "alive", False)
if not experiment or not is_alive:
_create_experiment(trainer.args)
@@ -308,17 +317,17 @@ def on_train_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata['curr_epoch']
- curr_step = metadata['curr_step']
+ curr_epoch = metadata["curr_epoch"]
+ curr_step = metadata["curr_step"]
experiment.log_metrics(
- trainer.label_loss_items(trainer.tloss, prefix='train'),
+ trainer.label_loss_items(trainer.tloss, prefix="train"),
step=curr_step,
epoch=curr_epoch,
)
if curr_epoch == 1:
- _log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
+ _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
def on_fit_epoch_end(trainer):
@@ -328,14 +337,15 @@ def on_fit_epoch_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata['curr_epoch']
- curr_step = metadata['curr_step']
- save_assets = metadata['save_assets']
+ curr_epoch = metadata["curr_epoch"]
+ curr_step = metadata["curr_step"]
+ save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
+
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
@@ -355,8 +365,8 @@ def on_train_end(trainer):
return
metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata['curr_epoch']
- curr_step = metadata['curr_step']
+ curr_epoch = metadata["curr_epoch"]
+ curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
@@ -371,8 +381,13 @@ def on_train_end(trainer):
_comet_image_prediction_count = 0
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_train_epoch_end': on_train_epoch_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_train_end': on_train_end} if comet_ml else {}
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end,
+ }
+ if comet_ml
+ else {}
+)
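# Illustrative sketch (not part of the patch): the environment-variable lookups the Comet helpers
# above share; the values printed are the defaults used when nothing is set.
import os

mode = os.getenv("COMET_MODE", "online")
interval = int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
log_cm = os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
print(mode, interval, log_cm)  # online 1 False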
diff --git a/ultralytics/utils/callbacks/dvc.py b/ultralytics/utils/callbacks/dvc.py
index 7fa05c6b..ab51dc52 100644
--- a/ultralytics/utils/callbacks/dvc.py
+++ b/ultralytics/utils/callbacks/dvc.py
@@ -4,9 +4,10 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['dvc'] is True # verify integration is enabled
+ assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
- assert checks.check_version('dvclive', '2.11.0', verbose=True)
+
+ assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
@@ -24,24 +25,24 @@ except (ImportError, AssertionError, TypeError):
dvclive = None
-def _log_images(path, prefix=''):
+def _log_images(path, prefix=""):
"""Logs images at specified path with an optional prefix using DVCLive."""
if live:
name = path.name
# Group images by batch to enable sliders in UI
- if m := re.search(r'_batch(\d+)', name):
+ if m := re.search(r"_batch(\d+)", name):
ni = m[1]
- new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
+ new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
live.log_image(os.path.join(prefix, name), path)
-def _log_plots(plots, prefix=''):
+def _log_plots(plots, prefix=""):
"""Logs plot images for training progress if they have not been previously processed."""
for name, params in plots.items():
- timestamp = params['timestamp']
+ timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
@@ -53,15 +54,15 @@ def _log_confusion_matrix(validator):
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
- if validator.confusion_matrix.task == 'detect':
- names += ['background']
+ if validator.confusion_matrix.task == "detect":
+ names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
- live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
+ live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer):
@@ -71,12 +72,12 @@ def on_pretrain_routine_start(trainer):
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
+ LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer):
"""Logs plots related to the training process at the end of the pretraining routine."""
- _log_plots(trainer.plots, 'train')
+ _log_plots(trainer.plots, "train")
def on_train_start(trainer):
@@ -95,17 +96,18 @@ def on_fit_epoch_end(trainer):
"""Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
global _training_epoch
if live and _training_epoch:
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
+ all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
+
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
- _log_plots(trainer.plots, 'train')
- _log_plots(trainer.validator.plots, 'val')
+ _log_plots(trainer.plots, "train")
+ _log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
@@ -115,24 +117,29 @@ def on_train_end(trainer):
"""Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
+ all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
- _log_plots(trainer.plots, 'val')
- _log_plots(trainer.validator.plots, 'val')
+ _log_plots(trainer.plots, "val")
+ _log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
- live.log_artifact(trainer.best, copy=True, type='model')
+ live.log_artifact(trainer.best, copy=True, type="model")
live.end()
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_pretrain_routine_end': on_pretrain_routine_end,
- 'on_train_start': on_train_start,
- 'on_train_epoch_start': on_train_epoch_start,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_train_end': on_train_end} if dvclive else {}
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_pretrain_routine_end": on_pretrain_routine_end,
+ "on_train_start": on_train_start,
+ "on_train_epoch_start": on_train_epoch_start,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end,
+ }
+ if dvclive
+ else {}
+)
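# Illustrative sketch (not part of the patch): the rename _log_images() above performs so DVCLive
# groups per-batch images under one name and renders a slider. The path is hypothetical.
import re
from pathlib import Path

path = Path("val_batch2_pred.jpg")
m = re.search(r"_batch(\d+)", path.name)
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)  # "val_batch_pred"
name = (Path(new_stem) / m[1]).with_suffix(path.suffix)  # val_batch_pred/2.jpg
print(name)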
diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py
index f3a3f353..60c74161 100644
--- a/ultralytics/utils/callbacks/hub.py
+++ b/ultralytics/utils/callbacks/hub.py
@@ -11,60 +11,62 @@ from ultralytics.utils import LOGGER, SETTINGS
def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
- session = getattr(trainer, 'hub_session', None)
+ session = getattr(trainer, "hub_session", None)
if session:
# Start timer for upload rate limit
session.timers = {
- 'metrics': time(),
- 'ckpt': time(), } # start timer on session.rate_limit
+ "metrics": time(),
+ "ckpt": time(),
+ } # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
- session = getattr(trainer, 'hub_session', None)
+ session = getattr(trainer, "hub_session", None)
if session:
# Upload metrics after val end
all_plots = {
- **trainer.label_loss_items(trainer.tloss, prefix='train'),
- **trainer.metrics, }
+ **trainer.label_loss_items(trainer.tloss, prefix="train"),
+ **trainer.metrics,
+ }
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
+
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
- if time() - session.timers['metrics'] > session.rate_limits['metrics']:
+ if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
- session.timers['metrics'] = time() # reset timer
+ session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
- session = getattr(trainer, 'hub_session', None)
+ session = getattr(trainer, "hub_session", None)
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
- if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
- LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
+ if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
+ LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
session.upload_model(trainer.epoch, trainer.last, is_best)
- session.timers['ckpt'] = time() # reset timer
+ session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
- session = getattr(trainer, 'hub_session', None)
+ session = getattr(trainer, "hub_session", None)
if session:
# Upload final model and metrics with exponential backoff
- LOGGER.info(f'{PREFIX}Syncing final model...')
+ LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
- map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
+ map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
- LOGGER.info(f'{PREFIX}Done ✅\n'
- f'{PREFIX}View model at {session.model_url} 🚀')
+ LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
@@ -87,12 +89,17 @@ def on_export_start(exporter):
events(exporter.args)
-callbacks = ({
- 'on_pretrain_routine_end': on_pretrain_routine_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_model_save': on_model_save,
- 'on_train_end': on_train_end,
- 'on_train_start': on_train_start,
- 'on_val_start': on_val_start,
- 'on_predict_start': on_predict_start,
- 'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled
+callbacks = (
+ {
+ "on_pretrain_routine_end": on_pretrain_routine_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_model_save": on_model_save,
+ "on_train_end": on_train_end,
+ "on_train_start": on_train_start,
+ "on_val_start": on_val_start,
+ "on_predict_start": on_predict_start,
+ "on_export_start": on_export_start,
+ }
+ if SETTINGS["hub"] is True
+ else {}
+) # verify enabled
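# Illustrative sketch (not part of the patch): the timer gate the HUB callbacks above put in
# front of metric and checkpoint uploads. The limit and queue here are hypothetical.
from time import time

rate_limits = {"metrics": 3.0}  # hypothetical: minimum seconds between uploads
timers = {"metrics": time()}
queue = {0: '{"loss": 1.2}'}  # hypothetical queued payloads

if time() - timers["metrics"] > rate_limits["metrics"]:
    print(f"uploading {len(queue)} payload(s)")
    timers["metrics"] = time()  # reset timer
    queue = {}  # reset queue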
diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py
index 91dcff5e..9676fe85 100644
--- a/ultralytics/utils/callbacks/mlflow.py
+++ b/ultralytics/utils/callbacks/mlflow.py
@@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorst
try:
import os
- assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '') # do not log pytest
- assert SETTINGS['mlflow'] is True # verify integration is enabled
+ assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
+ assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
- assert hasattr(mlflow, '__version__') # verify package is not directory
+ assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
- PREFIX = colorstr('MLflow: ')
- SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
+ PREFIX = colorstr("MLflow: ")
+ SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError):
mlflow = None
@@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer):
"""
global mlflow
- uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
- LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
+ uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
+ LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
- experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
- run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
+ experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
+ run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
- LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
+ LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
- LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
- f'{PREFIX}WARNING ⚠️ Not tracking this run')
+ LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
- mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
- step=trainer.epoch)
+ mlflow.log_metrics(
+ metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch
+ )
mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
@@ -101,16 +101,23 @@ def on_train_end(trainer):
"""Log model artifacts at the end of the training."""
if mlflow:
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
- for f in trainer.save_dir.glob('*'): # log all other files in save_dir
- if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
+ for f in trainer.save_dir.glob("*"): # log all other files in save_dir
+ if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
mlflow.end_run()
- LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
- f"{PREFIX}disable with 'yolo settings mlflow=False'")
-
-
-callbacks = {
- 'on_pretrain_routine_end': on_pretrain_routine_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_train_end': on_train_end} if mlflow else {}
+ LOGGER.info(
+ f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
+ f"{PREFIX}disable with 'yolo settings mlflow=False'"
+ )
+
+
+callbacks = (
+ {
+ "on_pretrain_routine_end": on_pretrain_routine_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end,
+ }
+ if mlflow
+ else {}
+)
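# Illustrative sketch (not part of the patch): what the SANITIZE lambda above does. MLflow metric
# keys cannot contain parentheses, so they are stripped and values are cast to float.
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}

print(SANITIZE({"metrics/mAP50-95(B)": 0.52}))  # {'metrics/mAP50-95B': 0.52}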
diff --git a/ultralytics/utils/callbacks/neptune.py b/ultralytics/utils/callbacks/neptune.py
index 088e3f8e..60c85371 100644
--- a/ultralytics/utils/callbacks/neptune.py
+++ b/ultralytics/utils/callbacks/neptune.py
@@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['neptune'] is True # verify integration is enabled
+ assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
- assert hasattr(neptune, '__version__')
+ assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
@@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0):
run[k].append(value=v, step=step)
-def _log_images(imgs_dict, group=''):
+def _log_images(imgs_dict, group=""):
"""Log scalars to the NeptuneAI experiment logger."""
if run:
for k, v in imgs_dict.items():
- run[f'{group}/{k}'].upload(File(v))
+ run[f"{group}/{k}"].upload(File(v))
def _log_plot(title, plot_path):
@@ -43,34 +43,35 @@ def _log_plot(title, plot_path):
img = mpimg.imread(plot_path)
fig = plt.figure()
- ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
+ ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
- run[f'Plots/{title}'].upload(fig)
+ run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer):
"""Callback function called before the training routine starts."""
try:
global run
- run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
- run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
+ run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
+ run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
+ LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer):
"""Callback function called at end of each training epoch."""
- _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+ _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
- _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
+ _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
- run['Configuration/Model'] = model_info_for_loggers(trainer)
+
+ run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
@@ -78,7 +79,7 @@ def on_val_end(validator):
"""Callback function called at end of each validation."""
if run:
# Log val_labels and val_pred
- _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
+ _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer):
@@ -86,19 +87,28 @@ def on_train_end(trainer):
if run:
# Log final results, CM matrix + PR plots
files = [
- 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
- *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+ "results.png",
+ "confusion_matrix.png",
+ "confusion_matrix_normalized.png",
+ *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+ ]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
- run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
- trainer.best)))
-
-
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_train_epoch_end': on_train_epoch_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_val_end': on_val_end,
- 'on_train_end': on_train_end} if neptune else {}
+ run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
+ File(str(trainer.best))
+ )
+
+
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_val_end": on_val_end,
+ "on_train_end": on_train_end,
+ }
+ if neptune
+ else {}
+)
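# Illustrative sketch (not part of the patch): the guard every integration module above ends
# with — the callbacks dict is only populated when the optional dependency imported cleanly.
try:
    import neptune
except ImportError:
    neptune = None

def on_train_end(trainer):  # placeholder handler for the sketch
    pass

callbacks = {"on_train_end": on_train_end} if neptune else {}
print(bool(callbacks))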
diff --git a/ultralytics/utils/callbacks/raytune.py b/ultralytics/utils/callbacks/raytune.py
index 417b3314..f2694554 100644
--- a/ultralytics/utils/callbacks/raytune.py
+++ b/ultralytics/utils/callbacks/raytune.py
@@ -3,7 +3,7 @@
from ultralytics.utils import SETTINGS
try:
- assert SETTINGS['raytune'] is True # verify integration is enabled
+ assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
@@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled():
metrics = trainer.metrics
- metrics['epoch'] = trainer.epoch
+ metrics["epoch"] = trainer.epoch
session.report(metrics)
-callbacks = {
- 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
+callbacks = (
+ {
+ "on_fit_epoch_end": on_fit_epoch_end,
+ }
+ if tune
+ else {}
+)
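
For context, `on_fit_epoch_end` is the only hook Ray Tune needs; the same `session.report` flow works inside a plain trainable. A hedged sketch assuming Ray 2.x and toy metrics:

```
from ray import tune
from ray.air import session

def trainable(config):
    for epoch in range(3):
        # One metrics dict per epoch, mirroring on_fit_epoch_end above
        session.report({"loss": 1.0 / (epoch + 1), "epoch": epoch})

tune.Tuner(trainable).fit()  # runs the trainable once with the default config
```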
diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py
index 86a50667..0a39d094 100644
--- a/ultralytics/utils/callbacks/tensorboard.py
+++ b/ultralytics/utils/callbacks/tensorboard.py
@@ -7,7 +7,7 @@ try:
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['tensorboard'] is True # verify integration is enabled
+ assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
except (ImportError, AssertionError, TypeError):
@@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer):
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
+ warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
+ LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
@@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
- prefix = colorstr('TensorBoard: ')
+ prefix = colorstr("TensorBoard: ")
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
+ LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
@@ -60,7 +60,7 @@ def on_train_start(trainer):
def on_train_epoch_end(trainer):
"""Logs scalar statistics at the end of a training epoch."""
- _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+ _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
@@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1)
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_train_start': on_train_start,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_start": on_train_start,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_epoch_end": on_train_epoch_end,
+ }
+ if SummaryWriter
+ else {}
+)
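
The `_log_scalars` helper reformatted above boils down to plain `SummaryWriter.add_scalar` calls; a standalone equivalent with toy loss values and a hypothetical log dir:

```
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/demo")  # inspect with: tensorboard --logdir runs/demo
for step, loss in enumerate([0.9, 0.5, 0.3], start=1):
    writer.add_scalar("train/loss", loss, step)
writer.close()
```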
diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py
index 88f9bd7c..7f0f57ac 100644
--- a/ultralytics/utils/callbacks/wb.py
+++ b/ultralytics/utils/callbacks/wb.py
@@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS['wandb'] is True # verify integration is enabled
+ assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
- assert hasattr(wb, '__version__') # verify package is not directory
+ assert hasattr(wb, "__version__") # verify package is not directory
import numpy as np
import pandas as pd
@@ -19,7 +19,7 @@ except (ImportError, AssertionError):
wb = None
-def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'):
+def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
@@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
"""
- df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3)
- fields = {'x': 'x', 'y': 'y', 'class': 'class'}
- string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title}
- return wb.plot_table('wandb/area-under-curve/v0',
- wb.Table(dataframe=df),
- fields=fields,
- string_fields=string_fields)
-
-
-def _plot_curve(x,
- y,
- names=None,
- id='precision-recall',
- title='Precision Recall Curve',
- x_title='Recall',
- y_title='Precision',
- num_x=100,
- only_mean=False):
+ df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3)
+ fields = {"x": "x", "y": "y", "class": "class"}
+ string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
+ return wb.plot_table(
+ "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+ )
+
+
+def _plot_curve(
+ x,
+ y,
+ names=None,
+ id="precision-recall",
+ title="Precision Recall Curve",
+ x_title="Recall",
+ y_title="Precision",
+ num_x=100,
+ only_mean=False,
+):
"""
Log a metric curve visualization.
@@ -88,7 +89,7 @@ def _plot_curve(x,
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
- classes = ['mean'] * len(x_log)
+ classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
@@ -99,7 +100,7 @@ def _plot_curve(x,
def _log_plots(plots, step):
"""Logs plots from the input dictionary if they haven't been logged already at the specified step."""
for name, params in plots.items():
- timestamp = params['timestamp']
+ timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
@@ -107,7 +108,7 @@ def _log_plots(plots, step):
def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
- wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
+ wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
def on_fit_epoch_end(trainer):
@@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer):
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
- wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
+ wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
@@ -131,17 +132,17 @@ def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
- art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
+ art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
- wb.run.log_artifact(art, aliases=['best'])
+ wb.run.log_artifact(art, aliases=["best"])
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
- id=f'curves/{curve_name}',
+ id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
@@ -149,8 +150,13 @@ def on_train_end(trainer):
wb.run.finish() # required or run continues on dashboard
-callbacks = {
- 'on_pretrain_routine_start': on_pretrain_routine_start,
- 'on_train_epoch_end': on_train_epoch_end,
- 'on_fit_epoch_end': on_fit_epoch_end,
- 'on_train_end': on_train_end} if wb else {}
+callbacks = (
+ {
+ "on_pretrain_routine_start": on_pretrain_routine_start,
+ "on_train_epoch_end": on_train_epoch_end,
+ "on_fit_epoch_end": on_fit_epoch_end,
+ "on_train_end": on_train_end,
+ }
+ if wb
+ else {}
+)
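
One detail worth isolating from `_plot_curve`: each per-class curve is resampled onto a fixed `num_x`-point grid with `np.interp` before being logged. A numpy-only sketch with toy data:

```
import numpy as np

x = np.linspace(0, 1, 1000)            # raw recall values
y = 1 - x**2                           # toy precision curve for one class
x_new = np.linspace(x[0], x[-1], 100)  # num_x=100 grid
y_new = np.interp(x_new, x, y)         # resampled curve, ready for wb.Table
print(len(x_new), round(float(y_new[50]), 3))  # 100 0.745
```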
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index ed804fff..ba25fd46 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -21,12 +21,33 @@ import requests
import torch
from matplotlib import font_manager
-from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace,
- ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker,
- is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file)
-
-
-def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
+from ultralytics.utils import (
+ ASSETS,
+ AUTOINSTALL,
+ LINUX,
+ LOGGER,
+ ONLINE,
+ ROOT,
+ USER_CONFIG_DIR,
+ SimpleNamespace,
+ ThreadingLocked,
+ TryExcept,
+ clean_url,
+ colorstr,
+ downloads,
+ emojis,
+ is_colab,
+ is_docker,
+ is_github_action_running,
+ is_jupyter,
+ is_kaggle,
+ is_online,
+ is_pip_package,
+ url2file,
+)
+
+
+def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
"""
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
@@ -46,23 +67,23 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
"""
if package:
- requires = [x for x in metadata.distribution(package).requires if 'extra == ' not in x]
+ requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
requires = Path(file_path).read_text().splitlines()
requirements = []
for line in requires:
line = line.strip()
- if line and not line.startswith('#'):
- line = line.split('#')[0].strip() # ignore inline comments
- match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
+ if line and not line.startswith("#"):
+ line = line.split("#")[0].strip() # ignore inline comments
+ match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
if match:
- requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))
+ requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
return requirements
-def parse_version(version='0.0.0') -> tuple:
+def parse_version(version="0.0.0") -> tuple:
"""
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
function replaces deprecated 'pkg_resources.parse_version(v)'.
@@ -74,9 +95,9 @@ def parse_version(version='0.0.0') -> tuple:
        (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1); any extra non-numeric string is discarded
"""
try:
- return tuple(map(int, re.findall(r'\d+', version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
+ return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
except Exception as e:
- LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}')
+ LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
return 0, 0, 0
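
Since `parse_version` is the primitive every version check below leans on, here is a quick standalone rerun of its regex (same behavior, minus logging):

```
import re

def parse_version(version="0.0.0"):
    return tuple(map(int, re.findall(r"\d+", version)[:3]))  # first 3 numeric groups

print(parse_version("2.0.1+cpu"))  # (2, 0, 1)
print(parse_version("1.13"))       # (1, 13)
```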
@@ -121,15 +142,19 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
else:
- raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
- f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
+ raise TypeError(
+ f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
+ f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
+ )
# Apply max_dim
if len(imgsz) > max_dim:
- msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
- "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
+ msg = (
+ "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
+ "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
+ )
if max_dim != 1:
- raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
+ raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
@@ -137,7 +162,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
# Print warning message if image size was updated
if sz != imgsz:
- LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
+ LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@@ -145,12 +170,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
return sz
-def check_version(current: str = '0.0.0',
- required: str = '0.0.0',
- name: str = 'version',
- hard: bool = False,
- verbose: bool = False,
- msg: str = '') -> bool:
+def check_version(
+ current: str = "0.0.0",
+ required: str = "0.0.0",
+ name: str = "version",
+ hard: bool = False,
+ verbose: bool = False,
+ msg: str = "",
+) -> bool:
"""
Check current version against the required version or range.
@@ -181,7 +208,7 @@ def check_version(current: str = '0.0.0',
```
"""
if not current: # if current is '' or None
- LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
+ LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
return True
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
try:
@@ -189,34 +216,34 @@ def check_version(current: str = '0.0.0',
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError:
if hard:
- raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed'))
+ raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed"))
else:
return False
if not required: # if required is '' or None
return True
- op = ''
- version = ''
+ op = ""
+ version = ""
result = True
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
- for r in required.strip(',').split(','):
- op, version = re.match(r'([^0-9]*)([\d.]+)', r).groups() # split '>=22.04' -> ('>=', '22.04')
+ for r in required.strip(",").split(","):
+ op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
- if op == '==' and c != v:
+ if op == "==" and c != v:
result = False
- elif op == '!=' and c == v:
+ elif op == "!=" and c == v:
result = False
- elif op in ('>=', '') and not (c >= v): # if no constraint passed assume '>=required'
+ elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
result = False
- elif op == '<=' and not (c <= v):
+ elif op == "<=" and not (c <= v):
result = False
- elif op == '>' and not (c > v):
+ elif op == ">" and not (c > v):
result = False
- elif op == '<' and not (c < v):
+ elif op == "<" and not (c < v):
result = False
if not result:
- warning = f'WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}'
+ warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
if hard:
raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
if verbose:
@@ -224,7 +251,7 @@ def check_version(current: str = '0.0.0',
return result
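
A condensed sketch of the constraint loop above, trimmed to two operators for brevity (the real function also handles ==, !=, <= and >):

```
import re

def _v(s):
    return tuple(map(int, re.findall(r"\d+", s)[:3]))  # '22.04' -> (22, 4)

def satisfies(current, required):
    c = _v(current)
    for r in required.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()
        v = _v(version)
        if op in (">=", "") and not c >= v:  # bare version implies '>='
            return False
        if op == "<" and not c < v:
            return False
    return True

print(satisfies("22.10", ">=22.04,<23.0"))  # True
```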
-def check_latest_pypi_version(package_name='ultralytics'):
+def check_latest_pypi_version(package_name="ultralytics"):
"""
Returns the latest version of a PyPI package without downloading or installing it.
@@ -236,9 +263,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
"""
with contextlib.suppress(Exception):
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
- response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
+ response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
- return response.json()['info']['version']
+ return response.json()["info"]["version"]
def check_pip_update_available():
@@ -251,16 +278,19 @@ def check_pip_update_available():
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
from ultralytics import __version__
+
latest = check_latest_pypi_version()
- if check_version(__version__, f'<{latest}'): # check if current version is < latest version
- LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
- f"Update with 'pip install -U ultralytics'")
+ if check_version(__version__, f"<{latest}"): # check if current version is < latest version
+ LOGGER.info(
+ f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
+ f"Update with 'pip install -U ultralytics'"
+ )
return True
return False
@ThreadingLocked()
-def check_font(font='Arial.ttf'):
+def check_font(font="Arial.ttf"):
"""
Find font locally or download to user's configuration directory if it does not already exist.
@@ -283,13 +313,13 @@ def check_font(font='Arial.ttf'):
return matches[0]
# Download to USER_CONFIG_DIR if missing
- url = f'https://ultralytics.com/assets/{name}'
+ url = f"https://ultralytics.com/assets/{name}"
if downloads.is_url(url):
downloads.safe_download(url=url, file=file)
return file
-def check_python(minimum: str = '3.8.0') -> bool:
+def check_python(minimum: str = "3.8.0") -> bool:
"""
Check current python version against the required minimum version.
@@ -299,11 +329,11 @@ def check_python(minimum: str = '3.8.0') -> bool:
Returns:
        (bool): Whether the installed Python version meets the minimum required version.
"""
- return check_version(platform.python_version(), minimum, name='Python ', hard=True)
+ return check_version(platform.python_version(), minimum, name="Python ", hard=True)
@TryExcept()
-def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
+def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
"""
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
@@ -329,41 +359,42 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
```
"""
- prefix = colorstr('red', 'bold', 'requirements:')
+ prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
- assert file.exists(), f'{prefix} {file} not found, check failed.'
- requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude]
+ assert file.exists(), f"{prefix} {file} not found, check failed."
+ requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
elif isinstance(requirements, str):
requirements = [requirements]
pkgs = []
for r in requirements:
- r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
- match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', r_stripped)
- name, required = match[1], match[2].strip() if match[2] else ''
+ r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
+ match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
+ name, required = match[1], match[2].strip() if match[2] else ""
try:
assert check_version(metadata.version(name), required) # exception if requirements not met
except (AssertionError, metadata.PackageNotFoundError):
pkgs.append(r)
- s = ' '.join(f'"{x}"' for x in pkgs) # console string
+ s = " ".join(f'"{x}"' for x in pkgs) # console string
if s:
if install and AUTOINSTALL: # check environment variable
            n = len(pkgs)  # number of packages to update
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
- assert is_online(), 'AutoUpdate skipped (offline)'
- LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
+ assert is_online(), "AutoUpdate skipped (offline)"
+ LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
dt = time.time() - t
LOGGER.info(
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
- f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
+ )
except Exception as e:
- LOGGER.warning(f'{prefix} ❌ {e}')
+ LOGGER.warning(f"{prefix} ❌ {e}")
return False
else:
return False
@@ -386,76 +417,82 @@ def check_torchvision():
import torchvision
# Compatibility table
- compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
+ compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
# Extract only the major and minor versions
- v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
- v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
+ v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
+ v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(v_torchvision != v for v in compatible_versions):
- print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
- f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
- "'pip install -U torch torchvision' to update both.\n"
- 'For a full compatibility table see https://github.com/pytorch/vision#installation')
+ print(
+ f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
+ f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
+ "'pip install -U torch torchvision' to update both.\n"
+ "For a full compatibility table see https://github.com/pytorch/vision#installation"
+ )
-def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
+def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
"""Check file(s) for acceptable suffix."""
if file and suffix:
if isinstance(suffix, str):
- suffix = (suffix, )
+ suffix = (suffix,)
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower().strip() # file suffix
if len(s):
- assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
def check_yolov5u_filename(file: str, verbose: bool = True):
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
- if 'yolov3' in file or 'yolov5' in file:
- if 'u.yaml' in file:
- file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml
- elif '.pt' in file and 'u' not in file:
+ if "yolov3" in file or "yolov5" in file:
+ if "u.yaml" in file:
+ file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
+ elif ".pt" in file and "u" not in file:
original_file = file
- file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
- file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
- file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
+ file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
+ file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
+ file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file and verbose:
LOGGER.info(
f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
- f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
- f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
+ f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
+ f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
+ )
return file
-def check_model_file_from_stem(model='yolov8n'):
+def check_model_file_from_stem(model="yolov8n"):
"""Return a model filename from a valid model stem."""
if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
- return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt
+ return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt
else:
return model
-def check_file(file, suffix='', download=True, hard=True):
+def check_file(file, suffix="", download=True, hard=True):
"""Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
- if (not file or ('://' not in file and Path(file).exists()) or # '://' check required in Windows Python<3.10
- file.lower().startswith('grpc://')): # file exists or gRPC Triton images
+ if (
+ not file
+ or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
+ or file.lower().startswith("grpc://")
+ ): # file exists or gRPC Triton images
return file
- elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')): # download
+ elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).exists():
- LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
+ LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)
return file
else: # search
- files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file
+ files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
@@ -463,7 +500,7 @@ def check_file(file, suffix='', download=True, hard=True):
return files[0] if len(files) else [] # return file
-def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
+def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
"""Search/download YAML file (if necessary) and return path, checking suffix."""
return check_file(file, suffix, hard=hard)
@@ -482,51 +519,52 @@ def check_is_path_safe(basedir, path):
base_dir_resolved = Path(basedir).resolve()
path_resolved = Path(path).resolve()
- return path_resolved.is_file() and path_resolved.parts[:len(base_dir_resolved.parts)] == base_dir_resolved.parts
+ return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
def check_imshow(warn=False):
"""Check if environment supports image displays."""
try:
if LINUX:
- assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle()
- cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
+ assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
+        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8x8 test image
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
if warn:
- LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
+ LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
return False
-def check_yolo(verbose=True, device=''):
+def check_yolo(verbose=True, device=""):
"""Return a human-readable YOLO software and hardware summary."""
import psutil
from ultralytics.utils.torch_utils import select_device
if is_jupyter():
- if check_requirements('wandb', install=False):
- os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang
+ if check_requirements("wandb", install=False):
+ os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
if is_colab():
- shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
+ shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
if verbose:
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
- total, used, free = shutil.disk_usage('/')
- s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
+ total, used, free = shutil.disk_usage("/")
+ s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
with contextlib.suppress(Exception): # clear display if ipython is installed
from IPython import display
+
display.clear_output()
else:
- s = ''
+ s = ""
select_device(device=device, newline=False)
- LOGGER.info(f'Setup complete ✅ {s}')
+ LOGGER.info(f"Setup complete ✅ {s}")
def collect_system_info():
@@ -537,32 +575,36 @@ def collect_system_info():
from ultralytics.utils import ENVIRONMENT, is_git_dir
from ultralytics.utils.torch_utils import get_cpu_info
- ram_info = psutil.virtual_memory().total / (1024 ** 3) # Convert bytes to GB
+ ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
check_yolo()
- LOGGER.info(f"\n{'OS':<20}{platform.platform()}\n"
- f"{'Environment':<20}{ENVIRONMENT}\n"
- f"{'Python':<20}{sys.version.split()[0]}\n"
- f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
- f"{'RAM':<20}{ram_info:.2f} GB\n"
- f"{'CPU':<20}{get_cpu_info()}\n"
- f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n")
-
- for r in parse_requirements(package='ultralytics'):
+ LOGGER.info(
+ f"\n{'OS':<20}{platform.platform()}\n"
+ f"{'Environment':<20}{ENVIRONMENT}\n"
+ f"{'Python':<20}{sys.version.split()[0]}\n"
+ f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
+ f"{'RAM':<20}{ram_info:.2f} GB\n"
+ f"{'CPU':<20}{get_cpu_info()}\n"
+ f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
+ )
+
+ for r in parse_requirements(package="ultralytics"):
try:
current = metadata.version(r.name)
- is_met = '✅ ' if check_version(current, str(r.specifier), hard=True) else '❌ '
+ is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ "
except metadata.PackageNotFoundError:
- current = '(not installed)'
- is_met = '❌ '
- LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}')
+ current = "(not installed)"
+ is_met = "❌ "
+ LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
if is_github_action_running():
- LOGGER.info(f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
- f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
- f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
- f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
- f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
- f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n")
+ LOGGER.info(
+ f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
+ f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
+ f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
+ f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
+ f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
+ f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
+ )
def check_amp(model):
@@ -587,7 +629,7 @@ def check_amp(model):
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
"""
device = next(model.parameters()).device # get model device
- if device.type in ('cpu', 'mps'):
+ if device.type in ("cpu", "mps"):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
@@ -598,22 +640,27 @@ def check_amp(model):
del m
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # shapes equal, values within 0.5 abs tolerance
- im = ASSETS / 'bus.jpg' # image to check
- prefix = colorstr('AMP: ')
- LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+ im = ASSETS / "bus.jpg" # image to check
+ prefix = colorstr("AMP: ")
+ LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
- assert amp_allclose(YOLO('yolov8n.pt'), im)
- LOGGER.info(f'{prefix}checks passed ✅')
+
+ assert amp_allclose(YOLO("yolov8n.pt"), im)
+ LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
- LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
+ LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
- LOGGER.warning(f'{prefix}checks skipped ⚠️. '
- f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
+ LOGGER.warning(
+ f"{prefix}checks skipped ⚠️. "
+ f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
+ )
except AssertionError:
- LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
- f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
+ LOGGER.warning(
+ f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
+ f"NaN losses or zero-mAP results, so AMP will be disabled during training."
+ )
return False
return True
@@ -621,8 +668,8 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
with contextlib.suppress(Exception):
- return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
- return ''
+ return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
+ return ""
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
@@ -630,7 +677,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
def strip_auth(v):
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
- return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
+ return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
@@ -638,11 +685,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
- file = Path(file).resolve().relative_to(ROOT).with_suffix('')
+ file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
- s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
- LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
+ s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
+ LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
def cuda_device_count() -> int:
@@ -654,11 +701,12 @@ def cuda_device_count() -> int:
"""
try:
# Run the nvidia-smi command and capture its output
- output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
- encoding='utf-8')
+ output = subprocess.check_output(
+ ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
+ )
# Take the first line and strip any leading/trailing white space
- first_line = output.strip().split('\n')[0]
+ first_line = output.strip().split("\n")[0]
return int(first_line)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
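
The nvidia-smi probe in `cuda_device_count` reads cleanly as a standalone helper that degrades to 0 on machines without the binary; a sketch mirroring the calls shown above:

```
import subprocess

def gpu_count() -> int:
    try:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
        )
        return int(out.strip().split("\n")[0])  # first line holds the count
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        return 0  # no driver, no GPU, or unparsable output
```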
diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py
index b07204e9..b669e52f 100644
--- a/ultralytics/utils/dist.py
+++ b/ultralytics/utils/dist.py
@@ -18,13 +18,13 @@ def find_free_network_port() -> int:
`MASTER_PORT` environment variable.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('127.0.0.1', 0))
+ s.bind(("127.0.0.1", 0))
return s.getsockname()[1] # port
def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
- module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
+ module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
@@ -39,13 +39,15 @@ if __name__ == "__main__":
trainer = {name}(cfg=cfg, overrides=overrides)
results = trainer.train()
"""
- (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
- with tempfile.NamedTemporaryFile(prefix='_temp_',
- suffix=f'{id(trainer)}.py',
- mode='w+',
- encoding='utf-8',
- dir=USER_CONFIG_DIR / 'DDP',
- delete=False) as file:
+ (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
+ with tempfile.NamedTemporaryFile(
+ prefix="_temp_",
+ suffix=f"{id(trainer)}.py",
+ mode="w+",
+ encoding="utf-8",
+ dir=USER_CONFIG_DIR / "DDP",
+ delete=False,
+ ) as file:
file.write(content)
return file.name
@@ -53,16 +55,17 @@ if __name__ == "__main__":
def generate_ddp_command(world_size, trainer):
"""Generates and returns command for distributed training."""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
file = generate_ddp_file(trainer)
- dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
+ dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
port = find_free_network_port()
- cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
+ cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
return cmd, file
def ddp_cleanup(trainer, file):
"""Delete temp file if created."""
- if f'{id(trainer)}.py' in file: # if temp_file suffix in file
+ if f"{id(trainer)}.py" in file: # if temp_file suffix in file
os.remove(file)
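
The port-picking trick in `find_free_network_port` deserves a standalone note: binding to port 0 asks the OS for any free ephemeral port, which then seeds `MASTER_PORT`. Runnable in isolation:

```
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("127.0.0.1", 0))  # port 0 -> OS assigns a free port
    port = s.getsockname()[1]
print(f"MASTER_PORT candidate: {port}")
```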
diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py
index 363e9a0d..fe8b8af0 100644
--- a/ultralytics/utils/downloads.py
+++ b/ultralytics/utils/downloads.py
@@ -15,15 +15,17 @@ import torch
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
-GITHUB_ASSETS_REPO = 'ultralytics/assets'
-GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose', '-obb')] + \
- [f'yolov5{k}{resolution}u.pt' for k in 'nsmlx' for resolution in ('', '6')] + \
- [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
- [f'yolo_nas_{k}.pt' for k in 'sml'] + \
- [f'sam_{k}.pt' for k in 'bl'] + \
- [f'FastSAM-{k}.pt' for k in 'sx'] + \
- [f'rtdetr-{k}.pt' for k in 'lx'] + \
- ['mobile_sam.pt']
+GITHUB_ASSETS_REPO = "ultralytics/assets"
+GITHUB_ASSETS_NAMES = (
+ [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ + [f"yolo_nas_{k}.pt" for k in "sml"]
+ + [f"sam_{k}.pt" for k in "bl"]
+ + [f"FastSAM-{k}.pt" for k in "sx"]
+ + [f"rtdetr-{k}.pt" for k in "lx"]
+ + ["mobile_sam.pt"]
+)
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
@@ -56,7 +58,7 @@ def is_url(url, check=True):
return False
-def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
+def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
"""
    Deletes all ".DS_Store" files under a specified directory.
@@ -77,12 +79,12 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
"""
for file in files_to_delete:
matches = list(Path(path).rglob(file))
- LOGGER.info(f'Deleting {file} files: {matches}')
+ LOGGER.info(f"Deleting {file} files: {matches}")
for f in matches:
f.unlink()
-def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
+def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
"""
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
@@ -111,17 +113,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
    # Zip with progress bar
- files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
- zip_file = directory.with_suffix('.zip')
+ files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
+ zip_file = directory.with_suffix(".zip")
compression = ZIP_DEFLATED if compress else ZIP_STORED
- with ZipFile(zip_file, 'w', compression) as f:
- for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
+ with ZipFile(zip_file, "w", compression) as f:
+ for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
f.write(file, file.relative_to(directory))
return zip_file # return path to zip file
-def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
+def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
"""
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
@@ -161,7 +163,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
top_level_dirs = {Path(f).parts[0] for f in files}
- if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith('/')):
+ if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
# Zip has multiple files at top level
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
else:
@@ -172,20 +174,20 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
# Check if destination directory already exists and contains files
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
- LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
+ LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
return path
- for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
+ for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
- if '..' in Path(f).parts:
- LOGGER.warning(f'Potentially insecure file path: {f}, skipping extraction.')
+ if ".." in Path(f).parts:
+ LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
continue
zipObj.extract(f, extract_path)
return path # return unzip dir
-def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
+def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", sf=1.5, hard=True):
"""
Check if there is sufficient disk space to download and store a file.
@@ -199,20 +201,23 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
"""
try:
r = requests.head(url) # response
- assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' # check response
+ assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
except Exception:
return True # requests issue, default to True
# Check file size
gib = 1 << 30 # bytes per GiB
- data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
+ data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
    total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd()))  # GiB
+
if data * sf < free:
return True # sufficient space
# Insufficient space
- text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
- f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
+ text = (
+        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required. "
+ f"Please free {data * sf - free:.1f} GB additional disk space and try again."
+ )
if hard:
raise MemoryError(text)
LOGGER.warning(text)
@@ -238,36 +243,41 @@ def get_google_drive_file_info(link):
url, filename = get_google_drive_file_info(link)
```
"""
- file_id = link.split('/d/')[1].split('/view')[0]
- drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
+ file_id = link.split("/d/")[1].split("/view")[0]
+ drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
filename = None
# Start session
with requests.Session() as session:
response = session.get(drive_url, stream=True)
- if 'quota exceeded' in str(response.content.lower()):
+ if "quota exceeded" in str(response.content.lower()):
raise ConnectionError(
- emojis(f'❌ Google Drive file download quota exceeded. '
- f'Please try again later or download this file manually at {link}.'))
+ emojis(
+ f"❌ Google Drive file download quota exceeded. "
+ f"Please try again later or download this file manually at {link}."
+ )
+ )
for k, v in response.cookies.items():
- if k.startswith('download_warning'):
- drive_url += f'&confirm={v}' # v is token
- cd = response.headers.get('content-disposition')
+ if k.startswith("download_warning"):
+ drive_url += f"&confirm={v}" # v is token
+ cd = response.headers.get("content-disposition")
if cd:
filename = re.findall('filename="(.+)"', cd)[0]
return drive_url, filename
-def safe_download(url,
- file=None,
- dir=None,
- unzip=True,
- delete=False,
- curl=False,
- retry=3,
- min_bytes=1E0,
- exist_ok=False,
- progress=True):
+def safe_download(
+ url,
+ file=None,
+ dir=None,
+ unzip=True,
+ delete=False,
+ curl=False,
+ retry=3,
+ min_bytes=1e0,
+ exist_ok=False,
+ progress=True,
+):
"""
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
@@ -294,36 +304,38 @@ def safe_download(url,
path = safe_download(link)
```
"""
- gdrive = url.startswith('https://drive.google.com/') # check if the URL is a Google Drive link
+ gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
if gdrive:
url, file = get_google_drive_file_info(url)
- f = Path(dir or '.') / (file or url2file(url)) # URL converted to filename
- if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
+ f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
+ if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
elif not f.is_file(): # URL and file do not exist
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
- LOGGER.info(f'{desc}...')
+ LOGGER.info(f"{desc}...")
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
check_disk_space(url)
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
- s = 'sS' * (not progress) # silent
- r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
- assert r == 0, f'Curl return value {r}'
+ s = "sS" * (not progress) # silent
+ r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
+ assert r == 0, f"Curl return value {r}"
else: # urllib download
- method = 'torch'
- if method == 'torch':
+ method = "torch"
+ if method == "torch":
torch.hub.download_url_to_file(url, f, progress=progress)
else:
- with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
- desc=desc,
- disable=not progress,
- unit='B',
- unit_scale=True,
- unit_divisor=1024) as pbar:
- with open(f, 'wb') as f_opened:
+ with request.urlopen(url) as response, TQDM(
+ total=int(response.getheader("Content-Length", 0)),
+ desc=desc,
+ disable=not progress,
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as pbar:
+ with open(f, "wb") as f_opened:
for data in response:
f_opened.write(data)
pbar.update(len(data))
@@ -334,26 +346,26 @@ def safe_download(url,
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not is_online():
- raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
+ raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
elif i >= retry:
- raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
- LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
+ raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
+ LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
- if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
+ if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
from zipfile import is_zipfile
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
- elif f.suffix in ('.tar', '.gz'):
- LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
- subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
+ elif f.suffix in (".tar", ".gz"):
+ LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
+ subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
if delete:
f.unlink() # remove zip
return unzip_dir
-def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
+def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
"""
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
function fetches the latest release assets.
@@ -372,20 +384,20 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
```
"""
- if version != 'latest':
- version = f'tags/{version}' # i.e. tags/v6.2
- url = f'https://api.github.com/repos/{repo}/releases/{version}'
+ if version != "latest":
+ version = f"tags/{version}" # i.e. tags/v6.2
+ url = f"https://api.github.com/repos/{repo}/releases/{version}"
r = requests.get(url) # github api
- if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
+ if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
r = requests.get(url) # try again
if r.status_code != 200:
- LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
- return '', []
+ LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
+ return "", []
data = r.json()
- return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
+ return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
-def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **kwargs):
+def attempt_download_asset(file, repo="ultralytics/assets", release="v0.0.0", **kwargs):
"""
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
locally first, then tries to download it from the specified GitHub repository release.
@@ -409,32 +421,32 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
- file = Path(file.strip().replace("'", ''))
+ file = Path(file.strip().replace("'", ""))
if file.exists():
return str(file)
- elif (SETTINGS['weights_dir'] / file).exists():
- return str(SETTINGS['weights_dir'] / file)
+ elif (SETTINGS["weights_dir"] / file).exists():
+ return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
- download_url = f'https://github.com/{repo}/releases/download'
- if str(file).startswith(('http:/', 'https:/')): # download
- url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
+ download_url = f"https://github.com/{repo}/releases/download"
+ if str(file).startswith(("http:/", "https:/")): # download
+ url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
if Path(file).is_file():
- LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
+ LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
- safe_download(url=url, file=file, min_bytes=1E5, **kwargs)
+ safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
- safe_download(url=f'{download_url}/{release}/{name}', file=file, min_bytes=1E5, **kwargs)
+ safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
else:
tag, assets = get_github_assets(repo, release)
if not assets:
tag, assets = get_github_assets(repo) # latest release
if name in assets:
- safe_download(url=f'{download_url}/{tag}/{name}', file=file, min_bytes=1E5, **kwargs)
+ safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
return str(file)
@@ -464,14 +476,18 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
if threads > 1:
with ThreadPool(threads) as pool:
pool.map(
- lambda x: safe_download(url=x[0],
- dir=x[1],
- unzip=unzip,
- delete=delete,
- curl=curl,
- retry=retry,
- exist_ok=exist_ok,
- progress=threads <= 1), zip(url, repeat(dir)))
+ lambda x: safe_download(
+ url=x[0],
+ dir=x[1],
+ unzip=unzip,
+ delete=delete,
+ curl=curl,
+ retry=retry,
+ exist_ok=exist_ok,
+ progress=threads <= 1,
+ ),
+ zip(url, repeat(dir)),
+ )
pool.close()
pool.join()
else:
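
The reshaped `GITHUB_ASSETS_NAMES` comprehension is easy to sanity-check in a REPL; the yolov8 term alone expands to 25 names (5 sizes x 5 task suffixes):

```
names = [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
print(len(names), names[:3])  # 25 ['yolov8n.pt', 'yolov8n-cls.pt', 'yolov8n-seg.pt']
```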
diff --git a/ultralytics/utils/errors.py b/ultralytics/utils/errors.py
index 745ca0a4..86aee1d9 100644
--- a/ultralytics/utils/errors.py
+++ b/ultralytics/utils/errors.py
@@ -17,6 +17,6 @@ class HUBModelError(Exception):
The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
"""
- def __init__(self, message='Model not found. Please check model URL and try again.'):
+ def __init__(self, message="Model not found. Please check model URL and try again."):
"""Create an exception for when a model is not found."""
super().__init__(emojis(message))
diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py
index 9fa1488f..ae8e90c2 100644
--- a/ultralytics/utils/files.py
+++ b/ultralytics/utils/files.py
@@ -50,13 +50,13 @@ def spaces_in_path(path):
"""
# If path has spaces, replace them with underscores
- if ' ' in str(path):
+ if " " in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
+ tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
# Copy file/directory
if path.is_dir():
@@ -82,7 +82,7 @@ def spaces_in_path(path):
yield path
-def increment_path(path, exist_ok=False, sep='', mkdir=False):
+def increment_path(path, exist_ok=False, sep="", mkdir=False):
"""
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
@@ -102,11 +102,11 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
- path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
+ path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
# Method 1
for n in range(2, 9999):
- p = f'{path}{sep}{n}{suffix}' # increment path
+ p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
@@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
def file_age(path=__file__):
"""Return days since last file update."""
- dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
+ dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
- return f'{t.year}-{t.month}-{t.day}'
+ return f"{t.year}-{t.month}-{t.day}"
def file_size(path):
@@ -137,11 +137,11 @@ def file_size(path):
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
- return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
+ return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
return 0.0
-def get_latest_run(search_dir='.'):
+def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
- last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
- return max(last_list, key=os.path.getctime) if last_list else ''
+ last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
+ return max(last_list, key=os.path.getctime) if last_list else ""
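
A usage sketch for increment_path() as reformatted above (paths are illustrative; a file keeps its suffix while the stem is incremented):

from ultralytics.utils.files import increment_path

run_dir = increment_path("runs/exp", mkdir=True)  # runs/exp, then runs/exp2, runs/exp3, ...
log_file = increment_path("results.txt", sep="_")  # becomes results_2.txt only if results.txt exists
print(run_dir, log_file)
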
diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py
index a1d85aaf..4e9ef2c6 100644
--- a/ultralytics/utils/instance.py
+++ b/ultralytics/utils/instance.py
@@ -26,9 +26,9 @@ to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height (YOLO format)
# `ltwh` means left top and width, height (COCO format)
-_formats = ['xyxy', 'xywh', 'ltwh']
+_formats = ["xyxy", "xywh", "ltwh"]
-__all__ = 'Bboxes', # tuple or list
+__all__ = ("Bboxes",) # tuple or list
class Bboxes:
@@ -46,9 +46,9 @@ class Bboxes:
This class does not handle normalization or denormalization of bounding boxes.
"""
- def __init__(self, bboxes, format='xyxy') -> None:
+ def __init__(self, bboxes, format="xyxy") -> None:
"""Initializes the Bboxes class with bounding box data in a specified format."""
- assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
+ assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
assert bboxes.shape[1] == 4
@@ -58,21 +58,21 @@ class Bboxes:
def convert(self, format):
"""Converts bounding box format from one type to another."""
- assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
+ assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
if self.format == format:
return
- elif self.format == 'xyxy':
- func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
- elif self.format == 'xywh':
- func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
+ elif self.format == "xyxy":
+ func = xyxy2xywh if format == "xywh" else xyxy2ltwh
+ elif self.format == "xywh":
+ func = xywh2xyxy if format == "xyxy" else xywh2ltwh
else:
- func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
+ func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
self.bboxes = func(self.bboxes)
self.format = format
def areas(self):
"""Return box areas."""
- self.convert('xyxy')
+ self.convert("xyxy")
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
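
A quick numerical check of the three box conventions listed above (xyxy = corners, xywh = center plus size, ltwh = top-left plus size), written in plain NumPy so the arithmetic behind convert() is visible:

import numpy as np

xywh = np.array([[50.0, 40.0, 20.0, 10.0]])  # cx, cy, w, h
xyxy = np.concatenate([xywh[:, :2] - xywh[:, 2:] / 2, xywh[:, :2] + xywh[:, 2:] / 2], axis=1)
ltwh = np.concatenate([xyxy[:, :2], xywh[:, 2:]], axis=1)
print(xyxy)  # [[40. 35. 60. 45.]]
print(ltwh)  # [[40. 35. 20. 10.]]
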
@@ -124,7 +124,7 @@ class Bboxes:
return len(self.bboxes)
@classmethod
- def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
+ def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
"""
Concatenate a list of Bboxes objects into a single Bboxes object.
@@ -148,7 +148,7 @@ class Bboxes:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
- def __getitem__(self, index) -> 'Bboxes':
+ def __getitem__(self, index) -> "Bboxes":
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
@@ -169,7 +169,7 @@ class Bboxes:
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
b = self.bboxes[index]
- assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
+ assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
return Bboxes(b)
@@ -205,7 +205,7 @@ class Instances:
This class does not perform input validation, and it assumes the inputs are well-formed.
"""
- def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
+ def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
"""
Args:
bboxes (ndarray): bboxes with shape [N, 4].
@@ -263,7 +263,7 @@ class Instances:
def add_padding(self, padw, padh):
"""Handle rect and mosaic situation."""
- assert not self.normalized, 'you should add padding with absolute coordinates.'
+ assert not self.normalized, "you should add padding with absolute coordinates."
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
@@ -271,7 +271,7 @@ class Instances:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
- def __getitem__(self, index) -> 'Instances':
+ def __getitem__(self, index) -> "Instances":
"""
Retrieve a specific instance or a set of instances using indexing.
@@ -301,7 +301,7 @@ class Instances:
def flipud(self, h):
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
- if self._bboxes.format == 'xyxy':
+ if self._bboxes.format == "xyxy":
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
@@ -314,7 +314,7 @@ class Instances:
def fliplr(self, w):
"""Reverses the order of the bounding boxes and segments horizontally."""
- if self._bboxes.format == 'xyxy':
+ if self._bboxes.format == "xyxy":
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
@@ -328,10 +328,10 @@ class Instances:
def clip(self, w, h):
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
ori_format = self._bboxes.format
- self.convert_bbox(format='xyxy')
+ self.convert_bbox(format="xyxy")
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
- if ori_format != 'xyxy':
+ if ori_format != "xyxy":
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
@@ -367,7 +367,7 @@ class Instances:
return len(self.bboxes)
@classmethod
- def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
+ def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
"""
Concatenates a list of Instances objects into a single Instances object.
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 38d633bf..900338dc 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -28,22 +28,27 @@ class VarifocalLoss(nn.Module):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
- loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
- weight).mean(1).sum()
+ loss = (
+ (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
+ .mean(1)
+ .sum()
+ )
return loss
class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
- def __init__(self, ):
+ def __init__(
+ self,
+ ):
"""Initializer for FocalLoss class with no parameters."""
super().__init__()
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
- loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+ loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
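
A hedged sketch of the modulation the commented lines above refer to: per-element BCE is scaled by an alpha class weight and a (1 - p_t)**gamma factor that down-weights easy examples. The exact factored form is an assumption matching the gamma=1.5, alpha=0.25 defaults:

import torch
import torch.nn.functional as F

pred = torch.randn(8)  # logits
label = torch.randint(0, 2, (8,)).float()
gamma, alpha = 1.5, 0.25
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
p_t = label * pred.sigmoid() + (1 - label) * (1 - pred.sigmoid())
alpha_factor = label * alpha + (1 - label) * (1 - alpha)
print((loss * alpha_factor * (1.0 - p_t) ** gamma).mean())
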
@@ -91,8 +96,10 @@ class BboxLoss(nn.Module):
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
- return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
- F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
+ return (
+ F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
+ ).mean(-1, keepdim=True)
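
A standalone sketch of the Distribution Focal Loss term above: each continuous regression target is split between its two nearest integer bins (tl, tr) with linear weights, and cross-entropy against both bins is blended (flat shapes here; the code above keeps per-anchor dims with .view):

import torch
import torch.nn.functional as F

reg_max = 16
pred_dist = torch.randn(4, reg_max)  # 4 targets, reg_max bins each
target = torch.tensor([2.3, 7.9, 0.0, 14.5]).clamp(0, reg_max - 1.01)
tl = target.long()  # target left bin
tr = tl + 1  # target right bin
wl = tr - target  # weight left
wr = 1 - wl  # weight right
dfl = (
    F.cross_entropy(pred_dist, tl, reduction="none") * wl
    + F.cross_entropy(pred_dist, tr, reduction="none") * wr
).mean()
print(dfl)
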
class RotatedBboxLoss(BboxLoss):
@@ -145,7 +152,7 @@ class v8DetectionLoss:
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
- self.bce = nn.BCEWithLogitsLoss(reduction='none')
+ self.bce = nn.BCEWithLogitsLoss(reduction="none")
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
@@ -190,7 +197,8 @@ class v8DetectionLoss:
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
@@ -201,7 +209,7 @@ class v8DetectionLoss:
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
- targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@@ -210,8 +218,13 @@ class v8DetectionLoss:
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )
target_scores_sum = max(target_scores.sum(), 1)
@@ -222,8 +235,9 @@ class v8DetectionLoss:
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
- loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
+ loss[0], loss[2] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
@@ -246,7 +260,8 @@ class v8SegmentationLoss(v8DetectionLoss):
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -259,24 +274,31 @@ class v8SegmentationLoss(v8DetectionLoss):
# Targets
try:
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
- raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
- "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
- 'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e
+ raise TypeError(
+ "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
+ "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+ "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+ "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
+ ) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )
target_scores_sum = max(target_scores.sum(), 1)
@@ -286,15 +308,23 @@ class v8SegmentationLoss(v8DetectionLoss):
if fg_mask.sum():
# Bbox loss
- loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
- target_scores, target_scores_sum, fg_mask)
+ loss[0], loss[3] = self.bbox_loss(
+ pred_distri,
+ pred_bboxes,
+ anchor_points,
+ target_bboxes / stride_tensor,
+ target_scores,
+ target_scores_sum,
+ fg_mask,
+ )
# Masks loss
- masks = batch['masks'].to(self.device).float()
+ masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
- masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
+ masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
- loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto,
- pred_masks, imgsz, self.overlap)
+ loss[1] = self.calculate_segmentation_loss(
+ fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+ )
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
@@ -308,8 +338,9 @@ class v8SegmentationLoss(v8DetectionLoss):
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
@staticmethod
- def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor,
- area: torch.Tensor) -> torch.Tensor:
+ def single_mask_loss(
+ gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
+ ) -> torch.Tensor:
"""
Compute the instance segmentation loss for a single image.
@@ -327,8 +358,8 @@ class v8SegmentationLoss(v8DetectionLoss):
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
- pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
- loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
+ pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
+ loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
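
A shape sketch of the einsum used in single_mask_loss() above: n per-instance coefficient vectors linearly combine c shared prototype masks into n mask logits:

import torch

n, c, h, w = 5, 32, 80, 80
pred = torch.randn(n, c)  # mask coefficients per instance
proto = torch.randn(c, h, w)  # shared prototype masks
pred_mask = torch.einsum("in,nhw->ihw", pred, proto)
print(pred_mask.shape)  # torch.Size([5, 80, 80])
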
def calculate_segmentation_loss(
@@ -387,8 +418,9 @@ class v8SegmentationLoss(v8DetectionLoss):
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
- loss += self.single_mask_loss(gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i],
- marea_i[fg_mask_i])
+ loss += self.single_mask_loss(
+ gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
+ )
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
@@ -415,7 +447,8 @@ class v8PoseLoss(v8DetectionLoss):
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -428,8 +461,8 @@ class v8PoseLoss(v8DetectionLoss):
# Targets
batch_size = pred_scores.shape[0]
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@@ -439,8 +472,13 @@ class v8PoseLoss(v8DetectionLoss):
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )
target_scores_sum = max(target_scores.sum(), 1)
@@ -451,14 +489,16 @@ class v8PoseLoss(v8DetectionLoss):
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
- loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
- keypoints = batch['keypoints'].to(self.device).float().clone()
+ loss[0], loss[4] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )
+ keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
- loss[1], loss[2] = self.calculate_keypoints_loss(fg_mask, target_gt_idx, keypoints, batch_idx,
- stride_tensor, target_bboxes, pred_kpts)
+ loss[1], loss[2] = self.calculate_keypoints_loss(
+ fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ )
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose # pose gain
@@ -477,8 +517,9 @@ class v8PoseLoss(v8DetectionLoss):
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
- def calculate_keypoints_loss(self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes,
- pred_kpts):
+ def calculate_keypoints_loss(
+ self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ ):
"""
Calculate the keypoints loss for the model.
@@ -507,21 +548,23 @@ class v8PoseLoss(v8DetectionLoss):
max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
# Create a tensor to hold batched keypoints
- batched_keypoints = torch.zeros((batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]),
- device=keypoints.device)
+ batched_keypoints = torch.zeros(
+ (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
+ )
# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
keypoints_i = keypoints[batch_idx == i]
- batched_keypoints[i, :keypoints_i.shape[0]] = keypoints_i
+ batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
# Expand dimensions of target_gt_idx to match the shape of batched_keypoints
target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
# Use target_gt_idx_expanded to select keypoints from batched_keypoints
selected_keypoints = batched_keypoints.gather(
- 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]))
+ 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
+ )
# Divide coordinates by stride
selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
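
A shape sketch of the gather pattern above: variable per-image keypoint counts are zero-padded into one tensor, then target_gt_idx selects the matched ground-truth keypoints for every anchor:

import torch

batch_size, max_kpts, nk, nd = 2, 3, 17, 3
batched_keypoints = torch.zeros(batch_size, max_kpts, nk, nd)
target_gt_idx = torch.randint(0, max_kpts, (batch_size, 5))  # 5 anchors per image
idx = target_gt_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, nk, nd)
selected = batched_keypoints.gather(1, idx)
print(selected.shape)  # torch.Size([2, 5, 17, 3])
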
@@ -547,13 +590,12 @@ class v8ClassificationLoss:
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
- loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
+ loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items
class v8OBBLoss(v8DetectionLoss):
-
def __init__(self, model): # model must be de-paralleled
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
@@ -583,7 +625,8 @@ class v8OBBLoss(v8DetectionLoss):
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
batch_size = pred_angle.shape[0]  # batch size
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -596,19 +639,21 @@ class v8OBBLoss(v8DetectionLoss):
# targets
try:
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
- raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n'
- "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
- 'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
+ raise TypeError(
+ "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
+ "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
+ "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
+ "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
+ ) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
@@ -616,10 +661,14 @@ class v8OBBLoss(v8DetectionLoss):
bboxes_for_assigner = pred_bboxes.clone().detach()
# Only the first four elements need to be scaled
bboxes_for_assigner[..., :4] *= stride_tensor
- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
- bboxes_for_assigner.type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes,
- mask_gt)
+ _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+ pred_scores.detach().sigmoid(),
+ bboxes_for_assigner.type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )
target_scores_sum = max(target_scores.sum(), 1)
@@ -630,8 +679,9 @@ class v8OBBLoss(v8DetectionLoss):
# Bbox loss
if fg_mask.sum():
target_bboxes[..., :4] /= stride_tensor
- loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
+ loss[0], loss[2] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )
else:
loss[0] += (pred_angle * 0).sum()
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 8d60b372..676ef737 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -11,7 +11,10 @@ import torch
from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
-OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
+OKS_SIGMA = (
+ np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
+ / 10.0
+)
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
@@ -33,8 +36,9 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
# Intersection area
- inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
- (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
+ inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
+ np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
+ ).clip(0)
# Box2 area
area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
@@ -99,8 +103,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# Intersection area
- inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
- (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
+ inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
+ b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
+ ).clamp_(0)
# Union Area
union = w1 * h1 + w2 * h2 - inter + eps
@@ -111,10 +116,10 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
- c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
+ c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
- v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+ v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
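
A worked check of the plain-IoU part of bbox_iou() above, on two axis-aligned xyxy boxes (eps omitted for clarity):

import torch

b1 = torch.tensor([0.0, 0.0, 4.0, 4.0])  # x1, y1, x2, y2
b2 = torch.tensor([2.0, 2.0, 6.0, 6.0])
inter = (b1[2].minimum(b2[2]) - b1[0].maximum(b2[0])).clamp(0) * (
    b1[3].minimum(b2[3]) - b1[1].maximum(b2[1])
).clamp(0)
union = 16.0 + 16.0 - inter
print(inter.item(), (inter / union).item())  # 4.0, 4/28 ≈ 0.143
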
@@ -202,12 +207,19 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
a1, b1, c1 = _get_covariance_matrix(obb1)
a2, b2, c2 = _get_covariance_matrix(obb2)
- t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
- ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+ t1 = (
+ ((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
+ / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
+ ) * 0.25
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
- t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
- (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
- (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+ t3 = (
+ torch.log(
+ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
+ / (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
+ + eps
+ )
+ * 0.5
+ )
bd = t1 + t2 + t3
bd = torch.clamp(bd, eps, 100.0)
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
@@ -215,7 +227,7 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
if CIoU: # only include the wh aspect ratio part
w1, h1 = obb1[..., 2:4].split(1, dim=-1)
w2, h2 = obb2[..., 2:4].split(1, dim=-1)
- v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+ v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - v * alpha # CIoU
@@ -239,12 +251,19 @@ def batch_probiou(obb1, obb2, eps=1e-7):
a1, b1, c1 = _get_covariance_matrix(obb1)
a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
- t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) /
- ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25
+ t1 = (
+ ((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2)))
+ / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)
+ ) * 0.25
t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5
- t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) /
- (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) *
- (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5
+ t3 = (
+ torch.log(
+ ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)))
+ / (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps)
+ + eps
+ )
+ * 0.5
+ )
bd = t1 + t2 + t3
bd = torch.clamp(bd, eps, 100.0)
hd = torch.sqrt(1.0 - torch.exp(-bd) + eps)
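
A broadcasting sketch for batch_probiou() above: obb1 terms stay column-shaped while obb2 terms are squeezed and re-expanded with [None], so the pairwise result is an N x M matrix, the shape contract that nms_rotated() in ops.py relies on:

import torch

N, M = 4, 6
obb1, obb2 = torch.rand(N, 5), torch.rand(M, 5)  # x, y, w, h, angle
x1, y1 = obb1[..., :2].split(1, dim=-1)  # each (N, 1)
x2, y2 = (t.squeeze(-1)[None] for t in obb2[..., :2].split(1, dim=-1))  # each (1, M)
print(((x1 - x2) ** 2 + (y1 - y2) ** 2).shape)  # torch.Size([4, 6])
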
@@ -279,10 +298,10 @@ class ConfusionMatrix:
iou_thres (float): The Intersection over Union threshold.
"""
- def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
+ def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
"""Initialize attributes for the YOLO model."""
self.task = task
- self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
+ self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
self.nc = nc # number of classes
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
self.iou_thres = iou_thres
@@ -361,11 +380,11 @@ class ConfusionMatrix:
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
- return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect
+ return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect
- @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
+ @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
@plt_settings()
- def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
+ def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
"""
Plot the confusion matrix using seaborn and save it to a file.
@@ -377,30 +396,31 @@ class ConfusionMatrix:
"""
import seaborn as sn
- array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
- ticklabels = (list(names) + ['background']) if labels else 'auto'
+ ticklabels = (list(names) + ["background"]) if labels else "auto"
with warnings.catch_warnings():
- warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
- sn.heatmap(array,
- ax=ax,
- annot=nc < 30,
- annot_kws={
- 'size': 8},
- cmap='Blues',
- fmt='.2f' if normalize else '.0f',
- square=True,
- vmin=0.0,
- xticklabels=ticklabels,
- yticklabels=ticklabels).set_facecolor((1, 1, 1))
- title = 'Confusion Matrix' + ' Normalized' * normalize
- ax.set_xlabel('True')
- ax.set_ylabel('Predicted')
+ warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
+ sn.heatmap(
+ array,
+ ax=ax,
+ annot=nc < 30,
+ annot_kws={"size": 8},
+ cmap="Blues",
+ fmt=".2f" if normalize else ".0f",
+ square=True,
+ vmin=0.0,
+ xticklabels=ticklabels,
+ yticklabels=ticklabels,
+ ).set_facecolor((1, 1, 1))
+ title = "Confusion Matrix" + " Normalized" * normalize
+ ax.set_xlabel("True")
+ ax.set_ylabel("Predicted")
ax.set_title(title)
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
fig.savefig(plot_fname, dpi=250)
@@ -411,7 +431,7 @@ class ConfusionMatrix:
def print(self):
"""Print the confusion matrix to the console."""
for i in range(self.nc + 1):
- LOGGER.info(' '.join(map(str, self.matrix[i])))
+ LOGGER.info(" ".join(map(str, self.matrix[i])))
def smooth(y, f=0.05):
@@ -419,28 +439,28 @@ def smooth(y, f=0.05):
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = np.ones(nf // 2) # ones padding
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
- return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
+ return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
@plt_settings()
-def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
+def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
"""Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
- ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
+ ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
else:
- ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
+ ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
- ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
- ax.set_xlabel('Recall')
- ax.set_ylabel('Precision')
+ ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
+ ax.set_xlabel("Recall")
+ ax.set_ylabel("Precision")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
- ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
- ax.set_title('Precision-Recall Curve')
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+ ax.set_title("Precision-Recall Curve")
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
@@ -448,24 +468,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N
@plt_settings()
-def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
+def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
- ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
+ ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
else:
- ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
+ ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
y = smooth(py.mean(0), 0.05)
- ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+ ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
- ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
- ax.set_title(f'{ylabel}-Confidence Curve')
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+ ax.set_title(f"{ylabel}-Confidence Curve")
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
@@ -494,8 +514,8 @@ def compute_ap(recall, precision):
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
# Integrate area under curve
- method = 'interp' # methods: 'continuous', 'interp'
- if method == 'interp':
+ method = "interp" # methods: 'continuous', 'interp'
+ if method == "interp":
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
else: # 'continuous'
@@ -505,16 +525,9 @@ def compute_ap(recall, precision):
return ap, mpre, mrec
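
A minimal numeric sketch of the 101-point 'interp' integration kept above (COCO-style AP): precision is resampled onto a fixed recall grid and integrated with the trapezoid rule:

import numpy as np

mrec = np.array([0.0, 0.2, 0.5, 1.0])  # recall envelope
mpre = np.array([1.0, 1.0, 0.6, 0.2])  # monotone precision envelope
x = np.linspace(0, 1, 101)  # 101-point COCO grid
ap = np.trapz(np.interp(x, mrec, mpre), x)
print(round(float(ap), 3))
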
-def ap_per_class(tp,
- conf,
- pred_cls,
- target_cls,
- plot=False,
- on_plot=None,
- save_dir=Path(),
- names=(),
- eps=1e-16,
- prefix=''):
+def ap_per_class(
+ tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
+):
"""
Computes the average precision per class for object detection evaluation.
@@ -591,10 +604,10 @@ def ap_per_class(tp,
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
if plot:
- plot_pr_curve(x, prec_values, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
- plot_mc_curve(x, f1_curve, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
- plot_mc_curve(x, p_curve, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
- plot_mc_curve(x, r_curve, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
+ plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
+ plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
+ plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
+ plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
@@ -746,8 +759,18 @@ class Metric(SimpleClass):
Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
on the values provided in the `results` tuple.
"""
- (self.p, self.r, self.f1, self.all_ap, self.ap_class_index, self.p_curve, self.r_curve, self.f1_curve, self.px,
- self.prec_values) = results
+ (
+ self.p,
+ self.r,
+ self.f1,
+ self.all_ap,
+ self.ap_class_index,
+ self.p_curve,
+ self.r_curve,
+ self.f1_curve,
+ self.px,
+ self.prec_values,
+ ) = results
@property
def curves(self):
@@ -757,8 +780,12 @@ class Metric(SimpleClass):
@property
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
- return [[self.px, self.prec_values, 'Recall', 'Precision'], [self.px, self.f1_curve, 'Confidence', 'F1'],
- [self.px, self.p_curve, 'Confidence', 'Precision'], [self.px, self.r_curve, 'Confidence', 'Recall']]
+ return [
+ [self.px, self.prec_values, "Recall", "Precision"],
+ [self.px, self.f1_curve, "Confidence", "F1"],
+ [self.px, self.p_curve, "Confidence", "Precision"],
+ [self.px, self.r_curve, "Confidence", "Recall"],
+ ]
class DetMetrics(SimpleClass):
@@ -793,33 +820,35 @@ class DetMetrics(SimpleClass):
curves_results: TODO
"""
- def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
+ def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
- self.task = 'detect'
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+ self.task = "detect"
def process(self, tp, conf, pred_cls, target_cls):
"""Process predicted results for object detection and update metrics."""
- results = ap_per_class(tp,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- save_dir=self.save_dir,
- names=self.names,
- on_plot=self.on_plot)[2:]
+ results = ap_per_class(
+ tp,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ on_plot=self.on_plot,
+ )[2:]
self.box.nc = len(self.names)
self.box.update(results)
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
- return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
+ return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
@@ -847,12 +876,12 @@ class DetMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
- return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
- return ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)']
+ return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
@property
def curves_results(self):
@@ -889,7 +918,7 @@ class SegmentMetrics(SimpleClass):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
- def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
+ def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
self.save_dir = save_dir
self.plot = plot
@@ -897,8 +926,8 @@ class SegmentMetrics(SimpleClass):
self.names = names
self.box = Metric()
self.seg = Metric()
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
- self.task = 'segment'
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+ self.task = "segment"
def process(self, tp, tp_m, conf, pred_cls, target_cls):
"""
@@ -912,26 +941,30 @@ class SegmentMetrics(SimpleClass):
target_cls (list): List of target classes.
"""
- results_mask = ap_per_class(tp_m,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- on_plot=self.on_plot,
- save_dir=self.save_dir,
- names=self.names,
- prefix='Mask')[2:]
+ results_mask = ap_per_class(
+ tp_m,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ on_plot=self.on_plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Mask",
+ )[2:]
self.seg.nc = len(self.names)
self.seg.update(results_mask)
- results_box = ap_per_class(tp,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- on_plot=self.on_plot,
- save_dir=self.save_dir,
- names=self.names,
- prefix='Box')[2:]
+ results_box = ap_per_class(
+ tp,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ on_plot=self.on_plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Box",
+ )[2:]
self.box.nc = len(self.names)
self.box.update(results_box)
@@ -939,8 +972,15 @@ class SegmentMetrics(SimpleClass):
def keys(self):
"""Returns a list of keys for accessing metrics."""
return [
- 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
- 'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
+ "metrics/precision(B)",
+ "metrics/recall(B)",
+ "metrics/mAP50(B)",
+ "metrics/mAP50-95(B)",
+ "metrics/precision(M)",
+ "metrics/recall(M)",
+ "metrics/mAP50(M)",
+ "metrics/mAP50-95(M)",
+ ]
def mean_results(self):
"""Return the mean metrics for bounding box and segmentation results."""
@@ -968,14 +1008,21 @@ class SegmentMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns results of object detection model for evaluation."""
- return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
return [
- 'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
- 'Precision-Recall(M)', 'F1-Confidence(M)', 'Precision-Confidence(M)', 'Recall-Confidence(M)']
+ "Precision-Recall(B)",
+ "F1-Confidence(B)",
+ "Precision-Confidence(B)",
+ "Recall-Confidence(B)",
+ "Precision-Recall(M)",
+ "F1-Confidence(M)",
+ "Precision-Confidence(M)",
+ "Recall-Confidence(M)",
+ ]
@property
def curves_results(self):
@@ -1012,7 +1059,7 @@ class PoseMetrics(SegmentMetrics):
results_dict: Returns the dictionary containing all the detection and pose metrics and fitness score.
"""
- def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
+ def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
super().__init__(save_dir, plot, on_plot, names)
self.save_dir = save_dir
@@ -1021,8 +1068,8 @@ class PoseMetrics(SegmentMetrics):
self.names = names
self.box = Metric()
self.pose = Metric()
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
- self.task = 'pose'
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+ self.task = "pose"
def process(self, tp, tp_p, conf, pred_cls, target_cls):
"""
@@ -1036,26 +1083,30 @@ class PoseMetrics(SegmentMetrics):
target_cls (list): List of target classes.
"""
- results_pose = ap_per_class(tp_p,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- on_plot=self.on_plot,
- save_dir=self.save_dir,
- names=self.names,
- prefix='Pose')[2:]
+ results_pose = ap_per_class(
+ tp_p,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ on_plot=self.on_plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Pose",
+ )[2:]
self.pose.nc = len(self.names)
self.pose.update(results_pose)
- results_box = ap_per_class(tp,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- on_plot=self.on_plot,
- save_dir=self.save_dir,
- names=self.names,
- prefix='Box')[2:]
+ results_box = ap_per_class(
+ tp,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ on_plot=self.on_plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ prefix="Box",
+ )[2:]
self.box.nc = len(self.names)
self.box.update(results_box)
@@ -1063,8 +1114,15 @@ class PoseMetrics(SegmentMetrics):
def keys(self):
"""Returns list of evaluation metric keys."""
return [
- 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
- 'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
+ "metrics/precision(B)",
+ "metrics/recall(B)",
+ "metrics/mAP50(B)",
+ "metrics/mAP50-95(B)",
+ "metrics/precision(P)",
+ "metrics/recall(P)",
+ "metrics/mAP50(P)",
+ "metrics/mAP50-95(P)",
+ ]
def mean_results(self):
"""Return the mean results of box and pose."""
@@ -1088,8 +1146,15 @@ class PoseMetrics(SegmentMetrics):
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
return [
- 'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)',
- 'Precision-Recall(P)', 'F1-Confidence(P)', 'Precision-Confidence(P)', 'Recall-Confidence(P)']
+ "Precision-Recall(B)",
+ "F1-Confidence(B)",
+ "Precision-Confidence(B)",
+ "Recall-Confidence(B)",
+ "Precision-Recall(P)",
+ "F1-Confidence(P)",
+ "Precision-Confidence(P)",
+ "Recall-Confidence(P)",
+ ]
@property
def curves_results(self):
@@ -1119,8 +1184,8 @@ class ClassifyMetrics(SimpleClass):
"""Initialize a ClassifyMetrics instance."""
self.top1 = 0
self.top5 = 0
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
- self.task = 'classify'
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+ self.task = "classify"
def process(self, targets, pred):
"""Target classes and predicted classes."""
@@ -1137,12 +1202,12 @@ class ClassifyMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns a dictionary with model's performance metrics and fitness score."""
- return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
+ return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
@property
def keys(self):
"""Returns a list of keys for the results_dict property."""
- return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
+ return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
@property
def curves(self):
@@ -1156,32 +1221,33 @@ class ClassifyMetrics(SimpleClass):
class OBBMetrics(SimpleClass):
-
- def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
+ def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
- self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+ self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
def process(self, tp, conf, pred_cls, target_cls):
"""Process predicted results for object detection and update metrics."""
- results = ap_per_class(tp,
- conf,
- pred_cls,
- target_cls,
- plot=self.plot,
- save_dir=self.save_dir,
- names=self.names,
- on_plot=self.on_plot)[2:]
+ results = ap_per_class(
+ tp,
+ conf,
+ pred_cls,
+ target_cls,
+ plot=self.plot,
+ save_dir=self.save_dir,
+ names=self.names,
+ on_plot=self.on_plot,
+ )[2:]
self.box.nc = len(self.names)
self.box.update(results)
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
- return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
+ return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
@@ -1209,7 +1275,7 @@ class OBBMetrics(SimpleClass):
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
- return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index d589436d..8a86239c 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -52,7 +52,7 @@ class Profile(contextlib.ContextDecorator):
def __str__(self):
"""Returns a human-readable string representing the accumulated elapsed time in the profiler."""
- return f'Elapsed time is {self.t} s'
+ return f"Elapsed time is {self.t} s"
def time(self):
"""Get current time."""
@@ -76,9 +76,13 @@ def segment2box(segment, width=640, height=640):
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
- x, y, = x[inside], y[inside]
- return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
- 4, dtype=segment.dtype) # xyxy
+ x = x[inside]
+ y = y[inside]
+ return (
+ np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
+ if any(x)
+ else np.zeros(4, dtype=segment.dtype)
+ ) # xyxy
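
A numeric check for segment2box() above: polygon points outside the image are dropped before taking the min/max box:

import numpy as np

segment = np.array([[10.0, 20.0], [700.0, 30.0], [50.0, 600.0], [30.0, 40.0]])
x, y = segment.T
inside = (x >= 0) & (y >= 0) & (x <= 640) & (y <= 640)
x, y = x[inside], y[inside]
print(np.array([x.min(), y.min(), x.max(), y.max()]))  # [ 10.  20.  50. 600.]
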
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
@@ -101,8 +105,10 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xyw
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
- pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
- (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
+ pad = (
+ round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
+ round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
+ ) # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
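
A numeric check of the letterbox gain/pad computation above, mapping a 640x640 letterboxed image back onto a 480x640 (h, w) original:

img1_shape, img0_shape = (640, 640), (480, 640)
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
pad = (
    round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
    round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
)
print(gain, pad)  # 1.0 (0, 80): no x padding, 80 px of y padding per side
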
@@ -145,7 +151,7 @@ def nms_rotated(boxes, scores, threshold=0.45):
Returns:
"""
if len(boxes) == 0:
- return np.empty((0, ), dtype=np.int8)
+ return np.empty((0,), dtype=np.int8)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
@@ -199,8 +205,8 @@ def non_max_suppression(
"""
# Checks
- assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
- assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
+ assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
@@ -284,7 +290,7 @@ def non_max_suppression(
output[xi] = x[i]
if (time.time() - t) > time_limit:
- LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
+ LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
@@ -378,7 +384,7 @@ def xyxy2xywh(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
- assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
+ assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
@@ -398,7 +404,7 @@ def xywh2xyxy(x):
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
- assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
+ assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
dw = x[..., 2] / 2 # half-width
dh = x[..., 3] / 2 # half-height
@@ -423,7 +429,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
- assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
+ assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
@@ -449,7 +455,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
"""
if clip:
x = clip_boxes(x, (h - eps, w - eps))
- assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
+ assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
@@ -526,8 +532,11 @@ def xyxyxyxy2xywhr(corners):
# especially some objects are cut off by augmentations in dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi])
- rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray(
- rboxes, dtype=points.dtype)
+ rboxes = (
+ torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
+ if is_torch
+ else np.asarray(rboxes, dtype=points.dtype)
+ )
return rboxes
@@ -546,7 +555,7 @@ def xywhr2xyxyxyxy(center):
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
ctr = center[..., :2]
- w, h, angle = (center[..., i:i + 1] for i in range(2, 5))
+ w, h, angle = (center[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
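
As a worked example of the corner reconstruction above (the four corners of an xywhr box are ctr ± vec1 ± vec2, with vec1 along the rotated width axis and vec2 along the rotated height axis), assuming NumPy and a hand-picked box:

```python
import numpy as np

x, y, w, h, r = 50.0, 50.0, 20.0, 10.0, np.pi / 4  # center, size, rotation in radians
ctr = np.array([x, y])
vec1 = np.array([w / 2 * np.cos(r), w / 2 * np.sin(r)])   # half-width direction
vec2 = np.array([-h / 2 * np.sin(r), h / 2 * np.cos(r)])  # half-height direction
corners = np.stack([ctr + vec1 + vec2, ctr + vec1 - vec2, ctr - vec1 - vec2, ctr - vec1 + vec2])
print(corners.round(2))  # the 4 corner points of the rotated box
```
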
@@ -607,8 +616,9 @@ def resample_segments(segments, n=1000):
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
- segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
- dtype=np.float32).reshape(2, -1).T # segment xy
+ segments[i] = (
+ np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
+ ) # segment xy
return segments
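
A standalone sketch of the resampling step above, under the same assumptions as the source (closed polygon, independent linear interpolation of x and y over index space); resample_polygon is an illustrative name:

```python
import numpy as np

def resample_polygon(s: np.ndarray, n: int = 1000) -> np.ndarray:
    """Resample an (m, 2) polygon to (n, 2) evenly spaced points."""
    s = np.concatenate((s, s[0:1, :]), axis=0)  # repeat first point to close the loop
    x = np.linspace(0, len(s) - 1, n)           # fractional indices of new samples
    xp = np.arange(len(s))                      # integer indices of original points
    return np.stack([np.interp(x, xp, s[:, i]) for i in range(2)], axis=1).astype(np.float32)

square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
print(resample_polygon(square, n=8).round(2))  # 8 points walking the square
```
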
@@ -647,7 +657,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
- masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
@@ -680,7 +690,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
- masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
+ masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5)
@@ -724,7 +734,7 @@ def scale_masks(masks, shape, padding=True):
bottom, right = (int(round(mh - pad[1] + 0.1)), int(round(mw - pad[0] + 0.1)))
masks = masks[..., top:bottom, left:right]
- masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
+ masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
return masks
@@ -763,7 +773,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
return coords
-def masks2segments(masks, strategy='largest'):
+def masks2segments(masks, strategy="largest"):
"""
Takes a list of masks (n, h, w) and returns a list of segments (n, xy).
@@ -775,16 +785,16 @@ def masks2segments(masks, strategy='largest'):
segments (List): list of segment masks
"""
segments = []
- for x in masks.int().cpu().numpy().astype('uint8'):
+ for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
- if strategy == 'concat': # concatenate all segments
+ if strategy == "concat": # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
- elif strategy == 'largest': # select largest segment
+ elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
- segments.append(c.astype('float32'))
+ segments.append(c.astype("float32"))
return segments
@@ -811,4 +821,4 @@ def clean_str(s):
Returns:
(str): a string with special characters replaced by an underscore _
"""
- return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
+ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
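
Since this file touches every box-format converter, a quick round-trip sanity check (assuming an installed ultralytics; the sample box is arbitrary):

```python
import numpy as np
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

boxes_xyxy = np.array([[10.0, 20.0, 50.0, 80.0]], dtype=np.float32)
assert np.allclose(xywh2xyxy(xyxy2xywh(boxes_xyxy)), boxes_xyxy)  # lossless round trip
print(xyxy2xywh(boxes_xyxy))  # [[30. 50. 40. 60.]] -> x center, y center, width, height
```
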
diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py
index 541cf45a..f9c5bb1b 100644
--- a/ultralytics/utils/patches.py
+++ b/ultralytics/utils/patches.py
@@ -52,7 +52,7 @@ def imshow(winname: str, mat: np.ndarray):
winname (str): Name of the window.
mat (np.ndarray): Image to be shown.
"""
- _imshow(winname.encode('unicode_escape').decode(), mat)
+ _imshow(winname.encode("unicode_escape").decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
@@ -72,6 +72,6 @@ def torch_save(*args, **kwargs):
except ImportError:
import pickle
- if 'pickle_module' not in kwargs:
- kwargs['pickle_module'] = pickle # noqa
+ if "pickle_module" not in kwargs:
+ kwargs["pickle_module"] = pickle # noqa
return _torch_save(*args, **kwargs)
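
For context, the fallback being reformatted here follows the usual wrap-and-delegate pattern; a sketch of its plausible full shape (not the verbatim source): prefer dill when installed, otherwise fall back to the standard pickle module.

```python
import torch

_torch_save = torch.save  # keep a handle to the original function

def torch_save(*args, **kwargs):
    """torch.save wrapper that serializes lambdas via dill when available."""
    try:
        import dill as pickle  # optional dependency that can pickle lambdas
    except ImportError:
        import pickle
    if "pickle_module" not in kwargs:
        kwargs["pickle_module"] = pickle
    return _torch_save(*args, **kwargs)
```
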
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 46fca19f..280a3e28 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -33,15 +33,55 @@ class Colors:
def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
- hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
- '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
- self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
+ hexs = (
+ "FF3838",
+ "FF9D97",
+ "FF701F",
+ "FFB21D",
+ "CFD231",
+ "48F90A",
+ "92CC17",
+ "3DDB86",
+ "1A9334",
+ "00D4BB",
+ "2C99A8",
+ "00C2FF",
+ "344593",
+ "6473FF",
+ "0018EC",
+ "8438FF",
+ "520085",
+ "CB38FF",
+ "FF95C8",
+ "FF37C7",
+ )
+ self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
self.n = len(self.palette)
- self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
- [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
- [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
- [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
- dtype=np.uint8)
+ self.pose_palette = np.array(
+ [
+ [255, 128, 0],
+ [255, 153, 51],
+ [255, 178, 102],
+ [230, 230, 0],
+ [255, 153, 255],
+ [153, 204, 255],
+ [255, 102, 255],
+ [255, 51, 255],
+ [102, 178, 255],
+ [51, 153, 255],
+ [255, 153, 153],
+ [255, 102, 102],
+ [255, 51, 51],
+ [153, 255, 153],
+ [102, 255, 102],
+ [51, 255, 51],
+ [0, 255, 0],
+ [0, 0, 255],
+ [255, 0, 0],
+ [255, 255, 255],
+ ],
+ dtype=np.uint8,
+ )
def __call__(self, i, bgr=False):
"""Converts hex color codes to RGB values."""
@@ -51,7 +91,7 @@ class Colors:
@staticmethod
def hex2rgb(h):
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
- return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+ return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
colors = Colors() # create instance for 'from utils.plots import colors'
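
The slice fix above is easiest to read with a concrete value; hex2rgb peels two hex digits per channel (a standalone copy, for illustration only):

```python
def hex2rgb(h: str) -> tuple:
    """'#FF3838' -> (255, 56, 56): int() of each 2-digit hex pair."""
    return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))

print(hex2rgb("#FF3838"))  # (255, 56, 56)
```
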
@@ -71,9 +111,9 @@ class Annotator:
kpt_color (List[int]): Color palette for keypoints.
"""
- def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
+ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
- assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
+ assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
self.pil = pil or non_ascii
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
@@ -81,26 +121,45 @@ class Annotator:
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
try:
- font = check_font('Arial.Unicode.ttf' if non_ascii else font)
+ font = check_font("Arial.Unicode.ttf" if non_ascii else font)
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
except Exception:
self.font = ImageFont.load_default()
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
- if check_version(pil_version, '9.2.0'):
+ if check_version(pil_version, "9.2.0"):
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
else: # use cv2
self.im = im if im.flags.writeable else im.copy()
self.tf = max(self.lw - 1, 1) # font thickness
self.sf = self.lw / 3 # font scale
# Pose
- self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
- [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
+ self.skeleton = [
+ [16, 14],
+ [14, 12],
+ [17, 15],
+ [15, 13],
+ [12, 13],
+ [6, 12],
+ [7, 13],
+ [6, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [9, 11],
+ [2, 3],
+ [1, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [4, 6],
+ [5, 7],
+ ]
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
- def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
+ def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""Add one xyxy box to image with label."""
if isinstance(box, torch.Tensor):
box = box.tolist()
@@ -134,13 +193,16 @@ class Annotator:
outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
- cv2.putText(self.im,
- label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
- 0,
- self.sf,
- txt_color,
- thickness=self.tf,
- lineType=cv2.LINE_AA)
+ cv2.putText(
+ self.im,
+ label,
+ (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+ 0,
+ self.sf,
+ txt_color,
+ thickness=self.tf,
+ lineType=cv2.LINE_AA,
+ )
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
"""
@@ -171,7 +233,7 @@ class Annotator:
im_gpu = im_gpu.flip(dims=[0]) # flip channel
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
- im_mask = (im_gpu * 255)
+ im_mask = im_gpu * 255
im_mask_np = im_mask.byte().cpu().numpy()
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
if self.pil:
@@ -230,9 +292,9 @@ class Annotator:
"""Add rectangle to image (PIL-only)."""
self.draw.rectangle(xy, fill, outline, width)
- def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
"""Adds text to an image using PIL or cv2."""
- if anchor == 'bottom': # start y from font bottom
+ if anchor == "bottom": # start y from font bottom
w, h = self.font.getsize(text) # text width, height
xy[1] += 1 - h
if self.pil:
@@ -241,8 +303,8 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
- if '\n' in text:
- lines = text.split('\n')
+ if "\n" in text:
+ lines = text.split("\n")
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
@@ -314,15 +376,12 @@ class Annotator:
text_y = t_size_in[1]
# Create a rounded rectangle for in_count
- cv2.rectangle(self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color,
- -1)
- cv2.putText(self.im,
- str(counts), (text_x, text_y + t_size_in[1]),
- 0,
- tl / 2,
- txt_color,
- self.tf,
- lineType=cv2.LINE_AA)
+ cv2.rectangle(
+ self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
+ )
+ cv2.putText(
+ self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
+ )
@staticmethod
def estimate_pose_angle(a, b, c):
@@ -375,7 +434,7 @@ class Annotator:
center_kpt (int): centroid pose index for workout monitoring
line_thickness (int): thickness for text display
"""
- angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
+ angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}")
font_scale = 0.6 + (line_thickness / 10.0)
# Draw angle
@@ -383,21 +442,37 @@ class Annotator:
angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
- cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0],
- angle_background_position[1] + angle_background_size[1]),
- (255, 255, 255), -1)
+ cv2.rectangle(
+ self.im,
+ angle_background_position,
+ (
+ angle_background_position[0] + angle_background_size[0],
+ angle_background_position[1] + angle_background_size[1],
+ ),
+ (255, 255, 255),
+ -1,
+ )
cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
# Draw Counts
(count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
- count_background_position = (angle_background_position[0],
- angle_background_position[1] + angle_background_size[1] + 5)
+ count_background_position = (
+ angle_background_position[0],
+ angle_background_position[1] + angle_background_size[1] + 5,
+ )
count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))
- cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0],
- count_background_position[1] + count_background_size[1]),
- (255, 255, 255), -1)
+ cv2.rectangle(
+ self.im,
+ count_background_position,
+ (
+ count_background_position[0] + count_background_size[0],
+ count_background_position[1] + count_background_size[1],
+ ),
+ (255, 255, 255),
+ -1,
+ )
cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
# Draw Stage
@@ -406,9 +481,16 @@ class Annotator:
stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
stage_background_size = (stage_text_width + 10, stage_text_height + 10)
- cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0],
- stage_background_position[1] + stage_background_size[1]),
- (255, 255, 255), -1)
+ cv2.rectangle(
+ self.im,
+ stage_background_position,
+ (
+ stage_background_position[0] + stage_background_size[0],
+ stage_background_position[1] + stage_background_size[1],
+ ),
+ (255, 255, 255),
+ -1,
+ )
cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
@@ -423,14 +505,20 @@ class Annotator:
"""
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
- label = f'Track ID: {track_label}' if track_label else det_label
+ label = f"Track ID: {track_label}" if track_label else det_label
text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
- cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
- (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
+ cv2.rectangle(
+ self.im,
+ (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
+ (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
+ mask_color,
+ -1,
+ )
- cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255),
- 2)
+ cv2.putText(
+ self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
+ )
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
"""
@@ -452,24 +540,24 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
-def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
+def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
"""Plot training labels including class histograms and box statistics."""
import pandas as pd
import seaborn as sn
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
- warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
- warnings.filterwarnings('ignore', category=FutureWarning)
+ warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
+ warnings.filterwarnings("ignore", category=FutureWarning)
# Plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
nc = int(cls.max() + 1) # number of classes
boxes = boxes[:1000000] # limit to 1M boxes
- x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
+ x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
- sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
- plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
+ sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
+ plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
# Matplotlib labels
@@ -477,14 +565,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
for i in range(nc):
y[2].patches[i].set_color([x / 255 for x in colors(i)])
- ax[0].set_ylabel('instances')
+ ax[0].set_ylabel("instances")
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
else:
- ax[0].set_xlabel('classes')
- sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
- sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
+ ax[0].set_xlabel("classes")
+ sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
+ sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# Rectangles
boxes[:, 0:2] = 0.5 # center
@@ -493,20 +581,20 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
for cls, box in zip(cls[:500], boxes[:500]):
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
- ax[1].axis('off')
+ ax[1].axis("off")
for a in [0, 1, 2, 3]:
- for s in ['top', 'right', 'left', 'bottom']:
+ for s in ["top", "right", "left", "bottom"]:
ax[a].spines[s].set_visible(False)
- fname = save_dir / 'labels.jpg'
+ fname = save_dir / "labels.jpg"
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
-def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
+def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
@@ -545,29 +633,31 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = ops.xywh2xyxy(b).long()
xyxy = ops.clip_boxes(xyxy, im.shape)
- crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
+ crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
if save:
file.parent.mkdir(parents=True, exist_ok=True) # make directory
- f = str(increment_path(file).with_suffix('.jpg'))
+ f = str(increment_path(file).with_suffix(".jpg"))
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
return crop
@threaded
-def plot_images(images,
- batch_idx,
- cls,
- bboxes=np.zeros(0, dtype=np.float32),
- confs=None,
- masks=np.zeros(0, dtype=np.uint8),
- kpts=np.zeros((0, 51), dtype=np.float32),
- paths=None,
- fname='images.jpg',
- names=None,
- on_plot=None,
- max_subplots=16,
- save=True):
+def plot_images(
+ images,
+ batch_idx,
+ cls,
+ bboxes=np.zeros(0, dtype=np.float32),
+ confs=None,
+ masks=np.zeros(0, dtype=np.uint8),
+ kpts=np.zeros((0, 51), dtype=np.float32),
+ paths=None,
+ fname="images.jpg",
+ names=None,
+ on_plot=None,
+ max_subplots=16,
+ save=True,
+):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@@ -585,7 +675,7 @@ def plot_images(images,
max_size = 1920 # max image size
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
- ns = np.ceil(bs ** 0.5) # number of subplots (square)
+ ns = np.ceil(bs**0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
@@ -593,7 +683,7 @@ def plot_images(images,
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
- mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
+ mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
@@ -612,7 +702,7 @@ def plot_images(images,
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
- classes = cls[idx].astype('int')
+ classes = cls[idx].astype("int")
labels = confs is None
if len(bboxes):
@@ -633,14 +723,14 @@ def plot_images(images,
color = colors(c)
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
- label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
+ label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
annotator.box_label(box, label, color=color, rotated=is_obb)
elif len(classes):
for c in classes:
color = colors(c)
c = names.get(c, c) if names else c
- annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
+ annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# Plot keypoints
if len(kpts):
@@ -680,7 +770,9 @@ def plot_images(images,
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
- im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
+ im[y : y + h, x : x + w, :][mask] = (
+ im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
+ )
annotator.fromarray(im)
if save:
annotator.im.save(fname) # save
@@ -691,7 +783,7 @@ def plot_images(images,
@plt_settings()
-def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
+def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
"""
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -714,6 +806,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
"""
import pandas as pd
from scipy.ndimage import gaussian_filter1d
+
save_dir = Path(file).parent if file else Path(dir)
if classify:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@@ -728,32 +821,32 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
ax = ax.ravel()
- files = list(save_dir.glob('results*.csv'))
- assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
+ files = list(save_dir.glob("results*.csv"))
+ assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
for f in files:
try:
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate(index):
- y = data.values[:, j].astype('float')
+ y = data.values[:, j].astype("float")
# y[y == 0] = np.nan # don't show zero values
- ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
- ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
+ ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
+ ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
- LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
+ LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
ax[1].legend()
- fname = save_dir / 'results.png'
+ fname = save_dir / "results.png"
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
-def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
+def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
"""
Plots a scatter plot with points colored based on a 2D histogram.
@@ -774,14 +867,18 @@ def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none
# Calculate 2D histogram and corresponding colors
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
colors = [
- hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
- min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
+ hist[
+ min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+ min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
+ ]
+ for i in range(len(v))
+ ]
# Scatter plot
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
-def plot_tune_results(csv_file='tune_results.csv'):
+def plot_tune_results(csv_file="tune_results.csv"):
"""
Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
@@ -810,33 +907,33 @@ def plot_tune_results(csv_file='tune_results.csv'):
v = x[:, i + num_metrics_columns]
mu = v[j] # best single result
plt.subplot(n, n, i + 1)
- plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
- plt.plot(mu, fitness.max(), 'k+', markersize=15)
- plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
- plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
+ plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
+ plt.plot(mu, fitness.max(), "k+", markersize=15)
+ plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
+ plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
if i % n != 0:
plt.yticks([])
- file = csv_file.with_name('tune_scatter_plots.png') # filename
+ file = csv_file.with_name("tune_scatter_plots.png") # filename
plt.savefig(file, dpi=200)
plt.close()
- LOGGER.info(f'Saved {file}')
+ LOGGER.info(f"Saved {file}")
# Fitness vs iteration
x = range(1, len(fitness) + 1)
plt.figure(figsize=(10, 6), tight_layout=True)
- plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
- plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
- plt.title('Fitness vs Iteration')
- plt.xlabel('Iteration')
- plt.ylabel('Fitness')
+ plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
+ plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
+ plt.title("Fitness vs Iteration")
+ plt.xlabel("Iteration")
+ plt.ylabel("Fitness")
plt.grid(True)
plt.legend()
- file = csv_file.with_name('tune_fitness.png') # filename
+ file = csv_file.with_name("tune_fitness.png") # filename
plt.savefig(file, dpi=200)
plt.close()
- LOGGER.info(f'Saved {file}')
+ LOGGER.info(f"Saved {file}")
def output_to_target(output, max_det=300):
@@ -861,7 +958,7 @@ def output_to_rotated_target(output, max_det=300):
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
-def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
+def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
"""
Visualize feature maps of a given model module during inference.
@@ -872,7 +969,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
"""
- for m in ['Detect', 'Pose', 'Segment']:
+ for m in ["Detect", "Pose", "Segment"]:
if m in module_type:
return
batch, channels, height, width = x.shape # batch, channels, height, width
@@ -886,9 +983,9 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
plt.subplots_adjust(wspace=0.05, hspace=0.05)
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
- ax[i].axis('off')
+ ax[i].axis("off")
- LOGGER.info(f'Saving {f}... ({n}/{channels})')
- plt.savefig(f, dpi=300, bbox_inches='tight')
+ LOGGER.info(f"Saving {f}... ({n}/{channels})")
+ plt.savefig(f, dpi=300, bbox_inches="tight")
plt.close()
- np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
+ np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
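
A hedged usage sketch for save_one_box from this file; the image path and box coordinates are placeholders:

```python
import cv2
import torch
from pathlib import Path
from ultralytics.utils.plotting import save_one_box

im = cv2.imread("bus.jpg")  # BGR image; illustrative file name
xyxy = torch.tensor([48.0, 398.0, 245.0, 902.0])  # one xyxy box in pixels
crop = save_one_box(xyxy, im, file=Path("crops/person.jpg"), square=True)
print(crop.shape)  # squared, padded crop as a numpy array
```
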
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
index 97406cd7..b72255d4 100644
--- a/ultralytics/utils/tal.py
+++ b/ultralytics/utils/tal.py
@@ -7,7 +7,7 @@ from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
-TORCH_1_10 = check_version(torch.__version__, '1.10.0')
+TORCH_1_10 = check_version(torch.__version__, "1.10.0")
class TaskAlignedAssigner(nn.Module):
@@ -61,12 +61,17 @@ class TaskAlignedAssigner(nn.Module):
if self.n_max_boxes == 0:
device = gt_bboxes.device
- return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
- torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
- torch.zeros_like(pd_scores[..., 0]).to(device))
-
- mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
- mask_gt)
+ return (
+ torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
+ torch.zeros_like(pd_bboxes).to(device),
+ torch.zeros_like(pd_scores).to(device),
+ torch.zeros_like(pd_scores[..., 0]).to(device),
+ torch.zeros_like(pd_scores[..., 0]).to(device),
+ )
+
+ mask_pos, align_metric, overlaps = self.get_pos_mask(
+ pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
+ )
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
@@ -148,7 +153,7 @@ class TaskAlignedAssigner(nn.Module):
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
- count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+ count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
# Filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)
@@ -192,9 +197,11 @@ class TaskAlignedAssigner(nn.Module):
target_labels.clamp_(0)
# 10x faster than F.one_hot()
- target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
- dtype=torch.int64,
- device=target_labels.device) # (b, h*w, 80)
+ target_scores = torch.zeros(
+ (target_labels.shape[0], target_labels.shape[1], self.num_classes),
+ dtype=torch.int64,
+ device=target_labels.device,
+ ) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
@@ -252,7 +259,6 @@ class TaskAlignedAssigner(nn.Module):
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
-
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""Iou calculation for rotated bounding boxes."""
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@@ -295,7 +301,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
_, _, h, w = feats[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
- sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
+ sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
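
A standalone sketch of the anchor grid built above, assuming a single feature map: one (x, y) cell center per location, offset by 0.5, with a matching per-point stride column (make_anchors_single is an illustrative helper):

```python
import torch

def make_anchors_single(h: int, w: int, stride: int, offset: float = 0.5):
    sx = torch.arange(w, dtype=torch.float32) + offset  # cell-center x coordinates
    sy = torch.arange(h, dtype=torch.float32) + offset  # cell-center y coordinates
    sy, sx = torch.meshgrid(sy, sx, indexing="ij")      # row-major (y, x) grid
    points = torch.stack((sx, sy), -1).view(-1, 2)      # (h*w, 2) anchor centers
    strides = torch.full((h * w, 1), float(stride))
    return points, strides

pts, strd = make_anchors_single(2, 3, stride=8)
print(pts)  # [[0.5, 0.5], [1.5, 0.5], [2.5, 0.5], [0.5, 1.5], ...]
```
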
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 57139d8f..bdd7e335 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -25,11 +25,11 @@ try:
except ImportError:
thop = None
-TORCH_1_9 = check_version(torch.__version__, '1.9.0')
-TORCH_2_0 = check_version(torch.__version__, '2.0.0')
-TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
-TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0')
-TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0')
+TORCH_1_9 = check_version(torch.__version__, "1.9.0")
+TORCH_2_0 = check_version(torch.__version__, "2.0.0")
+TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
+TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")
+TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
@contextmanager
@@ -60,13 +60,13 @@ def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo
- k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available)
+ k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
info = cpuinfo.get_cpu_info() # info dict
- string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown')
- return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '')
+ string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
+ return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
-def select_device(device='', batch=0, newline=False, verbose=True):
+def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
@@ -103,49 +103,57 @@ def select_device(device='', batch=0, newline=False, verbose=True):
if isinstance(device, torch.device):
return device
- s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
+ s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
device = str(device).lower()
- for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
- device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
- cpu = device == 'cpu'
- mps = device in ('mps', 'mps:0') # Apple Metal Performance Shaders (MPS)
+ for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
+ device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
+ cpu = device == "cpu"
+ mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
if cpu or mps:
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
- if device == 'cuda':
- device = '0'
- visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
- os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
- if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
+ if device == "cuda":
+ device = "0"
+ visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
+ if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))):
LOGGER.info(s)
- install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
- 'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
- raise ValueError(f"Invalid CUDA 'device={device}' requested."
- f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
- f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
- f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
- f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
- f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
- f'{install}')
+ install = (
+ "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
+ "CUDA devices are seen by torch.\n"
+ if torch.cuda.device_count() == 0
+ else ""
+ )
+ raise ValueError(
+ f"Invalid CUDA 'device={device}' requested."
+ f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
+ f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
+ f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
+ f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
+ f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
+ f"{install}"
+ )
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
- devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
+ devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
- raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
- f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
- space = ' ' * (len(s) + 1)
+ raise ValueError(
+ f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
+ f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
+ )
+ space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
- arg = 'cuda:0'
+ arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available
- s += f'MPS ({get_cpu_info()})\n'
- arg = 'mps'
+ s += f"MPS ({get_cpu_info()})\n"
+ arg = "mps"
else: # revert to CPU
- s += f'CPU ({get_cpu_info()})\n'
- arg = 'cpu'
+ s += f"CPU ({get_cpu_info()})\n"
+ arg = "cpu"
if verbose:
LOGGER.info(s if newline else s.rstrip())
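
A hedged usage sketch for select_device; the device strings mirror the parsing above ('cpu', 'mps', '0', '0,1' and 'cuda:0' are all accepted):

```python
from ultralytics.utils.torch_utils import select_device

device = select_device("cpu")  # explicit CPU; logs a one-line environment summary
# device = select_device("0,1", batch=16)  # multi-GPU; batch must divide evenly by GPU count
print(device)  # torch.device
```
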
@@ -161,14 +169,20 @@ def time_sync():
def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
- fusedconv = nn.Conv2d(conv.in_channels,
- conv.out_channels,
- kernel_size=conv.kernel_size,
- stride=conv.stride,
- padding=conv.padding,
- dilation=conv.dilation,
- groups=conv.groups,
- bias=True).requires_grad_(False).to(conv.weight.device)
+ fusedconv = (
+ nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ dilation=conv.dilation,
+ groups=conv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(conv.weight.device)
+ )
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
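
A quick equivalence check for the fusion above (assuming an installed ultralytics): in eval mode, the fused conv should reproduce conv followed by BN up to float error.

```python
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()  # eval mode -> running stats, which is what fusion folds in
x = torch.randn(1, 3, 16, 16)
fused = fuse_conv_and_bn(conv, bn)
assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)
print("fusion matches conv->bn within 1e-5")
```
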
@@ -185,15 +199,21 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
- fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
- deconv.out_channels,
- kernel_size=deconv.kernel_size,
- stride=deconv.stride,
- padding=deconv.padding,
- output_padding=deconv.output_padding,
- dilation=deconv.dilation,
- groups=deconv.groups,
- bias=True).requires_grad_(False).to(deconv.weight.device)
+ fuseddconv = (
+ nn.ConvTranspose2d(
+ deconv.in_channels,
+ deconv.out_channels,
+ kernel_size=deconv.kernel_size,
+ stride=deconv.stride,
+ padding=deconv.padding,
+ output_padding=deconv.output_padding,
+ dilation=deconv.dilation,
+ groups=deconv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(deconv.weight.device)
+ )
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
@@ -221,18 +241,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
n_l = len(list(model.modules())) # number of layers
if detailed:
LOGGER.info(
- f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
+ f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
+ )
for i, (name, p) in enumerate(model.named_parameters()):
- name = name.replace('module_list.', '')
- LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
- (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
+ name = name.replace("module_list.", "")
+ LOGGER.info(
+ "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
+ % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
+ )
flops = get_flops(model, imgsz)
- fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
- fs = f', {flops:.1f} GFLOPs' if flops else ''
- yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
- model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
- LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
+ fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
+ fs = f", {flops:.1f} GFLOPs" if flops else ""
+ yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
+ model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
+ LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}")
return n_l, n_p, n_g, flops
@@ -262,13 +285,15 @@ def model_info_for_loggers(trainer):
"""
if trainer.args.profile: # profile ONNX and TensorRT times
from ultralytics.utils.benchmarks import ProfileModels
+
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
- results.pop('model/name')
+ results.pop("model/name")
else: # only return PyTorch times from most recent validation
results = {
- 'model/parameters': get_num_params(trainer.model),
- 'model/GFLOPs': round(get_flops(trainer.model), 3)}
- results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
+ "model/parameters": get_num_params(trainer.model),
+ "model/GFLOPs": round(get_flops(trainer.model), 3),
+ }
+ results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
return results
@@ -284,14 +309,14 @@ def get_flops(model, imgsz=640):
imgsz = [imgsz, imgsz] # expand if int/float
try:
# Use stride size for input tensor
- stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
+ stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
- flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs
+ flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
except Exception:
# Use actual image size for input tensor (i.e. required for RTDETR models)
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
- return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs
+ return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
except Exception:
return 0.0
@@ -301,11 +326,11 @@ def get_flops_with_torch_profiler(model, imgsz=640):
if TORCH_2_0:
model = de_parallel(model)
p = next(model.parameters())
- stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
+ stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
with torch.profiler.profile(with_flops=True) as prof:
model(im)
- flops = sum(x.flops for x in prof.key_averages()) / 1E9
+ flops = sum(x.flops for x in prof.key_averages()) / 1e9
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
return flops
@@ -333,7 +358,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
return img
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
- img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
+ img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
if not same_shape: # pad/crop img
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
@@ -349,7 +374,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
for k, v in b.__dict__.items():
- if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+ if (len(include) and k not in include) or k.startswith("_") or k in exclude:
continue
else:
setattr(a, k, v)
@@ -357,7 +382,7 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
- return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
+ return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset
def intersect_dicts(da, db, exclude=()):
@@ -392,10 +417,10 @@ def init_seeds(seed=0, deterministic=False):
if TORCH_2_0:
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
torch.backends.cudnn.deterministic = True
- os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
- os.environ['PYTHONHASHSEED'] = str(seed)
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ os.environ["PYTHONHASHSEED"] = str(seed)
else:
- LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
+ LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
else:
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
@@ -430,13 +455,13 @@ class ModelEMA:
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
- def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+ def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
"""Updates attributes and saves stripped model with optimizer removed."""
if self.enabled:
copy_attr(self.ema, model, include, exclude)
-def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
@@ -456,26 +481,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
strip_optimizer(f)
```
"""
- x = torch.load(f, map_location=torch.device('cpu'))
- if 'model' not in x:
- LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
+ x = torch.load(f, map_location=torch.device("cpu"))
+ if "model" not in x:
+ LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
- if hasattr(x['model'], 'args'):
- x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict
- args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
- if x.get('ema'):
- x['model'] = x['ema'] # replace model with ema
- for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
+ if hasattr(x["model"], "args"):
+ x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
+ args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
+ if x.get("ema"):
+ x["model"] = x["ema"] # replace model with ema
+ for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
- x['epoch'] = -1
- x['model'].half() # to FP16
- for p in x['model'].parameters():
+ x["epoch"] = -1
+ x["model"].half() # to FP16
+ for p in x["model"].parameters():
p.requires_grad = False
- x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
+ x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f)
- mb = os.path.getsize(s or f) / 1E6 # file size
+ mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
@@ -496,18 +521,20 @@ def profile(input, ops, n=10, device=None):
results = []
if not isinstance(device, torch.device):
device = select_device(device)
- LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
- f"{'input':>24s}{'output':>24s}")
+ LOGGER.info(
+ f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+ f"{'input':>24s}{'output':>24s}"
+ )
for x in input if isinstance(input, list) else [input]:
x = x.to(device)
x.requires_grad = True
for m in ops if isinstance(ops, list) else [ops]:
- m = m.to(device) if hasattr(m, 'to') else m # device
- m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+ m = m.to(device) if hasattr(m, "to") else m # device
+ m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try:
- flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
+ flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
except Exception:
flops = 0
@@ -521,13 +548,13 @@ def profile(input, ops, n=10, device=None):
t[2] = time_sync()
except Exception: # no backward method
# print(e) # for debug
- t[2] = float('nan')
+ t[2] = float("nan")
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
- mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
- s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
+ mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
+ s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
- LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
+ LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
results.append([p, flops, mem, tf, tb, s_in, s_out])
except Exception as e:
LOGGER.info(e)
@@ -548,7 +575,7 @@ class EarlyStopping:
"""
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
- self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
+ self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
self.possible_stop = False # possible stop may occur next epoch
def __call__(self, epoch, fitness):
@@ -572,8 +599,10 @@ class EarlyStopping:
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
stop = delta >= self.patience # stop training if patience exceeded
if stop:
- LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
- f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
- f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
- f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
+ LOGGER.info(
+ f"Stopping training early as no improvement observed in last {self.patience} epochs. "
+ f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
+ f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
+ f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
+ )
return stop
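
A minimal training-loop sketch for EarlyStopping; the fitness values are illustrative stand-ins for a metric such as mAP:

```python
from ultralytics.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=3)
for epoch, fitness in enumerate([0.20, 0.30, 0.29, 0.28, 0.27]):
    if stopper(epoch, fitness):
        print(f"stopped at epoch {epoch}")  # 3 epochs with no improvement over 0.30
        break
```
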
diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py
index b79b04b4..98c4be20 100644
--- a/ultralytics/utils/triton.py
+++ b/ultralytics/utils/triton.py
@@ -22,7 +22,7 @@ class TritonRemoteModel:
output_names (List[str]): The names of the model outputs.
"""
- def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
+ def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
"""
Initialize the TritonRemoteModel.
@@ -36,7 +36,7 @@ class TritonRemoteModel:
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
- endpoint = splits.path.strip('/').split('/')[0]
+ endpoint = splits.path.strip("/").split("/")[0]
scheme = splits.scheme
url = splits.netloc
@@ -44,26 +44,28 @@ class TritonRemoteModel:
self.url = url
# Choose the Triton client based on the communication scheme
- if scheme == 'http':
+ if scheme == "http":
import tritonclient.http as client # noqa
+
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa
+
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
- config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
+ config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
- config['output'] = sorted(config['output'], key=lambda x: x.get('name'))
+ config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
# Define model attributes
- type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
+ type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
- self.input_formats = [x['data_type'] for x in config['input']]
+ self.input_formats = [x["data_type"] for x in config["input"]]
self.np_input_formats = [type_map[x] for x in self.input_formats]
- self.input_names = [x['name'] for x in config['input']]
- self.output_names = [x['name'] for x in config['output']]
+ self.input_names = [x["name"] for x in config["input"]]
+ self.output_names = [x["name"] for x in config["output"]]
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
"""
@@ -80,7 +82,7 @@ class TritonRemoteModel:
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
- infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
+ infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)
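
A hedged usage sketch for TritonRemoteModel; the URL, endpoint and input shape are placeholders for a running Triton server serving an exported model:

```python
import numpy as np
from ultralytics.utils.triton import TritonRemoteModel

model = TritonRemoteModel("http://localhost:8000/yolov8n")  # scheme and endpoint parsed from the URL
outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
print([o.shape for o in outputs])
```
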
diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py
index a06f813d..db7d4907 100644
--- a/ultralytics/utils/tuner.py
+++ b/ultralytics/utils/tuner.py
@@ -6,12 +6,9 @@ from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
-def run_ray_tune(model,
- space: dict = None,
- grace_period: int = 10,
- gpu_per_trial: int = None,
- max_samples: int = 10,
- **train_args):
+def run_ray_tune(
+ model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
+):
"""
Runs hyperparameter tuning using Ray Tune.
@@ -38,12 +35,12 @@ def run_ray_tune(model,
```
"""
- LOGGER.info('💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune')
+ LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}
try:
- subprocess.run('pip install ray[tune]'.split(), check=True)
+ subprocess.run("pip install ray[tune]".split(), check=True)
import ray
from ray import tune
@@ -56,33 +53,34 @@ def run_ray_tune(model,
try:
import wandb
- assert hasattr(wandb, '__version__')
+ assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
- 'lr0': tune.uniform(1e-5, 1e-1),
- 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
- 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
- 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
- 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
- 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
- 'box': tune.uniform(0.02, 0.2), # box loss gain
- 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
- 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
- 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
- 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
- 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
- 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
- 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
- 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
- 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
- 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
- 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
- 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
- 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
- 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
+ "lr0": tune.uniform(1e-5, 1e-1),
+ "lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
+ "momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
+ "weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
+ "warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
+ "warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
+ "box": tune.uniform(0.02, 0.2), # box loss gain
+ "cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
+ "hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
+ "hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
+ "hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
+ "degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
+ "translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
+ "scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
+ "shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
+ "perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
+ "flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
+ "fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
+ "mosaic": tune.uniform(0.0, 1.0), # image mixup (probability)
+ "mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
+ "copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
+ }
# Put the model in ray store
task = model.task
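Passing a custom space keeps only the listed dimensions under search; anything omitted is simply not tuned and falls back to the trainer defaults. A sketch of a reduced space, assuming only learning rate and momentum are of interest:

    from ray import tune

    space = {
        "lr0": tune.uniform(1e-5, 1e-1),      # initial learning rate
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum / Adam beta1
    }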
@@ -107,35 +105,39 @@ def run_ray_tune(model,
# Get search space
if not space:
space = default_space
- LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
+ LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
# Get dataset
- data = train_args.get('data', TASK2DATA[task])
- space['data'] = data
- if 'data' not in train_args:
+ data = train_args.get("data", TASK2DATA[task])
+ space["data"] = data
+ if "data" not in train_args:
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
- trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
+ trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
- asha_scheduler = ASHAScheduler(time_attr='epoch',
- metric=TASK2METRIC[task],
- mode='max',
- max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
- grace_period=grace_period,
- reduction_factor=3)
+ asha_scheduler = ASHAScheduler(
+ time_attr="epoch",
+ metric=TASK2METRIC[task],
+ mode="max",
+ max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
+ grace_period=grace_period,
+ reduction_factor=3,
+ )
# Define the callbacks for the hyperparameter search
- tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
+ tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
# Create the Ray Tune hyperparameter search tuner
- tune_dir = get_save_dir(DEFAULT_CFG, name='tune').resolve() # must be absolute dir
+ tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
- tuner = tune.Tuner(trainable_with_resources,
- param_space=space,
- tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
- run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir))
+ tuner = tune.Tuner(
+ trainable_with_resources,
+ param_space=space,
+ tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
+ run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
+ )
# Run the hyperparameter search
tuner.fit()
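End to end, the reflowed function is called the same way as before; a minimal sketch, assuming ray[tune] is installed and using a hypothetical nano checkpoint and epoch budget:

    from ultralytics import YOLO
    from ultralytics.utils.tuner import run_ray_tune

    model = YOLO("yolov8n.pt")  # hypothetical checkpoint
    results = run_ray_tune(model, grace_period=10, max_samples=10, epochs=30)  # epochs is forwarded via **train_args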