`ultralytics 8.0.239` Ultralytics Actions and `hub-sdk` adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Branch: main
Glenn Jocher committed 1 year ago (via GitHub)
parent e795277391
commit fe27db2f6e
Changed files, with the number of changed lines in parentheses:

  1. .github/workflows/ci.yaml (2)
  2. .github/workflows/format.yml (25)
  3. .pre-commit-config.yaml (3)
  4. docs/build_docs.py (44)
  5. docs/build_reference.py (56)
  6. docs/update_translations.py (301)
  7. examples/YOLOv8-ONNXRuntime/main.py (27)
  8. examples/YOLOv8-OpenCV-ONNX-Python/main.py (41)
  9. examples/YOLOv8-OpenCV-int8-tflite-Python/main.py (87)
  10. examples/YOLOv8-Region-Counter/yolov8_region_counter.py (135)
  11. examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py (62)
  12. examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py (83)
  13. pyproject.toml (10)
  14. tests/test_explorer.py (2)
  15. ultralytics/__init__.py (4)
  16. ultralytics/cfg/__init__.py (379)
  17. ultralytics/data/__init__.py (11)
  18. ultralytics/data/annotator.py (8)
  19. ultralytics/data/augment.py (355)
  20. ultralytics/data/base.py (102)
  21. ultralytics/data/build.py (57)
  22. ultralytics/data/converter.py (312)
  23. ultralytics/data/dataset.py (212)
  24. ultralytics/data/explorer/__init__.py (2)
  25. ultralytics/data/explorer/explorer.py (181)
  26. ultralytics/data/explorer/gui/dash.py (198)
  27. ultralytics/data/explorer/utils.py (70)
  28. ultralytics/data/loaders.py (137)
  29. ultralytics/data/split_dota.py (60)
  30. ultralytics/data/utils.py (302)
  31. ultralytics/engine/exporter.py (640)
  32. ultralytics/engine/model.py (136)
  33. ultralytics/engine/predictor.py (129)
  34. ultralytics/engine/results.py (88)
  35. ultralytics/engine/trainer.py (336)
  36. ultralytics/engine/tuner.py (125)
  37. ultralytics/engine/validator.py (57)
  38. ultralytics/hub/__init__.py (61)
  39. ultralytics/hub/auth.py (50)
  40. ultralytics/hub/session.py (155)
  41. ultralytics/hub/utils.py (103)
  42. ultralytics/models/__init__.py (2)
  43. ultralytics/models/fastsam/__init__.py (2)
  44. ultralytics/models/fastsam/model.py (12)
  45. ultralytics/models/fastsam/predict.py (5)
  46. ultralytics/models/fastsam/prompt.py (103)
  47. ultralytics/models/fastsam/val.py (2)
  48. ultralytics/models/nas/__init__.py (2)
  49. ultralytics/models/nas/model.py (17)
  50. ultralytics/models/nas/predict.py (14)
  51. ultralytics/models/nas/val.py (20)
  52. ultralytics/models/rtdetr/__init__.py (2)
  53. ultralytics/models/rtdetr/model.py (20)
  54. ultralytics/models/rtdetr/train.py (34)
  55. ultralytics/models/rtdetr/val.py (42)
  56. ultralytics/models/sam/__init__.py (2)
  57. ultralytics/models/sam/amg.py (27)
  58. ultralytics/models/sam/build.py (86)
  59. ultralytics/models/sam/model.py (12)
  60. ultralytics/models/sam/modules/decoders.py (10)
  61. ultralytics/models/sam/modules/encoders.py (72)
  62. ultralytics/models/sam/modules/sam.py (9)
  63. ultralytics/models/sam/modules/tiny_encoder.py (168)
  64. ultralytics/models/sam/modules/transformer.py (5)
  65. ultralytics/models/sam/predict.py (66)
  66. ultralytics/models/utils/loss.py (194)
  67. ultralytics/models/utils/ops.py (65)
  68. ultralytics/models/yolo/__init__.py (2)
  69. ultralytics/models/yolo/classify/__init__.py (2)
  70. ultralytics/models/yolo/classify/predict.py (14)
  71. ultralytics/models/yolo/classify/train.py (68)
  72. ultralytics/models/yolo/classify/val.py (50)
  73. ultralytics/models/yolo/detect/__init__.py (2)
  74. ultralytics/models/yolo/detect/predict.py (14)
  75. ultralytics/models/yolo/detect/train.py (84)
  76. ultralytics/models/yolo/detect/val.py (180)
  77. ultralytics/models/yolo/model.py (56)
  78. ultralytics/models/yolo/obb/__init__.py (2)
  79. ultralytics/models/yolo/obb/predict.py (23)
  80. ultralytics/models/yolo/obb/train.py (6)
  81. ultralytics/models/yolo/obb/val.py (127)
  82. ultralytics/models/yolo/pose/__init__.py (2)
  83. ultralytics/models/yolo/pose/predict.py (29)
  84. ultralytics/models/yolo/pose/train.py (53)
  85. ultralytics/models/yolo/pose/val.py (153)
  86. ultralytics/models/yolo/segment/__init__.py (2)
  87. ultralytics/models/yolo/segment/predict.py (18)
  88. ultralytics/models/yolo/segment/train.py (31)
  89. ultralytics/models/yolo/segment/val.py (156)
  90. ultralytics/nn/__init__.py (32)
  91. ultralytics/nn/autobackend.py (256)
  92. ultralytics/nn/modules/__init__.py (109)
  93. ultralytics/nn/modules/block.py (35)
  94. ultralytics/nn/modules/conv.py (64)
  95. ultralytics/nn/modules/head.py (110)
  96. ultralytics/nn/modules/transformer.py (81)
  97. ultralytics/nn/modules/utils.py (35)
  98. ultralytics/nn/tasks.py (355)
  99. ultralytics/solutions/ai_gym.py (106)
  100. ultralytics/solutions/distance_calculation.py (31)

Some files were not shown because too many files have changed in this diff.

@@ -95,7 +95,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ['3.10']
python-version: ['3.11']
model: [yolov8n]
steps:
- uses: actions/checkout@v4

@@ -0,0 +1,25 @@
# Ultralytics 🚀 - AGPL-3.0 license
# Ultralytics Actions https://github.com/ultralytics/actions
# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
name: Ultralytics Actions
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
format:
runs-on: ubuntu-latest
steps:
- name: Run Ultralytics Formatting
uses: ultralytics/actions@main
with:
token: ${{ secrets.GITHUB_TOKEN }} # automatically generated
python: true
docstrings: true
markdown: true
spelling: true
links: false

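This new workflow runs Ultralytics Actions on every push and pull request against main, formatting Python code, docstrings and Markdown and spell-checking the repo, which is what drives the sweeping single-quote to double-quote changes throughout this commit. A rough local approximation, assuming black and codespell are installed; the exact tools and options the action invokes are not shown in this diff:

import subprocess

# Approximate the CI formatting pass locally; the tool choices here are assumptions.
for cmd in (
    ["black", "--line-length", "120", "."],  # Black formats to double quotes by default
    ["codespell", "--ignore-words-list", "crate,nd,strack,dota"],
):
    subprocess.run(cmd, check=True)
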
@@ -22,7 +22,6 @@ repos:
- id: check-case-conflict
# - id: check-yaml
- id: check-docstring-first
- id: double-quote-string-fixer
- id: detect-private-key
- repo: https://github.com/asottile/pyupgrade
@@ -64,7 +63,7 @@ repos:
- id: codespell
exclude: 'docs/de|docs/fr|docs/pt|docs/es|docs/mkdocs_de.yml'
args:
- --ignore-words-list=crate,nd,strack,dota,ane,segway,fo,gool,winn
- --ignore-words-list=crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5

@@ -30,45 +30,47 @@ import subprocess
from pathlib import Path
DOCS = Path(__file__).parent.resolve()
SITE = DOCS.parent / 'site'
SITE = DOCS.parent / "site"
def build_docs():
"""Build docs using mkdocs."""
if SITE.exists():
print(f'Removing existing {SITE}')
print(f"Removing existing {SITE}")
shutil.rmtree(SITE)
# Build the main documentation
print(f'Building docs from {DOCS}')
subprocess.run(f'mkdocs build -f {DOCS}/mkdocs.yml', check=True, shell=True)
print(f"Building docs from {DOCS}")
subprocess.run(f"mkdocs build -f {DOCS}/mkdocs.yml", check=True, shell=True)
# Build other localized documentations
for file in DOCS.glob('mkdocs_*.yml'):
print(f'Building MkDocs site with configuration file: {file}')
subprocess.run(f'mkdocs build -f {file}', check=True, shell=True)
print(f'Site built at {SITE}')
for file in DOCS.glob("mkdocs_*.yml"):
print(f"Building MkDocs site with configuration file: {file}")
subprocess.run(f"mkdocs build -f {file}", check=True, shell=True)
print(f"Site built at {SITE}")
def update_html_links():
"""Update href links in HTML files to remove '.md' and '/index.md', excluding links starting with 'https://'."""
html_files = Path(SITE).rglob('*.html')
html_files = Path(SITE).rglob("*.html")
total_updated_links = 0
for html_file in html_files:
with open(html_file, 'r+', encoding='utf-8') as file:
with open(html_file, "r+", encoding="utf-8") as file:
content = file.read()
# Find all links to be updated, excluding those starting with 'https://'
links_to_update = re.findall(r'href="(?!https://)([^"]+?)(/index)?\.md"', content)
# Update the content and count the number of links updated
updated_content, number_of_links_updated = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"',
r'href="\1"', content)
updated_content, number_of_links_updated = re.subn(
r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content
)
total_updated_links += number_of_links_updated
# Special handling for '/index' links
updated_content, number_of_index_links_updated = re.subn(r'href="([^"]+)/index"', r'href="\1/"',
updated_content)
updated_content, number_of_index_links_updated = re.subn(
r'href="([^"]+)/index"', r'href="\1/"', updated_content
)
total_updated_links += number_of_index_links_updated
# Write the updated content back to the file
@@ -78,23 +80,23 @@ def update_html_links():
# Print updated links for this file
for link in links_to_update:
print(f'Updated link in {html_file}: {link[0]}')
print(f"Updated link in {html_file}: {link[0]}")
print(f'Total number of links updated: {total_updated_links}')
print(f"Total number of links updated: {total_updated_links}")
def update_page_title(file_path: Path, new_title: str):
"""Update the title of an HTML file."""
# Read the content of the file
with open(file_path, encoding='utf-8') as file:
with open(file_path, encoding="utf-8") as file:
content = file.read()
# Replace the existing title with the new title
updated_content = re.sub(r'<title>.*?</title>', f'<title>{new_title}</title>', content)
updated_content = re.sub(r"<title>.*?</title>", f"<title>{new_title}</title>", content)
# Write the updated content back to the file
with open(file_path, 'w', encoding='utf-8') as file:
with open(file_path, "w", encoding="utf-8") as file:
file.write(updated_content)
@@ -109,8 +111,8 @@ def main():
print('Serve site at http://localhost:8000 with "python -m http.server --directory site"')
# Update titles
update_page_title(SITE / '404.html', new_title='Ultralytics Docs - Not Found')
update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found")
if __name__ == '__main__':
if __name__ == "__main__":
main()

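As a quick sanity check, the two link-rewriting substitutions above can be exercised on a synthetic HTML snippet (the sample markup is illustrative, not taken from the built site):

import re

html = '<a href="modes/train.md">Train</a> <a href="guides/index.md">Guides</a>'
html, n = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', html)
print(n, html)  # 2 <a href="modes/train">Train</a> <a href="guides">Guides</a>
print(re.sub(r'href="([^"]+)/index"', r'href="\1/"', '<a href="guides/index">'))  # <a href="guides/">
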
@@ -14,14 +14,14 @@ from ultralytics.utils import ROOT
NEW_YAML_DIR = ROOT.parent
CODE_DIR = ROOT
REFERENCE_DIR = ROOT.parent / 'docs/en/reference'
REFERENCE_DIR = ROOT.parent / "docs/en/reference"
def extract_classes_and_functions(filepath: Path) -> tuple:
"""Extracts class and function names from a given Python file."""
content = filepath.read_text()
class_pattern = r'(?:^|\n)class\s(\w+)(?:\(|:)'
func_pattern = r'(?:^|\n)def\s(\w+)\('
class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)"
func_pattern = r"(?:^|\n)def\s(\w+)\("
classes = re.findall(class_pattern, content)
functions = re.findall(func_pattern, content)
@@ -31,31 +31,31 @@ def extract_classes_and_functions(filepath: Path) -> tuple:
def create_markdown(py_filepath: Path, module_path: str, classes: list, functions: list):
"""Creates a Markdown file containing the API reference for the given Python module."""
md_filepath = py_filepath.with_suffix('.md')
md_filepath = py_filepath.with_suffix(".md")
# Read existing content and keep header content between first two ---
header_content = ''
header_content = ""
if md_filepath.exists():
existing_content = md_filepath.read_text()
header_parts = existing_content.split('---')
header_parts = existing_content.split("---")
for part in header_parts:
if 'description:' in part or 'comments:' in part:
header_content += f'---{part}---\n\n'
if "description:" in part or "comments:" in part:
header_content += f"---{part}---\n\n"
module_name = module_path.replace('.__init__', '')
module_path = module_path.replace('.', '/')
url = f'https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py'
edit = f'https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py'
module_name = module_path.replace(".__init__", "")
module_path = module_path.replace(".", "/")
url = f"https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py"
edit = f"https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py"
title_content = (
f'# Reference for `{module_path}.py`\n\n'
f'!!! Note\n\n'
f' This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠. Thank you 🙏!\n\n'
f"# Reference for `{module_path}.py`\n\n"
f"!!! Note\n\n"
f" This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠. Thank you 🙏!\n\n"
)
md_content = ['<br><br>\n'] + [f'## ::: {module_name}.{class_name}\n\n<br><br>\n' for class_name in classes]
md_content.extend(f'## ::: {module_name}.{func_name}\n\n<br><br>\n' for func_name in functions)
md_content = header_content + title_content + '\n'.join(md_content)
if not md_content.endswith('\n'):
md_content += '\n'
md_content = ["<br><br>\n"] + [f"## ::: {module_name}.{class_name}\n\n<br><br>\n" for class_name in classes]
md_content.extend(f"## ::: {module_name}.{func_name}\n\n<br><br>\n" for func_name in functions)
md_content = header_content + title_content + "\n".join(md_content)
if not md_content.endswith("\n"):
md_content += "\n"
md_filepath.parent.mkdir(parents=True, exist_ok=True)
md_filepath.write_text(md_content)
@@ -80,28 +80,28 @@ def create_nav_menu_yaml(nav_items: list):
for item_str in nav_items:
item = Path(item_str)
parts = item.parts
current_level = nav_tree['reference']
current_level = nav_tree["reference"]
for part in parts[2:-1]: # skip the first two parts (docs and reference) and the last part (filename)
current_level = current_level[part]
md_file_name = parts[-1].replace('.md', '')
md_file_name = parts[-1].replace(".md", "")
current_level[md_file_name] = item
nav_tree_sorted = sort_nested_dict(nav_tree)
def _dict_to_yaml(d, level=0):
"""Converts a nested dictionary to a YAML-formatted string with indentation."""
yaml_str = ''
indent = ' ' * level
yaml_str = ""
indent = " " * level
for k, v in d.items():
if isinstance(v, dict):
yaml_str += f'{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}'
yaml_str += f"{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}"
else:
yaml_str += f"{indent}- {k}: {str(v).replace('docs/en/', '')}\n"
return yaml_str
# Print updated YAML reference section
print('Scan complete, new mkdocs.yaml reference section is:\n\n', _dict_to_yaml(nav_tree_sorted))
print("Scan complete, new mkdocs.yaml reference section is:\n\n", _dict_to_yaml(nav_tree_sorted))
# Save new YAML reference section
# (NEW_YAML_DIR / 'nav_menu_updated.yml').write_text(_dict_to_yaml(nav_tree_sorted))
@@ -111,7 +111,7 @@ def main():
"""Main function to extract class and function names, create Markdown files, and generate a YAML navigation menu."""
nav_items = []
for py_filepath in CODE_DIR.rglob('*.py'):
for py_filepath in CODE_DIR.rglob("*.py"):
classes, functions = extract_classes_and_functions(py_filepath)
if classes or functions:
@@ -124,5 +124,5 @@ def main():
create_nav_menu_yaml(nav_items)
if __name__ == '__main__':
if __name__ == "__main__":
main()

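The extraction regexes above can be verified on a tiny synthetic module (the source string below is illustrative):

import re

content = "class Foo(Base):\n    pass\n\ndef bar(x):\n    return x\n"
classes = re.findall(r"(?:^|\n)class\s(\w+)(?:\(|:)", content)
functions = re.findall(r"(?:^|\n)def\s(\w+)\(", content)
print(classes, functions)  # ['Foo'] ['bar']
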
@@ -22,69 +22,232 @@ class MarkdownLinkFixer:
self.base_dir = Path(base_dir)
self.update_links = update_links
self.update_text = update_text
self.md_link_regex = re.compile(r'\[([^]]+)]\(([^:)]+)\.md\)')
self.md_link_regex = re.compile(r"\[([^]]+)]\(([^:)]+)\.md\)")
@staticmethod
def replace_front_matter(content, lang_dir):
"""Ensure front matter keywords remain in English."""
english = ['comments', 'description', 'keywords']
english = ["comments", "description", "keywords"]
translations = {
'zh': ['评论', '描述', '关键词'], # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
'es': ['comentarios', 'descripción', 'palabras clave'], # Spanish
'ru': ['комментарии', 'описание', 'ключевые слова'], # Russian
'pt': ['comentários', 'descrição', 'palavras-chave'], # Portuguese
'fr': ['commentaires', 'description', 'mots-clés'], # French
'de': ['kommentare', 'beschreibung', 'schlüsselwörter'], # German
'ja': ['コメント', '説明', 'キーワード'], # Japanese
'ko': ['댓글', '설명', '키워드'], # Korean
'hi': ['टिप्पणियाँ', 'विवरण', 'कीवर्ड'], # Hindi
'ar': ['التعليقات', 'الوصف', 'الكلمات الرئيسية'] # Arabic
"zh": ["评论", "描述", "关键词"], # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字
"es": ["comentarios", "descripción", "palabras clave"], # Spanish
"ru": ["комментарии", "описание", "ключевые слова"], # Russian
"pt": ["comentários", "descrição", "palavras-chave"], # Portuguese
"fr": ["commentaires", "description", "mots-clés"], # French
"de": ["kommentare", "beschreibung", "schlüsselwörter"], # German
"ja": ["コメント", "説明", "キーワード"], # Japanese
"ko": ["댓글", "설명", "키워드"], # Korean
"hi": ["िपणि", "िवरण", "वर"], # Hindi
"ar": ["التعليقات", "الوصف", "الكلمات الرئيسية"], # Arabic
} # front matter translations for comments, description, keyword
for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
content = re.sub(rf'{term} *[::].*', f'{eng_key}: true', content, flags=re.IGNORECASE) if \
eng_key == 'comments' else re.sub(rf'{term} *[::] *', f'{eng_key}: ', content, flags=re.IGNORECASE)
content = (
re.sub(rf"{term} *[::].*", f"{eng_key}: true", content, flags=re.IGNORECASE)
if eng_key == "comments"
else re.sub(rf"{term} *[::] *", f"{eng_key}: ", content, flags=re.IGNORECASE)
)
return content
@staticmethod
def replace_admonitions(content, lang_dir):
"""Ensure front matter keywords remain in English."""
english = [
'Note', 'Summary', 'Tip', 'Info', 'Success', 'Question', 'Warning', 'Failure', 'Danger', 'Bug', 'Example',
'Quote', 'Abstract', 'Seealso', 'Admonition']
"Note",
"Summary",
"Tip",
"Info",
"Success",
"Question",
"Warning",
"Failure",
"Danger",
"Bug",
"Example",
"Quote",
"Abstract",
"Seealso",
"Admonition",
]
translations = {
'en':
english,
'zh': ['笔记', '摘要', '提示', '信息', '成功', '问题', '警告', '失败', '危险', '故障', '示例', '引用', '摘要', '另见', '警告'],
'es': [
'Nota', 'Resumen', 'Consejo', 'Información', 'Éxito', 'Pregunta', 'Advertencia', 'Fracaso', 'Peligro',
'Error', 'Ejemplo', 'Cita', 'Abstracto', 'Véase También', 'Amonestación'],
'ru': [
'Заметка', 'Сводка', 'Совет', 'Информация', 'Успех', 'Вопрос', 'Предупреждение', 'Неудача', 'Опасность',
'Ошибка', 'Пример', 'Цитата', 'Абстракт', 'См. Также', 'Предостережение'],
'pt': [
'Nota', 'Resumo', 'Dica', 'Informação', 'Sucesso', 'Questão', 'Aviso', 'Falha', 'Perigo', 'Bug',
'Exemplo', 'Citação', 'Abstrato', 'Veja Também', 'Advertência'],
'fr': [
'Note', 'Résumé', 'Conseil', 'Info', 'Succès', 'Question', 'Avertissement', 'Échec', 'Danger', 'Bug',
'Exemple', 'Citation', 'Abstrait', 'Voir Aussi', 'Admonestation'],
'de': [
'Hinweis', 'Zusammenfassung', 'Tipp', 'Info', 'Erfolg', 'Frage', 'Warnung', 'Ausfall', 'Gefahr',
'Fehler', 'Beispiel', 'Zitat', 'Abstrakt', 'Siehe Auch', 'Ermahnung'],
'ja': ['ノート', '要約', 'ヒント', '情報', '成功', '質問', '警告', '失敗', '危険', 'バグ', '例', '引用', '抄録', '参照', '訓告'],
'ko': ['노트', '요약', '팁', '정보', '성공', '질문', '경고', '실패', '위험', '버그', '예제', '인용', '추상', '참조', '경고'],
'hi': [
'', '', '', 'नक', 'सफलत', 'रश', 'वन', 'िफलत', 'खतर', 'बग', 'उदहरण',
'उदधरण', '', '', 'आग'],
'ar': [
'ملاحظة', 'ملخص', 'نصيحة', 'معلومات', 'نجاح', 'سؤال', 'تحذير', 'فشل', 'خطر', 'عطل', 'مثال', 'اقتباس',
'ملخص', 'انظر أيضاً', 'تحذير']}
"en": english,
"zh": [
"笔记",
"摘要",
"提示",
"信息",
"成功",
"问题",
"警告",
"失败",
"危险",
"故障",
"示例",
"引用",
"摘要",
"另见",
"警告",
],
"es": [
"Nota",
"Resumen",
"Consejo",
"Información",
"Éxito",
"Pregunta",
"Advertencia",
"Fracaso",
"Peligro",
"Error",
"Ejemplo",
"Cita",
"Abstracto",
"Véase También",
"Amonestación",
],
"ru": [
"Заметка",
"Сводка",
"Совет",
"Информация",
"Успех",
"Вопрос",
"Предупреждение",
"Неудача",
"Опасность",
"Ошибка",
"Пример",
"Цитата",
"Абстракт",
"См. Также",
"Предостережение",
],
"pt": [
"Nota",
"Resumo",
"Dica",
"Informação",
"Sucesso",
"Questão",
"Aviso",
"Falha",
"Perigo",
"Bug",
"Exemplo",
"Citação",
"Abstrato",
"Veja Também",
"Advertência",
],
"fr": [
"Note",
"Résumé",
"Conseil",
"Info",
"Succès",
"Question",
"Avertissement",
"Échec",
"Danger",
"Bug",
"Exemple",
"Citation",
"Abstrait",
"Voir Aussi",
"Admonestation",
],
"de": [
"Hinweis",
"Zusammenfassung",
"Tipp",
"Info",
"Erfolg",
"Frage",
"Warnung",
"Ausfall",
"Gefahr",
"Fehler",
"Beispiel",
"Zitat",
"Abstrakt",
"Siehe Auch",
"Ermahnung",
],
"ja": [
"ノート",
"要約",
"ヒント",
"情報",
"成功",
"質問",
"警告",
"失敗",
"危険",
"バグ",
"",
"引用",
"抄録",
"参照",
"訓告",
],
"ko": [
"노트",
"요약",
"",
"정보",
"성공",
"질문",
"경고",
"실패",
"위험",
"버그",
"예제",
"인용",
"추상",
"참조",
"경고",
],
"hi": [
"",
"",
"",
"नक",
"सफलत",
"रश",
"वन",
"िफलत",
"खतर",
"बग",
"उदहरण",
"उदधरण",
"",
"",
"आग",
],
"ar": [
"ملاحظة",
"ملخص",
"نصيحة",
"معلومات",
"نجاح",
"سؤال",
"تحذير",
"فشل",
"خطر",
"عطل",
"مثال",
"اقتباس",
"ملخص",
"انظر أيضاً",
"تحذير",
],
}
for term, eng_key in zip(translations.get(lang_dir.stem, []), english):
if lang_dir.stem != 'en':
content = re.sub(rf'!!! *{eng_key} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
content = re.sub(rf'!!! *{term} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
content = re.sub(rf'!!! *{term}', f'!!! {eng_key}', content, flags=re.IGNORECASE)
if lang_dir.stem != "en":
content = re.sub(rf"!!! *{eng_key} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
content = re.sub(rf"!!! *{term}", f"!!! {eng_key}", content, flags=re.IGNORECASE)
content = re.sub(r'!!! *"', '!!! Example "', content, flags=re.IGNORECASE)
return content
@@ -92,30 +255,30 @@ class MarkdownLinkFixer:
@staticmethod
def update_iframe(content):
"""Update the 'allow' attribute of iframe if it does not contain the specific English permissions."""
english = 'accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share'
english = "accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
pattern = re.compile(f'allow="(?!{re.escape(english)}).+?"')
return pattern.sub(f'allow="{english}"', content)
def link_replacer(self, match, parent_dir, lang_dir, use_abs_link=False):
"""Replace broken links with corresponding links in the /en/ directory."""
text, path = match.groups()
linked_path = (parent_dir / path).resolve().with_suffix('.md')
linked_path = (parent_dir / path).resolve().with_suffix(".md")
if not linked_path.exists():
en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / 'en')))
en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / "en")))
if en_linked_path.exists():
if use_abs_link:
# Use absolute links WARNING: BUGS, DO NOT USE
docs_root_relative_path = en_linked_path.relative_to(lang_dir.parent)
updated_path = str(docs_root_relative_path).replace('en/', '/../')
updated_path = str(docs_root_relative_path).replace("en/", "/../")
else:
# Use relative links
steps_up = len(parent_dir.relative_to(self.base_dir).parts)
updated_path = Path('../' * steps_up) / en_linked_path.relative_to(self.base_dir)
updated_path = str(updated_path).replace('/en/', '/')
updated_path = Path("../" * steps_up) / en_linked_path.relative_to(self.base_dir)
updated_path = str(updated_path).replace("/en/", "/")
print(f"Redirecting link '[{text}]({path})' from {parent_dir} to {updated_path}")
return f'[{text}]({updated_path})'
return f"[{text}]({updated_path})"
else:
print(f"Warning: Broken link '[{text}]({path})' found in {parent_dir} does not exist in /docs/en/.")
@@ -124,28 +287,30 @@ class MarkdownLinkFixer:
@staticmethod
def update_html_tags(content):
"""Updates HTML tags in docs."""
alt_tag = 'MISSING'
alt_tag = "MISSING"
# Remove closing slashes from self-closing HTML tags
pattern = re.compile(r'<([^>]+?)\s*/>')
content = re.sub(pattern, r'<\1>', content)
pattern = re.compile(r"<([^>]+?)\s*/>")
content = re.sub(pattern, r"<\1>", content)
# Find all images without alt tags and add placeholder alt text
pattern = re.compile(r'!\[(.*?)\]\((.*?)\)')
content, num_replacements = re.subn(pattern, lambda match: f'![{match.group(1) or alt_tag}]({match.group(2)})',
content)
pattern = re.compile(r"!\[(.*?)\]\((.*?)\)")
content, num_replacements = re.subn(
pattern, lambda match: f"![{match.group(1) or alt_tag}]({match.group(2)})", content
)
# Add missing alt tags to HTML images
pattern = re.compile(r'<img\s+(?!.*?\balt\b)[^>]*src=["\'](.*?)["\'][^>]*>')
content, num_replacements = re.subn(pattern, lambda match: match.group(0).replace('>', f' alt="{alt_tag}">', 1),
content)
content, num_replacements = re.subn(
pattern, lambda match: match.group(0).replace(">", f' alt="{alt_tag}">', 1), content
)
return content
def process_markdown_file(self, md_file_path, lang_dir):
"""Process each markdown file in the language directory."""
print(f'Processing file: {md_file_path}')
with open(md_file_path, encoding='utf-8') as file:
print(f"Processing file: {md_file_path}")
with open(md_file_path, encoding="utf-8") as file:
content = file.read()
if self.update_links:
@@ -157,23 +322,23 @@ class MarkdownLinkFixer:
content = self.update_iframe(content)
content = self.update_html_tags(content)
with open(md_file_path, 'w', encoding='utf-8') as file:
with open(md_file_path, "w", encoding="utf-8") as file:
file.write(content)
def process_language_directory(self, lang_dir):
"""Process each language-specific directory."""
print(f'Processing language directory: {lang_dir}')
for md_file in lang_dir.rglob('*.md'):
print(f"Processing language directory: {lang_dir}")
for md_file in lang_dir.rglob("*.md"):
self.process_markdown_file(md_file, lang_dir)
def run(self):
"""Run the link fixing and front matter updating process for each language-specific directory."""
for subdir in self.base_dir.iterdir():
if subdir.is_dir() and re.match(r'^\w\w$', subdir.name):
if subdir.is_dir() and re.match(r"^\w\w$", subdir.name):
self.process_language_directory(subdir)
if __name__ == '__main__':
if __name__ == "__main__":
# Set the path to your MkDocs 'docs' directory here
docs_dir = str(Path(__file__).parent.resolve())
fixer = MarkdownLinkFixer(docs_dir, update_links=True, update_text=True)

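For example, on a German page the admonition rewrite above restores the English keyword that MkDocs requires while keeping the translated word as the visible title (the sample content is synthetic):

import re

content = "!!! Tipp\n    Ein Hinweis.\n"
term, eng_key = "Tipp", "Tip"
content = re.sub(rf"!!! *{term} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE)
print(content)  # !!! Tip "Tipp" followed by the indented body
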
@@ -28,7 +28,7 @@ class YOLOv8:
self.iou_thres = iou_thres
# Load the class names from the COCO dataset
self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Generate a color palette for the classes
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -57,7 +57,7 @@ class YOLOv8:
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
# Create the label text with class name and score
label = f'{self.classes[class_id]}: {score:.2f}'
label = f"{self.classes[class_id]}: {score:.2f}"
# Calculate the dimensions of the label text
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -67,8 +67,9 @@ class YOLOv8:
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
# Draw a filled rectangle as the background for the label text
cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color,
cv2.FILLED)
cv2.rectangle(
img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
)
# Draw the label text on the image
cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -182,7 +183,7 @@ class YOLOv8:
output_img: The output image with drawn detections.
"""
# Create an inference session using the ONNX model and specify execution providers
session = ort.InferenceSession(self.onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
# Get the model inputs
model_inputs = session.get_inputs()
@@ -202,17 +203,17 @@ class YOLOv8:
return self.postprocess(self.img, outputs) # output image
if __name__ == '__main__':
if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='yolov8n.onnx', help='Input your ONNX model.')
parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.")
parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
args = parser.parse_args()
# Check the requirements and select the appropriate backend (CPU or GPU)
check_requirements('onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')
check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")
# Create an instance of the YOLOv8 class with the specified arguments
detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)
@@ -221,8 +222,8 @@ if __name__ == '__main__':
output_image = detection.main()
# Display the output image in a window
cv2.namedWindow('Output', cv2.WINDOW_NORMAL)
cv2.imshow('Output', output_image)
cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
cv2.imshow("Output", output_image)
# Wait for a key press to exit
cv2.waitKey(0)

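A minimal sketch of the session setup this example relies on, assuming a yolov8n.onnx export exists locally; onnxruntime falls back to the CPU provider when CUDA is unavailable:

import onnxruntime as ort

session = ort.InferenceSession("yolov8n.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
print(session.get_providers())        # the providers actually selected
print(session.get_inputs()[0].shape)  # typically [1, 3, 640, 640] for YOLOv8
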
@@ -6,7 +6,7 @@ import numpy as np
from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_yaml
CLASSES = yaml_load(check_yaml('coco128.yaml'))['names']
CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"]
colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))
@@ -23,7 +23,7 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
"""
label = f'{CLASSES[class_id]} ({confidence:.2f})'
label = f"{CLASSES[class_id]} ({confidence:.2f})"
color = colors[class_id]
cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
@@ -76,8 +76,11 @@ def main(onnx_model, input_image):
(minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
if maxScore >= 0.25:
box = [
outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]),
outputs[0][i][2], outputs[0][i][3]]
outputs[0][i][0] - (0.5 * outputs[0][i][2]),
outputs[0][i][1] - (0.5 * outputs[0][i][3]),
outputs[0][i][2],
outputs[0][i][3],
]
boxes.append(box)
scores.append(maxScore)
class_ids.append(maxClassIndex)
@@ -92,26 +95,34 @@ def main(onnx_model, input_image):
index = result_boxes[i]
box = boxes[index]
detection = {
'class_id': class_ids[index],
'class_name': CLASSES[class_ids[index]],
'confidence': scores[index],
'box': box,
'scale': scale}
"class_id": class_ids[index],
"class_name": CLASSES[class_ids[index]],
"confidence": scores[index],
"box": box,
"scale": scale,
}
detections.append(detection)
draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale),
round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale))
draw_bounding_box(
original_image,
class_ids[index],
scores[index],
round(box[0] * scale),
round(box[1] * scale),
round((box[0] + box[2]) * scale),
round((box[1] + box[3]) * scale),
)
# Display the image with bounding boxes
cv2.imshow('image', original_image)
cv2.imshow("image", original_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
return detections
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.')
parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.")
parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.")
args = parser.parse_args()
main(args.model, args.img)

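The box list built above converts each prediction from center format (cx, cy, w, h) to the top-left (x, y, w, h) layout the downstream NMS step works with; a worked example with made-up numbers:

cx, cy, w, h = 320.0, 240.0, 100.0, 50.0   # illustrative raw model output, center format
box = [cx - 0.5 * w, cy - 0.5 * h, w, h]   # shift the center to the top-left corner
print(box)  # [270.0, 215.0, 100.0, 50.0]
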
@@ -13,14 +13,9 @@ img_height = 640
class LetterBox:
def __init__(self,
new_shape=(img_width, img_height),
auto=False,
scaleFill=False,
scaleup=True,
center=True,
stride=32):
def __init__(
self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
):
self.new_shape = new_shape
self.auto = auto
self.scaleFill = scaleFill
@@ -33,9 +28,9 @@ class LetterBox:
if labels is None:
labels = {}
img = labels.get('img') if image is None else image
img = labels.get("img") if image is None else image
shape = img.shape[:2] # current shape [height, width]
new_shape = labels.pop('rect_shape', self.new_shape)
new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
@@ -63,15 +58,16 @@ class LetterBox:
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=(114, 114, 114)) # add border
if labels.get('ratio_pad'):
labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
) # add border
if labels.get("ratio_pad"):
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
labels['img'] = img
labels['resized_shape'] = new_shape
labels["img"] = img
labels["resized_shape"] = new_shape
return labels
else:
return img
@@ -79,15 +75,14 @@ class LetterBox:
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
labels['instances'].convert_bbox(format='xyxy')
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
labels['instances'].scale(*ratio)
labels['instances'].add_padding(padw, padh)
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
labels["instances"].scale(*ratio)
labels["instances"].add_padding(padw, padh)
return labels
class Yolov8TFLite:
def __init__(self, tflite_model, input_image, confidence_thres, iou_thres):
"""
Initializes an instance of the Yolov8TFLite class.
@@ -105,7 +100,7 @@ class Yolov8TFLite:
self.iou_thres = iou_thres
# Load the class names from the COCO dataset
self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Generate a color palette for the classes
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
@@ -134,7 +129,7 @@ class Yolov8TFLite:
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
# Create the label text with class name and score
label = f'{self.classes[class_id]}: {score:.2f}'
label = f"{self.classes[class_id]}: {score:.2f}"
# Calculate the dimensions of the label text
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
@@ -144,8 +139,13 @@ class Yolov8TFLite:
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
# Draw a filled rectangle as the background for the label text
cv2.rectangle(img, (int(label_x), int(label_y - label_height)),
(int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED)
cv2.rectangle(
img,
(int(label_x), int(label_y - label_height)),
(int(label_x + label_width), int(label_y + label_height)),
color,
cv2.FILLED,
)
# Draw the label text on the image
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
@@ -161,7 +161,7 @@ class Yolov8TFLite:
# Read the input image using OpenCV
self.img = cv2.imread(self.input_image)
print('image befor', self.img)
print("image before", self.img)
# Get the height and width of the input image
self.img_height, self.img_width = self.img.shape[:2]
@@ -209,8 +209,10 @@ class Yolov8TFLite:
# Get the box, score, and class ID corresponding to the index
box = boxes[i]
gain = min(img_width / self.img_width, img_height / self.img_height)
pad = round((img_width - self.img_width * gain) / 2 -
0.1), round((img_height - self.img_height * gain) / 2 - 0.1)
pad = (
round((img_width - self.img_width * gain) / 2 - 0.1),
round((img_height - self.img_height * gain) / 2 - 0.1),
)
box[0] = (box[0] - pad[0]) / gain
box[1] = (box[1] - pad[1]) / gain
box[2] = box[2] / gain
@@ -242,7 +244,7 @@ class Yolov8TFLite:
output_details = interpreter.get_output_details()
# Store the shape of the input for later use
input_shape = input_details[0]['shape']
input_shape = input_details[0]["shape"]
self.input_width = input_shape[1]
self.input_height = input_shape[2]
@@ -251,19 +253,19 @@ class Yolov8TFLite:
img_data = img_data
# img_data = img_data.cpu().numpy()
# Set the input tensor to the interpreter
print(input_details[0]['index'])
print(input_details[0]["index"])
print(img_data.shape)
img_data = img_data.transpose((0, 2, 3, 1))
scale, zero_point = input_details[0]['quantization']
interpreter.set_tensor(input_details[0]['index'], img_data)
scale, zero_point = input_details[0]["quantization"]
interpreter.set_tensor(input_details[0]["index"], img_data)
# Run inference
interpreter.invoke()
# Get the output tensor from the interpreter
output = interpreter.get_tensor(output_details[0]['index'])
scale, zero_point = output_details[0]['quantization']
output = interpreter.get_tensor(output_details[0]["index"])
scale, zero_point = output_details[0]["quantization"]
output = (output.astype(np.float32) - zero_point) * scale
output[:, [0, 2]] *= img_width
@@ -273,16 +275,15 @@ class Yolov8TFLite:
return self.postprocess(self.img, output)
if __name__ == '__main__':
if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model',
type=str,
default='yolov8n_full_integer_quant.tflite',
help='Input your TFLite model.')
parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold')
parser.add_argument(
"--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model."
)
parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
args = parser.parse_args()
# Create an instance of the Yolov8TFLite class with the specified arguments
@@ -292,7 +293,7 @@ if __name__ == '__main__':
output_image = detection.main()
# Display the output image in a window
cv2.imshow('Output', output_image)
cv2.imshow("Output", output_image)
# Wait for a key press to exit
cv2.waitKey(0)

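The dequantization step above maps the integer tensor back to real values as (q - zero_point) * scale; a worked example with illustrative quantization parameters:

import numpy as np

output = np.array([12, 140, 255], dtype=np.uint8)  # raw quantized model output
scale, zero_point = 0.0039, 128                    # illustrative quantization params
print((output.astype(np.float32) - zero_point) * scale)  # approx. [-0.4524  0.0468  0.4953]
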
@@ -16,21 +16,22 @@ track_history = defaultdict(list)
current_region = None
counting_regions = [
{
'name': 'YOLOv8 Polygon Region',
'polygon': Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
'counts': 0,
'dragging': False,
'region_color': (255, 42, 4), # BGR Value
'text_color': (255, 255, 255) # Region Text Color
"name": "YOLOv8 Polygon Region",
"polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points
"counts": 0,
"dragging": False,
"region_color": (255, 42, 4), # BGR Value
"text_color": (255, 255, 255), # Region Text Color
},
{
'name': 'YOLOv8 Rectangle Region',
'polygon': Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
'counts': 0,
'dragging': False,
'region_color': (37, 255, 225), # BGR Value
'text_color': (0, 0, 0), # Region Text Color
}, ]
"name": "YOLOv8 Rectangle Region",
"polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points
"counts": 0,
"dragging": False,
"region_color": (37, 255, 225), # BGR Value
"text_color": (0, 0, 0), # Region Text Color
},
]
def mouse_callback(event, x, y, flags, param):
@@ -40,32 +41,33 @@ def mouse_callback(event, x, y, flags, param):
# Mouse left button down event
if event == cv2.EVENT_LBUTTONDOWN:
for region in counting_regions:
if region['polygon'].contains(Point((x, y))):
if region["polygon"].contains(Point((x, y))):
current_region = region
current_region['dragging'] = True
current_region['offset_x'] = x
current_region['offset_y'] = y
current_region["dragging"] = True
current_region["offset_x"] = x
current_region["offset_y"] = y
# Mouse move event
elif event == cv2.EVENT_MOUSEMOVE:
if current_region is not None and current_region['dragging']:
dx = x - current_region['offset_x']
dy = y - current_region['offset_y']
current_region['polygon'] = Polygon([
(p[0] + dx, p[1] + dy) for p in current_region['polygon'].exterior.coords])
current_region['offset_x'] = x
current_region['offset_y'] = y
if current_region is not None and current_region["dragging"]:
dx = x - current_region["offset_x"]
dy = y - current_region["offset_y"]
current_region["polygon"] = Polygon(
[(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords]
)
current_region["offset_x"] = x
current_region["offset_y"] = y
# Mouse left button up event
elif event == cv2.EVENT_LBUTTONUP:
if current_region is not None and current_region['dragging']:
current_region['dragging'] = False
if current_region is not None and current_region["dragging"]:
current_region["dragging"] = False
def run(
weights='yolov8n.pt',
weights="yolov8n.pt",
source=None,
device='cpu',
device="cpu",
view_img=False,
save_img=False,
exist_ok=False,
@@ -100,8 +102,8 @@ def run(
raise FileNotFoundError(f"Source path '{source}' does not exist.")
# Setup Model
model = YOLO(f'{weights}')
model.to('cuda') if device == '0' else model.to('cpu')
model = YOLO(f"{weights}")
model.to("cuda") if device == "0" else model.to("cpu")
# Extract classes names
names = model.model.names
@@ -109,12 +111,12 @@ def run(
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
# Output setup
save_dir = increment_path(Path('ultralytics_rc_output') / 'exp', exist_ok)
save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok)
save_dir.mkdir(parents=True, exist_ok=True)
video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
# Iterate over video frames
while videocapture.isOpened():
@@ -146,43 +148,48 @@ def run(
# Check if detection inside region
for region in counting_regions:
if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))):
region['counts'] += 1
if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))):
region["counts"] += 1
# Draw regions (Polygons/Rectangles)
for region in counting_regions:
region_label = str(region['counts'])
region_color = region['region_color']
region_text_color = region['text_color']
region_label = str(region["counts"])
region_color = region["region_color"]
region_text_color = region["text_color"]
polygon_coords = np.array(region['polygon'].exterior.coords, dtype=np.int32)
centroid_x, centroid_y = int(region['polygon'].centroid.x), int(region['polygon'].centroid.y)
polygon_coords = np.array(region["polygon"].exterior.coords, dtype=np.int32)
centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y)
text_size, _ = cv2.getTextSize(region_label,
cv2.FONT_HERSHEY_SIMPLEX,
fontScale=0.7,
thickness=line_thickness)
text_size, _ = cv2.getTextSize(
region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness
)
text_x = centroid_x - text_size[0] // 2
text_y = centroid_y + text_size[1] // 2
cv2.rectangle(frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5),
region_color, -1)
cv2.putText(frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color,
line_thickness)
cv2.rectangle(
frame,
(text_x - 5, text_y - text_size[1] - 5),
(text_x + text_size[0] + 5, text_y + 5),
region_color,
-1,
)
cv2.putText(
frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness
)
cv2.polylines(frame, [polygon_coords], isClosed=True, color=region_color, thickness=region_thickness)
if view_img:
if vid_frame_count == 1:
cv2.namedWindow('Ultralytics YOLOv8 Region Counter Movable')
cv2.setMouseCallback('Ultralytics YOLOv8 Region Counter Movable', mouse_callback)
cv2.imshow('Ultralytics YOLOv8 Region Counter Movable', frame)
cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable")
cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback)
cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame)
if save_img:
video_writer.write(frame)
for region in counting_regions: # Reinitialize count for each region
region['counts'] = 0
region["counts"] = 0
if cv2.waitKey(1) & 0xFF == ord('q'):
if cv2.waitKey(1) & 0xFF == ord("q"):
break
del vid_frame_count
@@ -194,16 +201,16 @@ def run(
def parse_opt():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--source', type=str, required=True, help='video file path')
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-img', action='store_true', help='save results')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness')
parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness')
parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness')
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
parser.add_argument("--source", type=str, required=True, help="video file path")
parser.add_argument("--view-img", action="store_true", help="show results")
parser.add_argument("--save-img", action="store_true", help="save results")
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness")
parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness")
parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness")
return parser.parse_args()
@@ -213,6 +220,6 @@ def main(opt):
run(**vars(opt))
if __name__ == '__main__':
if __name__ == "__main__":
opt = parse_opt()
main(opt)

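The region counts above hinge on Shapely point-in-polygon tests; a quick standalone check using the same polygon as the default 'YOLOv8 Polygon Region':

from shapely.geometry import Point, Polygon

region = Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)])
print(region.contains(Point(200, 200)))  # True: a box center inside the region is counted
print(region.contains(Point(10, 10)))    # False: outside, not counted
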
@@ -9,7 +9,7 @@ from sahi.utils.yolov8 import download_yolov8s_model
from ultralytics.utils.files import increment_path
def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False):
"""
Run object detection on a video using YOLOv8 and SAHI.
@@ -25,41 +25,41 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
if not Path(source).exists():
raise FileNotFoundError(f"Source path '{source}' does not exist.")
yolov8_model_path = f'models/{weights}'
yolov8_model_path = f"models/{weights}"
download_yolov8s_model(yolov8_model_path)
detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
model_path=yolov8_model_path,
confidence_threshold=0.3,
device='cpu')
detection_model = AutoDetectionModel.from_pretrained(
model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu"
)
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
# Output setup
save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
save_dir.mkdir(parents=True, exist_ok=True)
video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
while videocapture.isOpened():
success, frame = videocapture.read()
if not success:
break
results = get_sliced_prediction(frame,
detection_model,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2)
results = get_sliced_prediction(
frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
)
object_prediction_list = results.object_prediction_list
boxes_list = []
clss_list = []
for ind, _ in enumerate(object_prediction_list):
boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
boxes = (
object_prediction_list[ind].bbox.minx,
object_prediction_list[ind].bbox.miny,
object_prediction_list[ind].bbox.maxx,
object_prediction_list[ind].bbox.maxy,
)
clss = object_prediction_list[ind].category.name
boxes_list.append(boxes)
clss_list.append(clss)
@@ -69,21 +69,19 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
label = str(cls)
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
-1)
cv2.putText(frame,
label, (int(x1), int(y1) - 2),
0,
0.6, [255, 255, 255],
thickness=1,
lineType=cv2.LINE_AA)
cv2.rectangle(
frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1
)
cv2.putText(
frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA
)
if view_img:
cv2.imshow(Path(source).stem, frame)
if save_img:
video_writer.write(frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
if cv2.waitKey(1) & 0xFF == ord("q"):
break
video_writer.release()
videocapture.release()
@@ -93,11 +91,11 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False,
def parse_opt():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
parser.add_argument('--source', type=str, required=True, help='video file path')
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-img', action='store_true', help='save results')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
parser.add_argument("--source", type=str, required=True, help="video file path")
parser.add_argument("--view-img", action="store_true", help="show results")
parser.add_argument("--save-img", action="store_true", help="save results")
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
return parser.parse_args()
@@ -106,6 +104,6 @@ def main(opt):
run(**vars(opt))
if __name__ == '__main__':
if __name__ == "__main__":
opt = parse_opt()
main(opt)

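A minimal end-to-end sketch of the SAHI calls reformatted above, run on a dummy frame so no video file is needed (the model path is a placeholder and the weights are assumed to be available locally):

import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8", model_path="models/yolov8n.pt", confidence_threshold=0.3, device="cpu"
)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a video frame
results = get_sliced_prediction(
    frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
)
print(len(results.object_prediction_list))  # number of merged detections
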
@@ -21,18 +21,21 @@ class YOLOv8Seg:
"""
# Build Ort session
self.session = ort.InferenceSession(onnx_model,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])
self.session = ort.InferenceSession(
onnx_model,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
if ort.get_device() == "GPU"
else ["CPUExecutionProvider"],
)
# Numpy dtype: support both FP32 and FP16 onnx model
self.ndtype = np.half if self.session.get_inputs()[0].type == 'tensor(float16)' else np.single
self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single
# Get model width and height(YOLOv8-seg only has one input)
self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:]
# Load COCO class names
self.classes = yaml_load(check_yaml('coco128.yaml'))['names']
self.classes = yaml_load(check_yaml("coco128.yaml"))["names"]
# Create color palette
self.color_palette = Colors()
@@ -60,14 +63,16 @@ class YOLOv8Seg:
preds = self.session.run(None, {self.session.get_inputs()[0].name: im})
# Post-process
boxes, segments, masks = self.postprocess(preds,
im0=im0,
ratio=ratio,
pad_w=pad_w,
pad_h=pad_h,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
nm=nm)
boxes, segments, masks = self.postprocess(
preds,
im0=im0,
ratio=ratio,
pad_w=pad_w,
pad_h=pad_h,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
nm=nm,
)
return boxes, segments, masks
def preprocess(self, img):
@@ -98,7 +103,7 @@ class YOLOv8Seg:
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
# Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0
img_process = img[None] if len(img.shape) == 3 else img
return img_process, ratio, (pad_w, pad_h)
@@ -124,7 +129,7 @@ class YOLOv8Seg:
x, protos = preds[0], preds[1] # Two outputs: predictions and protos
# Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
x = np.einsum('bcn->bnc', x)
x = np.einsum("bcn->bnc", x)
# Predictions filtering by conf-threshold
x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]
@@ -138,7 +143,6 @@ class YOLOv8Seg:
# Decode and return
if len(x) > 0:
# Bounding boxes format change: cxcywh -> xyxy
x[..., [0, 1]] -= x[..., [2, 3]] / 2
x[..., [2, 3]] += x[..., [0, 1]]
@@ -173,13 +177,13 @@ class YOLOv8Seg:
segments (List): list of segment masks.
"""
segments = []
for x in masks.astype('uint8'):
for x in masks.astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE
if c:
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
segments.append(c.astype('float32'))
segments.append(c.astype("float32"))
return segments
@staticmethod
@@ -219,7 +223,7 @@ class YOLOv8Seg:
masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN
masks = np.ascontiguousarray(masks)
masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape
masks = np.einsum('HWN -> NHW', masks) # HWN -> NHW
masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW
masks = self.crop_mask(masks, bboxes)
return np.greater(masks, 0.5)
@@ -250,8 +254,9 @@ class YOLOv8Seg:
if len(masks.shape) < 2:
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
masks = masks[top:bottom, left:right]
masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]),
interpolation=cv2.INTER_LINEAR) # INTER_CUBIC would be better
masks = cv2.resize(
masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
) # INTER_CUBIC would be better
if len(masks.shape) == 2:
masks = masks[:, :, None]
return masks
@@ -279,32 +284,46 @@ class YOLOv8Seg:
cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True))
# draw bbox rectangle
cv2.rectangle(
im,
(int(box[0]), int(box[1])),
(int(box[2]), int(box[3])),
self.color_palette(int(cls_), bgr=True),
1,
cv2.LINE_AA,
)
cv2.putText(
im,
f"{self.classes[cls_]}: {conf:.3f}",
(int(box[0]), int(box[1] - 9)),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
self.color_palette(int(cls_), bgr=True),
2,
cv2.LINE_AA,
)
# Mix image
im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0)
# Show image
if vis:
cv2.imshow("demo", im)
cv2.waitKey(0)
cv2.destroyAllWindows()
# Save image
if save:
cv2.imwrite("demo.jpg", im)
if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
args = parser.parse_args()
# Build model
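The rest of the __main__ block falls outside this hunk; a hedged sketch of how the parsed args would typically drive the class above, assuming its __call__ accepts the conf/iou thresholds shown in postprocess():

import cv2  # already imported at the top of this script

model = YOLOv8Seg(args.model)  # wrap the ONNX Runtime session built in __init__
img = cv2.imread(args.source)  # defaults to the bundled bus.jpg asset
boxes, segments, masks = model(img, conf_threshold=args.conf, iou_threshold=args.iou)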

@ -107,9 +107,9 @@ export = [
]
explorer = [
"lancedb", # vector search
"duckdb", # SQL queries, supports lancedb tables
"streamlit", # visualizing with GUI
"lancedb", # vector search
"duckdb", # SQL queries, supports lancedb tables
"streamlit", # visualizing with GUI
]
# tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos)
@ -179,5 +179,5 @@ pre-summary-newline = true
close-quotes-on-newline = true
[tool.codespell]
ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend"
skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml'
ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall"
skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml"

@ -3,6 +3,8 @@ import PIL
from ultralytics import Explorer
from ultralytics.utils import ASSETS
def test_similarity():
"""Test similarity calculations and SQL queries for correctness and response length."""

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.0.239"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
@ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"

@ -8,34 +8,53 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Union
from ultralytics.utils import (
ASSETS,
DEFAULT_CFG,
DEFAULT_CFG_DICT,
DEFAULT_CFG_PATH,
LOGGER,
RANK,
ROOT,
RUNS_DIR,
SETTINGS,
SETTINGS_YAML,
TESTS_RUNNING,
IterableSimpleNamespace,
__version__,
checks,
colorstr,
deprecation_warn,
yaml_load,
yaml_print,
)
# Define valid tasks and modes
MODES = "train", "val", "predict", "export", "track", "benchmark"
TASKS = "detect", "segment", "classify", "pose", "obb"
TASK2DATA = {
"detect": "coco8.yaml",
"segment": "coco8-seg.yaml",
"classify": "imagenet10",
"pose": "coco8-pose.yaml",
"obb": "dota8-obb.yaml",
}
TASK2MODEL = {
"detect": "yolov8n.pt",
"segment": "yolov8n-seg.pt",
"classify": "yolov8n-cls.pt",
"pose": "yolov8n-pose.pt",
"obb": "yolov8n-obb.pt",
}
TASK2METRIC = {
"detect": "metrics/mAP50-95(B)",
"segment": "metrics/mAP50-95(M)",
"classify": "metrics/accuracy_top1",
"pose": "metrics/mAP50-95(P)",
"obb": "metrics/mAP50-95(OBB)",
}
CLI_HELP_MSG = f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@ -74,16 +93,83 @@ CLI_HELP_MSG = \
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
CFG_FRACTION_KEYS = (
"dropout",
"iou",
"lr0",
"lrf",
"momentum",
"weight_decay",
"warmup_momentum",
"warmup_bias_lr",
"label_smoothing",
"hsv_h",
"hsv_s",
"hsv_v",
"translate",
"scale",
"perspective",
"flipud",
"fliplr",
"mosaic",
"mixup",
"copy_paste",
"conf",
"iou",
"fraction",
) # fraction floats 0.0 - 1.0
CFG_INT_KEYS = (
"epochs",
"patience",
"batch",
"workers",
"seed",
"close_mosaic",
"mask_ratio",
"max_det",
"vid_stride",
"line_width",
"workspace",
"nbs",
"save_period",
)
CFG_BOOL_KEYS = (
"save",
"exist_ok",
"verbose",
"deterministic",
"single_cls",
"rect",
"cos_lr",
"overlap_mask",
"val",
"save_json",
"save_hybrid",
"half",
"dnn",
"plots",
"show",
"save_txt",
"save_conf",
"save_crop",
"save_frames",
"show_labels",
"show_conf",
"visualize",
"augment",
"agnostic_nms",
"retina_masks",
"show_boxes",
"keras",
"optimize",
"int8",
"dynamic",
"simplify",
"nms",
"profile",
"multi_scale",
)
def cfg2dict(cfg):
@ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
if "save_dir" not in cfg:
overrides.pop("save_dir", None) # special override keys to ignore
check_dict_alignment(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Special handling for numeric project/name
for k in 'project', 'name':
for k in "project", "name":
if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
if cfg.get("name") == "model": # assign model to 'name' arg
cfg["name"] = cfg.get("model", "").split(".")[0]
LOGGER.warning(f"WARNING ⚠ 'name=model' automatically updated to 'name={cfg['name']}'.")
# Type and Value checks
for k, v in cfg.items():
if v is not None: # None values may be from optional args
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
raise TypeError(
f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
)
elif k in CFG_FRACTION_KEYS:
if not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
raise TypeError(
f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
)
if not (0.0 <= v <= 1.0):
raise ValueError(f"'{k}={v}' is an invalid value. "
f"Valid '{k}' values are between 0.0 and 1.0.")
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
elif k in CFG_INT_KEYS and not isinstance(v, int):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be an int (i.e. '{k}=8')")
raise TypeError(
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
)
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
raise TypeError(
f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
)
# Return instance
return IterableSimpleNamespace(**cfg)
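Taken together, the checks above mean get_cfg() fails fast on bad types and coerces numeric project/name values to strings; a small illustration using default-config keys:

from ultralytics.cfg import get_cfg

cfg = get_cfg(overrides={"imgsz": 320, "name": 42})
print(cfg.imgsz)  # 320, attribute access via IterableSimpleNamespace
print(cfg.name)   # "42", numeric names are coerced to str above
# get_cfg(overrides={"lr0": 1.5})     # ValueError: valid 'lr0' values are between 0.0 and 1.0
# get_cfg(overrides={"epochs": 3.5})  # TypeError: 'epochs' must be an int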
@ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
def get_save_dir(args, name=None):
"""Return save_dir as created from train/val/predict arguments."""
if getattr(args, "save_dir", None):
save_dir = args.save_dir
else:
from ultralytics.utils.files import increment_path
project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
name = name or args.name or f"{args.mode}"
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
return Path(save_dir)
@ -175,18 +267,18 @@ def _handle_deprecation(custom):
"""Hardcoded function to handle deprecated config keys."""
for key in custom.copy().keys():
if key == "boxes":
deprecation_warn(key, "show_boxes")
custom["show_boxes"] = custom.pop("boxes")
if key == "hide_labels":
deprecation_warn(key, "show_labels")
custom["show_labels"] = custom.pop("hide_labels") == "False"
if key == "hide_conf":
deprecation_warn(key, "show_conf")
custom["show_conf"] = custom.pop("hide_conf") == "False"
if key == "line_thickness":
deprecation_warn(key, "line_width")
custom["line_width"] = custom.pop("line_thickness")
return custom
@ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
if mismatched:
from difflib import get_close_matches
string = ""
for x in mismatched:
matches = get_close_matches(x, base_keys) # key list
matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]
match_str = f"Similar arguments are i.e. {matches}." if matches else ""
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
@ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
"""
new_args = []
for i, arg in enumerate(args):
if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
new_args[-1] += f"={args[i + 1]}"
del args[i + 1]
elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val']
new_args.append(f"{arg}{args[i + 1]}")
del args[i + 1]
elif arg.startswith("=") and i > 0: # merge ['arg', '=val']
new_args[-1] += arg
else:
new_args.append(arg)
@ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None:
"""
from ultralytics import hub
if args[0] == "login":
key = args[1] if len(args) > 1 else ""
# Log in to Ultralytics HUB using the provided API key
hub.login(key)
elif args[0] == "logout":
# Log out from Ultralytics HUB
hub.logout()
@ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None:
python my_script.py yolo settings reset
```
"""
url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL
try:
if any(args):
if args[0] == "reset":
SETTINGS_YAML.unlink() # delete the settings file
SETTINGS.reset() # create new settings
LOGGER.info("Settings reset successfully") # inform the user that settings have been reset
else: # save a new setting
new = dict(parse_key_value_pair(a) for a in args)
check_dict_alignment(SETTINGS, new)
SETTINGS.update(new)
LOGGER.info(f"💡 Learn about settings at {url}")
yaml_print(SETTINGS_YAML) # print the current settings
except Exception as e:
LOGGER.warning(f"WARNING ⚠ settings error: '{e}'. Please see {url} for help.")
@ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None:
def handle_explorer():
"""Open the Ultralytics Explorer GUI."""
checks.check_requirements("streamlit")
subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])
def parse_key_value_pair(pair):
"""Parse one 'key=value' pair and return key and value."""
k, v = pair.split("=", 1) # split on first '=' sign
k, v = k.strip(), v.strip() # remove spaces
assert v, f"missing '{k}' value"
return k, smart_value(v)
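Together these helpers turn CLI tokens into typed overrides; the expected conversions (smart_value() falls back to the raw string when nothing else matches):

from ultralytics.cfg import parse_key_value_pair, smart_value

parse_key_value_pair("imgsz=640")  # -> ("imgsz", 640)
parse_key_value_pair("half=true")  # -> ("half", True)
smart_value("none")                # -> None
smart_value("0.25")                # -> 0.25
smart_value("yolov8n.pt")          # -> "yolov8n.pt", left as a plain string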
@ -318,11 +410,11 @@ def parse_key_value_pair(pair):
def smart_value(v):
"""Convert a string to an underlying type such as int, float, bool, etc."""
v_lower = v.lower()
if v_lower == "none":
return None
elif v_lower == "true":
return True
elif v_lower == "false":
return False
else:
with contextlib.suppress(Exception):
@ -330,7 +422,7 @@ def smart_value(v):
return v
def entrypoint(debug=""):
"""
This function is the ultralytics package entrypoint, responsible for parsing the command line arguments passed
to the package.
@ -345,139 +437,150 @@ def entrypoint(debug=''):
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg
"""
args = (debug.split(" ") if debug else sys.argv)[1:]
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
special = {
"help": lambda: LOGGER.info(CLI_HELP_MSG),
"checks": checks.collect_system_info,
"version": lambda: LOGGER.info(__version__),
"settings": lambda: handle_yolo_settings(args[1:]),
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
"hub": lambda: handle_yolo_hub(args[1:]),
"login": lambda: handle_yolo_hub(args),
"copy-cfg": copy_default_cfg,
"explorer": lambda: handle_explorer(),
}
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
# Define common misuses of special commands, i.e. -h, -help, --help
special.update({k[0]: v for k, v in special.items()}) # singular
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular
special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}
overrides = {} # basic overrides, i.e. imgsz=320
for a in merge_equals_args(args): # merge spaces around '=' sign
if a.startswith("--"):
LOGGER.warning(f"WARNING ⚠ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
a = a[2:]
if a.endswith(","):
LOGGER.warning(f"WARNING ⚠ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
a = a[:-1]
if "=" in a:
try:
k, v = parse_key_value_pair(a)
if k == "cfg" and v is not None: # custom.yaml passed
LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}
else:
overrides[k] = v
except (NameError, SyntaxError, ValueError, AssertionError) as e:
check_dict_alignment(full_args_dict, {a: ""}, e)
elif a in TASKS:
overrides["task"] = a
elif a in MODES:
overrides["mode"] = a
elif a.lower() in special:
special[a.lower()]()
return
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
elif a in DEFAULT_CFG_DICT:
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
raise SyntaxError(
f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}"
)
else:
check_dict_alignment(full_args_dict, {a: ""})
# Check keys
check_dict_alignment(full_args_dict, overrides)
# Mode
mode = overrides.get("mode")
if mode is None:
mode = DEFAULT_CFG.mode or "predict"
LOGGER.warning(f"WARNING ⚠ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
elif mode not in MODES:
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
# Task
task = overrides.pop("task", None)
if task:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
if "model" not in overrides:
overrides["model"] = TASK2MODEL[task]
# Model
model = overrides.pop("model", DEFAULT_CFG.model)
if model is None:
model = "yolov8n.pt"
LOGGER.warning(f"WARNING ⚠ 'model' is missing. Using default 'model={model}'.")
overrides["model"] = model
stem = Path(model).stem.lower()
if "rtdetr" in stem: # guess architecture
from ultralytics import RTDETR
model = RTDETR(model) # no task argument
elif "fastsam" in stem:
from ultralytics import FastSAM
model = FastSAM(model)
elif "sam" in stem:
from ultralytics import SAM
model = SAM(model)
else:
from ultralytics import YOLO
model = YOLO(model, task=task)
if isinstance(overrides.get("pretrained"), str):
model.load(overrides["pretrained"])
# Task Update
if task != model.task:
if task:
LOGGER.warning(f"WARNING ⚠ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
LOGGER.warning(
f"WARNING ⚠ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model."
)
task = model.task
# Mode
if mode in ("predict", "track") and "source" not in overrides:
overrides["source"] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ("train", "val"):
if "data" not in overrides and "resume" not in overrides:
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == "export":
if "format" not in overrides:
overrides["format"] = DEFAULT_CFG.format or "torchscript"
LOGGER.warning(f"WARNING ⚠ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
getattr(model, mode)(**overrides) # default args from model
# Show help
LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
"""Copy and create a new default configuration file with '_copy' appended to its name."""
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(
f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8"
)
if __name__ == "__main__":
# Example: entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug="")

@ -4,5 +4,12 @@ from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
__all__ = (
"BaseDataset",
"ClassificationDataset",
"SemanticDataset",
"YOLODataset",
"build_yolo_dataset",
"build_dataloader",
"load_inference_source",
)

@ -5,7 +5,7 @@ from pathlib import Path
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
data = Path(data)
if not output_dir:
output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
Path(output_dir).mkdir(exist_ok=True, parents=True)
det_results = det_model(data, stream=True, device=device)
@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
for i in range(len(segments)):
s = segments[i]
if len(s) == 0:
continue
segment = map(str, segments[i].reshape(-1).tolist())
f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")

@ -117,11 +117,11 @@ class BaseMixTransform:
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels
# Mosaic or MixUp
labels = self._mix_transform(labels)
labels.pop("mix_labels", None)
return labels
def _mix_transform(self, labels):
@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform):
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in (4, 9), "grid must be equal to 4 or 9."
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform):
def _mix_transform(self, labels):
"""Apply mixup transformation to the input image and labels."""
assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
return (
self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
) # This code is modified for mosaic3 method.
def _mosaic3(self, labels):
"""Create a 1x3 image mosaic."""
mosaic_labels = []
s = self.imgsz
for i in range(3):
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch["img"]
h, w = labels_patch.pop("resized_shape")
# Place img in img3
if i == 0: # center
@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform):
padw, padh = c[:2]
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]
# hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels
def _mosaic4(self, labels):
@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
for i in range(4):
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch["img"]
h, w = labels_patch.pop("resized_shape")
# Place img in img4
if i == 0: # top left
@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform):
labels_patch = self._update_labels(labels_patch, padw, padh)
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels["img"] = img4
return final_labels
def _mosaic9(self, labels):
@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
hp, wp = -1, -1 # height, width previous
for i in range(9):
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
img = labels_patch["img"]
h, w = labels_patch.pop("resized_shape")
# Place img in img9
if i == 0: # center
@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform):
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# Image
img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform):
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
return final_labels
@staticmethod
def _update_labels(labels, padw, padh):
"""Update labels."""
nh, nw = labels["img"].shape[:2]
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(nw, nh)
labels["instances"].add_padding(padw, padh)
return labels
def _cat_labels(self, mosaic_labels):
@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform):
instances = []
imgsz = self.imgsz * 2 # mosaic imgsz
for labels in mosaic_labels:
cls.append(labels["cls"])
instances.append(labels["instances"])
# Final labels
final_labels = {
"im_file": mosaic_labels[0]["im_file"],
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (imgsz, imgsz),
"cls": np.concatenate(cls, 0),
"instances": Instances.concatenate(instances, axis=0),
"mosaic_border": self.border,
}
final_labels["instances"].clip(imgsz, imgsz)
good = final_labels["instances"].remove_zero_area_boxes()
final_labels["cls"] = final_labels["cls"][good]
return final_labels
@ -335,10 +338,10 @@ class MixUp(BaseMixTransform):
def _mix_transform(self, labels):
"""Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
labels2 = labels["mix_labels"][0]
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
return labels
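The Beta(32, 32) draw above concentrates r tightly around 0.5, so both images contribute roughly equally; a standalone sketch of the pixel blend (labels are concatenated, never blended):

import numpy as np

r = np.random.beta(32.0, 32.0)                        # ~0.5 with low variance
img1 = np.zeros((4, 4, 3), dtype=np.uint8)            # toy inputs
img2 = np.full((4, 4, 3), 255, dtype=np.uint8)
mixed = (img1 * r + img2 * (1 - r)).astype(np.uint8)  # same formula as _mix_transform above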
@ -366,14 +369,9 @@ class RandomPerspective:
box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation.
"""
def __init__(
self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
):
"""Initializes RandomPerspective object with transformation parameters."""
self.degrees = degrees
@ -519,18 +517,18 @@ class RandomPerspective:
Args:
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
if self.pre_transform and "mosaic_border" not in labels:
labels = self.pre_transform(labels)
labels.pop("ratio_pad", None) # do not need ratio pad
img = labels["img"]
cls = labels["cls"]
instances = labels.pop("instances")
# Make sure the coord formats are right
instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])
border = labels.pop("mosaic_border", self.border)
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
# M is affine matrix
# Scale for func:`box_candidates`
@ -546,20 +544,20 @@ class RandomPerspective:
if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
# Clip
new_instances.clip(*self.size)
# Filter instances
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# Make the bboxes have the same scale with new_bboxes
i = self.box_candidates(
box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
)
labels["instances"] = new_instances[i]
labels["cls"] = cls[i]
labels["img"] = img
labels["resized_shape"] = img.shape[:2]
return labels
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
@ -611,7 +609,7 @@ class RandomHSV:
The modified image replaces the original image in the input 'labels' dict.
"""
img = labels["img"]
if self.hgain or self.sgain or self.vgain:
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
@ -634,7 +632,7 @@ class RandomFlip:
Also updates any instances (bounding boxes, keypoints, etc.) accordingly.
"""
def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
"""
Initializes the RandomFlip class with probability and direction.
@ -644,7 +642,7 @@ class RandomFlip:
Default is 'horizontal'.
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
"""
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0
self.p = p
@ -662,25 +660,25 @@ class RandomFlip:
Returns:
(dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys.
"""
img = labels["img"]
instances = labels.pop("instances")
instances.convert_bbox(format="xywh")
h, w = img.shape[:2]
h = 1 if instances.normalized else h
w = 1 if instances.normalized else w
# Flip up-down
if self.direction == "vertical" and random.random() < self.p:
img = np.flipud(img)
instances.flipud(h)
if self.direction == "horizontal" and random.random() < self.p:
img = np.fliplr(img)
instances.fliplr(w)
# For keypoints
if self.flip_idx is not None and instances.keypoints is not None:
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
labels["img"] = np.ascontiguousarray(img)
labels["instances"] = instances
return labels
@ -700,9 +698,9 @@ class LetterBox:
"""Return updated labels and image with added border."""
if labels is None:
labels = {}
img = labels.get("img") if image is None else image
shape = img.shape[:2] # current shape [height, width]
new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
@ -730,25 +728,26 @@ class LetterBox:
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
) # add border
if labels.get("ratio_pad"):
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
labels["img"] = img
labels["resized_shape"] = new_shape
return labels
else:
return img
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
labels["instances"].scale(*ratio)
labels["instances"].add_padding(padw, padh)
return labels
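The ratio/pad bookkeeping above follows standard letterbox arithmetic; a worked example with illustrative numbers:

# Fit a 1080x1920 (h, w) frame into a 640x640 canvas
h0, w0, new = 1080, 1920, 640
r = min(new / h0, new / w0)                                  # 0.333..., shrink-to-fit scale
new_unpad = round(w0 * r), round(h0 * r)                     # (640, 360) resized (w, h)
dw, dh = (new - new_unpad[0]) / 2, (new - new_unpad[1]) / 2  # (0.0, 140.0) padding per side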
@ -785,11 +784,11 @@ class CopyPaste:
1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work.
2. This method modifies the input dictionary 'labels' in place.
"""
im = labels["img"]
cls = labels["cls"]
h, w = im.shape[:2]
instances = labels.pop("instances")
instances.convert_bbox(format="xyxy")
instances.denormalize(w, h)
if self.p and len(instances.segments):
n = len(instances)
@ -812,9 +811,9 @@ class CopyPaste:
i = cv2.flip(im_new, 1).astype(bool)
im[i] = result[i]
labels["img"] = im
labels["cls"] = cls
labels["instances"] = instances
return labels
@ -831,12 +830,13 @@ class Albumentations:
"""Initialize the transform object for YOLO bbox formatted params."""
self.p = p
self.transform = None
prefix = colorstr("albumentations: ")
try:
import albumentations as A
check_version(A.__version__, "1.0.3", hard=True) # version requirement
# Transforms
T = [
A.Blur(p=0.01),
A.MedianBlur(p=0.01),
@ -844,31 +844,32 @@ class Albumentations:
A.CLAHE(p=0.01),
A.RandomBrightnessContrast(p=0.0),
A.RandomGamma(p=0.0),
A.ImageCompression(quality_lower=75, p=0.0),
]
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f"{prefix}{e}")
def __call__(self, labels):
"""Generates object detections and returns a dictionary with detection results."""
im = labels["img"]
cls = labels["cls"]
if len(cls):
labels["instances"].convert_bbox("xywh")
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
bboxes = np.array(new["bboxes"], dtype=np.float32)
labels["instances"].update(bboxes=bboxes)
return labels
@ -888,15 +889,17 @@ class Format:
batch_idx (bool): Keep batch indexes. Default is True.
"""
def __init__(
self,
bbox_format="xywh",
normalize=True,
return_mask=False,
return_keypoint=False,
return_obb=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True,
):
"""Initializes the Format class with given parameters."""
self.bbox_format = bbox_format
self.normalize = normalize
@ -909,10 +912,10 @@ class Format:
def __call__(self, labels):
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
img = labels.pop("img")
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
@ -922,22 +925,24 @@ class Format:
masks, instances, cls = self._format_segments(instances, cls, w, h)
masks = torch.from_numpy(masks)
else:
masks = torch.zeros(
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
)
labels["masks"] = masks
if self.normalize:
instances.normalize(w, h)
labels["img"] = self._format_img(img)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels["keypoints"] = torch.from_numpy(instances.keypoints)
if self.return_obb:
labels["bboxes"] = (
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
)
# Then we can use collate_fn
if self.batch_idx:
labels["batch_idx"] = torch.zeros(nl)
return labels
def _format_img(self, img):
@ -964,33 +969,39 @@ class Format:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose(
[
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
),
]
)
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
if dataset.use_keypoints:
kpt_shape = dataset.data.get("kpt_shape", None)
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")
return Compose(
[
pre_transform,
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
]
) # transforms
# Classification augmentations -----------------------------------------------------------------------------------------
@ -1031,10 +1042,13 @@ def classify_transforms(
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
tfl += [
T.ToTensor(),
T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
),
]
return T.Compose(tfl)
@ -1053,7 +1067,7 @@ def classify_augmentations(
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.0,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
"""
@ -1080,13 +1094,13 @@ def classify_augmentations(
"""
# Transforms to apply if albumentations not installed
if not isinstance(size, int):
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.0:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
if vflip > 0.0:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
@ -1097,27 +1111,29 @@ def classify_augmentations(
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter
if auto_augment == "randaugment":
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
elif auto_augment == "augmix":
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
elif auto_augment == "autoaugment":
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
else:
raise ValueError(
f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None'
)
if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
@ -1125,7 +1141,8 @@ def classify_augmentations(
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
T.RandomErasing(p=erasing, inplace=True),
]
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
@ -1177,7 +1194,7 @@ class ClassifyLetterBox:
# Create padded image
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out
@ -1205,7 +1222,7 @@ class CenterCrop:
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
# NOTE: keep this class for backward compatibility
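A hedged usage sketch for the classification pipeline above, assuming the default ImageNet mean/std and a PIL input:

from PIL import Image
from ultralytics.data.augment import classify_transforms

tfm = classify_transforms(size=224)    # Resize -> CenterCrop -> ToTensor -> Normalize
x = tfm(Image.new("RGB", (640, 480)))  # torch.Tensor of shape (3, 224, 224)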

@ -47,20 +47,22 @@ class BaseDataset(Dataset):
transforms (callable): Image transformation function.
"""
def __init__(
self,
img_path,
imgsz=640,
cache=False,
augment=True,
hyp=DEFAULT_CFG,
prefix="",
rect=False,
batch_size=16,
stride=32,
pad=0.5,
single_cls=False,
classes=None,
fraction=1.0,
):
"""Initialize BaseDataset with given configuration and options."""
super().__init__()
self.img_path = img_path
@ -86,10 +88,10 @@ class BaseDataset(Dataset):
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
# Cache images
if cache == "ram" and not self.check_cache_ram():
cache = False
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
if cache:
self.cache_images(cache)
@ -103,23 +105,23 @@ class BaseDataset(Dataset):
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p) # os-agnostic
if p.is_dir(): # dir
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# F = list(p.rglob('*.*')) # pathlib
elif p.is_file(): # file
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found in {img_path}"
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:
im_files = im_files[: round(len(im_files) * self.fraction)]
return im_files
def update_labels(self, include_class: Optional[list]):
@ -127,19 +129,19 @@ class BaseDataset(Dataset):
include_class_array = np.array(include_class).reshape(1, -1)
for i in range(len(self.labels)):
if include_class is not None:
cls = self.labels[i]["cls"]
bboxes = self.labels[i]["bboxes"]
segments = self.labels[i]["segments"]
keypoints = self.labels[i]["keypoints"]
j = (cls == include_class_array).any(1)
self.labels[i]["cls"] = cls[j]
self.labels[i]["bboxes"] = bboxes[j]
if segments:
self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
if keypoints is not None:
self.labels[i]["keypoints"] = keypoints[j]
if self.single_cls:
self.labels[i]["cls"][:, 0] = 0
def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
@ -149,13 +151,13 @@ class BaseDataset(Dataset):
try:
im = np.load(fn)
except Exception as e:
LOGGER.warning(f"{self.prefix}WARNING ⚠ Removing corrupt *.npy image file {fn} due to: {e}")
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f) # BGR
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
@ -181,17 +183,17 @@ class BaseDataset(Dataset):
def cache_images(self, cache):
"""Cache images to memory or disk."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache == "disk":
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
pbar.close()
def cache_images_to_disk(self, i):
@ -207,15 +209,17 @@ class BaseDataset(Dataset):
for _ in range(n):
im = cv2.imread(random.choice(self.im_files)) # sample image
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
b += im.nbytes * ratio**2
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
mem = psutil.virtual_memory()
cache = mem_required < mem.available # to cache or not to cache, that is the question
if not cache:
LOGGER.info(
f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
f'with {int(safety_margin * 100)}% safety margin but only '
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
f"{'caching images ✅' if cache else 'not caching images ⚠'}"
)
return cache
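For intuition, the RAM check above can be reproduced standalone. A toy sketch (the image size, count, and margin below are assumed values, not ones taken from this diff):

```python
import psutil

gb = 1 << 30
mean_img_bytes = 3 * 640 * 640  # assumed mean resized BGR image size in bytes
n_images, safety_margin = 5000, 0.5
mem_required = mean_img_bytes * n_images * (1 + safety_margin)
cache = mem_required < psutil.virtual_memory().available
print(f"{mem_required / gb:.1f}GB required, caching={cache}")
```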
def set_rectangle(self):
@@ -223,7 +227,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]
@@ -250,12 +254,14 @@ class BaseDataset(Dataset):
def get_image_and_label(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
label.pop("shape", None) # shape is for rect, remove it
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
label["ratio_pad"] = (
label["resized_shape"][0] / label["ori_shape"][0],
label["resized_shape"][1] / label["ori_shape"][1],
) # for evaluation
if self.rect:
label["rect_shape"] = self.batch_shapes[self.batch[index]]
return self.update_labels_info(label)
def __len__(self):

@@ -9,8 +9,16 @@ import torch
from PIL import Image
from torch.utils.data import dataloader, distributed
from ultralytics.data.loaders import (
LOADERS,
LoadImages,
LoadPilAndNumpy,
LoadScreenshots,
LoadStreams,
LoadTensor,
SourceTypes,
autocast_list,
)
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
@@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
def __init__(self, *args, **kwargs):
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
super().__init__(*args, **kwargs)
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
@@ -70,29 +78,30 @@ class _RepeatSampler:
def seed_worker(worker_id): # noqa
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
"""Build YOLO Dataset."""
return YOLODataset(
img_path=img_path,
imgsz=cfg.imgsz,
batch_size=batch,
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
task=cfg.task,
classes=cfg.classes,
data=data,
fraction=cfg.fraction if mode == "train" else 1.0,
)
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
@@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return InfiniteDataLoader(
dataset=dataset,
batch_size=batch,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator,
)
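Taken together, the two builders above are typically wired up like this. A minimal sketch, assuming coco128.yaml resolves locally and default args from get_cfg():

```python
from ultralytics.cfg import get_cfg
from ultralytics.data.build import build_dataloader, build_yolo_dataset
from ultralytics.data.utils import check_det_dataset

cfg = get_cfg()  # default args: imgsz, rect, cache, single_cls, fraction, ...
data = check_det_dataset("coco128.yaml")  # resolves dataset paths and names
dataset = build_yolo_dataset(cfg, img_path=data["train"], batch=16, data=data, mode="train")
loader = build_dataloader(dataset, batch=16, workers=8, shuffle=True, rank=-1)
batch = next(iter(loader))  # dict with "img", "cls", "bboxes", "batch_idx", ...
```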
def check_source(source):
@@ -120,9 +131,9 @@ def check_source(source):
if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
screenshot = source.lower() == "screen"
if is_url and is_file:
source = check_file(source) # download
elif isinstance(source, LOADERS):
@@ -135,7 +146,7 @@ def check_source(source):
elif isinstance(source, torch.Tensor):
tensor = True
else:
raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")
return source, webcam, screenshot, from_img, in_memory, tensor
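A quick sketch of how check_source classifies a few inputs (flag order: source, webcam, screenshot, from_img, in_memory, tensor; the URL case downloads the file via check_file):

```python
from ultralytics.data.build import check_source

_, webcam, *_ = check_source(0)  # int -> local camera, webcam=True
_, _, screenshot, *_ = check_source("screen")  # screen capture, screenshot=True
src, *_ = check_source("https://ultralytics.com/images/bus.jpg")  # url+file -> downloaded path
```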
@@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
# Attach source types to the dataset
setattr(dataset, "source_type", source_type)
return dataset

@@ -20,10 +20,98 @@ def coco91_to_coco80_class():
corresponding 91-index class ID.
"""
return [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
None,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
None,
24,
25,
None,
None,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
None,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
None,
60,
None,
None,
61,
None,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
None,
73,
74,
75,
76,
77,
78,
79,
None,
]
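As a sanity check on the mapping above — COCO category IDs are 1-based, so callers subtract 1 before indexing, exactly as convert_coco does below:

```python
coco80 = coco91_to_coco80_class()
assert coco80[1 - 1] == 0  # category_id 1 (person) -> class 0
assert coco80[12 - 1] is None  # category_id 12 is unused in the 80-class set
assert coco80[90 - 1] == 79  # category_id 90 (toothbrush) -> class 79
```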
def coco80_to_coco91_class():
@@ -42,16 +130,96 @@ def coco80_to_coco91_class():
```
"""
return [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
27,
28,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
67,
70,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
84,
85,
86,
87,
88,
89,
90,
]
def convert_coco(
labels_dir="../coco/annotations/",
save_dir="coco_converted/",
use_segments=False,
use_keypoints=False,
cls91to80=True,
):
"""
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@@ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/',
# Create dataset directory
save_dir = increment_path(save_dir) # increment if save directory already exists
for p in save_dir / "labels", save_dir / "images":
p.mkdir(parents=True, exist_ok=True) # make dir
# Convert classes
coco80 = coco91_to_coco80_class()
# Import json
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
fn.mkdir(parents=True, exist_ok=True)
with open(json_file) as f:
data = json.load(f)
# Create image dict
images = {f'{x["id"]:d}': x for x in data["images"]}
# Create image-annotations dict
imgToAnns = defaultdict(list)
for ann in data["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
# Write labels file
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
bboxes = []
segments = []
keypoints = []
for ann in anns:
if ann["iscrowd"]:
continue
# The COCO box format is [top left x, top left y, width, height]
box = np.array(ann["bbox"], dtype=np.float64)
box[:2] += box[2:] / 2 # xy top-left corner to center
box[[0, 2]] /= w # normalize x
box[[1, 3]] /= h # normalize y
if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
continue
cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
if use_segments and ann.get("segmentation") is not None:
if len(ann["segmentation"]) == 0:
segments.append([])
continue
elif len(ann["segmentation"]) > 1:
s = merge_multi_segment(ann["segmentation"])
s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
else:
s = [j for i in ann["segmentation"] for j in i] # all segments concatenated
s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
s = [cls] + s
segments.append(s)
if use_keypoints and ann.get("keypoints") is not None:
keypoints.append(
box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
)
# Write
with open((fn / f).with_suffix(".txt"), "a") as file:
for i in range(len(bboxes)):
if use_keypoints:
line = (*(keypoints[i]),) # cls, box, keypoints
else:
line = (
*(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
) # cls, box or segments
file.write(("%g " * len(line)).rstrip() % line + "\n")
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
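A sketch of driving the converter above (the paths are assumptions; labels_dir must hold COCO instances_*.json files):

```python
from ultralytics.data.converter import convert_coco

convert_coco(
    labels_dir="../datasets/coco/annotations/",  # folder with instances_*.json
    save_dir="coco_converted/",  # auto-incremented if it already exists
    use_segments=True,  # also write normalized polygon labels
    cls91to80=True,  # remap 91 COCO category ids to the 80-class set
)
```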
def convert_dota_to_yolo_obb(dota_root_path: str):
@@ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
# Class names to indices mapping
class_mapping = {
"plane": 0,
"ship": 1,
"storage-tank": 2,
"baseball-diamond": 3,
"tennis-court": 4,
"basketball-court": 5,
"ground-track-field": 6,
"harbor": 7,
"bridge": 8,
"large-vehicle": 9,
"small-vehicle": 10,
"helicopter": 11,
"roundabout": 12,
"soccer-ball-field": 13,
"swimming-pool": 14,
"container-crane": 15,
"airport": 16,
"helipad": 17,
}
def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
"""Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
orig_label_path = orig_label_dir / f"{image_name}.txt"
save_path = save_dir / f"{image_name}.txt"
with orig_label_path.open("r") as f, save_path.open("w") as g:
lines = f.readlines()
for line in lines:
parts = line.strip().split()
@@ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
class_idx = class_mapping[class_name]
coords = [float(p) for p in parts[:8]]
normalized_coords = [
coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
]
formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]
g.write(f"{class_idx} {' '.join(formatted_coords)}\n")
for phase in ["train", "val"]:
image_dir = dota_root_path / "images" / phase
orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
save_dir = dota_root_path / "labels" / phase
save_dir.mkdir(parents=True, exist_ok=True)
image_paths = list(image_dir.iterdir())
for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
if image_path.suffix != ".png":
continue
image_name_without_ext = image_path.stem
img = cv2.imread(str(image_path))
@@ -293,7 +465,7 @@ def merge_multi_segment(segments):
s.append(segments[i])
else:
idx = [0, idx[1] - idx[0]]
s.append(segments[i][idx[0] : idx[1] + 1])
else:
for i in range(len(idx_list) - 1, -1, -1):

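The DOTA converter above is driven the same way. A sketch under the directory layout its docstring assumes (images/{train,val} and labels/{train,val}_original under the root):

```python
from ultralytics.data.converter import convert_dota_to_yolo_obb

convert_dota_to_yolo_obb("../datasets/DOTAv1")  # writes labels/{train,val}/*.txt
```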
@@ -18,7 +18,7 @@ from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
class YOLODataset(BaseDataset):
@@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
self.use_segments = task == "segment"
self.use_keypoints = task == "pose"
self.use_obb = task == "obb"
self.data = data
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(*args, **kwargs)
def cache_labels(self, path=Path("./labels.cache")):
"""
Cache dataset labels, check images and read shapes.
@@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
Returns:
(dict): labels.
"""
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
)
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(
func=verify_image_label,
iterable=zip(
self.im_files,
self.label_files,
repeat(self.prefix),
repeat(self.use_keypoints),
repeat(len(self.data["names"])),
repeat(nkpt),
repeat(ndim),
),
)
pbar = TQDM(results, desc=desc, total=total)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
@@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
ne += ne_f
nc += nc_f
if im_file:
x["labels"].append(
dict(
im_file=im_file,
shape=shape,
@@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
segments=segments,
keypoints=keypoint,
normalized=True,
bbox_format="xywh",
)
)
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠ No labels found in {path}. {HELP_URL}")
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
return x
def get_labels(self):
"""Returns dictionary of labels for YOLO training."""
self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0):
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"]
if not labels:
LOGGER.warning(f"WARNING ⚠ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"WARNING ⚠ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
)
for lb in labels:
lb["segments"] = []
if len_cls == 0:
LOGGER.warning(f"WARNING ⚠ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
return labels
def build_transforms(self, hyp=None):
@@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
transforms.append(
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
return_obb=self.use_obb,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
)
return transforms
def close_mosaic(self, hyp):
@@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
"""Customize your label format here."""
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
# We can also make it support classification and semantic segmentation by adding or removing dict keys here.
bboxes = label.pop("bboxes")
segments = label.pop("segments", [])
keypoints = label.pop("keypoints", None)
bbox_format = label.pop("bbox_format")
normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes
segment_resamples = 100 if self.use_obb else 1000
@@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
@staticmethod
@@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
values = list(zip(*[list(b.values()) for b in batch]))
for i, k in enumerate(keys):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
for i in range(len(new_batch["batch_idx"])):
new_batch["batch_idx"][i] += i # add target image index for build_targets()
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
return new_batch
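The batch_idx offset in collate_fn above is what lets loss code map each target row back to its image in the stacked batch. A toy illustration of the same arithmetic:

```python
import torch

per_image = [torch.zeros(2), torch.zeros(3)]  # 2 + 3 targets from two images
for i in range(len(per_image)):
    per_image[i] += i  # tag each target with its image index
print(torch.cat(per_image, 0))  # tensor([0., 0., 1., 1., 1.])
```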
@@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
"""
def __init__(self, root, args, augment=False, cache=False, prefix=""):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
@@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
super().__init__(root=root)
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
self.cache_ram = cache is True or cache == "ram"
self.cache_disk = cache == "disk"
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
self.torch_transforms = (
classify_augmentations(
size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v,
)
if augment
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
)
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
@@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
# Convert NumPy array to PIL image
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
sample = self.torch_transforms(im)
return {"img": sample, "cls": j}
def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
@@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
def verify_images(self):
"""Verify all images in dataset."""
desc = f"{self.prefix}Scanning {self.root}..."
path = Path(self.root).with_suffix(".cache") # *.cache file path
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
if LOCAL_RANK in (-1, 0):
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
return samples
# Run scan if *.cache retrieval failed
@@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
msgs.append(msg)
nf += nf_f
nc += nc_f
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
return samples
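End to end, the classification dataset above can be exercised like this. A minimal sketch; the args namespace here is a hypothetical stand-in for the full training-args object:

```python
from types import SimpleNamespace

from ultralytics.data.dataset import ClassificationDataset

args = SimpleNamespace(  # hypothetical subset of the real train args
    imgsz=224, fraction=1.0, scale=0.5, fliplr=0.5, flipud=0.0, erasing=0.4,
    auto_augment=None, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, crop_fraction=1.0,
)
ds = ClassificationDataset("path/to/dataset/train", args, augment=False)
sample = ds[0]  # {"img": transformed tensor, "cls": class index}
```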
@@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
@@ -320,15 +342,15 @@ def load_dataset_cache_file(path):
def save_dataset_cache_file(prefix, path, x):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = DATASET_CACHE_VERSION # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠ Cache directory {path.parent} is not writeable, cache not saved.")
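A round-trip sketch of the two cache helpers above (the file name is arbitrary; saving injects the version key that the load-time assertions check):

```python
from pathlib import Path

from ultralytics.data.dataset import (
    DATASET_CACHE_VERSION,
    load_dataset_cache_file,
    save_dataset_cache_file,
)

x = {"labels": [], "results": (0, 0, 0, 0, 0), "msgs": []}
save_dataset_cache_file("demo: ", Path("demo.cache"), x)  # sets x["version"]
cache = load_dataset_cache_file(Path("demo.cache"))
assert cache["version"] == DATASET_CACHE_VERSION
```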
# TODO: support semantic segmentation

@@ -2,4 +2,4 @@
from .utils import plot_query_result
__all__ = ["plot_query_result"]

@@ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
class ExplorerDataset(YOLODataset):
def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs)
@@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset):
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
return im, (h0, w0), im.shape[:2]
@@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset):
def build_transforms(self, hyp: IterableSimpleNamespace = None):
"""Creates transforms for dataset images without resizing."""
return Format(
bbox_format="xyxy",
normalize=False,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
@@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset):
class Explorer:
def __init__(
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
) -> None:
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
import lancedb
self.connection = lancedb.connect(uri)
self.table_name = Path(data).name.lower() + "_" + model.lower()
self.sim_idx_base_name = (
f"{self.table_name}_sim_idx".lower()
) # Use this name and append thres and top_k to reuse the table
self.model = YOLO(model)
self.data = data # None
@@ -74,7 +72,7 @@ class Explorer:
self.table = None
self.progress = 0
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
@@ -90,20 +88,20 @@ class Explorer:
```
"""
if self.table is not None and not force:
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
return
if self.table_name in self.connection.table_names() and not force:
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
self.table = self.connection.open_table(self.table_name)
self.progress = 1
return
if self.data is None:
raise ValueError("Data must be provided to create embeddings table")
data_info = check_det_dataset(self.data)
if split not in data_info:
raise ValueError(
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
)
choice_set = data_info[split]
@@ -113,13 +111,16 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
table.add(
self._yield_batches(
dataset,
data_info,
self.model,
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
)
)
self.table = table
@@ -131,12 +132,12 @@ class Explorer:
for k in exclude_keys:
batch.pop(k, None)
batch = sanitize_batch(batch, data_info)
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
yield [batch]
def query(
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
) -> Any: # pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@@ -157,18 +158,18 @@ class Explorer:
```
"""
if self.table is None:
raise ValueError("Table is not created. Please create the table first.")
if isinstance(imgs, str):
imgs = [imgs]
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(
self, query: str, return_type: str = "pandas"
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
@@ -187,27 +188,29 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in [
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
import duckdb
if self.table is None:
raise ValueError("Table is not created. Please create the table first.")
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith("SELECT") and not query.startswith("WHERE"):
raise ValueError(
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
)
if query.startswith("WHERE"):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f"Running query: {query}")
rs = duckdb.sql(query)
if return_type == "pandas":
return rs.df()
elif return_type == "arrow":
return rs.arrow()
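A usage sketch matching the docstring examples above (assumes coco128 is available so the embeddings table can be built):

```python
from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()
df = exp.sql_query("WHERE labels LIKE '%person%' LIMIT 10")  # pandas DataFrame
```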
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
@@ -228,18 +231,20 @@ class Explorer:
result = exp.plot_sql_query(query)
```
"""
result = self.sql_query(query, return_type="arrow")
if len(result) == 0:
LOGGER.info("No results found.")
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(
self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = "pandas",
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@@ -259,21 +264,25 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in [
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
if return_type == "pandas":
return similar.to_pandas()
elif return_type == "arrow":
return similar
def plot_similar(
self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True,
) -> Image.Image:
"""
Plot the similar images. Accepts images or indexes.
@@ -293,9 +302,9 @@ class Explorer:
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
similar = self.get_similar(img, idx, limit, return_type="arrow")
if len(similar) == 0:
LOGGER.info("No results found.")
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
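Continuing that sketch, similarity search accepts either image paths or table indices:

```python
similar_df = exp.get_similar(img="https://ultralytics.com/images/bus.jpg", limit=10)
grid = exp.plot_similar(idx=0, limit=9)  # PIL.Image grid of the 9 nearest images
```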
@@ -323,34 +332,37 @@ class Explorer:
```
"""
if self.table is None:
raise ValueError("Table is not created. Please create the table first.")
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
if sim_idx_table_name in self.connection.table_names() and not force:
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
return self.connection.open_table(sim_idx_table_name).to_pandas()
if top_k and not (1.0 >= top_k >= 0.0):
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
if max_dist < 0.0:
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
top_k = max(top_k, 1)
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
im_files = features["im_file"]
embeddings = features["vector"]
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
def _yield_sim_idx():
"""Generates a dataframe with similarity indices and distances for images."""
for i in tqdm(range(len(embeddings))):
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
yield [
{
"idx": i,
"im_file": im_files[i],
"count": len(sim_idx),
"sim_im_files": sim_idx["im_file"].tolist(),
}
]
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
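And the similarity index built above is consumed as a plain DataFrame, reused on later calls unless force=True:

```python
sim_idx = exp.similarity_index(max_dist=0.2, top_k=0.01)
print(sim_idx[["im_file", "count"]].head())
```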
@@ -381,7 +393,7 @@ class Explorer:
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
sim_count = sim_idx["count"].tolist()
sim_count = np.array(sim_count)
indices = np.arange(len(sim_count))
@@ -390,25 +402,26 @@ class Explorer:
plt.bar(indices, sim_count)
# Customize the plot (optional)
plt.xlabel("data idx")
plt.ylabel("Count")
plt.title("Similarity Count")
buffer = BytesIO()
plt.savefig(buffer, format="png")
buffer.seek(0)
# Use Pillow to open the image from the buffer
return Image.fromarray(np.array(Image.open(buffer)))
def _check_imgs_or_idxs(
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
) -> List[np.ndarray]:
if img is None and idx is None:
raise ValueError("Either img or idx must be provided.")
if img is not None and idx is not None:
raise ValueError("Only one of img or idx must be provided.")
if idx is not None:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
return img if isinstance(img, list) else [img]
@@ -433,7 +446,7 @@ class Explorer:
try:
df = self.sql_query(result)
except Exception as e:
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
LOGGER.error(e)
return None
return df

@@ -9,100 +9,114 @@ from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
import streamlit as st
from streamlit_select import image_select
def _get_explorer():
"""Initializes an Explorer from the session-state dataset/model and builds its embeddings table with progress."""
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
thread = Thread(
target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
)
thread.start()
progress_bar = st.progress(0, text="Creating embeddings table...")
while exp.progress < 1:
time.sleep(0.1)
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
thread.join()
st.session_state["explorer"] = exp
progress_bar.empty()
def init_explorer_form():
"""Renders the dataset and model selection form used to initialize the Explorer."""
datasets = ROOT / "cfg" / "datasets"
ds = [d.name for d in datasets.glob("*.yaml")]
models = [
"yolov8n.pt",
"yolov8s.pt",
"yolov8m.pt",
"yolov8l.pt",
"yolov8x.pt",
"yolov8n-seg.pt",
"yolov8s-seg.pt",
"yolov8m-seg.pt",
"yolov8l-seg.pt",
"yolov8x-seg.pt",
"yolov8n-pose.pt",
"yolov8s-pose.pt",
"yolov8m-pose.pt",
"yolov8l-pose.pt",
"yolov8x-pose.pt",
]
with st.form(key="explorer_init_form"):
col1, col2 = st.columns(2)
with col1:
st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
with col2:
st.selectbox("Select model", models, key="model")
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
st.form_submit_button("Explore", on_click=_get_explorer)
def query_form():
"""Sets up a Streamlit form for running SQL queries against the embeddings table."""
with st.form("query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input(
"Query",
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility="collapsed",
key="query",
)
with col2:
st.form_submit_button("Query", on_click=run_sql_query)
def ai_query_form():
"""Sets up a Streamlit form that sends a natural-language query to the AI SQL generator."""
with st.form("ai_query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
with col2:
st.form_submit_button("Ask AI", on_click=run_ai_query)
def find_similar_imgs(imgs):
"""Finds images similar to the given inputs and stores their paths in session state."""
exp = st.session_state["explorer"]
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
paths = similar.to_pydict()["im_file"]
st.session_state["imgs"] = paths
def similarity_form(selected_imgs):
"""Renders a Streamlit form for similarity search over the currently selected images."""
st.write("Similarity Search")
with st.form("similarity_form"):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
st.number_input(
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
)
with subcol2:
disabled = not len(selected_imgs)
st.write("Selected: ", len(selected_imgs))
st.form_submit_button(
"Search",
disabled=disabled,
on_click=find_similar_imgs,
args=(selected_imgs,),
)
if disabled:
st.error("Select at least one image to search.")
# def persist_reset_form():
@@ -117,100 +131,108 @@ def similarity_form(selected_imgs):
def run_sql_query():
"""Executes the SQL query from session state and stores the resulting image paths."""
st.session_state["error"] = None
query = st.session_state.get("query")
if query.rstrip().lstrip():
exp = st.session_state["explorer"]
res = exp.sql_query(query, return_type="arrow")
st.session_state["imgs"] = res.to_pydict()["im_file"]
def run_ai_query():
"""Runs the natural-language AI query and updates session state with the results."""
if not SETTINGS["openai_api_key"]:
st.session_state[
"error"
] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
return
st.session_state["error"] = None
query = st.session_state.get("ai_query")
if query.rstrip().lstrip():
exp = st.session_state["explorer"]
res = exp.ask_ai(query)
if not isinstance(res, pd.DataFrame) or res.empty:
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
return
st.session_state["imgs"] = res["im_file"].to_list()
def reset_explorer():
"""Resets the explorer to its initial state by clearing session variables."""
st.session_state["explorer"] = None
st.session_state["imgs"] = None
st.session_state["error"] = None
def utralytics_explorer_docs_callback():
"""Displays the Ultralytics Explorer logo and a link to the API documentation."""
with st.container(border=True):
st.image(
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
width=100,
)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None,
)
st.link_button("Ultralytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
def layout():
"""Builds the Streamlit page layout for the Explorer demo."""
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
if st.session_state.get("explorer") is None:
init_explorer_form()
return
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
exp = st.session_state.get("explorer")
col1, col2 = st.columns([0.75, 0.25], gap="small")
imgs = []
if st.session_state.get("error"):
st.error(st.session_state["error"])
else:
imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
st.write("Max Images Displayed:")
with subcol2:
num = st.number_input(
"Max Images Displayed",
min_value=0,
max_value=total_imgs,
value=min(500, total_imgs),
key="num_imgs_displayed",
label_visibility="collapsed",
)
with subcol3:
st.write("Start Index:")
with subcol4:
start_idx = st.number_input(
"Start Index",
min_value=0,
max_value=total_imgs,
value=0,
key="start_index",
label_visibility="collapsed",
)
with subcol5:
reset = st.button("Reset", use_container_width=False, key="reset")
if reset:
st.session_state["imgs"] = None
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
imgs_displayed = imgs[start_idx : start_idx + num]
selected_imgs = image_select(
f"Total samples: {total_imgs}",
images=imgs_displayed,
use_container_width=False,
# indices=[i for i in range(num)] if select_all else None,
@@ -222,5 +244,5 @@ def layout():
utralytics_explorer_docs_callback()
if __name__ == "__main__":
layout()

@@ -46,14 +46,13 @@ def get_sim_index_schema():
def sanitize_batch(batch, dataset_info):
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
batch["cls"] = batch["cls"].flatten().int().tolist()
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
batch["bboxes"] = [box for box, _ in box_cls_pair]
batch["cls"] = [cls for _, cls in box_cls_pair]
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
return batch
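A toy run of sanitize_batch (the inputs are assumptions shaped like a real batch after the exclude_keys pass in Explorer._yield_batches):

```python
import torch

from ultralytics.data.explorer.utils import sanitize_batch

batch = {
    "cls": torch.tensor([[1.0], [0.0]]),
    "bboxes": torch.tensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
}
out = sanitize_batch(batch, {"names": {0: "person", 1: "bicycle"}})
print(out["labels"])  # ['person', 'bicycle'] -- box/class pairs sorted by class id
```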
@@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
similar_set (list): Pyarrow or pandas object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = (
similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
)
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get("im_file", [])
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
cls = similar_set.get("cls", [])
plot_size = 640
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
return plot_images(
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
)
def prompt_sql_query(query):
"""Generates an SQL query for the given user request using OpenAI."""
check_requirements("openai>=1.6.1")
from openai import OpenAI
if not SETTINGS["openai_api_key"]:
logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
openai_api_key = getpass.getpass("OpenAI API key: ")
SETTINGS.update({"openai_api_key": openai_api_key})
openai = OpenAI(api_key=SETTINGS["openai_api_key"])
messages = [
{
"role": "system",
"content": """
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
@ -165,10 +157,10 @@ def prompt_sql_query(query):
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
""",
},
{"role": "user", "content": f"{query}"},
]
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
return response.choices[0].message.content

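A minimal usage sketch (assumes an OpenAI API key is already stored in SETTINGS; the request string is illustrative):

sql = prompt_sql_query("Get all data points that contain at least one dog")
# e.g. "SELECT * FROM 'table' WHERE ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;"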
@ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements
@dataclass
class SourceTypes:
"""Class to represent various types of input sources for predictions."""
webcam: bool = False
screenshot: bool = False
from_img: bool = False
@ -59,12 +60,12 @@ class LoadStreams:
__len__: Return the length of the sources object.
"""
def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
self.mode = "stream"
self.imgsz = imgsz
self.vid_stride = vid_stride # video frame-rate stride
@ -79,33 +80,36 @@ class LoadStreams:
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f"{i + 1}/{n}: {s}... "
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
s = get_best_youtube_url(s)
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0 and (is_colab() or is_kaggle()):
raise NotImplementedError(
"'source=0' webcam not supported in Colab and Kaggle notebooks. "
"Try running 'source=0' in a local environment."
)
self.caps[i] = cv2.VideoCapture(s) # store video capture object
if not self.caps[i].isOpened():
raise ConnectionError(f"{st}Failed to open {s}")
w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
"inf"
) # infinite stream fallback
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
success, im = self.caps[i].read() # guarantee first frame
if not success or im is None:
raise ConnectionError(f"{st}Failed to read images from {s}")
self.imgs[i].append(im)
self.shape[i] = im.shape
self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
LOGGER.info("") # newline
# Check for common shapes
self.bs = self.__len__()
@ -121,7 +125,7 @@ class LoadStreams:
success, im = cap.retrieve()
if not success:
im = np.zeros(self.shape[i], dtype=np.uint8)
LOGGER.warning("WARNING ⚠ Video stream unresponsive, please check your IP camera connection.")
cap.open(stream) # re-open stream if signal was lost
if self.buffer:
self.imgs[i].append(im)
@ -140,7 +144,7 @@ class LoadStreams:
try:
cap.release() # release video capture
except Exception as e:
LOGGER.warning(f"WARNING ⚠ Could not release VideoCapture object: {e}")
cv2.destroyAllWindows()
def __iter__(self):
@ -154,16 +158,15 @@ class LoadStreams:
images = []
for i, x in enumerate(self.imgs):
# Wait until a frame is available in each buffer
while not x:
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit
self.close()
raise StopIteration
time.sleep(1 / min(self.fps))
x = self.imgs[i]
if not x:
LOGGER.warning(f"WARNING ⚠ Waiting for stream {i}")
# Get and remove the first frame from imgs buffer
if self.buffer:
@ -174,7 +177,7 @@ class LoadStreams:
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
x.clear()
return self.sources, images, None, ""
def __len__(self):
"""Return the length of the sources object."""
@ -209,7 +212,7 @@ class LoadScreenshots:
def __init__(self, source, imgsz=640):
"""Source = [screen_number left top width height] (pixels)."""
check_requirements("mss")
import mss # noqa
source, *params = source.split()
@ -221,18 +224,18 @@ class LoadScreenshots:
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
self.imgsz = imgsz
self.mode = "stream"
self.frame = 0
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
self.top = monitor["top"] if top is None else (monitor["top"] + top)
self.left = monitor["left"] if left is None else (monitor["left"] + left)
self.width = width or monitor["width"]
self.height = height or monitor["height"]
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
def __iter__(self):
"""Returns an iterator of the object."""
@ -241,7 +244,7 @@ class LoadScreenshots:
def __next__(self):
"""mss screen capture: get raw pixels from the screen as np array."""
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
self.frame += 1
return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string
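A short usage sketch of the screenshot loader (the source string is illustrative; its five fields are screen number, left, top, width and height in pixels, and mss must be installed):

loader = LoadScreenshots("0 100 100 512 256", imgsz=640)
screens, imgs, _, s = next(iter(loader))  # one BGR numpy frame per iteration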
@ -274,32 +277,32 @@ class LoadImages:
def __init__(self, path, imgsz=640, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
parent = None
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
parent = Path(path).parent
path = Path(path).read_text().splitlines() # list of sources
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
if "*" in a:
files.extend(sorted(glob.glob(a, recursive=True))) # glob
elif os.path.isdir(a):
files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir
elif os.path.isfile(a):
files.append(a) # files (absolute or relative to CWD)
elif parent and (parent / p).is_file():
files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
else:
raise FileNotFoundError(f"{p} does not exist")
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.imgsz = imgsz
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = "image"
self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos):
@ -307,8 +310,10 @@ class LoadImages:
else:
self.cap = None
if self.nf == 0:
raise FileNotFoundError(
f"No images or videos found in {p}. "
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
)
def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""
@ -323,7 +328,7 @@ class LoadImages:
if self.video_flag[self.count]:
# Read video
self.mode = "video"
for _ in range(self.vid_stride):
self.cap.grab()
success, im0 = self.cap.retrieve()
@ -338,15 +343,15 @@ class LoadImages:
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
else:
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
if im0 is None:
raise FileNotFoundError(f"Image Not Found {path}")
s = f"image {self.count}/{self.nf} {path}: "
return [path], [im0], self.cap, s
@ -385,20 +390,20 @@ class LoadPilAndNumpy:
"""Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list):
im0 = [im0]
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
self.im0 = [self._single_check(im) for im in im0]
self.imgsz = imgsz
self.mode = "image"
# Generate fake paths
self.bs = len(self.im0)
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
if isinstance(im, Image.Image):
if im.mode != "RGB":
im = im.convert("RGB")
im = np.asarray(im)[:, :, ::-1]
im = np.ascontiguousarray(im) # contiguous
return im
@ -412,7 +417,7 @@ class LoadPilAndNumpy:
if self.count == 1: # loop only once as it's batch inference
raise StopIteration
self.count += 1
return self.paths, self.im0, None, ""
def __iter__(self):
"""Enables iteration for class LoadPilAndNumpy."""
@ -441,14 +446,16 @@ class LoadTensor:
"""Initialize Tensor Dataloader."""
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0]
self.mode = "image"
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
@staticmethod
def _single_check(im, stride=32):
"""Validate and format an image to torch.Tensor."""
s = (
f"WARNING ⚠ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
)
if len(im.shape) != 4:
if len(im.shape) != 3:
raise ValueError(s)
@ -457,8 +464,10 @@ class LoadTensor:
if im.shape[2] % stride or im.shape[3] % stride:
raise ValueError(s)
if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07
LOGGER.warning(
f"WARNING ⚠ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
f"Dividing input by 255."
)
im = im.float() / 255.0
return im
@ -473,7 +482,7 @@ class LoadTensor:
if self.count == 1:
raise StopIteration
self.count += 1
return self.paths, self.im0, None, ""
def __len__(self):
"""Returns the batch size."""
@ -485,12 +494,14 @@ def autocast_list(source):
files = []
for im in source:
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
raise TypeError(
f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
f"See https://docs.ultralytics.com/modes/predict for supported source types."
)
return files
@ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True):
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
"""
if use_pafy:
check_requirements(("pafy", "youtube_dl==2020.12.2"))
import pafy # noqa
return pafy.new(url).getbestvideo(preftype="mp4").url
else:
check_requirements("yt-dlp")
import yt_dlp
with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
info_dict = ydl.extract_info(url, download=False) # extract info
for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last
# Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
return f.get("url")

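Typical use is a one-liner; a sketch via the yt-dlp path (the URL is the example from the comments above, and the call returns None when no video-only MP4 of at least 1080p is found):

import cv2

best = get_best_youtube_url("https://youtu.be/LNwODJXcvt4", use_pafy=False)
if best:
    cap = cv2.VideoCapture(best)  # stream the selected format directly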
@ -14,7 +14,7 @@ from tqdm import tqdm
from ultralytics.data.utils import exif_size, img2label_paths
from ultralytics.utils.checks import check_requirements
check_requirements("shapely")
from shapely.geometry import Polygon
@ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
return outputs
def load_yolo_dota(data_root, split="train"):
"""
Load DOTA dataset.
@ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'):
- train
- val
"""
assert split in ["train", "val"]
im_dir = os.path.join(data_root, f"images/{split}")
assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root."
im_files = glob(os.path.join(data_root, f"images/{split}/*"))
lb_files = img2label_paths(im_files)
annos = []
for im_file, lb_file in zip(im_files, lb_files):
@ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
h, w = im_size
windows = []
for crop_size, gap in zip(crop_sizes, gaps):
assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
step = crop_size - gap
xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
@ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0
def get_window_obj(anno, windows, iof_thr=0.7):
"""Get objects for each window."""
h, w = anno["ori_size"]
label = anno["label"]
if len(label):
label[:, 1::2] *= w
label[:, 2::2] *= h
@ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
- train
- val
"""
im = cv2.imread(anno["filepath"])
name = Path(anno["filepath"]).stem
for i, window in enumerate(windows):
x_start, y_start, x_stop, y_stop = window.tolist()
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
patch_im = im[y_start:y_stop, x_start:x_stop]
ph, pw = patch_im.shape[:2]
cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im)
label = window_objs[i]
if len(label) == 0:
continue
@ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
label[:, 1::2] /= pw
label[:, 2::2] /= ph
with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f:
for lb in label:
formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]):
"""
Split both images and labels.
@ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024
- labels
- split
"""
im_dir = Path(save_dir) / "images" / split
im_dir.mkdir(parents=True, exist_ok=True)
lb_dir = Path(save_dir) / "labels" / split
lb_dir.mkdir(parents=True, exist_ok=True)
annos = load_yolo_dota(data_root, split=split)
for anno in tqdm(annos, total=len(annos), desc=split):
windows = get_windows(anno["ori_size"], crop_sizes, gaps)
window_objs = get_window_obj(anno, windows)
crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
@ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
for split in ["train", "val"]:
split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
@ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
save_dir = Path(save_dir) / "images" / "test"
save_dir.mkdir(parents=True, exist_ok=True)
im_dir = Path(os.path.join(data_root, "images/test"))
assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
im_files = glob(str(im_dir / "*"))
for im_file in tqdm(im_files, total=len(im_files), desc="test"):
w, h = exif_size(Image.open(im_file))
windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
im = cv2.imread(im_file)
name = Path(im_file).stem
for window in windows:
x_start, y_start, x_stop, y_stop = window.tolist()
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
patch_im = im[y_start:y_stop, x_start:x_stop]
cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)
if __name__ == "__main__":
split_trainval(
data_root="DOTAv2",
save_dir="DOTAv2-split",
)
split_test(
data_root="DOTAv2",
save_dir="DOTAv2-split",
)

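The rates argument drives multi-scale cropping: each rate r yields windows of int(crop_size / r) pixels with int(gap / r) overlap. A sketch reusing the placeholder paths above:

split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split", rates=[0.5, 1.0, 1.5])
# produces 2048, 1024 and 682 px windows from the default 1024 px crop_size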
@ -17,36 +17,47 @@ import numpy as np
from PIL import Image, ImageOps
from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
DATASETS_DIR,
LOGGER,
NUM_THREADS,
ROOT,
SETTINGS_YAML,
TQDM,
clean_url,
colorstr,
emojis,
yaml_load,
yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
def img2label_paths(img_paths):
"""Define label paths as a function of image paths."""
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
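# A quick illustration with hypothetical paths: only the last "/images/" component is
# swapped for "/labels/" and the suffix becomes ".txt", e.g.
#   img2label_paths(["/data/images/train/0001.jpg"]) -> ["/data/labels/train/0001.txt"]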
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
h = hashlib.sha256(str(size).encode()) # hash sizes
h.update("".join(paths).encode()) # hash paths
return h.hexdigest() # return hash
def exif_size(img: Image.Image):
"""Returns exif-corrected PIL size."""
s = img.size # (width, height)
if img.format == "JPEG": # only support JPEG images
with contextlib.suppress(Exception):
exif = img.getexif()
if exif:
@ -60,24 +71,24 @@ def verify_image(args):
"""Verify one image."""
(im_file, cls), prefix = args
# Number (found, corrupt), message
nf, nc, msg = 0, 0, ""
try:
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠ {im_file}: corrupt JPEG restored and saved"
nf = 1
except Exception as e:
nc = 1
msg = f"{prefix}WARNING ⚠ {im_file}: ignoring corrupt image/label: {e}"
return (im_file, cls), nf, nc, msg
@ -85,21 +96,21 @@ def verify_image_label(args):
"""Verify one image-label pair."""
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
try:
# Verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠ {im_file}: corrupt JPEG restored and saved"
# Verify labels
if os.path.isfile(lb_file):
@ -114,25 +125,26 @@ def verify_image_label(args):
nl = len(lb)
if nl:
if keypoint:
assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
points = lb[:, 5:].reshape(-1, ndim)[:, :2]
else:
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
points = lb[:, 1:]
assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
# All labels
max_cls = lb[:, 0].max() # max label count
assert max_cls <= num_cls, (
f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
f"Possible class labels are 0-{num_cls - 1}"
)
_, i = np.unique(lb, axis=0, return_index=True)
if len(i) < nl: # duplicate row check
lb = lb[i] # remove duplicates
if segments:
segments = [segments[x] for x in i]
msg = f"{prefix}WARNING ⚠ {im_file}: {nl - len(i)} duplicate labels removed"
else:
ne = 1 # label empty
lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
@ -148,7 +160,7 @@ def verify_image_label(args):
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
except Exception as e:
nc = 1
msg = f"{prefix}WARNING ⚠ {im_file}: ignoring corrupt image/label: {e}"
return [None, None, None, None, None, nm, nf, ne, nc, msg]
@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
masks = np.zeros(
(imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8,
)
areas = []
ms = []
for si in range(len(segments)):
@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path:
Returns:
(Path): The path of the found YAML file.
"""
files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive
assert files, f"No YAML file found in '{path.resolve()}'"
if len(files) > 1:
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True):
file = check_file(dataset)
# Download (optional)
extract_dir = ""
if zipfile.is_zipfile(file) or is_tarfile(file):
new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
file = find_dataset_yaml(DATASETS_DIR / new_dir)
@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True):
data = yaml_load(file, append_filename=True) # dictionary
# Checks
for k in "train", "val":
if k not in data:
if k != "val" or "validation" not in data:
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
)
LOGGER.info("WARNING ⚠ renaming data YAML 'validation' key to 'val' to match YOLO format.")
data["val"] = data.pop("validation") # replace 'validation' key with 'val' key
if "names" not in data and "nc" not in data:
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
if "names" not in data:
data["names"] = [f"class_{i}" for i in range(data["nc"])]
else:
data["nc"] = len(data["names"])
data["names"] = check_class_names(data["names"])
# Resolve paths
path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
if not path.is_absolute():
path = (DATASETS_DIR / path).resolve()
# Set paths
data["path"] = path # download scripts
for k in "train", "val", "test":
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
if not x.exists() and data[k].startswith("../"):
x = (path / data[k][3:]).resolve()
data[k] = str(x)
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse YAML
val, s = (data.get(x) for x in ("val", "download"))
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True):
raise FileNotFoundError(m)
t = time.time()
r = None # success
if s.startswith("http") and s.endswith(".zip"): # URL
safe_download(url=s, dir=DATASETS_DIR, delete=True)
elif s.startswith("bash "): # bash script
LOGGER.info(f"Running {s} ...")
r = os.system(s)
else: # python script
exec(s, {"yaml": data})
dt = f"({round(time.time() - t, 1)}s)"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}\n")
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
return data # dictionary
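# A minimal sketch of typical use (the dataset name is illustrative): the returned dict
# carries resolved absolute "train"/"val"/"test" paths plus "names", "nc" and "path", e.g.
#   data = check_det_dataset("coco8.yaml")
#   data["nc"], data["names"][0]  # -> 80, "person" for the COCO-style coco8 dataset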
def check_cls_dataset(dataset, split=""):
"""
Checks a classification dataset such as Imagenet.
@ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''):
"""
# Download (optional if dataset=https://file.zip is passed directly)
if str(dataset).startswith(("http:/", "https:/")):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
if not data_dir.is_dir():
LOGGER.warning(f"\nDataset not found ⚠, missing path {data_dir}, attempting download...")
t = time.time()
if str(dataset) == "imagenet":
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
download(url, dir=data_dir.parent)
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
train_set = data_dir / "train"
val_set = (
data_dir / "val"
if (data_dir / "val").exists()
else data_dir / "validation"
if (data_dir / "validation").exists()
else None
) # data/test or data/val
test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
if split == "val" and not val_set:
LOGGER.warning("WARNING ⚠ Dataset 'split=val' not found, using 'split=test' instead.")
elif split == "test" and not test_set:
LOGGER.warning("WARNING ⚠ Dataset 'split=test' not found, using 'split=val' instead.")
nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
# Print to console
for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
prefix = f'{colorstr(f"{k}:")} {v}...'
if v is None:
LOGGER.info(prefix)
else:
files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
nf = len(files) # number of files
nd = len({file.parent for file in files}) # number of directories
if nf == 0:
if k == "train":
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
else:
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠ no images found")
elif nd != nc:
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌ requires {nc} classes, not {nd}")
else:
LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")
return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
class HUBDatasetStats:
@ -423,42 +443,43 @@ class HUBDatasetStats:
```
"""
def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
"""Initialize class."""
path = Path(path).resolve()
LOGGER.info(f"Starting HUB dataset checks for {path}....")
self.task = task # detect, segment, pose, classify
if self.task == "classify":
unzip_dir = unzip_file(path)
data = check_cls_dataset(unzip_dir)
data["path"] = unzip_dir
else: # detect, segment, pose
_, data_dir, yaml_path = self._unzip(Path(path))
try:
# Load YAML with checks
data = yaml_load(yaml_path)
data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets
yaml_save(yaml_path, data)
data = check_det_dataset(yaml_path, autodownload) # dict
data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
except Exception as e:
raise Exception("error/HUB/dataset_stats/init") from e
self.hub_dir = Path(f'{data["path"]}-hub')
self.im_dir = self.hub_dir / "images"
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
self.data = data
@staticmethod
def _unzip(path):
"""Unzip data.zip."""
if not str(path).endswith(".zip"): # path is data.yaml
return False, None, path
unzip_dir = unzip_file(path, path=path.parent)
assert unzip_dir.is_dir(), (
f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
)
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f):
@ -470,31 +491,31 @@ class HUBDatasetStats:
def _round(labels):
"""Update labels to integer class and 4 decimal place floats."""
if self.task == "detect":
coordinates = labels["bboxes"]
elif self.task == "segment":
coordinates = [x.flatten() for x in labels["segments"]]
elif self.task == "pose":
n = labels["keypoints"].shape[0]
coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
else:
raise ValueError("Undefined dataset task.")
zipped = zip(labels["cls"], coordinates)
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
for split in "train", "val", "test":
self.stats[split] = None # predefine
path = self.data.get(split)
# Check split
if path is None: # no split
continue
files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
if not files: # no images
continue
# Get dataset statistics
if self.task == "classify":
from torchvision.datasets import ImageFolder
dataset = ImageFolder(self.data[split])
@ -504,38 +525,35 @@ class HUBDatasetStats:
x[im[1]] += 1
self.stats[split] = {
"instance_stats": {"total": len(dataset), "per_class": x.tolist()},
"image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
"labels": [{Path(k).name: v} for k, v in dataset.imgs],
}
else:
from ultralytics.data import YOLODataset
dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
x = np.array(
[
np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
]
) # shape(128x80)
self.stats[split] = {
"instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
"image_stats": {
"total": len(dataset),
"unlabelled": int(np.all(x == 0, 1).sum()),
"per_class": (x > 0).sum(0).tolist(),
},
"labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
}
# Save, print and return
if save:
stats_path = self.hub_dir / "stats.json"
LOGGER.info(f"Saving {stats_path.resolve()}...")
with open(stats_path, "w") as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
@ -545,14 +563,14 @@ class HUBDatasetStats:
"""Compress images for Ultralytics HUB."""
from ultralytics.data import YOLODataset # ClassificationDataset
for split in "train", "val", "test":
if self.data.get(split) is None:
continue
dataset = YOLODataset(img_path=self.data[split], data=self.data)
with ThreadPool(NUM_THREADS) as pool:
for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
pass
LOGGER.info(f"Done. All images saved to {self.im_dir}")
return self.im_dir
@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f"WARNING ⚠ HUB ops PIL failure {f}: {e}")
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
cv2.imwrite(str(f_new or f), im)
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
"""
path = Path(path) # images dir
files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
n = len(files) # number of files
random.seed(0) # for reproducibility
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files
for x in txt:
if (path.parent / x).exists():
(path.parent / x).unlink() # remove existing
LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
for i, img in TQDM(zip(indices, files), total=n):
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], "a") as f:
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file

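A usage sketch with the defaults shown above (weights are the train/val/test fractions; the *.txt files land in the parent of the images directory):

autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False)
# writes autosplit_train.txt and autosplit_val.txt next to coco8/images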
File diff suppressed because it is too large

@ -53,7 +53,7 @@ class Model(nn.Module):
list(ultralytics.engine.results.Results): The prediction results.
"""
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None:
"""
Initializes the YOLO model.
@ -89,7 +89,7 @@ class Model(nn.Module):
# Load or create new YOLO model
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(model).suffix in (".yaml", ".yml"):
self._new(model, task)
else:
self._load(model, task)
@ -112,16 +112,20 @@ class Model(nn.Module):
def is_triton_model(model):
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
from urllib.parse import urlsplit
url = urlsplit(model)
return url.netloc and url.path and url.scheme in {"http", "grpc"}
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
return any(
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
)
) # MODELID
def _new(self, cfg: str, task=None, model=None, verbose=True):
"""
@ -136,9 +140,9 @@ class Model(nn.Module):
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = task or guess_model_task(cfg_dict)
self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides["model"] = self.cfg
self.overrides["task"] = self.task
# Below added to allow export from YAMLs
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
@ -153,9 +157,9 @@ class Model(nn.Module):
task (str | None): model task
"""
suffix = Path(weights).suffix
if suffix == ".pt":
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"]
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
@ -163,12 +167,12 @@ class Model(nn.Module):
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
self.overrides["model"] = weights
self.overrides["task"] = self.task
def _check_is_pytorch_model(self):
"""Raises TypeError is model is not a PyTorch model."""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(
@ -176,19 +180,20 @@ class Model(nn.Module):
f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
)
def reset_weights(self):
"""Resets the model modules parameters to randomly initialized values, losing all training information."""
self._check_is_pytorch_model()
for m in self.model.modules():
if hasattr(m, "reset_parameters"):
m.reset_parameters()
for p in self.model.parameters():
p.requires_grad = True
return self
def load(self, weights="yolov8n.pt"):
"""Transfers parameters with matching names and shapes from 'weights' to model."""
self._check_is_pytorch_model()
if isinstance(weights, (str, Path)):
@ -226,8 +231,8 @@ class Model(nn.Module):
Returns:
(List[torch.Tensor]): A list of image embeddings.
"""
if not kwargs.get("embed"):
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
return self.predict(source, stream, **kwargs)
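# A short sketch (weights file and image are illustrative): with no "embed" indices passed,
# the second-to-last layer is selected automatically, e.g.
#   from ultralytics import YOLO
#   vectors = YOLO("yolov8n.pt").embed("bus.jpg")  # list of torch.Tensor embeddings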
def predict(self, source=None, stream=False, predictor=None, **kwargs):
@ -249,21 +254,22 @@ class Model(nn.Module):
source = ASSETS
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
)
custom = {"conf": 0.25, "save": is_cli} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right
prompts = args.pop("prompts", None) # for SAM-type models
if not self.predictor:
self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)
if "project" in args or "name" in args:
self.predictor.save_dir = get_save_dir(self.predictor.args)
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
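A minimal predict sketch through the public YOLO subclass (the weights and source are illustrative); explicit kwargs take priority over the 0.25 default confidence set above:

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model.predict("https://ultralytics.com/images/bus.jpg", conf=0.5)
print(len(results))  # one Results object per input image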
@ -280,11 +286,12 @@ class Model(nn.Module):
Returns:
(List[ultralytics.engine.results.Results]): The tracking results.
"""
if not hasattr(self.predictor, "trackers"):
from ultralytics.trackers import register_tracker
register_tracker(self, persist)
kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
kwargs["mode"] = "track"
return self.predict(source=source, stream=stream, **kwargs)
def val(self, validator=None, **kwargs):
@ -295,10 +302,10 @@ class Model(nn.Module):
validator (BaseValidator): Customized validator.
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
custom = {"rect": True} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@ -313,16 +320,17 @@ class Model(nn.Module):
self._check_is_pytorch_model()
from ultralytics.utils.benchmarks import benchmark
custom = {"verbose": False} # method defaults
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
return benchmark(
model=self,
data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
imgsz=args["imgsz"],
half=args["half"],
int8=args["int8"],
device=args["device"],
verbose=kwargs.get("verbose"),
)
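# Editor's note: a benchmark sketch (not part of this diff); imgsz/half/int8/device are
# read from the merged args exactly as above, and the values shown are illustrative.
#   model.benchmark(data="coco8.yaml", imgsz=640, half=False)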
def export(self, **kwargs):
"""
@@ -334,8 +342,8 @@ class Model(nn.Module):
self._check_is_pytorch_model()
from .exporter import Exporter
custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
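# Editor's note: an export sketch (not part of this diff); imgsz is inherited from the
# trained model and batch is forced to 1 by the defaults above. The format is illustrative.
#   onnx_path = model.export(format="onnx")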
def train(self, trainer=None, **kwargs):
@@ -347,32 +355,32 @@ class Model(nn.Module):
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
if any(kwargs):
LOGGER.warning("WARNING ⚠ using HUB training arguments, ignoring local training arguments.")
kwargs = self.session.train_args # overwrite kwargs
checks.check_pip_update_available()
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
if args.get("resume"):
args["resume"] = self.ckpt_path
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
if not args.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
if SETTINGS["hub"] is True and not self.session:
# Create a model in HUB
try:
self.session = self._get_hub_session(self.model_name)
if self.session:
self.session.create_model(args)
# Check model was created
if not getattr(self.session.model, "id", None):
self.session = None
except PermissionError:
# Ignore permission error
@@ -385,7 +393,7 @@ class Model(nn.Module):
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
return self.metrics
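# Editor's note: a local (non-HUB) training sketch (not part of this diff); with 'resume'
# unset, the weights are attached manually as above. Dataset/epochs are illustrative.
#   model = YOLO("yolov8n.pt")
#   model.train(data="coco128.yaml", epochs=3, imgsz=640)
#   print(model.metrics)  # populated from trainer.validator on RANK in (-1, 0)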
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
@@ -398,12 +406,13 @@ class Model(nn.Module):
self._check_is_pytorch_model()
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
else:
from .tuner import Tuner
custom = {} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
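# Editor's note: sketches of both tuning paths exposed above (not part of this diff);
# datasets and iteration counts are illustrative.
#   model.tune(data="coco128.yaml", epochs=10, iterations=30)     # native Tuner
#   model.tune(use_ray=True, iterations=10, data="coco128.yaml")  # Ray Tune backend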
def _apply(self, fn):
@@ -411,13 +420,13 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self = super()._apply(fn) # noqa
self.predictor = None # reset predictor as device may have changed
self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
return self
@property
def names(self):
"""Returns class names of the loaded model."""
return self.model.names if hasattr(self.model, "names") else None
@property
def device(self):
@@ -427,7 +436,7 @@ class Model(nn.Module):
@property
def transforms(self):
"""Returns transform of the loaded model."""
return self.model.transforms if hasattr(self.model, "transforms") else None
def add_callback(self, event: str, func):
"""Add a callback."""
@@ -445,7 +454,7 @@ class Model(nn.Module):
@staticmethod
def _reset_ckpt_args(args):
"""Reset arguments when loading a PyTorch model."""
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
# def __getattr__(self, attr):
@@ -461,7 +470,8 @@ class Model(nn.Module):
name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError(
emojis(f"WARNING ⚠ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
) from e
@property
def task_map(self):
@@ -471,4 +481,4 @@ class Model(nn.Module):
Returns:
task_map (dict): The map of model task to mode classes.
"""
raise NotImplementedError("Please provide task map for your model!")

@@ -132,8 +132,11 @@ class BasePredictor:
def inference(self, im, *args, **kwargs):
"""Runs inference on a given image using the specified model and arguments."""
visualize = (
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
if self.args.visualize and (not self.source_type.tensor)
else False
)
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
def pre_transform(self, im):
@@ -153,35 +156,38 @@ class BasePredictor:
def write_results(self, idx, results, batch):
"""Write inference results to a file or directory."""
p, im, _ = batch
log_string = ""
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
log_string += f"{idx}: "
frame = self.dataset.count
else:
frame = getattr(self.dataset, "frame", 0)
self.data_path = p
self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
log_string += "%gx%g " % im.shape[2:] # print string
result = results[idx]
log_string += result.verbose()
if self.args.save or self.args.show: # Add bbox to image
plot_args = {
"line_width": self.args.line_width,
"boxes": self.args.show_boxes,
"conf": self.args.show_conf,
"labels": self.args.show_labels,
}
if not self.args.retina_masks:
plot_args["im_gpu"] = im[idx]
self.plotted_img = result.plot(**plot_args)
# Write
if self.args.save_txt:
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
if self.args.save_crop:
result.save_crop(
save_dir=self.save_dir / "crops",
file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
)
return log_string
@@ -210,17 +216,24 @@ class BasePredictor:
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.transforms = (
getattr(
self.model.model,
"transforms",
classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
)
if self.args.task == "classify"
else None
)
self.dataset = load_inference_source(
source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
)
self.source_type = self.dataset.source_type
if not getattr(self, "stream", True) and (
self.dataset.mode == "stream" # streams
or len(self.dataset) > 1000 # images
or any(getattr(self.dataset, "video_flag", [False]))
): # videos
LOGGER.warning(STREAM_WARNING)
self.vid_path = [None] * self.dataset.bs
self.vid_writer = [None] * self.dataset.bs
@@ -230,7 +243,7 @@ class BasePredictor:
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info("")
# Setup model
if not self.model:
@@ -242,7 +255,7 @@ class BasePredictor:
# Check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
if not self.done_warmup:
@@ -250,10 +263,10 @@ class BasePredictor:
self.done_warmup = True
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
self.run_callbacks("on_predict_start")
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
self.batch = batch
path, im0s, vid_cap, s = batch
@@ -272,15 +285,16 @@ class BasePredictor:
with profilers[2]:
self.results = self.postprocess(preds, im, im0s)
self.run_callbacks("on_predict_postprocess_end")
# Visualize, save, write results
n = len(im0s)
for i in range(n):
self.seen += 1
self.results[i].speed = {
"preprocess": profilers[0].dt * 1e3 / n,
"inference": profilers[1].dt * 1e3 / n,
"postprocess": profilers[2].dt * 1e3 / n,
}
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
p = Path(p)
@@ -293,12 +307,12 @@ class BasePredictor:
if self.args.save and self.plotted_img is not None:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks("on_predict_batch_end")
yield from self.results
# Print time (inference-only)
if self.args.verbose:
LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")
# Release assets
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
@@ -306,25 +320,29 @@ class BasePredictor:
# Print results
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
LOGGER.info(
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
f"{(1, 3, *im.shape[2:])}" % t
)
if self.args.save or self.args.save_txt or self.args.save_crop:
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
self.model = AutoBackend(
model or self.args.model,
device=select_device(self.args.device, verbose=verbose),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
fuse=True,
verbose=verbose,
)
self.device = self.model.device # update device
self.args.half = self.model.fp16 # update half
@@ -333,18 +351,18 @@ class BasePredictor:
def show(self, p):
"""Display an image in a window using OpenCV imshow()."""
im0 = self.plotted_img
if platform.system() == "Linux" and p not in self.windows:
self.windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path):
"""Save video predictions as mp4 at specified path."""
im0 = self.plotted_img
# Save imgs
if self.dataset.mode == "image":
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
@@ -361,15 +379,16 @@ class BasePredictor:
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
self.vid_writer[idx] = cv2.VideoWriter(
str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
)
# Write video
self.vid_writer[idx].write(im0)
# Write frame
if self.args.save_frames:
cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
self.vid_frame[idx] += 1
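# Editor's note (not part of this diff): the suffix/fourcc table above pairs a container
# with a codec per platform; an equivalent standalone writer, with illustrative values:
#   import cv2
#   writer = cv2.VideoWriter("out.mp4", cv2.VideoWriter_fourcc(*"avc1"), 30, (640, 480))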
def run_callbacks(self, event: str):

@@ -98,15 +98,15 @@ class Results(SimpleClass):
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
self.names = names
self.path = path
self.save_dir = None
self._keys = "boxes", "masks", "probs", "keypoints", "obb"
def __getitem__(self, idx):
"""Return a Results object for the specified index."""
return self._apply("__getitem__", idx)
def __len__(self):
"""Return the number of detections in the Results object."""
@@ -146,19 +146,19 @@ class Results(SimpleClass):
def cpu(self):
"""Return a copy of the Results object with all tensors on CPU memory."""
return self._apply("cpu")
def numpy(self):
"""Return a copy of the Results object with all tensors as numpy arrays."""
return self._apply("numpy")
def cuda(self):
"""Return a copy of the Results object with all tensors on GPU memory."""
return self._apply("cuda")
def to(self, *args, **kwargs):
"""Return a copy of the Results object with tensors on the specified device and dtype."""
return self._apply("to", *args, **kwargs)
def new(self):
"""Return a new Results object with the same image, path, and names."""
@@ -169,7 +169,7 @@ class Results(SimpleClass):
conf=True,
line_width=None,
font_size=None,
font="Arial.ttf",
pil=False,
img=None,
im_gpu=None,
@@ -229,14 +229,20 @@ class Results(SimpleClass):
font_size,
font,
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
example=names,
)
# Plot Segment results
if pred_masks and show_masks:
if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
im_gpu = (
torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
.permute(2, 0, 1)
.flip(0)
.contiguous()
/ 255
)
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
@@ -244,14 +250,14 @@ class Results(SimpleClass):
if pred_boxes is not None and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
name = ("" if id is None else f"id:{id} ") + names[c]
label = (f"{name} {conf:.2f}" if conf else name) if labels else None
box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)
# Plot Classify results
if pred_probs is not None and show_probs:
text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
x = round(self.orig_shape[0] * 0.03)
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
@@ -264,11 +270,11 @@ class Results(SimpleClass):
def verbose(self):
"""Return log string for each task."""
log_string = ""
probs = self.probs
boxes = self.boxes
if len(self) == 0:
return log_string if probs is not None else f"{log_string}(no detections), "
if probs is not None:
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
if boxes:
@@ -293,7 +299,7 @@ class Results(SimpleClass):
texts = []
if probs is not None:
# Classify
[texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
elif boxes:
# Detect/segment/pose
for j, d in enumerate(boxes):
@@ -304,16 +310,16 @@ class Results(SimpleClass):
line = (c, *seg)
if kpts is not None:
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
line += (*kpt.reshape(-1).tolist(),)
line += (conf,) * save_conf + (() if id is None else (id,))
texts.append(("%g " * len(line)).rstrip() % line)
if texts:
Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory
with open(txt_file, "a") as f:
f.writelines(text + "\n" for text in texts)
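# Editor's note: a save_txt() usage sketch (not part of this diff). Each row is a
# space-separated "%g"-formatted tuple (class, coords, then optional conf and track id)
# appended to the target file; the paths are illustrative.
#   res = model("bus.jpg")[0]
#   res.save_txt("labels/bus.txt", save_conf=True)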
def save_crop(self, save_dir, file_name=Path("im.jpg")):
"""
Save cropped predictions to `save_dir/cls/file_name.jpg`.
@@ -322,21 +328,23 @@ class Results(SimpleClass):
file_name (str | pathlib.Path): File name.
"""
if self.probs is not None:
LOGGER.warning("WARNING ⚠ Classify task does not support `save_crop`.")
return
if self.obb is not None:
LOGGER.warning("WARNING ⚠ OBB task does not support `save_crop`.")
return
for d in self.boxes:
save_one_box(
d.xyxy,
self.orig_img.copy(),
file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
BGR=True,
)
def tojson(self, normalize=False):
"""Convert the object to JSON format."""
if self.probs is not None:
LOGGER.warning("Warning: Classify task does not support `tojson` yet.")
return
import json
@@ -346,19 +354,19 @@ class Results(SimpleClass):
data = self.boxes.data.cpu().tolist()
h, w = self.orig_shape if normalize else (1, 1)
for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id
box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
conf = row[-2]
class_id = int(row[-1])
name = self.names[class_id]
result = {"name": name, "class": class_id, "confidence": conf, "box": box}
if self.boxes.is_track:
result["track_id"] = int(row[-3]) # track ID
if self.masks:
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
if self.keypoints is not None:
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
results.append(result)
# Convert detections to JSON
@@ -397,7 +405,7 @@ class Boxes(BaseTensor):
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.orig_shape = orig_shape
@@ -474,7 +482,8 @@ class Masks(BaseTensor):
"""Return normalized segments."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.data)
]
@property
@lru_cache(maxsize=1)
@@ -482,7 +491,8 @@ class Masks(BaseTensor):
"""Return segments in pixel coordinates."""
return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.data)
]
class Keypoints(BaseTensor):
@@ -610,7 +620,7 @@ class OBB(BaseTensor):
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 8
self.orig_shape = orig_shape

@@ -23,14 +23,31 @@ from torch import nn, optim
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (
DEFAULT_CFG,
LOGGER,
RANK,
TQDM,
__version__,
callbacks,
clean_url,
colorstr,
emojis,
yaml_save,
)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
EarlyStopping,
ModelEMA,
de_parallel,
init_seeds,
one_cycle,
select_device,
strip_optimizer,
)
class BaseTrainer:
@@ -89,12 +106,12 @@ class BaseTrainer:
# Dirs
self.save_dir = get_save_dir(self.args)
self.args.name = self.save_dir.name # update name for loggers
self.wdir = self.save_dir / "weights" # weights dir
if RANK in (-1, 0):
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.args.save_dir = str(self.save_dir)
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
self.save_period = self.args.save_period
self.batch_size = self.args.batch
@@ -104,18 +121,18 @@ class BaseTrainer:
print_args(vars(self.args))
# Device
if self.device.type in ("cpu", "mps"):
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if self.args.task == "classify":
self.data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
self.data = check_det_dataset(self.args.data)
if "yaml_file" in self.data:
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
@@ -131,8 +148,8 @@ class BaseTrainer:
self.fitness = None
self.loss = None
self.tloss = None
self.loss_names = ["Loss"]
self.csv = self.save_dir / "results.csv"
self.plot_idx = [0, 1, 2]
# Callbacks
@@ -156,7 +173,7 @@ class BaseTrainer:
def train(self):
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
world_size = len(self.args.device.split(","))
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
world_size = len(self.args.device)
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
@@ -165,14 +182,16 @@ class BaseTrainer:
world_size = 0
# Run subprocess if DDP training, else train normally
if world_size > 1 and "LOCAL_RANK" not in os.environ:
# Argument checks
if self.args.rect:
LOGGER.warning("WARNING ⚠ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
self.args.rect = False
if self.args.batch == -1:
LOGGER.warning(
"WARNING ⚠ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
"default 'batch=16'"
)
self.args.batch = 16
# Command
@@ -199,37 +218,45 @@ class BaseTrainer:
def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
self.device = torch.device("cuda", RANK)
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
dist.init_process_group(
"nccl" if dist.is_nccl_available() else "gloo",
timeout=timedelta(seconds=10800), # 3 hours
rank=RANK,
world_size=world_size,
)
def _setup_train(self, world_size):
"""Builds dataloaders and optimizer on correct rank process."""
# Model
self.run_callbacks("on_pretrain_routine_start")
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
# Freeze layers
freeze_list = (
self.args.freeze
if isinstance(self.args.freeze, list)
else range(self.args.freeze)
if isinstance(self.args.freeze, int)
else []
)
always_freeze_names = [".dfl"] # always freeze these layers
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
for k, v in self.model.named_parameters():
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
if any(x in k for x in freeze_layer_names):
LOGGER.info(f"Freezing layer '{k}'")
v.requires_grad = False
elif not v.requires_grad:
LOGGER.info(
f"WARNING ⚠ setting 'requires_grad=True' for frozen layer '{k}'. "
"See ultralytics.engine.trainer for customization of frozen layers."
)
v.requires_grad = True
# Check AMP
@@ -246,7 +273,7 @@ class BaseTrainer:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
self.stride = gs # for multi-scale training
@@ -256,15 +283,14 @@ class BaseTrainer:
# Dataloaders
batch_size = self.batch_size // max(world_size, 1)
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
if RANK in (-1, 0):
# NOTE: When training on the DOTA dataset, doubling the batch size may cause OOM, since some images contain more than 2000 objects.
self.test_loader = self.get_dataloader(
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
)
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.ema = ModelEMA(self.model)
if self.args.plots:
@@ -274,18 +300,20 @@ class BaseTrainer:
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
self.optimizer = self.build_optimizer(
model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
momentum=self.args.momentum,
decay=weight_decay,
iterations=iterations,
)
# Scheduler
self._setup_scheduler()
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
@@ -299,19 +327,23 @@ class BaseTrainer:
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
self.run_callbacks("on_train_start")
LOGGER.info(
f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for '
f'{self.args.time} hours...'
if self.args.time
else f"{self.epochs} epochs..."
)
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
epoch = self.epochs # predefine for resume fully trained model edge cases
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.run_callbacks('on_train_epoch_start')
self.run_callbacks("on_train_epoch_start")
self.model.train()
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
@@ -327,7 +359,7 @@ class BaseTrainer:
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
self.run_callbacks("on_train_batch_start")
# Warmup
ni = i + nb * epoch
if ni <= nw:
@@ -335,10 +367,11 @@ class BaseTrainer:
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
for j, x in enumerate(self.optimizer.param_groups):
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x["lr"] = np.interp(
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
)
if "momentum" in x:
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward
with torch.cuda.amp.autocast(self.amp):
@@ -346,8 +379,9 @@ class BaseTrainer:
self.loss, self.loss_items = self.model(batch)
if RANK != -1:
self.loss *= world_size
self.tloss = (
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
)
# Backward
self.scaler.scale(self.loss).backward()
@@ -368,24 +402,25 @@ class BaseTrainer:
break
# Log
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if RANK in (-1, 0):
pbar.set_description(
("%11s" * 2 + "%11.4g" * (2 + loss_len))
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
)
self.run_callbacks("on_batch_end")
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
self.run_callbacks("on_train_batch_end")
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.run_callbacks("on_train_epoch_end")
if RANK in (-1, 0):
final_epoch = epoch + 1 == self.epochs
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
@@ -398,14 +433,14 @@ class BaseTrainer:
# Save model
if self.args.save or final_epoch:
self.save_model()
self.run_callbacks("on_model_save")
# Scheduler
t = time.time()
self.epoch_time = t - self.epoch_time_start
self.epoch_time_start = t
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
if self.args.time:
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
@@ -413,7 +448,7 @@ class BaseTrainer:
self.scheduler.last_epoch = self.epoch # do not move
self.stop |= epoch >= self.epochs # stop if exceeded epochs
self.scheduler.step()
self.run_callbacks("on_fit_epoch_end")
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
# Early Stopping
@@ -426,39 +461,43 @@ class BaseTrainer:
if RANK in (-1, 0):
# Do final val with best.pt
LOGGER.info(
f"\n{epoch - self.start_epoch + 1} epochs completed in "
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
)
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.run_callbacks("on_train_end")
torch.cuda.empty_cache()
self.run_callbacks("teardown")
def save_model(self):
"""Save model training checkpoints with additional metadata."""
import pandas as pd # scope for faster startup
metrics = {**self.metrics, **{"fitness": self.fitness}}
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
ckpt = {
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(),
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": metrics,
"train_results": results,
"date": datetime.now().isoformat(),
"version": __version__,
}
# Save last and best
torch.save(ckpt, self.last)
if self.best_fitness == self.fitness:
torch.save(ckpt, self.best)
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
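# Editor's note: a sketch of reading back a checkpoint written above (not part of this
# diff); the path is illustrative and the keys match the ckpt dict in save_model().
#   import torch
#   ckpt = torch.load("runs/detect/train/weights/last.pt")
#   print(ckpt["epoch"], ckpt["version"], ckpt["train_args"]["imgsz"])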
@staticmethod
def get_dataset(data):
@@ -467,7 +506,7 @@ class BaseTrainer:
Returns None if data format is not recognized.
"""
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""Load/create/download model for any task."""
@@ -476,9 +515,9 @@ class BaseTrainer:
model, weights = self.model, None
ckpt = None
if str(model).endswith(".pt"):
weights, ckpt = attempt_load_one_weight(model)
cfg = ckpt["model"].yaml
else:
cfg = model
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
@@ -505,7 +544,7 @@ class BaseTrainer:
The returned dict is expected to contain the "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
return metrics, fitness
@@ -516,24 +555,24 @@ class BaseTrainer:
def get_validator(self):
"""Returns a NotImplementedError when the get_validator function is called."""
raise NotImplementedError("get_validator function not implemented in trainer")
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Returns dataloader derived from torch.data.Dataloader."""
raise NotImplementedError("get_dataloader function not implemented in trainer")
def build_dataset(self, img_path, mode="train", batch=None):
"""Build dataset."""
raise NotImplementedError("build_dataset function not implemented in trainer")
def label_loss_items(self, loss_items=None, prefix="train"):
"""Returns a loss dict with labelled training loss items tensor."""
# Not needed for classification but necessary for segmentation & detection
return {"loss": loss_items} if loss_items is not None else ["loss"]
def set_model_attributes(self):
"""To set or update model parameters before training."""
self.model.names = self.data["names"]
def build_targets(self, preds, targets):
"""Builds target tensors for training YOLO model."""
@@ -541,7 +580,7 @@ class BaseTrainer:
def progress_string(self):
"""Returns a string describing training progress."""
return ""
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
@@ -556,9 +595,9 @@ class BaseTrainer:
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
with open(self.csv, "a") as f:
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
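# Editor's note (not part of this diff): the header row is written once and one row is
# appended per epoch, so the CSV loads directly; the path is illustrative.
#   import pandas as pd
#   df = pd.read_csv("runs/detect/train/results.csv")
#   df.columns = df.columns.str.strip()  # cells are %23s-padded, as with k.strip() in save_model() above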
def plot_metrics(self):
"""Plot and display metrics visually."""
@@ -567,7 +606,7 @@ class BaseTrainer:
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
path = Path(name)
self.plots[path] = {"data": data, "timestamp": time.time()}
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
@@ -575,11 +614,11 @@ class BaseTrainer:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f"\nValidating {f}...")
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
@@ -591,19 +630,21 @@ class BaseTrainer:
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
ckpt_args = attempt_load_weights(last).args
if not Path(ckpt_args["data"]).exists():
ckpt_args["data"] = self.args.data
resume = True
self.args = get_cfg(ckpt_args)
self.args.model = str(last) # reinstate model
for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
if k in overrides:
setattr(self.args, k, overrides[k])
except Exception as e:
raise FileNotFoundError(
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
"i.e. 'yolo train resume model=path/to/last.pt'"
) from e
self.resume = resume
def resume_training(self, ckpt):
@@ -611,23 +652,26 @@ class BaseTrainer:
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt["epoch"] + 1
if ckpt["optimizer"] is not None:
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
best_fitness = ckpt["best_fitness"]
if self.ema and ckpt.get("ema"):
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
self.ema.updates = ckpt["updates"]
if self.resume:
assert start_epoch > 0, (
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
)
LOGGER.info(
f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
)
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
)
self.epochs += ckpt["epoch"] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
@@ -635,13 +679,13 @@ class BaseTrainer:
def _close_dataloader_mosaic(self):
"""Update dataloaders to stop using mosaic augmentation."""
if hasattr(self.train_loader.dataset, "mosaic"):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, "close_mosaic"):
LOGGER.info("Closing dataloader mosaic")
self.train_loader.dataset.close_mosaic(hyp=self.args)
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
weight decay, and number of iterations.
@@ -661,41 +705,45 @@ class BaseTrainer:
"""
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
if name == "auto":
LOGGER.info(
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
)
nc = getattr(model, "nc", 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
fullname = f"{module_name}.{param_name}" if module_name else param_name
if "bias" in fullname: # bias (no decay)
g[2].append(param)
elif isinstance(module, bn): # weight (no decay)
g[1].append(param)
else: # weight (with decay)
g[0].append(param)
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
elif name == "RMSProp":
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
elif name == "SGD":
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
else:
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers "
f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
"To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
)
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
LOGGER.info(
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
)
return optimizer
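# Editor's note (not part of this diff): a worked example of the 'auto' heuristic above
# for an nc=80 model: lr_fit = round(0.002 * 5 / (4 + 80), 6) = 0.000119, so short runs
# (iterations <= 10000) get AdamW(lr=0.000119, momentum=0.9) and longer runs get
# SGD(lr=0.01, momentum=0.9, nesterov=True).
#   assert round(0.002 * 5 / (4 + 80), 6) == 0.000119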

@@ -73,40 +73,43 @@ class Tuner:
Args:
args (dict, optional): Configuration for hyperparameter evolution.
"""
self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
"lr0": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
"lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
"weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4
"warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": (0.0, 0.95), # warmup initial momentum
"box": (1.0, 20.0), # box loss gain
"cls": (0.2, 4.0), # cls loss gain (scale with pixels)
"dfl": (0.4, 6.0), # dfl loss gain
"hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": (0.0, 45.0), # image rotation (+/- deg)
"translate": (0.0, 0.9), # image translation (+/- fraction)
"scale": (0.0, 0.95), # image scale (+/- gain)
"shear": (0.0, 10.0), # image shear (+/- deg)
"perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": (0.0, 1.0), # image flip up-down (probability)
"fliplr": (0.0, 1.0), # image flip left-right (probability)
"mosaic": (0.0, 1.0), # image mixup (probability)
"mixup": (0.0, 1.0), # image mixup (probability)
"copy_paste": (0.0, 1.0), # segment copy-paste (probability)
}
self.args = get_cfg(overrides=args)
self.tune_dir = get_save_dir(self.args, name="tune")
self.tune_csv = self.tune_dir / "tune_results.csv"
self.callbacks = _callbacks or callbacks.get_default_callbacks()
self.prefix = colorstr("Tuner: ")
callbacks.add_integration_callbacks(self)
LOGGER.info(
f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
)
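# Editor's note (not part of this diff): each search-space entry is (min, max[, gain]),
# and mutated values are expected to stay inside these bounds. A hypothetical clipping
# step for an lr0 proposal:
#   lo, hi = 1e-5, 1e-1        # bounds from self.space["lr0"]
#   proposal = 0.01 * 1.4      # parent value scaled by a sampled mutation factor
#   lr0 = min(max(proposal, lo), hi)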
def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
"""
Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
@@ -121,15 +124,15 @@ class Tuner:
"""
if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
# Select parent(s)
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
fitness = x[:, 0] # first column
n = min(n, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness)][:n] # top n mutations
w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)
if parent == "single" or len(x) == 1:
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == "weighted":
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
# Mutate
@@ -174,44 +177,44 @@ class Tuner:
t0 = time.time()
best_save_dir, best_metrics = None, None
(self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
for i in range(iterations):
# Mutate hyperparameters
mutated_hyp = self._mutate()
LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
metrics = {}
train_args = {**vars(self.args), **mutated_hyp}
save_dir = get_save_dir(get_cfg(train_args))
weights_dir = save_dir / "weights"
ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
try:
# Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
return_code = subprocess.run(cmd, check=True).returncode
metrics = torch.load(ckpt_file)["train_metrics"]
assert return_code == 0, "training failed"
except Exception as e:
LOGGER.warning(f"WARNING ❌ training failure for hyperparameter tuning iteration {i + 1}\n{e}")
# Save results and mutated_hyp to CSV
fitness = metrics.get("fitness", 0.0)
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
with open(self.tune_csv, "a") as f:
f.write(headers + ",".join(map(str, log_row)) + "\n")
# Get best results
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
fitness = x[:, 0] # first column
best_idx = fitness.argmax()
best_is_current = best_idx == i
if best_is_current:
best_save_dir = save_dir
best_metrics = {k: round(v, 5) for k, v in metrics.items()}
for ckpt in weights_dir.glob("*.pt"):
shutil.copy2(ckpt, self.tune_dir / "weights")
elif cleanup:
shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space
@@ -219,15 +222,19 @@ class Tuner:
plot_tune_results(self.tune_csv)
# Save and print tune results
header = (
f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
f'{self.prefix}Best fitness metrics are {best_metrics}\n'
f'{self.prefix}Best fitness model is {best_save_dir}\n'
f'{self.prefix}Best fitness hyperparameters are printed below.\n'
)
LOGGER.info("\n" + header)
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
yaml_save(
self.tune_dir / "best_hyperparameters.yaml",
data=data,
header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
)
yaml_print(self.tune_dir / "best_hyperparameters.yaml")

@@ -89,10 +89,10 @@ class BaseValidator:
self.nc = None
self.iouv = None
self.jdict = None
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.save_dir = save_dir or get_save_dir(self.args)
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
@@ -110,7 +110,7 @@ class BaseValidator:
if self.training:
self.device = trainer.device
self.data = trainer.data
self.args.half = self.device.type != "cpu" # force FP16 val during training
model = trainer.ema.ema or trainer.model
model = model.half() if self.args.half else model.float()
# self.model = model
@@ -119,11 +119,13 @@ class BaseValidator:
model.eval()
else:
callbacks.add_integration_callbacks(self)
model = AutoBackend(
model or self.args.model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
)
# self.model = model
self.device = model.device # update device
self.args.half = model.fp16 # update half
@@ -133,16 +135,16 @@ class BaseValidator:
self.args.batch = model.batch_size
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
self.data = check_det_dataset(self.args.data)
elif self.args.task == "classify":
self.data = check_cls_dataset(self.args.data, split=self.args.split)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type in ("cpu", "mps"):
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
if not pt:
self.args.rect = False
@@ -152,13 +154,13 @@ class BaseValidator:
model.eval()
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
self.run_callbacks("on_val_start")
dt = Profile(), Profile(), Profile(), Profile()
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
self.run_callbacks("on_val_batch_start")
self.batch_i = batch_i
# Preprocess
with dt[0]:
@@ -166,7 +168,7 @@ class BaseValidator:
# Inference
with dt[1]:
preds = model(batch["img"], augment=augment)
# Loss
with dt[2]:
@@ -182,23 +184,25 @@ class BaseValidator:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
self.run_callbacks("on_val_batch_end")
stats = self.get_stats()
self.check_stats(stats)
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
self.finalize_metrics()
self.print_results()
self.run_callbacks("on_val_end")
if self.training:
model.float()
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
LOGGER.info(
"Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
% tuple(self.speed.values())
)
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), "w") as f:
LOGGER.info(f"Saving {f.name}...")
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json:
@@ -228,6 +232,7 @@ class BaseValidator:
if use_scipy:
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
import scipy # scope import to avoid importing for all commands
cost_matrix = iou * (iou >= threshold)
if cost_matrix.any():
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
@@ -257,11 +262,11 @@ class BaseValidator:
def get_dataloader(self, dataset_path, batch_size):
"""Get data loader from dataset path and batch size."""
raise NotImplementedError("get_dataloader function not implemented for this validator")
def build_dataset(self, img_path):
"""Build dataset."""
raise NotImplementedError("build_dataset function not implemented in validator")
def preprocess(self, batch):
"""Preprocesses an input batch."""
@@ -306,7 +311,7 @@ class BaseValidator:
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
# TODO: may need to put the following functions into a callback
def plot_val_samples(self, batch, ni):
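A hedged sanity check for the `BaseValidator` changes above: running `model.val()` exercises `__call__`, including the reformatted `AutoBackend` construction and the per-image speed summary (dataset name is a standard docs example, used as a placeholder):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
metrics = model.val(data="coco8.yaml", imgsz=640)  # runs BaseValidator.__call__ in eval mode
print(metrics.box.map)  # mAP50-95
```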

@@ -21,10 +21,10 @@ def login(api_key: str = None, save=True) -> bool:
Returns:
bool: True if authentication is successful, False otherwise.
"""
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
saved_key = SETTINGS.get("api_key")
active_key = api_key or saved_key
credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials
client = HUBClient(credentials) # initialize HUBClient
@@ -32,17 +32,18 @@ def login(api_key: str = None, save=True) -> bool:
# Successfully authenticated with HUB
if save and client.api_key != saved_key:
SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key
# Set message based on whether key was provided or retrieved from settings
log_message = (
"New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
)
LOGGER.info(f"{PREFIX}{log_message}")
return True
else:
# Failed to authenticate with HUB
LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
return False
@@ -57,50 +58,50 @@ def logout():
hub.logout()
```
"""
SETTINGS["api_key"] = ""
SETTINGS.save()
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
def reset_model(model_id=""):
"""Reset a trained model to an untrained state."""
r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
if r.status_code == 200:
LOGGER.info(f"{PREFIX}Model reset successfully")
return
LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")
def export_fmts_hub():
"""Returns a list of HUB-supported export formats."""
from ultralytics.engine.exporter import export_formats
return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
def export_model(model_id="", format="torchscript"):
"""Export a model to all formats."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
r = requests.post(
f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
)
assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(model_id="", format="torchscript"):
"""Get an exported model dictionary with download URL."""
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
r = requests.post(
f"{HUB_API_ROOT}/get-export",
json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
headers={"x-api-key": Auth().api_key},
)
assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
return r.json()
def check_dataset(path="", task="detect"):
"""
Checks a HUB dataset Zip file for errors before it is uploaded to the HUB. Usage examples are given below.
@@ -119,4 +120,4 @@ def check_dataset(path='', task='detect'):
```
"""
HUBDatasetStats(path=path, task=task).get_json()
LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")

@@ -6,7 +6,7 @@ from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT
from ultralytics.hub.utils import PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
class Auth:
@@ -23,9 +23,10 @@ class Auth:
api_key (str or bool): API key for authentication, initialized as False.
model_key (bool): Placeholder for model key, initialized as False.
"""
id_token = api_key = model_key = False
def __init__(self, api_key="", verbose=False):
"""
Initialize the Auth class with an optional API key.
@@ -33,18 +34,18 @@ class Auth:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
"""
# Split the input API key in case it contains a combined key_model and keep only the API key part
api_key = api_key.split("_")[0]
# Set API key attribute as value passed or SETTINGS API key if none passed
self.api_key = api_key or SETTINGS.get("api_key", "")
# If an API key is provided
if self.api_key:
# If the provided API key matches the API key in the SETTINGS
if self.api_key == SETTINGS.get("api_key"):
# Log that the user is already logged in
if verbose:
LOGGER.info(f"{PREFIX}Authenticated ✅")
return
else:
# Attempt to authenticate with the provided API key
@@ -59,12 +60,12 @@ class Auth:
# Update SETTINGS with the new API key after successful authentication
if success:
SETTINGS.update({"api_key": self.api_key})
# Log that the new login was successful
if verbose:
LOGGER.info(f"{PREFIX}New authentication successful ✅")
elif verbose:
LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")
def request_api_key(self, max_attempts=3):
"""
@@ -73,13 +74,14 @@ class Auth:
Returns the model ID.
"""
import getpass
for attempts in range(max_attempts):
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
self.api_key = input_key.split("_")[0] # remove model id if present
if self.authenticate():
return True
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
def authenticate(self) -> bool:
"""
@@ -90,14 +92,14 @@ class Auth:
"""
try:
if header := self.get_auth_header():
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
if not r.json().get("success", False):
raise ConnectionError("Unable to authenticate.")
return True
raise ConnectionError("User has not authenticated locally.")
except ConnectionError:
self.id_token = self.api_key = False # reset invalid
LOGGER.warning(f"{PREFIX}Invalid API key ⚠")
return False
def auth_with_cookies(self) -> bool:
@@ -111,12 +113,12 @@ class Auth:
if not is_colab():
return False # Currently only works with Colab
try:
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
if authn.get("success", False):
self.id_token = authn.get("data", {}).get("idToken", None)
self.authenticate()
return True
raise ConnectionError("Unable to fetch browser authentication details.")
except ConnectionError:
self.id_token = False # reset invalid
return False
@@ -129,7 +131,7 @@ class Auth:
(dict): The authentication header if id_token or API key is set, None otherwise.
"""
if self.id_token:
return {"authorization": f"Bearer {self.id_token}"}
elif self.api_key:
return {"x-api-key": self.api_key}
# else returns None
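A short sketch of the `Auth` flow above (the key is a placeholder; construction attempts authentication against HUB, so treat this as illustrative only):

```python
from ultralytics.hub.auth import Auth

auth = Auth(api_key="YOUR_API_KEY", verbose=True)
header = auth.get_auth_header()  # {"x-api-key": ...}, {"authorization": "Bearer ..."}, or None
```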

@@ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
Args:
url (str): Model identifier used to initialize the HUB training session.
Attributes:
agent_id (str): Identifier for the instance communicating with the server.
model_id (str): Identifier for the YOLO model being trained.
@@ -40,17 +37,18 @@ class HUBTrainingSession:
Initialize the HUBTrainingSession with the provided model identifier.
Args:
identifier (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with a specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
self.rate_limits = {
"metrics": 3.0,
"ckpt": 900.0,
"heartbeat": 300.0,
} # rate limits (seconds)
self.metrics_queue = {} # holds metrics for each epoch until upload
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
@@ -58,8 +56,8 @@ class HUBTrainingSession:
api_key, model_id, self.filename = self._parse_identifier(identifier)
# Get credentials
active_key = api_key or SETTINGS.get("api_key")
credentials = {"api_key": active_key} if active_key else None # set credentials
# Initialize client
self.client = HUBClient(credentials)
@@ -72,35 +70,37 @@ class HUBTrainingSession:
def load_model(self, model_id):
# Initialize model
self.model = self.client.model(model_id)
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
self._set_train_args()
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits["heartbeat"])
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def create_model(self, model_args):
# Initialize model
payload = {
"config": {
"batchSize": model_args.get("batch", -1),
"epochs": model_args.get("epochs", 300),
"imageSize": model_args.get("imgsz", 640),
"patience": model_args.get("patience", 100),
"device": model_args.get("device", ""),
"cache": model_args.get("cache", "ram"),
},
"dataset": {"name": model_args.get("data")},
"lineage": {
"architecture": {
"name": self.filename.replace(".pt", "").replace(".yaml", ""),
},
"parent": {},
},
"meta": {"name": self.filename},
}
if self.filename.endswith(".pt"):
payload["lineage"]["parent"]["name"] = self.filename
self.model.create_model(payload)
@@ -109,12 +109,12 @@ class HUBTrainingSession:
if not self.model.id:
return
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits["heartbeat"])
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def _parse_identifier(self, identifier):
"""
@@ -125,13 +125,13 @@ class HUBTrainingSession:
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
Args:
identifier (str): The identifier string to be parsed.
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
Raises:
HUBModelError: If the identifier format is not recognized.
"""
@@ -140,12 +140,12 @@ class HUBTrainingSession:
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
@@ -154,43 +154,46 @@ class HUBTrainingSession:
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
return api_key, model_id, filename
def _set_train_args(self, **kwargs):
if self.model.is_trained():
# Model is already trained
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable():
# Model has saved weights
self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
self.model_file = self.model.get_weights_url("last")
else:
# Model has no saved weights
def get_train_args(config):
return {
"batch": config["batchSize"],
"epochs": config["epochs"],
"imgsz": config["imageSize"],
"patience": config["patience"],
"device": config["device"],
"cache": config["cache"],
"data": self.model.get_dataset_url(),
}
self.train_args = get_train_args(self.model.data.get("config"))
# Set the model file as either a *.pt or *.yaml file
self.model_file = (
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
)
if not self.train_args.get("data"):
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id
@@ -206,12 +209,11 @@ class HUBTrainingSession:
*args,
**kwargs,
):
def retry_request():
t0 = time.time() # Record the start time for the timeout
for i in range(retry + 1):
if (time.time() - t0) > timeout:
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
break # Timeout reached, exit loop
response = request_func(*args, **kwargs)
@@ -219,8 +221,8 @@ class HUBTrainingSession:
self._show_upload_progress(progress_total, response)
if response is None:
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retry
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
@@ -231,13 +233,13 @@ class HUBTrainingSession:
message = self._get_failure_message(response, retry, timeout)
if verbose:
LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
if not self._should_retry(response.status_code):
LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
break # Not an error that should be retried, exit loop
time.sleep(2**i) # Exponential backoff for retries
return response
@@ -253,7 +255,8 @@ class HUBTrainingSession:
retry_codes = {
HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT,
}
return True if status_code in retry_codes else False
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
@@ -269,16 +272,18 @@ class HUBTrainingSession:
str: The retry message.
"""
if self._should_retry(response.status_code):
return f"Retrying {retry}x for {timeout}s." if retry else ""
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
headers = response.headers
return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s.")
return (
f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s."
)
else:
try:
return response.json().get("message", "No JSON message.")
except AttributeError:
return "Unable to read JSON."
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
@@ -303,7 +308,7 @@ class HUBTrainingSession:
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
self.request_queue(
self.model.upload_model,
epoch=epoch,
@@ -317,7 +322,7 @@ class HUBTrainingSession:
progress_total=progress_total,
)
else:
LOGGER.warning(f"{PREFIX}WARNING ⚠ Model upload issue. Missing model {weights}.")
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
"""
@@ -330,6 +335,6 @@ class HUBTrainingSession:
Returns:
(None)
"""
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
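For reference, the identifier forms accepted by `_parse_identifier()` above, as a sketch; creating a session performs real HUB requests, so the model ID here is a placeholder:

```python
from ultralytics.hub.session import HUBTrainingSession

# Accepted forms: a HUB URL, "apiKey_modelId", a bare 20-character model ID, or a local *.pt/*.yaml file
session = HUBTrainingSession("https://hub.ultralytics.com/models/MODEL_ID")
```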

@@ -9,12 +9,26 @@ from pathlib import Path
import requests
from ultralytics.utils import (
ENVIRONMENT,
LOGGER,
ONLINE,
RANK,
SETTINGS,
TESTS_RUNNING,
TQDM,
TryExcept,
__version__,
colorstr,
get_git_origin_url,
is_colab,
is_git_dir,
is_pip_package,
)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
PREFIX = colorstr("Ultralytics HUB: ")
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
def request_with_credentials(url: str) -> any:
@@ -31,11 +45,13 @@ def request_with_credentials(url: str) -> any:
OSError: If the function is not run in a Google Colab environment.
"""
if not is_colab():
raise OSError("request_with_credentials() must run in a Colab environment")
from google.colab import output # noqa
from IPython import display # noqa
display.display(
display.Javascript("""
display.Javascript(
"""
window._hub_tmp = new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
fetch("%s", {
@@ -50,8 +66,11 @@ def request_with_credentials(url: str) -> any:
reject(err);
});
});
""" % url))
return output.eval_js('_hub_tmp')
"""
% url
)
)
return output.eval_js("_hub_tmp")
def requests_with_progress(method, url, **kwargs):
@@ -71,13 +90,13 @@ def requests_with_progress(method, url, **kwargs):
content length.
- If 'progress' is a number then progress bar will display assuming content length = progress.
"""
progress = kwargs.pop("progress", False)
if not progress:
return requests.request(method, url, **kwargs)
response = requests.request(method, url, stream=True, **kwargs)
total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size
try:
pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
pbar.close()
@@ -118,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
break
try:
m = r.json().get("message", "No JSON message.")
except AttributeError:
m = "Unable to read JSON."
if i == 0:
if r.status_code in retry_codes:
m += f" Retrying {retry}x for {timeout}s." if retry else ""
elif r.status_code == 429: # rate limit
h = r.headers # response headers
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
m = (
f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
f"Please retry after {h['Retry-After']}s."
)
if verbose:
LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
if r.status_code not in retry_codes:
return r
time.sleep(2**i) # exponential standoff
return r
args = method, url
kwargs["progress"] = progress
if thread:
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
else:
@@ -155,7 +176,7 @@ class Events:
enabled (bool): A flag to enable or disable Events based on certain conditions.
"""
url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
def __init__(self):
"""Initializes the Events object with default values for events, rate_limit, and metadata."""
@@ -163,19 +184,21 @@ class Events:
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
"cli": Path(sys.argv[0]).name == "yolo",
"install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
"version": __version__,
"env": ENVIRONMENT,
"session_id": round(random.random() * 1e15),
"engagement_time_msec": 1000,
}
self.enabled = (
SETTINGS["sync"]
and RANK in (-1, 0)
and not TESTS_RUNNING
and ONLINE
and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
)
def __call__(self, cfg):
"""
@@ -191,11 +214,13 @@ class Events:
# Attempt to add to events
if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
params = {
**self.metadata,
"task": cfg.task,
"model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
}
if cfg.mode == "export":
params["format"] = cfg.format
self.events.append({"name": cfg.mode, "params": params})
# Check rate limit
t = time.time()
@@ -204,10 +229,10 @@ class Events:
return
# Time is over rate limiter, send now
data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list
# POST equivalent to requests.post(self.url, json=data)
smart_request("post", self.url, json=data, retry=0, verbose=False)
# Reset events and rate limit timer
self.events = []
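A minimal synchronous sketch of the `smart_request()` retry helper above (the URL is a placeholder; `thread=False` returns the response instead of spawning a daemon thread):

```python
from ultralytics.hub.utils import smart_request

r = smart_request("post", "https://example.com/endpoint", json={"ping": 1}, retry=2, timeout=10, thread=False)
```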

@@ -4,4 +4,4 @@ from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO
__all__ = "YOLO", "RTDETR", "SAM" # allow simpler import

@@ -5,4 +5,4 @@ from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"

@@ -21,14 +21,14 @@ class FastSAM(Model):
```
"""
def __init__(self, model="FastSAM-x.pt"):
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
if str(model) == "FastSAM.pt":
model = "FastSAM-x.pt"
assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
super().__init__(model=model, task="segment")
@property
def task_map(self):
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}

@@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor):
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
@@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor):
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
classes=self.args.classes,
)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)

@@ -23,7 +23,7 @@ class FastSAMPrompt:
clip: CLIP model for linear assignment.
"""
def __init__(self, source, results, device="cuda") -> None:
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
self.device = device
self.results = results
@@ -34,7 +34,8 @@ class FastSAMPrompt:
import clip # for linear_assignment
except ImportError:
from ultralytics.utils.checks import check_requirements
check_requirements("git+https://github.com/openai/CLIP.git")
import clip
self.clip = clip
@@ -46,11 +47,11 @@ class FastSAMPrompt:
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new("RGB", image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
@@ -65,11 +66,12 @@ class FastSAMPrompt:
mask = result.masks.data[i] == 1.0
if torch.sum(mask) >= filter:
annotation = {
"id": i,
"segmentation": mask.cpu().numpy(),
"bbox": result.boxes.data[i],
"score": result.boxes.conf[i],
}
annotation["area"] = annotation["segmentation"].sum()
annotations.append(annotation)
return annotations
@@ -91,16 +93,18 @@ class FastSAMPrompt:
y2 = max(y2, y_t + h_t)
return [x1, y1, x2, y2]
def plot(
self,
annotations,
output,
bbox=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
with_contours=True,
):
"""
Plots annotations, bounding boxes, and points on images and saves the output.
@@ -139,15 +143,17 @@ class FastSAMPrompt:
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
self.fast_show_mask(
masks,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if with_contours:
contour_all = []
@@ -166,10 +172,10 @@ class FastSAMPrompt:
# Save the figure
save_path = Path(output) / result_name
save_path.parent.mkdir(exist_ok=True, parents=True)
plt.axis("off")
plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
plt.close()
pbar.set_description(f"Saving {result_name} to {save_path}")
@staticmethod
def fast_show_mask(
@@ -212,26 +218,26 @@ class FastSAMPrompt:
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((h, w, 4))
h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
# Draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c="y",
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c="m",
)
if not retinamask:
@@ -258,7 +264,7 @@ class FastSAMPrompt:
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]["segmentation"].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
@@ -266,19 +272,19 @@ class FastSAMPrompt:
not_crop = []
filter_id = []
for _, mask in enumerate(annotations):
if np.sum(mask["segmentation"]) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask
cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image
cropped_images.append(bbox) # save cropped image bbox
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox):
"""Modifies the bounding box properties and calculates IoU between masks and bounding box."""
if self.results[0].masks is not None:
assert bbox[2] != 0 and bbox[3] != 0
if os.path.isdir(self.source):
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self.results[0].masks.data
@@ -290,7 +296,8 @@ class FastSAMPrompt:
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height),
]
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
@@ -299,7 +306,7 @@ class FastSAMPrompt:
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
@@ -316,13 +323,13 @@ class FastSAMPrompt:
raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
masks = self._format_results(self.results[0], 0)
target_height, target_width = self.results[0].orig_shape
h = masks[0]["segmentation"].shape[0]
w = masks[0]["segmentation"].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for annotation in masks:
mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
@@ -337,12 +344,12 @@ class FastSAMPrompt:
if self.results[0].masks is not None:
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
return self.results
def everything_prompt(self):
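Tying the prompt methods above together, a hedged end-to-end sketch (paths are placeholders; `text_prompt` pulls in CLIP on first call as shown in `__init__`):

```python
from ultralytics.models.fastsam import FastSAM, FastSAMPrompt

model = FastSAM("FastSAM-x.pt")
results = model("path/to/image.jpg", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
prompt = FastSAMPrompt("path/to/image.jpg", results, device="cpu")
ann = prompt.text_prompt(text="a photo of a dog")  # or box_prompt(...), point_prompt(...), everything_prompt()
prompt.plot(annotations=ann, output="./output/")
```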

@@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator):
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
"""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = "segment"
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

@@ -4,4 +4,4 @@ from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
__all__ = "NASPredictor", "NASValidator", "NAS"

@@ -44,20 +44,21 @@ class NAS(Model):
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
"""
def __init__(self, model="yolo_nas_s.pt") -> None:
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
super().__init__(model, task="detect")
@smart_inference_mode()
def _load(self, weights: str, task: str):
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
import super_gradients
suffix = Path(weights).suffix
if suffix == ".pt":
self.model = torch.load(weights)
elif suffix == "":
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
@@ -65,7 +66,7 @@ class NAS(Model):
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
self.model.task = "detect" # for export()
def info(self, detailed=False, verbose=True):
"""
@@ -80,4 +81,4 @@ class NAS(Model):
@property
def task_map(self):
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}

@@ -39,12 +39,14 @@ class NASPredictor(BasePredictor):
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

@@ -5,7 +5,7 @@ import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
__all__ = ["NASValidator"]
class NASValidator(DetectionValidator):
@@ -38,11 +38,13 @@ class NASValidator(DetectionValidator):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=False,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
max_time_img=0.5,
)

@@ -4,4 +4,4 @@ from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"

@@ -24,7 +24,7 @@ class RTDETR(Model):
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
"""
def __init__(self, model="rtdetr-l.pt") -> None:
"""
Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
@@ -34,9 +34,9 @@ class RTDETR(Model):
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
if model and model.split(".")[-1] not in ("pt", "yaml", "yml"):
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
super().__init__(model=model, task="detect")
@property
def task_map(self) -> dict:
@@ -47,8 +47,10 @@ class RTDETR(Model):
dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
"""
return {
"detect": {
"predictor": RTDETRPredictor,
"validator": RTDETRValidator,
"trainer": RTDETRTrainer,
"model": RTDETRDetectionModel,
}
}
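A brief RT-DETR usage sketch consistent with the constructor and task map above (image URL is the standard docs asset):

```python
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
model.info()
results = model("https://ultralytics.com/images/bus.jpg")
```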

@@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDetectionModel): Initialized model.
"""
model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build and return an RT-DETR dataset for training or validation.
@@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDataset): Dataset object for the specific mode.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == "train",
hyp=self.args,
rect=False,
cache=self.args.cache or None,
prefix=colorstr(f"{mode}: "),
data=self.data,
)
def get_validator(self):
"""
@@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRValidator): Validator object for model validation.
"""
self.loss_names = "giou_loss", "cls_loss", "l1_loss"
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
@@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
(dict): Preprocessed batch.
"""
batch = super().preprocess_batch(batch)
bs = len(batch["img"])
batch_idx = batch["batch_idx"]
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
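The trainer above follows the standard `DetectionTrainer` override API; a minimal sketch mirroring the module's own docstring example:

```python
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```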

@@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list
__all__ = ("RTDETRValidator",) # tuple or list
class RTDETRDataset(YOLODataset):
@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset):
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append(
Format(bbox_format='xywh',
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
)
return transforms
@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator):
For further details on the attributes and methods, refer to the parent DetectionValidator class.
"""
def build_dataset(self, img_path, mode='val', batch=None):
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build an RTDETR Dataset.
@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator):
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
prefix=colorstr(f"{mode}: "),
data=self.data,
)
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
@ -108,12 +112,12 @@ class RTDETRValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by applying transformations."""
idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
ori_shape = batch['ori_shape'][si]
imgsz = batch['img'].shape[2:]
ratio_pad = batch['ratio_pad'][si]
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) # target boxes
bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
@ -124,6 +128,6 @@ class RTDETRValidator(DetectionValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch with transformed bounding boxes and class labels."""
predn = pred.clone()
predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred
predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
return predn.float()
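A quick worked check of the `_prepare_pred` rescaling above: predictions live in `imgsz` x `imgsz` space, so each axis is scaled by `ori_shape / imgsz`. Sketch with assumed shapes (imgsz=640, original image 480x640 HxW):

import torch

predn = torch.tensor([[320.0, 320.0, 640.0, 640.0]])  # xyxy box at imgsz scale (illustrative)
predn[..., [0, 2]] *= 640 / 640  # x: ori_shape[1] / imgsz -> unchanged
predn[..., [1, 3]] *= 480 / 640  # y: ori_shape[0] / imgsz -> box becomes (320, 240, 640, 480)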

@ -3,4 +3,4 @@
from .model import SAM
from .predict import Predictor
__all__ = 'SAM', 'Predictor' # tuple or list
__all__ = "SAM", "Predictor" # tuple or list

@ -8,10 +8,9 @@ import numpy as np
import torch
def is_box_near_crop_edge(boxes: torch.Tensor,
crop_box: List[int],
orig_box: List[int],
atol: float = 20.0) -> torch.Tensor:
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
dtype=torch.int32))
unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
return intersections / unions
@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes.
@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
import cv2 # type: ignore
assert mode in {'holes', 'islands'}
correct_holes = mode == 'holes'
assert mode in {"holes", "islands"}
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
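The slice reformat in `batch_iterator` above (`b * batch_size : (b + 1) * batch_size`) is cosmetic; the batching contract is unchanged. A quick check of that contract:

from ultralytics.models.sam.amg import batch_iterator

data = [1, 2, 3, 4, 5]
print(list(batch_iterator(2, data)))  # [[[1, 2]], [[3, 4]], [[5]]] (last batch is the remainder)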

@ -64,46 +64,47 @@ def build_mobile_sam(checkpoint=None):
)
def _build_sam(encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
mobile_sam=False):
def _build_sam(
encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
):
"""Builds the selected SAM model architecture."""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder = (TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
) if mobile_sam else ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
))
image_encoder = (
TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
)
if mobile_sam
else ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
)
sam = Sam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim,
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, 'rb') as f:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.eval()
@ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim,
sam_model_map = {
'sam_h.pt': build_sam_vit_h,
'sam_l.pt': build_sam_vit_l,
'sam_b.pt': build_sam_vit_b,
'mobile_sam.pt': build_mobile_sam, }
"sam_h.pt": build_sam_vit_h,
"sam_l.pt": build_sam_vit_l,
"sam_b.pt": build_sam_vit_b,
"mobile_sam.pt": build_mobile_sam,
}
def build_sam(ckpt='sam_b.pt'):
def build_sam(ckpt="sam_b.pt"):
"""Build a SAM model specified by ckpt."""
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'):
model_builder = sam_model_map.get(k)
if not model_builder:
raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
return model_builder(ckpt)
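A usage sketch for the rewritten `build_sam` entry point; the checkpoint name must match a key of `sam_model_map` above, and the file is fetched via `attempt_download_asset` if absent:

from ultralytics.models.sam.build import build_sam

sam = build_sam(ckpt="mobile_sam.pt")  # selects build_mobile_sam, returns an eval-mode Sam module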

@ -32,7 +32,7 @@ class SAM(Model):
dataset.
"""
def __init__(self, model='sam_b.pt') -> None:
def __init__(self, model="sam_b.pt") -> None:
"""
Initializes the SAM model with a pre-trained model file.
@ -42,9 +42,9 @@ class SAM(Model):
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
"""
if model and Path(model).suffix not in ('.pt', '.pth'):
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
super().__init__(model=model, task='segment')
if model and Path(model).suffix not in (".pt", ".pth"):
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
"""
@ -70,7 +70,7 @@ class SAM(Model):
Returns:
(list): The model predictions.
"""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
@ -112,4 +112,4 @@ class SAM(Model):
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
"""
return {'segment': {'predictor': Predictor}}
return {"segment": {"predictor": Predictor}}

@ -64,8 +64,9 @@ class MaskDecoder(nn.Module):
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList([
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
@ -132,13 +133,14 @@ class MaskDecoder(nn.Module):
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = [
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
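For readers of the reslice above: token 0 of the transformer output is the IoU token and the next `num_mask_tokens` positions are mask tokens, so `hs[:, 1 : (1 + self.num_mask_tokens), :]` selects exactly that window. Illustrative shapes:

import torch

hs = torch.randn(2, 8, 256)  # (batch, tokens, dim); assumed sizes for illustration
iou_token_out = hs[:, 0, :]  # quality-prediction token
mask_tokens_out = hs[:, 1:5, :]  # e.g. num_mask_tokens = 4 -> tokens 1..4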

@ -28,23 +28,23 @@ class ImageEncoderViT(nn.Module):
"""
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
@ -283,9 +283,9 @@ class PromptEncoder(nn.Module):
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
1).expand(bs, -1, self.image_embedding_size[0],
self.image_embedding_size[1])
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module):
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
torch.use_deterministic_algorithms(False)
@ -425,14 +425,14 @@ class Attention(nn.Module):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
assert input_size is not None, "Input size must be provided if using relative positional encoding."
# Initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
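Aside on `head_dim**-0.5`: this is the standard 1/sqrt(head_dim) attention scaling, only re-spaced by the formatter. For instance, with the ViT-B defaults above:

import math

dim, num_heads = 768, 12  # defaults from ImageEncoderViT above
head_dim = dim // num_heads  # 64
assert head_dim**-0.5 == 1 / math.sqrt(head_dim)  # scale = 0.125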
@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
hw: Tuple[int, int]) -> torch.Tensor:
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode='linear',
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
@ -567,11 +568,12 @@ def add_decomposed_rel_pos(
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w)
B, q_h * q_w, k_h * k_w
)
return attn
@ -580,12 +582,12 @@ class PatchEmbed(nn.Module):
"""Image to Patch Embedding."""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Initialize PatchEmbed module.

@ -30,8 +30,9 @@ class Sam(nn.Module):
pixel_mean (List[float]): Mean pixel values for image normalization.
pixel_std (List[float]): Standard deviation values for image normalization.
"""
mask_threshold: float = 0.0
image_format: str = 'RGB'
image_format: str = "RGB"
def __init__(
self,
@ -39,7 +40,7 @@ class Sam(nn.Module):
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = (123.675, 116.28, 103.53),
pixel_std: List[float] = (58.395, 57.12, 57.375)
pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
"""
Initialize the Sam class to predict object masks from an image and input prompts.
@ -60,5 +61,5 @@ class Sam(nn.Module):
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential):
drop path.
"""
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
self.add_module("bn", bn)
class PatchEmbed(nn.Module):
@ -146,11 +146,11 @@ class ConvLayer(nn.Module):
input_resolution,
depth,
activation,
drop_path=0.,
drop_path=0.0,
downsample=None,
use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.,
conv_expand_ratio=4.0,
):
"""
Initializes the ConvLayer with the given dimensions and settings.
@ -173,18 +173,25 @@ class ConvLayer(nn.Module):
self.use_checkpoint = use_checkpoint
# Build blocks
self.blocks = nn.ModuleList([
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
) for i in range(depth)])
self.blocks = nn.ModuleList(
[
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
)
for i in range(depth)
]
)
# Patch merging layer
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
self.downsample = (
None
if downsample is None
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
)
def forward(self, x):
"""Processes the input through a series of convolutional layers and returns the activated output."""
@ -200,7 +207,7 @@ class Mlp(nn.Module):
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
super().__init__()
out_features = out_features or in_features
@ -232,12 +239,12 @@ class Attention(torch.nn.Module):
"""
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
"""
Initializes the Attention module.
@ -256,7 +263,7 @@ class Attention(torch.nn.Module):
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
@ -279,13 +286,13 @@ class Attention(torch.nn.Module):
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode=True):
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
super().train(mode)
if mode and hasattr(self, 'ab'):
if mode and hasattr(self, "ab"):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
@ -306,8 +313,9 @@ class Attention(torch.nn.Module):
v = v.permute(0, 2, 1, 3)
self.ab = self.ab.to(self.attention_biases.device)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = (q @ k.transpose(-2, -1)) * self.scale + (
self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
return self.proj(x)
@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module):
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
local_conv_size=3,
activation=nn.GELU,
):
@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module):
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
assert window_size > 0, "window_size must be greater than 0"
self.window_size = window_size
self.mlp_ratio = mlp_ratio
@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module):
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
assert dim % num_heads == 0, "dim must be divisible by num_heads"
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module):
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
assert L == H * W, "input feature has wrong size"
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
@ -394,8 +402,11 @@ class TinyViTBlock(nn.Module):
nH = pH // self.window_size
nW = pW // self.window_size
# Window partition
x = x.view(B, nH, self.window_size, nW, self.window_size,
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
x = (
x.view(B, nH, self.window_size, nW, self.window_size, C)
.transpose(2, 3)
.reshape(B * nH * nW, self.window_size * self.window_size, C)
)
x = self.attn(x)
# Window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
@ -417,8 +428,10 @@ class TinyViTBlock(nn.Module):
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
attentions heads, window size, and MLP ratio.
"""
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
)
class BasicLayer(nn.Module):
@ -431,9 +444,9 @@ class BasicLayer(nn.Module):
depth,
num_heads,
window_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
@ -468,22 +481,29 @@ class BasicLayer(nn.Module):
self.use_checkpoint = use_checkpoint
# Build blocks
self.blocks = nn.ModuleList([
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
) for i in range(depth)])
self.blocks = nn.ModuleList(
[
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
)
for i in range(depth)
]
)
# Patch merging layer
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
self.downsample = (
None
if downsample is None
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
)
def forward(self, x):
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
@ -493,7 +513,7 @@ class BasicLayer(nn.Module):
def extra_repr(self) -> str:
"""Returns a string representation of the extra_repr function with the layer's parameters."""
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class LayerNorm2d(nn.Module):
@ -549,8 +569,8 @@ class TinyViT(nn.Module):
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
@ -585,10 +605,9 @@ class TinyViT(nn.Module):
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
self.patch_embed = PatchEmbed(
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
@ -601,27 +620,30 @@ class TinyViT(nn.Module):
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
input_resolution=(
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1,
len(embed_dims) - 1)],
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs)
layer = BasicLayer(
num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs,
)
self.layers.append(layer)
# Classifier head
@ -680,7 +702,7 @@ class TinyViT(nn.Module):
def _check_lr_scale(m):
"""Checks if the learning rate scale attribute is present in module's parameters."""
for p in m.parameters():
assert hasattr(p, 'lr_scale'), p.param_name
assert hasattr(p, "lr_scale"), p.param_name
self.apply(_check_lr_scale)
@ -698,7 +720,7 @@ class TinyViT(nn.Module):
@torch.jit.ignore
def no_weight_decay_keywords(self):
"""Returns a dictionary of parameter names where weight decay should not be applied."""
return {'attention_biases'}
return {"attention_biases"}
def forward_features(self, x):
"""Runs the input through the model layers and returns the transformed output."""

@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module):
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
))
)
)
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
@ -227,7 +228,7 @@ class Attention(nn.Module):
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)

@ -19,8 +19,17 @@ from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .amg import (
batch_iterator,
batched_mask_to_box,
build_all_layer_point_grids,
calculate_stability_score,
generate_crop_boxes,
is_box_near_crop_edge,
remove_small_regions,
uncrop_boxes_xyxy,
uncrop_masks,
)
from .build import build_sam
@ -58,7 +67,7 @@ class Predictor(BasePredictor):
"""
if overrides is None:
overrides = {}
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
overrides.update(dict(task="segment", mode="predict", imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
self.args.retina_masks = True
self.im = None
@ -107,7 +116,7 @@ class Predictor(BasePredictor):
Returns:
(List[np.ndarray]): List of transformed images.
"""
assert len(im) == 1, 'SAM model does not currently support batched inference'
assert len(im) == 1, "SAM model does not currently support batched inference"
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
return [letterbox(image=x) for x in im]
@ -132,9 +141,9 @@ class Predictor(BasePredictor):
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
"""
# Override prompts if any stored in self.prompts
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
bboxes = self.prompts.pop("bboxes", bboxes)
points = self.prompts.pop("points", points)
masks = self.prompts.pop("masks", masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
@ -199,18 +208,20 @@ class Predictor(BasePredictor):
# `d` could be 1 or 3 depending on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def generate(self,
im,
crop_n_layers=0,
crop_overlap_ratio=512 / 1500,
crop_downscale_factor=1,
point_grids=None,
points_stride=32,
points_batch_size=64,
conf_thres=0.88,
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7):
def generate(
self,
im,
crop_n_layers=0,
crop_overlap_ratio=512 / 1500,
crop_downscale_factor=1,
point_grids=None,
points_stride=32,
points_batch_size=64,
conf_thres=0.88,
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7,
):
"""
Perform image segmentation using the Segment Anything Model (SAM).
@ -248,19 +259,20 @@ class Predictor(BasePredictor):
area = torch.tensor(w * h, device=im.device)
points_scale = np.array([[w, h]]) # w, h
# Crop image and interpolate to input size
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
# (num_points, 2)
points_for_image = point_grids[layer_idx] * points_scale
crop_masks, crop_scores, crop_bboxes = [], [], []
for (points, ) in batch_iterator(points_batch_size, points_for_image):
for (points,) in batch_iterator(points_batch_size, points_for_image):
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
# Interpolate predicted masks to input size
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
idx = pred_score > conf_thres
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
stability_score_offset)
stability_score = calculate_stability_score(
pred_mask, self.model.mask_threshold, stability_score_offset
)
idx = stability_score > stability_score_thresh
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
# Bool type is much more memory-efficient.
@ -404,7 +416,7 @@ class Predictor(BasePredictor):
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_source(image)
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
@ -446,9 +458,9 @@ class Predictor(BasePredictor):
scores = []
for mask in masks:
mask = mask.cpu().numpy().astype(np.uint8)
mask, changed = remove_small_regions(mask, min_area, mode='holes')
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
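The hole/island cleanup above calls `remove_small_regions` twice per mask, once per mode. A standalone sketch of that contract (mask contents illustrative; requires opencv-python):

import numpy as np
from ultralytics.models.sam.amg import remove_small_regions

mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:40, 8:40] = 1  # one large island
mask[50, 50] = 1  # one single-pixel speck below the area threshold
mask, changed = remove_small_regions(mask, area_thresh=4, mode="islands")
print(changed, mask[50, 50])  # changed is truthy and the speck is zeroed out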

@ -30,14 +30,9 @@ class DETRLoss(nn.Module):
device (torch.device): Device on which tensors are stored.
"""
def __init__(self,
nc=80,
loss_gain=None,
aux_loss=True,
use_fl=True,
use_vfl=False,
use_uni_match=False,
uni_match_ind=0):
def __init__(
self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
):
"""
DETR loss function.
@ -52,9 +47,9 @@ class DETRLoss(nn.Module):
super().__init__()
if loss_gain is None:
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
self.nc = nc
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss() if use_fl else None
@ -64,10 +59,10 @@ class DETRLoss(nn.Module):
self.uni_match_ind = uni_match_ind
self.device = None
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f'loss_class{postfix}'
name_class = f"loss_class{postfix}"
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
@ -82,28 +77,28 @@ class DETRLoss(nn.Module):
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
boxes.
"""
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f'loss_bbox{postfix}'
name_giou = f'loss_giou{postfix}'
name_bbox = f"loss_bbox{postfix}"
name_giou = f"loss_giou{postfix}"
loss = {}
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0., device=self.device)
loss[name_giou] = torch.tensor(0., device=self.device)
loss[name_bbox] = torch.tensor(0.0, device=self.device)
loss[name_giou] = torch.tensor(0.0, device=self.device)
return loss
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
return {k: v.squeeze() for k, v in loss.items()}
# This function is for future RT-DETR Segment models
@ -137,50 +132,57 @@ class DETRLoss(nn.Module):
# loss = 1 - (numerator + 1) / (denominator + 1)
# return loss.sum() / num_gts
def _get_loss_aux(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix='',
masks=None,
gt_mask=None):
def _get_loss_aux(
self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix="",
masks=None,
gt_mask=None,
):
"""Get auxiliary losses."""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask)
match_indices = self.matcher(
pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask,
)
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
loss_ = self._get_loss(aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices)
loss[0] += loss_[f'loss_class{postfix}']
loss[1] += loss_[f'loss_bbox{postfix}']
loss[2] += loss_[f'loss_giou{postfix}']
loss_ = self._get_loss(
aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices,
)
loss[0] += loss_[f"loss_class{postfix}"]
loss[1] += loss_[f"loss_bbox{postfix}"]
loss[2] += loss_[f"loss_giou{postfix}"]
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
f'loss_class_aux{postfix}': loss[0],
f'loss_bbox_aux{postfix}': loss[1],
f'loss_giou_aux{postfix}': loss[2]}
f"loss_class_aux{postfix}": loss[0],
f"loss_bbox_aux{postfix}": loss[1],
f"loss_giou_aux{postfix}": loss[2],
}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
@ -196,33 +198,37 @@ class DETRLoss(nn.Module):
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
pred_assigned = torch.cat([
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)])
gt_assigned = torch.cat([
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)])
pred_assigned = torch.cat(
[
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)
]
)
gt_assigned = torch.cat(
[
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)
]
)
return pred_assigned, gt_assigned
def _get_loss(self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix='',
match_indices=None):
def _get_loss(
self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix="",
match_indices=None,
):
"""Get losses."""
if match_indices is None:
match_indices = self.matcher(pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=masks,
gt_mask=gt_mask)
match_indices = self.matcher(
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
)
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
@ -242,7 +248,7 @@ class DETRLoss(nn.Module):
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
return loss
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
"""
Args:
pred_bboxes (torch.Tensor): [l, b, query, 4]
@ -254,21 +260,19 @@ class DETRLoss(nn.Module):
postfix (str): postfix of loss name.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get('match_indices', None)
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
match_indices = kwargs.get("match_indices", None)
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
total_loss = self._get_loss(pred_bboxes[-1],
pred_scores[-1],
gt_bboxes,
gt_cls,
gt_groups,
postfix=postfix,
match_indices=match_indices)
total_loss = self._get_loss(
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
postfix))
self._get_loss_aux(
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
)
)
return total_loss
@ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss):
# Check for denoising metadata to compute denoising training loss
if dn_meta is not None:
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
assert len(batch['gt_groups']) == len(dn_pos_idx)
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
assert len(batch["gt_groups"]) == len(dn_pos_idx)
# Get the match indices for denoising
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
# Compute the denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
total_loss.update(dn_loss)
else:
# If no denoising metadata is provided, set denoising loss to zero
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
return total_loss
@ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
assert len(dn_pos_idx[i]) == len(gt_idx), f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
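`DETRLoss.forward` expects the flattened batch layout used throughout this file: labels and boxes concatenated over the whole batch, with per-image counts in `gt_groups`. Illustrative shapes:

import torch

batch = {
    "cls": torch.tensor([1, 3, 2]),  # 3 ground-truth labels across the batch
    "bboxes": torch.rand(3, 4),  # matching xywh boxes
    "gt_groups": [2, 1],  # image 0 has 2 objects, image 1 has 1
}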

@ -37,7 +37,7 @@ class HungarianMatcher(nn.Module):
"""
super().__init__()
if cost_gain is None:
cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
@ -86,7 +86,7 @@ class HungarianMatcher(nn.Module):
# Compute the classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
@ -99,9 +99,11 @@ class HungarianMatcher(nn.Module):
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Final cost matrix
C = self.cost_gain['class'] * cost_class + \
self.cost_gain['bbox'] * cost_bbox + \
self.cost_gain['giou'] * cost_giou
C = (
self.cost_gain["class"] * cost_class
+ self.cost_gain["bbox"] * cost_bbox
+ self.cost_gain["giou"] * cost_giou
)
# Compute the mask cost and dice cost
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
@ -111,10 +113,11 @@ class HungarianMatcher(nn.Module):
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
# (idx for queries, idx for gt)
return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
for k, (i, j) in enumerate(indices)]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
return [
(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
for k, (i, j) in enumerate(indices)
]
# This function is for future RT-DETR Segment models
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
@ -147,14 +150,9 @@ class HungarianMatcher(nn.Module):
# return C
def get_cdn_group(batch,
num_classes,
num_queries,
class_embed,
num_dn=100,
cls_noise_ratio=0.5,
box_noise_scale=1.0,
training=False):
def get_cdn_group(
batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
):
"""
Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
@ -180,7 +178,7 @@ def get_cdn_group(batch,
if (not training) or num_dn <= 0:
return None, None, None, None
gt_groups = batch['gt_groups']
gt_groups = batch["gt_groups"]
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
@ -190,9 +188,9 @@ def get_cdn_group(batch,
num_group = 1 if num_group == 0 else num_group
# Pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch['cls'] # (bs*num, )
gt_bbox = batch['bboxes'] # bs*num, 4
b_idx = batch['batch_idx']
gt_cls = batch["cls"] # (bs*num, )
gt_bbox = batch["bboxes"] # bs*num, 4
b_idx = batch["batch_idx"]
# Each group has positive and negative queries.
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
@ -245,16 +243,21 @@ def get_cdn_group(batch,
# Reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
dn_meta = {
'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
'dn_num_group': num_group,
'dn_num_split': [num_dn, num_queries]}
return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
class_embed.device), dn_meta
"dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
"dn_num_group": num_group,
"dn_num_split": [num_dn, num_queries],
}
return (
padding_cls.to(class_embed.device),
padding_bbox.to(class_embed.device),
attn_mask.to(class_embed.device),
dn_meta,
)
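The `torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)` one-liner in the matcher above converts per-image object counts into start offsets into the concatenated ground-truth tensors:

import torch

gt_groups = [2, 3, 1]  # objects per image
offsets = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
print(offsets)  # tensor([0, 2, 5]); image k's gt indices start at offsets[k]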

@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment
from .model import YOLO
__all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO'
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO"

@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator'
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"

@ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes ClassificationPredictor setting the task to 'classify'."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor'
self.args.task = "classify"
self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
is_legacy_transform = any(self._legacy_transform_name in str(transform)
for transform in self.transforms.transforms)
is_legacy_transform = any(
self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
)
if is_legacy_transform: # to handle legacy transforms
img = torch.stack([self.transforms(im) for im in img], dim=0)
else:
img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
dim=0)
img = torch.stack(
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32

@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None:
overrides = {}
overrides['task'] = 'classify'
if overrides.get('imgsz') is None:
overrides['imgsz'] = 224
overrides["task"] = "classify"
if overrides.get("imgsz") is None:
overrides["imgsz"] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data['names']
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
if not self.args.pretrained and hasattr(m, "reset_parameters"):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
@ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer):
model, ckpt = str(self.model), None
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
if model.endswith(".pt"):
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.split('.')[-1] in ('yaml', 'yml'):
elif model.split(".")[-1] in ("yaml", "yml"):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt
def build_dataset(self, img_path, mode='train', batch=None):
def build_dataset(self, img_path, mode="train", batch=None):
"""Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
# Attach inference transforms
if mode != 'train':
if mode != "train":
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
@ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer):
def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
batch['img'] = batch['img'].to(self.device)
batch['cls'] = batch['cls'].to(self.device)
batch["img"] = batch["img"].to(self.device)
batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
"""Returns a formatted string showing training progress."""
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss']
self.loss_names = ["loss"]
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
def label_loss_items(self, loss_items=None, prefix='train'):
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection.
"""
keys = [f'{prefix}/{x}' for x in self.loss_names]
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
@ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer):
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f'\nValidating {f}...')
LOGGER.info(f"\nValidating {f}...")
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
images=batch["img"],
batch_idx=torch.arange(len(batch["img"])),
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
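A training sketch mirroring this trainer's documented usage (dataset name and epochs are illustrative):

from ultralytics.models.yolo.classify import ClassificationTrainer

args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)  # illustrative overrides
trainer = ClassificationTrainer(overrides=args)
trainer.train()  # imgsz falls back to 224 via the override logic above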

@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = 'classify'
self.args.task = "classify"
self.metrics = ClassifyMetrics()
def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names
self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
self.pred = []
self.targets = []
def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device)
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
batch["cls"] = batch["cls"].to(self.device)
return batch
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
self.targets.append(batch["cls"])
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.confusion_matrix.plot(
save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
@@ -88,24 +87,27 @@ class ClassificationValidator(BaseValidator):
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(
images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
images=batch["img"],
batch_idx=torch.arange(len(batch["img"])),
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot)
on_plot=self.on_plot,
)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
plot_images(
batch["img"],
batch_idx=torch.arange(len(batch["img"])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
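
The `argsort` line in `update_metrics` above is the whole top-k bookkeeping. A hedged sketch of how top-1/top-5 accuracy falls out of it (toy tensors, assumed shapes):

```python
import torch

preds = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])  # (batch, num_classes) scores
labels = torch.tensor([1, 2])  # ground-truth class per sample
n5 = min(preds.shape[1], 5)  # same guard as update_metrics for <5 classes
topk = preds.argsort(1, descending=True)[:, :n5]  # highest-scoring class indices
top1 = (topk[:, 0] == labels).float().mean()  # 0.50 here
top5 = (topk == labels[:, None]).any(1).float().mean()  # 1.00 here
```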

@@ -4,4 +4,4 @@ from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator'
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"

@@ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
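
The multi-line `non_max_suppression` call above only changes layout, not behavior. For intuition, a rough stand-in using `torchvision.ops.nms` (box layout assumed; the real Ultralytics helper additionally handles class-aware NMS and `max_det`):

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.9],
                      [1.0, 1.0, 11.0, 11.0, 0.8],    # overlaps the first box
                      [50.0, 50.0, 60.0, 60.0, 0.7]])  # (x1, y1, x2, y2, conf)
conf_thres, iou_thres = 0.25, 0.45
boxes = boxes[boxes[:, 4] > conf_thres]  # confidence filter
keep = nms(boxes[:, :4], boxes[:, 4], iou_thres)  # IoU suppression
print(boxes[keep])  # second box suppressed, distant third box kept
```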

@@ -30,7 +30,7 @@ class DetectionTrainer(BaseTrainer):
```
"""
def build_dataset(self, img_path, mode='train', batch=None):
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
@@ -40,33 +40,37 @@ class DetectionTrainer(BaseTrainer):
batch (int, optional): Size of batches; this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Construct and return dataloader."""
assert mode in ['train', 'val']
assert mode in ["train", "val"]
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == 'train'
if getattr(dataset, 'rect', False) and shuffle:
shuffle = mode == "train"
if getattr(dataset, "rect", False) and shuffle:
LOGGER.warning("WARNING ⚠ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
workers = self.args.workers if mode == 'train' else self.args.workers * 2
workers = self.args.workers if mode == "train" else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
if self.args.multi_scale:
imgs = batch['img']
sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride *
self.stride) # size
imgs = batch["img"]
sz = (
random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
// self.stride
* self.stride
) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride
for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
batch['img'] = imgs
ns = [
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
def set_model_attributes(self):
@@ -74,33 +78,32 @@ class DetectionTrainer(BaseTrainer):
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data['nc'] # attach number of classes to model
self.model.names = self.data['names'] # attach class names to model
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return yolo.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.detect.DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items=None, prefix='train'):
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection.
"""
keys = [f'{prefix}/{x}' for x in self.loss_names]
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
@@ -109,18 +112,25 @@ class DetectionTrainer(BaseTrainer):
def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
plot_images(
images=batch["img"],
batch_idx=batch["batch_idx"],
cls=batch["cls"].squeeze(-1),
bboxes=batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
@@ -128,6 +138,6 @@ class DetectionTrainer(BaseTrainer):
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
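
The reshaped multi-scale block in `preprocess_batch` is easier to read in isolation. A sketch with assumed values (not Ultralytics code):

```python
import math
import random

imgsz, stride = 640, 32
# Random size in [0.5x, 1.5x] imgsz, snapped down to a stride multiple
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + stride) // stride * stride
shape = (480, 640)  # current batch (h, w)
sf = sz / max(shape)  # scale factor for the longer side
if sf != 1:
    # stretch both sides up to the nearest stride multiple
    new_shape = [math.ceil(x * sf / stride) * stride for x in shape]
    print(sz, new_shape)  # every side remains divisible by 32
```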

@@ -34,7 +34,7 @@ class DetectionValidator(BaseValidator):
self.nt_per_class = None
self.is_coco = False
self.class_map = None
self.args.task = 'detect'
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
@@ -42,25 +42,30 @@ class DetectionValidator(BaseValidator):
def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
for k in ['batch_idx', 'cls', 'bboxes']:
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
for k in ["batch_idx", "cls", "bboxes"]:
batch[k] = batch[k].to(self.device)
if self.args.save_hybrid:
height, width = batch['img'].shape[2:]
nb = len(batch['img'])
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
self.lb = [
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
height, width = batch["img"].shape[2:]
nb = len(batch["img"])
bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
self.lb = (
[
torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
for i in range(nb)
]
if self.args.save_hybrid
else []
) # for autolabelling
return batch
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, '') # validation path
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
val = self.data.get(self.args.split, "") # validation path
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.names = model.names
@@ -74,26 +79,28 @@ class DetectionValidator(BaseValidator):
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det)
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
)
def _prepare_batch(self, si, batch):
"""Prepares a batch of images and annotations for validation."""
idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
ori_shape = batch['ori_shape'][si]
imgsz = batch['img'].shape[2:]
ratio_pad = batch['ratio_pad'][si]
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
@@ -103,8 +110,9 @@ class DetectionValidator(BaseValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""
predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
ratio_pad=pbatch['ratio_pad']) # native-space pred
ops.scale_boxes(
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
) # native-space pred
return predn
def update_metrics(self, preds, batch):
@@ -112,19 +120,21 @@ class DetectionValidator(BaseValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat['target_cls'] = cls
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
@@ -132,24 +142,24 @@ class DetectionValidator(BaseValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn = self._prepare_pred(pred, pbatch)
stat['conf'] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
stat["conf"] = predn[:, 4]
stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
stat['tp'] = self._process_batch(predn, bbox, cls)
stat["tp"] = self._process_batch(predn, bbox, cls)
# TODO: obb has not supported confusion_matrix yet.
if self.args.plots and self.args.task != 'obb':
if self.args.plots and self.args.task != "obb":
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
self.pred_to_json(predn, batch["im_file"][si])
if self.args.save_txt:
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file)
file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
@@ -159,19 +169,19 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
if len(stats) and stats['tp'].any():
if len(stats) and stats["tp"].any():
self.metrics.process(**stats)
self.nt_per_class = np.bincount(stats['target_cls'].astype(int),
minlength=self.nc) # number of targets per class
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
) # number of targets per class
return self.metrics.results_dict
def print_results(self):
"""Prints training/validation set metrics per class."""
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
LOGGER.warning(
f'WARNING ⚠ no labels found in {self.args.task} set, can not compute metrics without labels')
LOGGER.warning(f"WARNING ⚠ no labels found in {self.args.task} set, can not compute metrics without labels")
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
@@ -180,10 +190,9 @@ class DetectionValidator(BaseValidator):
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.confusion_matrix.plot(
save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
)
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@@ -201,7 +210,7 @@ class DetectionValidator(BaseValidator):
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], gt_cls, iou)
def build_dataset(self, img_path, mode='val', batch=None):
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build YOLO Dataset.
@@ -214,28 +223,32 @@ class DetectionValidator(BaseValidator):
def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
plot_images(
batch["img"],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -243,8 +256,8 @@ class DetectionValidator(BaseValidator):
for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(file, 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
with open(file, "a") as f:
f.write(("%g " * len(line)).rstrip() % line + "\n")
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
@@ -253,28 +266,31 @@ class DetectionValidator(BaseValidator):
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5)})
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
}
)
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, 'bbox')
eval = COCOeval(anno, pred, "bbox")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
@@ -282,5 +298,5 @@ class DetectionValidator(BaseValidator):
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
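
For reference, the pycocotools flow in `eval_json` reduces to a handful of calls; the paths below are placeholders for the `anno_json`/`pred_json` files above:

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO("annotations/instances_val2017.json")  # ground-truth annotations
pred = anno.loadRes("predictions.json")  # detections (string path, not Path)
coco_eval = COCOeval(anno, pred, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
map50_95, map50 = coco_eval.stats[:2]  # the two values copied into stats above
```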

@@ -12,28 +12,34 @@ class YOLO(Model):
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes."""
return {
'classify': {
'model': ClassificationModel,
'trainer': yolo.classify.ClassificationTrainer,
'validator': yolo.classify.ClassificationValidator,
'predictor': yolo.classify.ClassificationPredictor, },
'detect': {
'model': DetectionModel,
'trainer': yolo.detect.DetectionTrainer,
'validator': yolo.detect.DetectionValidator,
'predictor': yolo.detect.DetectionPredictor, },
'segment': {
'model': SegmentationModel,
'trainer': yolo.segment.SegmentationTrainer,
'validator': yolo.segment.SegmentationValidator,
'predictor': yolo.segment.SegmentationPredictor, },
'pose': {
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
'predictor': yolo.pose.PosePredictor, },
'obb': {
'model': OBBModel,
'trainer': yolo.obb.OBBTrainer,
'validator': yolo.obb.OBBValidator,
'predictor': yolo.obb.OBBPredictor, }, }
"classify": {
"model": ClassificationModel,
"trainer": yolo.classify.ClassificationTrainer,
"validator": yolo.classify.ClassificationValidator,
"predictor": yolo.classify.ClassificationPredictor,
},
"detect": {
"model": DetectionModel,
"trainer": yolo.detect.DetectionTrainer,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
},
"segment": {
"model": SegmentationModel,
"trainer": yolo.segment.SegmentationTrainer,
"validator": yolo.segment.SegmentationValidator,
"predictor": yolo.segment.SegmentationPredictor,
},
"pose": {
"model": PoseModel,
"trainer": yolo.pose.PoseTrainer,
"validator": yolo.pose.PoseValidator,
"predictor": yolo.pose.PosePredictor,
},
"obb": {
"model": OBBModel,
"trainer": yolo.obb.OBBTrainer,
"validator": yolo.obb.OBBValidator,
"predictor": yolo.obb.OBBPredictor,
},
}
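
The reformatted `task_map` is a plain dict-of-dicts, so consumers resolve a component class with two lookups. A self-contained sketch with stand-in classes (the real map returns the Ultralytics classes above):

```python
class DetectionTrainer: ...
class DetectionValidator: ...
class DetectionPredictor: ...

task_map = {
    "detect": {
        "trainer": DetectionTrainer,
        "validator": DetectionValidator,
        "predictor": DetectionPredictor,
    },
}
trainer_cls = task_map["detect"]["trainer"]  # two lookups, no reflection
print(trainer_cls.__name__)  # DetectionTrainer
```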

@@ -4,4 +4,4 @@ from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator
__all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"

@@ -25,26 +25,27 @@ class OBBPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes OBBPredictor with optional model and data configuration overrides."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'obb'
self.args.task = "obb"
def postprocess(self, preds, img, orig_imgs):
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes,
rotated=True)
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes,
rotated=True,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
img_path = self.batch[0][i]
# xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
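
The `torch.cat` line above reorders the NMS output into the `(xywh, r, conf, cls)` layout that `Results` expects for OBB. A toy sketch (row layout assumed from the comment):

```python
import torch

# One rotated detection: [x, y, w, h, conf, cls, ..., r] with the angle last
pred = torch.tensor([[100.0, 50.0, 40.0, 20.0, 0.9, 3.0, 0.35]])
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
print(obb)  # -> [x, y, w, h, r, conf, cls] = [100, 50, 40, 20, 0.35, 0.9, 3]
```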

@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
"""Initialize a OBBTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides['task'] = 'obb'
overrides["task"] = "obb"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return OBBModel initialized with specified config and weights."""
model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
def get_validator(self):
"""Return an instance of OBBValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'obb'
self.args.task = "obb"
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
super().init_metrics(model)
val = self.data.get(self.args.split, '') # validation path
self.is_dota = isinstance(val, str) and 'DOTA' in val # is DOTA
val = self.data.get(self.args.split, "") # validation path
self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
nc=self.nc,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
rotated=True)
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
nc=self.nc,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
rotated=True,
)
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
@@ -66,12 +68,12 @@ class OBBValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares and returns a batch for OBB validation."""
idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx]
ori_shape = batch['ori_shape'][si]
imgsz = batch['img'].shape[2:]
ratio_pad = batch['ratio_pad'][si]
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
@@ -81,18 +83,21 @@ class OBBValidator(DetectionValidator):
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
xywh=True) # native-space pred
ops.scale_boxes(
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
) # native-space pred
return predn
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
*output_to_rotated_target(preds, max_det=self.args.max_det),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
plot_images(
batch["img"],
*output_to_rotated_target(preds, max_det=self.args.max_det),
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
@@ -101,12 +106,15 @@ class OBBValidator(DetectionValidator):
rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(predn[i, 5].item())],
'score': round(predn[i, 4].item(), 5),
'rbox': [round(x, 3) for x in r],
'poly': [round(x, 3) for x in b]})
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(predn[i, 5].item())],
"score": round(predn[i, 4].item(), 5),
"rbox": [round(x, 3) for x in r],
"poly": [round(x, 3) for x in b],
}
)
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -116,8 +124,8 @@ class OBBValidator(DetectionValidator):
xywha[:, :4] /= gn
xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xyxyxyxy polygon
line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
with open(file, 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
with open(file, "a") as f:
f.write(("%g " * len(line)).rstrip() % line + "\n")
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -125,42 +133,43 @@ class OBBValidator(DetectionValidator):
import json
import re
from collections import defaultdict
pred_json = self.save_dir / 'predictions.json' # predictions
pred_txt = self.save_dir / 'predictions_txt' # predictions
pred_json = self.save_dir / "predictions.json" # predictions
pred_txt = self.save_dir / "predictions_txt" # predictions
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
for d in data:
image_id = d['image_id']
score = d['score']
classname = self.names[d['category_id']].replace(' ', '-')
image_id = d["image_id"]
score = d["score"]
classname = self.names[d["category_id"]].replace(" ", "-")
lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
d['poly'][0],
d['poly'][1],
d['poly'][2],
d['poly'][3],
d['poly'][4],
d['poly'][5],
d['poly'][6],
d['poly'][7],
d["poly"][0],
d["poly"][1],
d["poly"][2],
d["poly"][3],
d["poly"][4],
d["poly"][5],
d["poly"][6],
d["poly"][7],
)
with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
# Save merged results. This could produce a slightly lower mAP than the official merging script
# because of the probiou calculation.
pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
for d in data:
image_id = d['image_id'].split('__')[0]
pattern = re.compile(r'\d+___\d+')
x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
bbox, score, cls = d['rbox'], d['score'], d['category_id']
image_id = d["image_id"].split("__")[0]
pattern = re.compile(r"\d+___\d+")
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
bbox, score, cls = d["rbox"], d["score"], d["category_id"]
bbox[0] += x
bbox[1] += y
bbox.extend([score, cls])
@@ -178,11 +187,11 @@ class OBBValidator(DetectionValidator):
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
classname = self.names[int(x[-1])].replace(' ', '-')
classname = self.names[int(x[-1])].replace(" ", "-")
poly = [round(i, 3) for i in x[:-2]]
score = round(x[-2], 3)
lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
poly[0],
@@ -194,7 +203,7 @@ class OBBValidator(DetectionValidator):
poly[6],
poly[7],
)
with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
return stats
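
The merge loop above recovers each patch's offset from its image id. A hedged sketch of that parsing (the id format is assumed from the regex; this is not the official DOTA tooling):

```python
import re

image_id = "P0006__1024__0___512"  # made-up split-image id: base__size__x___y
base_id = image_id.split("__")[0]  # "P0006"
x, y = (int(c) for c in re.findall(r"\d+___\d+", image_id)[0].split("___"))
bbox = [100.0, 200.0, 40.0, 30.0, 0.3]  # xywhr in patch coordinates
bbox[0] += x  # shift back into original-image coordinates
bbox[1] += y
print(base_id, bbox[:2])  # P0006 [100.0, 712.0]
```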

@@ -4,4 +4,4 @@ from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"

@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'pose'
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
self.args.task = "pose"
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
nc=len(self.model.names))
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
nc=len(self.model.names),
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
img_path = self.batch[0][i]
results.append(
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
)
return results
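
The keypoint handling above views the tail of each NMS row as `(num_kpts, 3)` before rescaling with `scale_coords`. A shape-only sketch (kpt_shape assumed to be COCO's 17x3):

```python
import torch

kpt_shape = (17, 3)  # 17 keypoints, each (x, y, visibility)
# Each row: [x, y, w, h, conf, cls, kpt0_x, kpt0_y, kpt0_v, ...]
pred = torch.randn(2, 6 + kpt_shape[0] * kpt_shape[1])  # 2 detections
pred_kpts = pred[:, 6:].view(len(pred), *kpt_shape)
print(pred_kpts.shape)  # torch.Size([2, 17, 3])
```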

@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None:
overrides = {}
overrides['task'] = 'pose'
overrides["task"] = "pose"
super().__init__(cfg, overrides, _callbacks)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
if weights:
model.load(weights)
@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data['kpt_shape']
self.model.kpt_shape = self.data["kpt_shape"]
def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return yolo.pose.PoseValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
return yolo.pose.PoseValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img']
kpts = batch['keypoints']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images,
batch_idx,
cls,
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
images = batch["img"]
kpts = batch["keypoints"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images(
images,
batch_idx,
cls,
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_metrics(self):
"""Plots training/val metrics."""

@@ -31,38 +31,53 @@ class PoseValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = 'pose'
self.args.task = "pose"
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
LOGGER.warning("WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
'See https://github.com/ultralytics/ultralytics/issues/4031.')
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"WARNING ⚠ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float()
batch["keypoints"] = batch["keypoints"].to(self.device).float()
return batch
def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)')
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Pose(P",
"R",
"mAP50",
"mAP50-95)",
)
def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
return ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc,
)
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape']
self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
@@ -71,21 +86,21 @@ class PoseValidator(DetectionValidator):
def _prepare_batch(self, si, batch):
"""Prepares a batch for processing by converting keypoints to float and moving to device."""
pbatch = super()._prepare_batch(si, batch)
kpts = batch['keypoints'][batch['batch_idx'] == si]
h, w = pbatch['imgsz']
kpts = batch["keypoints"][batch["batch_idx"] == si]
h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
pbatch['kpts'] = kpts
kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
pbatch["kpts"] = kpts
return pbatch
def _prepare_pred(self, pred, pbatch):
"""Prepares and scales keypoints in a batch for pose processing."""
predn = super()._prepare_pred(pred, pbatch)
nk = pbatch['kpts'].shape[1]
nk = pbatch["kpts"].shape[1]
pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
return predn, pred_kpts
def update_metrics(self, preds, batch):
@@ -93,14 +108,16 @@ class PoseValidator(DetectionValidator):
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat['target_cls'] = cls
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
@@ -113,13 +130,13 @@ class PoseValidator(DetectionValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn, pred_kpts = self._prepare_pred(pred, pbatch)
stat['conf'] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
stat["conf"] = predn[:, 4]
stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
stat['tp'] = self._process_batch(predn, bbox, cls)
stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
stat["tp"] = self._process_batch(predn, bbox, cls)
stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
@@ -128,7 +145,7 @@ class PoseValidator(DetectionValidator):
# Save
if self.args.save_json:
self.pred_to_json(predn, batch['im_file'][si])
self.pred_to_json(predn, batch["im_file"][si])
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -159,26 +176,30 @@ class PoseValidator(DetectionValidator):
def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
kpts=batch["keypoints"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
plot_images(batch['img'],
*output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
plot_images(
batch["img"],
*output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
@@ -187,37 +208,41 @@ class PoseValidator(DetectionValidator):
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'keypoints': p[6:],
'score': round(p[4], 5)})
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"bbox": [round(x, 3) for x in b],
"keypoints": p[6:],
"score": round(p[4], 5),
}
)
def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
:2
] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
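
`eval_json` above runs two COCOeval passes and maps each onto the pose metrics keys with `idx = i * 4 + 2`. A sketch of that indexing (key layout assumed for illustration):

```python
metrics_keys = [
    "metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
    "metrics/precision(P)", "metrics/recall(P)", "metrics/mAP50(P)", "metrics/mAP50-95(P)",
]
for i, task in enumerate(["bbox", "keypoints"]):
    idx = i * 4 + 2
    # eval.stats[:2] is (mAP50-95, mAP50): keys[idx + 1] gets mAP50-95, keys[idx] gets mAP50
    print(task, "->", metrics_keys[idx], "/", metrics_keys[idx + 1])
```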

@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"

@@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""Applies non-max suppression and processes detections for each image in an input batch."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

@@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides['task'] = 'segment'
overrides["task"] = "segment"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
@@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return yolo.segment.SegmentationValidator(self.test_loader,
save_dir=self.save_dir,
args=copy(self.args),
_callbacks=self.callbacks)
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
return yolo.segment.SegmentationValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
masks=batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
masks=batch["masks"],
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_metrics(self):
"""Plots training/val metrics."""

@@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.plot_masks = None
self.process = None
self.args.task = 'segment'
self.args.task = "segment"
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
batch["masks"] = batch["masks"].to(self.device).float()
return batch
def init_metrics(self, model):
@@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator):
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
check_requirements("pycocotools>=2.0.6")
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
@@ -55,33 +55,46 @@ class SegmentationValidator(DetectionValidator):
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Mask(P",
"R",
"mAP50",
"mAP50-95)",
)
def postprocess(self, preds):
"""Post-processes YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc,
)
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by processing images and targets."""
prepared_batch = super()._prepare_batch(si, batch)
midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
prepared_batch['masks'] = batch['masks'][midx]
midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
prepared_batch["masks"] = batch["masks"][midx]
return prepared_batch
def _prepare_pred(self, pred, pbatch, proto):
"""Prepares a batch for training or inference by processing images and targets."""
predn = super()._prepare_pred(pred, pbatch)
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
return predn, pred_masks
def update_metrics(self, preds, batch):
@@ -89,14 +102,16 @@ class SegmentationValidator(DetectionValidator):
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
self.seen += 1
npr = len(pred)
stat = dict(conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat['target_cls'] = cls
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
@@ -106,24 +121,20 @@ class SegmentationValidator(DetectionValidator):
continue
# Masks
gt_masks = pbatch.pop('masks')
gt_masks = pbatch.pop("masks")
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
stat['conf'] = predn[:, 4]
stat['pred_cls'] = predn[:, 5]
stat["conf"] = predn[:, 4]
stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
stat['tp'] = self._process_batch(predn, bbox, cls)
stat['tp_m'] = self._process_batch(predn,
bbox,
cls,
pred_masks,
gt_masks,
self.args.overlap_mask,
masks=True)
stat["tp"] = self._process_batch(predn, bbox, cls)
stat["tp_m"] = self._process_batch(
predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
)
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
@@ -136,10 +147,12 @@ class SegmentationValidator(DetectionValidator):
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
pbatch['ori_shape'],
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
pred_masks = ops.scale_image(
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
pbatch["ori_shape"],
ratio_pad=batch["ratio_pad"][si],
)
self.pred_to_json(predn, batch["im_file"][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
@@ -166,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
@ -176,26 +189,29 @@ class SegmentationValidator(DetectionValidator):
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
masks=batch["masks"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch["img"],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
@ -205,8 +221,8 @@ class SegmentationValidator(DetectionValidator):
def single_encode(x):
"""Encode a single predicted mask as RLE; results are appended to jdict by the enclosing method."""
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
stem = Path(filename).stem
@ -217,37 +233,41 @@ class SegmentationValidator(DetectionValidator):
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
"segmentation": rles[i],
}
)
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
:2
] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
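# A minimal sketch (hedged) of the RLE round-trip that single_encode() above performs, using the same
# pycocotools API; the mask size and values here are illustrative only.
import numpy as np
from pycocotools.mask import decode, encode

mask = np.zeros((640, 640), dtype=np.uint8)
mask[100:200, 150:300] = 1  # synthetic binary instance mask
rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]  # Fortran-ordered HxWx1, as above
assert (decode(rle) == mask).all()  # RLE is lossless for binary masks
rle["counts"] = rle["counts"].decode("utf-8")  # JSON-serializable, as in pred_to_json()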

@ -1,9 +1,29 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .tasks import (
BaseModel,
ClassificationModel,
DetectionModel,
SegmentationModel,
attempt_load_one_weight,
attempt_load_weights,
guess_model_scale,
guess_model_task,
parse_model,
torch_safe_load,
yaml_model_load,
)
__all__ = (
"attempt_load_one_weight",
"attempt_load_weights",
"parse_model",
"yaml_model_load",
"guess_model_task",
"guess_model_scale",
"torch_safe_load",
"DetectionModel",
"SegmentationModel",
"ClassificationModel",
"BaseModel",
)

@ -32,10 +32,12 @@ def check_class_names(names):
names = {int(k): str(v) for k, v in names.items()}
n = len(names)
if max(names.keys()) >= n:
raise KeyError(
f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
)
if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 'n01440764'
names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names
names = {k: names_map[v] for k, v in names.items()}
return names
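# Hedged usage sketch for check_class_names() with illustrative inputs:
#   check_class_names({0: 'person', '1': 0.5})  -> {0: 'person', 1: '0.5'}  (keys cast to int, values to str)
#   check_class_names({0: 'a', 2: 'b'})         -> KeyError, since a 2-class dataset requires indices 0-1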
@ -44,8 +46,8 @@ def default_class_names(data=None):
"""Applies default class names to an input YAML file or returns numerical class names."""
if data:
with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))["names"]
return {i: f"class{i}" for i in range(999)} # return default if above errors
class AutoBackend(nn.Module):
@ -77,14 +79,16 @@ class AutoBackend(nn.Module):
"""
@torch.no_grad()
def __init__(
self,
weights="yolov8n.pt",
device=torch.device("cpu"),
dnn=False,
data=None,
fp16=False,
fuse=True,
verbose=True,
):
"""
Initialize the AutoBackend for inference.
@ -100,17 +104,31 @@ class AutoBackend(nn.Module):
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
(
pt,
jit,
onnx,
xml,
engine,
coreml,
saved_model,
pb,
tflite,
edgetpu,
tfjs,
paddle,
ncnn,
triton,
) = self._model_type(w)
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
stride = 32 # default stride
model, metadata = None, None
# Set device
cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA
if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats
device = torch.device("cpu")
cuda = False
# Download if not local
@ -121,77 +139,79 @@ class AutoBackend(nn.Module):
if nn_module: # in-memory PyTorch model
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
pt = True
elif pt: # PyTorch
from ultralytics.nn.tasks import attempt_load_weights
model = attempt_load_weights(
weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
)
if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript
LOGGER.info(f"Loading {w} for TorchScript inference...")
extra_files = {"config.txt": ""} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files["config.txt"]: # load metadata dict
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
elif dnn: # ONNX OpenCV DNN
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
import onnxruntime
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map # metadata
elif xml: # OpenVINO
LOGGER.info(f"Loading {w} for OpenVINO inference...")
check_requirements("openvino>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch # noqa
core = Core()
w = Path(w)
if not w.is_file(): # if not *.xml
w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
if ov_model.get_parameters()[0].get_layout().empty:
ov_model.get_parameters()[0].set_layout(Layout("NCHW"))
batch_dim = get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device
metadata = w.parent / "metadata.yaml"
elif engine: # TensorRT
LOGGER.info(f"Loading {w} for TensorRT inference...")
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
except ImportError:
if LINUX:
check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
import tensorrt as trt # noqa
check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
if device.type == "cpu":
device = torch.device("cuda:0")
Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
logger = trt.Logger(trt.Logger.INFO)
# Read file
with open(w, "rb") as f, trt.Runtime(logger) as runtime:
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
context = model.create_execution_context()
bindings = OrderedDict()
@ -213,116 +233,124 @@ class AutoBackend(nn.Module):
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
elif coreml: # CoreML
LOGGER.info(f"Loading {w} for CoreML inference...")
import coremltools as ct
model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
elif saved_model: # TF SavedModel
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
metadata = Path(w) / "metadata.yaml"
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
import tensorflow as tf
from ultralytics.engine.exporter import gd_outputs
def wrap_frozen_graph(gd, inputs, outputs):
"""Wrap frozen graphs for deployment."""
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
gd = tf.Graph().as_graph_def() # TF GraphDef
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
platform.system()
]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# Load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
elif tfjs: # TF.js
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
elif paddle: # PaddlePaddle
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
import paddle.inference as pdi # noqa
w = Path(w)
if not w.is_file(): # if not *.pdmodel
w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir
config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / "metadata.yaml"
elif ncnn: # ncnn
LOGGER.info(f"Loading {w} for ncnn inference...")
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn
import ncnn as pyncnn
net = pyncnn.Net()
net.opt.use_vulkan_compute = cuda
w = Path(w)
if not w.is_file(): # if not *.param
w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir
net.load_param(str(w))
net.load_model(str(w.with_suffix(".bin")))
metadata = w.parent / "metadata.yaml"
elif triton: # NVIDIA Triton Inference Server
check_requirements("tritonclient[all]")
from ultralytics.utils.triton import TritonRemoteModel
model = TritonRemoteModel(w)
else:
from ultralytics.engine.exporter import export_formats
raise TypeError(
f"model='{w}' is not a supported model format. "
"See https://docs.ultralytics.com/modes/predict for help."
f"\n\n{export_formats()}"
)
# Load external metadata YAML
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
metadata = yaml_load(metadata)
if metadata:
for k, v in metadata.items():
if k in ("stride", "batch"):
metadata[k] = int(v)
elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
metadata[k] = eval(v)
stride = metadata["stride"]
task = metadata["task"]
batch = metadata["batch"]
imgsz = metadata["imgsz"]
names = metadata["names"]
kpt_shape = metadata.get("kpt_shape")
elif not (pt or triton or nn_module):
LOGGER.warning(f"WARNING ⚠ Metadata not found for 'model={weights}'")
# Check names
if "names" not in locals(): # names missing
names = default_class_names(data)
names = check_class_names(names)
@ -367,26 +395,28 @@ class AutoBackend(nn.Module):
im = im.cpu().numpy() # FP32
y = list(self.ov_compiled_model(im).values())
elif self.engine: # TensorRT
if self.dynamic and im.shape != self.bindings["images"].shape:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
s = self.bindings["images"].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs["images"] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
elif self.coreml: # CoreML
im = im[0].cpu().numpy()
im_pil = Image.fromarray((im * 255).astype("uint8"))
# im = im.resize((192, 320), Image.BILINEAR)
y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
if "confidence" in y:
raise TypeError(
"Ultralytics only supports inference of non-pipelined CoreML models exported with "
f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
)
# TODO: CoreML NMS inference handling
# from ultralytics.utils.ops import xywh2xyxy
# box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
@ -425,20 +455,20 @@ class AutoBackend(nn.Module):
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
self.names = {i: f"class{i}" for i in range(nc)}
else: # Lite or Edge TPU
details = self.input_details[0]
integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
if integer:
scale, zero_point = details["quantization"]
im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
self.interpreter.set_tensor(details["index"], im)
self.interpreter.invoke()
y = []
for output in self.output_details:
x = self.interpreter.get_tensor(output["index"])
if integer:
scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
if x.ndim > 2: # if task is not classification
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
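# Hedged numeric sketch of the affine (de)quantization above, with illustrative int8 parameters:
#   scale, zero_point = 0.0039, -128
#   input 'de-scale':  0.5 / 0.0039 + (-128) ≈ 0.2 -> stored as int8 value 0
#   output 're-scale': (0 - (-128)) * 0.0039 ≈ 0.4992, recovering ~0.5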
@ -483,13 +513,13 @@ class AutoBackend(nn.Module):
(None): This method runs the forward pass and doesn't return a value.
"""
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
if any(warmup_types) and (self.device.type != "cpu" or self.triton):
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1):
self.forward(im) # warmup
@staticmethod
def _model_type(p="path/to/model.pt"):
"""
This function takes a path to a model file and returns the model type.
@ -499,18 +529,20 @@ class AutoBackend(nn.Module):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from ultralytics.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False) and not isinstance(p, str):
check_suffix(p, sf) # checks
name = Path(p).name
types = [s in name for s in sf]
types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats
types[8] &= not types[9] # tflite &= not edgetpu
if any(types):
triton = False
else:
from urllib.parse import urlsplit
url = urlsplit(p)
triton = url.netloc and url.path and url.scheme in {"http", "grpc"}
return types + [triton]
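# Minimal usage sketch (hedged) for AutoBackend, assuming a local 'yolov8n.pt' checkpoint:
import torch

from ultralytics.nn.autobackend import AutoBackend

backend = AutoBackend("yolov8n.pt", device=torch.device("cpu"), fp16=False)
backend.warmup(imgsz=(1, 3, 640, 640))  # no-op on CPU unless Triton, per warmup() above
y = backend(torch.zeros(1, 3, 640, 640))  # raw forward pass; NMS/postprocessing is handled by the predictor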

@ -17,18 +17,101 @@ Example:
```
"""
from .block import (
C1,
C2,
C3,
C3TR,
DFL,
SPP,
SPPF,
Bottleneck,
BottleneckCSP,
C2f,
C3Ghost,
C3x,
GhostBottleneck,
HGBlock,
HGStem,
Proto,
RepC3,
ResNetLayer,
)
from .conv import (
CBAM,
ChannelAttention,
Concat,
Conv,
Conv2,
ConvTranspose,
DWConv,
DWConvTranspose2d,
Focus,
GhostConv,
LightConv,
RepConv,
SpatialAttention,
)
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
from .transformer import (
AIFI,
MLP,
DeformableTransformerDecoder,
DeformableTransformerDecoderLayer,
LayerNorm2d,
MLPBlock,
MSDeformAttn,
TransformerBlock,
TransformerEncoderLayer,
TransformerLayer,
)
__all__ = (
"Conv",
"Conv2",
"LightConv",
"RepConv",
"DWConv",
"DWConvTranspose2d",
"ConvTranspose",
"Focus",
"GhostConv",
"ChannelAttention",
"SpatialAttention",
"CBAM",
"Concat",
"TransformerLayer",
"TransformerBlock",
"MLPBlock",
"LayerNorm2d",
"DFL",
"HGBlock",
"HGStem",
"SPP",
"SPPF",
"C1",
"C2",
"C3",
"C2f",
"C3x",
"C3TR",
"C3Ghost",
"GhostBottleneck",
"Bottleneck",
"BottleneckCSP",
"Proto",
"Detect",
"Segment",
"Pose",
"Classify",
"TransformerEncoderLayer",
"RepC3",
"RTDETRDecoder",
"AIFI",
"DeformableTransformerDecoder",
"DeformableTransformerDecoderLayer",
"MSDeformAttn",
"MLP",
"ResNetLayer",
"OBB",
)

@ -8,8 +8,26 @@ import torch.nn.functional as F
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
from .transformer import TransformerBlock
__all__ = (
"DFL",
"HGBlock",
"HGStem",
"SPP",
"SPPF",
"C1",
"C2",
"C3",
"C2f",
"C3x",
"C3TR",
"C3Ghost",
"GhostBottleneck",
"Bottleneck",
"BottleneckCSP",
"Proto",
"RepC3",
"ResNetLayer",
)
class DFL(nn.Module):
@ -284,9 +302,11 @@ class GhostBottleneck(nn.Module):
self.conv = nn.Sequential(
GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
GhostConv(c_, c2, 1, 1, act=False), # pw-linear
)
self.shortcut = (
nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
)
def forward(self, x):
"""Applies skip connection and concatenation to input tensor."""
@ -359,8 +379,9 @@ class ResNetLayer(nn.Module):
self.is_first = is_first
if self.is_first:
self.layer = nn.Sequential(
Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
else:
blocks = [ResNetBlock(c1, c2, s, e=e)]
blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])
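# Hedged shape check for GhostBottleneck, assuming the (c1, c2, k, s) constructor shown above:
import torch

from ultralytics.nn.modules import GhostBottleneck

m = GhostBottleneck(32, 32, k=3, s=2)  # s=2 enables the DWConv downsample and the conv shortcut
print(m(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 32, 32, 32]): spatial halved, channels kept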

@ -7,8 +7,21 @@ import numpy as np
import torch
import torch.nn as nn
__all__ = (
"Conv",
"Conv2",
"LightConv",
"DWConv",
"DWConvTranspose2d",
"ConvTranspose",
"Focus",
"GhostConv",
"ChannelAttention",
"SpatialAttention",
"CBAM",
"Concat",
"RepConv",
)
def autopad(k, p=None, d=1): # kernel, padding, dilation
@ -22,6 +35,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
@ -60,9 +74,9 @@ class Conv2(Conv):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]]
w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
self.conv.weight.data += w
self.__delattr__("cv2")
self.forward = self.forward_fuse
@ -102,6 +116,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose(nn.Module):
"""Convolution transpose 2d layer."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
@ -164,6 +179,7 @@ class RepConv(nn.Module):
This module is used in RT-DETR.
Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
"""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
@ -214,7 +230,7 @@ class RepConv(nn.Module):
beta = branch.bn.bias
eps = branch.bn.eps
elif isinstance(branch, nn.BatchNorm2d):
if not hasattr(self, "id_tensor"):
input_dim = self.c1 // self.g
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
for i in range(self.c1):
@ -232,29 +248,31 @@ class RepConv(nn.Module):
def fuse_convs(self):
"""Combines two convolution layers into a single layer and removes unused attributes from the class."""
if hasattr(self, "conv"):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.conv = nn.Conv2d(
in_channels=self.conv1.conv.in_channels,
out_channels=self.conv1.conv.out_channels,
kernel_size=self.conv1.conv.kernel_size,
stride=self.conv1.conv.stride,
padding=self.conv1.conv.padding,
dilation=self.conv1.conv.dilation,
groups=self.conv1.conv.groups,
bias=True,
).requires_grad_(False)
self.conv.weight.data = kernel
self.conv.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__("conv1")
self.__delattr__("conv2")
if hasattr(self, "nm"):
self.__delattr__("nm")
if hasattr(self, "bn"):
self.__delattr__("bn")
if hasattr(self, "id_tensor"):
self.__delattr__("id_tensor")
class ChannelAttention(nn.Module):
@ -278,7 +296,7 @@ class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
"""Initialize Spatial-attention module with kernel size argument."""
super().__init__()
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.act = nn.Sigmoid()
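# Hedged equivalence check for Conv2.fuse_convs() above, assuming the Conv2(c1, c2, k) constructor and
# that forward_fuse retains the BatchNorm; outputs should then match to numerical precision.
import torch

from ultralytics.nn.modules import Conv2

m = Conv2(8, 16, k=3).eval()
x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y1 = m(x)
    m.fuse_convs()  # folds the parallel 1x1 branch into the center of the 3x3 kernel
    y2 = m(x)
print(torch.allclose(y1, y2, atol=1e-6))  # True expected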

@ -14,11 +14,12 @@ from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
class Detect(nn.Module):
"""YOLOv8 Detect head for detection models."""
dynamic = False # force grid reconstruction
export = False # export mode
shape = None
@ -35,7 +36,8 @@ class Detect(nn.Module):
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
)
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
@ -53,14 +55,14 @@ class Detect(nn.Module):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
box = x_cat[:, : self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4 :]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = self.decode_bboxes(box)
if self.export and self.format in ("tflite", "edgetpu"):
# Precompute normalization factor to increase numerical stability
# See https://github.com/ultralytics/ultralytics/issues/7371
img_h = shape[2]
@ -79,7 +81,7 @@ class Detect(nn.Module):
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
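# e.g. nc=80, s=8: log(5 / 80 / (640 / 8) ** 2) = log(9.77e-6) ≈ -11.5, a deliberately low-confidence class prior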
def decode_bboxes(self, bboxes):
"""Decode bounding boxes."""
@ -214,26 +216,28 @@ class RTDETRDecoder(nn.Module):
and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
Transformer decoder layers to output the final predictions.
"""
export = False # export mode
def __init__(
self,
nc=80,
ch=(512, 1024, 2048),
hd=256, # hidden dim
nq=300, # num queries
ndp=4, # num decoder points
nh=8, # num head
ndl=6, # num decoder layers
d_ffn=1024, # dim of feedforward
dropout=0.0,
act=nn.ReLU(),
eval_idx=-1,
# Training args
nd=100, # num denoising
label_noise_ratio=0.5,
box_noise_scale=1.0,
learnt_init_query=False,
):
"""
Initializes the RTDETRDecoder module with the given parameters.
@ -302,28 +306,30 @@ class RTDETRDecoder(nn.Module):
feats, shapes = self._get_encoder_input(x)
# Prepare denoising training
dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
batch,
self.nc,
self.num_queries,
self.denoising_class_embed.weight,
self.num_denoising,
self.label_noise_ratio,
self.box_noise_scale,
self.training,
)
embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
# Decoder
dec_bboxes, dec_scores = self.decoder(
embed,
refer_bbox,
feats,
shapes,
self.dec_bbox_head,
self.dec_score_head,
self.query_pos_head,
attn_mask=attn_mask,
)
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
if self.training:
return x
@ -331,24 +337,24 @@ class RTDETRDecoder(nn.Module):
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
return y if self.export else (y, x)
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
"""Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
anchors = []
for i, (h, w) in enumerate(shapes):
sy = torch.arange(end=h, dtype=dtype, device=device)
sx = torch.arange(end=w, dtype=dtype, device=device)
grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
anchors = torch.log(anchors / (1 - anchors))
anchors = anchors.masked_fill(~valid_mask, float("inf"))
return anchors, valid_mask
def _get_encoder_input(self, x):
@ -415,13 +421,13 @@ class RTDETRDecoder(nn.Module):
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
# linear_init_(self.enc_score_head)
constant_(self.enc_score_head.bias, bias_cls)
constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
# linear_init_(cls_)
constant_(cls_.bias, bias_cls)
constant_(reg_.layers[-1].weight, 0.0)
constant_(reg_.layers[-1].bias, 0.0)
linear_init_(self.enc_output[0])
xavier_uniform_(self.enc_output[0].weight)
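# Hedged standalone re-derivation of the _generate_anchors() shapes above for typical P3-P5 feature maps
# (requires torch>=1.10 for meshgrid indexing='ij', mirroring the TORCH_1_10 guard):
import torch

anchors = []
for i, (h, w) in enumerate([(80, 80), (40, 40), (20, 20)]):
    sy, sx = torch.arange(h, dtype=torch.float32), torch.arange(w, dtype=torch.float32)
    gy, gx = torch.meshgrid(sy, sx, indexing="ij")
    xy = (torch.stack([gx, gy], -1).unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=torch.float32)
    wh = torch.ones_like(xy) * 0.05 * (2.0**i)  # grid_size doubles per level
    anchors.append(torch.cat([xy, wh], -1).view(-1, h * w, 4))
print(torch.cat(anchors, 1).shape)  # torch.Size([1, 8400, 4]) = 6400 + 1600 + 400 anchors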

@ -11,8 +11,18 @@ from torch.nn.init import constant_, xavier_uniform_
from .conv import Conv
from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
__all__ = (
"TransformerEncoderLayer",
"TransformerLayer",
"TransformerBlock",
"MLPBlock",
"LayerNorm2d",
"AIFI",
"DeformableTransformerDecoder",
"DeformableTransformerDecoderLayer",
"MSDeformAttn",
"MLP",
)
class TransformerEncoderLayer(nn.Module):
@ -22,9 +32,11 @@ class TransformerEncoderLayer(nn.Module):
"""Initialize the TransformerEncoderLayer with specified parameters."""
super().__init__()
from ...utils.torch_utils import TORCH_1_9
if not TORCH_1_9:
raise ModuleNotFoundError(
"TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
)
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
# Implementation of Feedforward model
self.fc1 = nn.Linear(c1, cm)
@ -91,12 +103,11 @@ class AIFI(TransformerEncoderLayer):
"""Builds 2D sine-cosine position embedding."""
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
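# Hedged continuation sketch: concatenating the four sin/cos terms yields the (1, h*w, embed_dim) embedding,
# e.g. for a 20x20 map with embed_dim=256 (the final cat order below is an assumption):
#   pe = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], 1)[None]  # shape (1, 400, 256)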
@ -213,10 +224,10 @@ class MSDeformAttn(nn.Module):
"""Initialize MSDeformAttn with the given parameters."""
super().__init__()
if d_model % n_heads != 0:
raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
_d_per_head = d_model // n_heads
# Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"
self.im2col_step = 64
@ -234,21 +245,24 @@ class MSDeformAttn(nn.Module):
def _reset_parameters(self):
"""Reset module parameters."""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
"""
@ -288,7 +302,7 @@ class MSDeformAttn(nn.Module):
add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
return self.output_proj(output)
@ -301,7 +315,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
"""
def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
"""Initialize the DeformableTransformerDecoderLayer with the given parameters."""
super().__init__()
@ -339,14 +353,16 @@ class DeformableTransformerDecoderLayer(nn.Module):
# Self attention
q = k = self.with_pos_embed(embed, query_pos)
tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
0
].transpose(0, 1)
embed = embed + self.dropout1(tgt)
embed = self.norm1(embed)
# Cross attention
tgt = self.cross_attn(
self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
)
embed = embed + self.dropout2(tgt)
embed = self.norm2(embed)
@ -370,16 +386,17 @@ class DeformableTransformerDecoder(nn.Module):
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
def forward(
self,
embed, # decoder embeddings
refer_bbox, # anchor
feats, # image features
shapes, # feature shapes
bbox_head,
score_head,
pos_mlp,
attn_mask=None,
padding_mask=None,
):
"""Perform the forward pass through the entire decoder."""
output = embed
dec_bboxes = []

@ -10,7 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import uniform_
__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"
def _get_clones(module, n):
@ -27,7 +27,7 @@ def linear_init_(module):
"""Initialize the weights and biases of a linear module."""
bound = 1 / math.sqrt(module.weight.shape[0])
uniform_(module.weight, -bound, bound)
if hasattr(module, "bias") and module.bias is not None:
uniform_(module.bias, -bound, bound)
@ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5):
return torch.log(x1 / x2)
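# Round-trip sketch: inverse_sigmoid(torch.sigmoid(torch.tensor([-2.0, 0.0, 3.0]))) ≈ tensor([-2., 0., 3.]), up to eps clamping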
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
"""
Multi-scale deformable attention.
@ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()
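# Hedged shape demo for multi_scale_deformable_attn_pytorch with random tensors and two levels (40x40 + 20x20 = 2000 keys):
import torch

bs, heads, dims, queries, points = 1, 8, 32, 300, 4
value = torch.randn(bs, 2000, heads, dims)
locs = torch.rand(bs, queries, heads, 2, points, 2)  # normalized (x, y) sampling locations in [0, 1]
w = torch.rand(bs, queries, heads, 2, points)
w = w / w.sum((-2, -1), keepdim=True)  # attention weights normalized over levels * points
out = multi_scale_deformable_attn_pytorch(value, [(40, 40), (20, 20)], locs, w)
print(out.shape)  # torch.Size([1, 300, 256]) = (bs, num_queries, num_heads * embed_dims)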

@ -7,16 +7,54 @@ from pathlib import Path
import torch
import torch.nn as nn
from ultralytics.nn.modules import (
AIFI,
C1,
C2,
C3,
C3TR,
OBB,
SPP,
SPPF,
Bottleneck,
BottleneckCSP,
C2f,
C3Ghost,
C3x,
Classify,
Concat,
Conv,
Conv2,
ConvTranspose,
Detect,
DWConv,
DWConvTranspose2d,
Focus,
GhostBottleneck,
GhostConv,
HGBlock,
HGStem,
Pose,
RepC3,
RepConv,
ResNetLayer,
RTDETRDecoder,
Segment,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
fuse_conv_and_bn,
fuse_deconv_and_bn,
initialize_weights,
intersect_dicts,
make_divisible,
model_info,
scale_img,
time_sync,
)
try:
import thop
@ -90,8 +128,10 @@ class BaseModel(nn.Module):
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
LOGGER.warning(
f"WARNING ⚠ {self.__class__.__name__} does not support augmented inference yet. "
f"Reverting to single-scale inference instead."
)
return self._predict_once(x)
def _profile_one_layer(self, m, x, dt):
@ -108,14 +148,14 @@ class BaseModel(nn.Module):
None
"""
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
t = time_sync()
for _ in range(10):
m(x.copy() if c else x)
dt.append((time_sync() - t) * 100)
if m == self.model[0]:
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
if c:
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
@ -129,15 +169,15 @@ class BaseModel(nn.Module):
"""
if not self.is_fused():
for m in self.model.modules():
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
if isinstance(m, Conv2):
m.fuse_convs()
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, "bn") # remove batchnorm
m.forward = m.forward_fuse # update forward
if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
delattr(m, "bn") # remove batchnorm
m.forward = m.forward_fuse # update forward
if isinstance(m, RepConv):
m.fuse_convs()
@ -156,7 +196,7 @@ class BaseModel(nn.Module):
Returns:
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, detailed=False, verbose=True, imgsz=640):
@ -196,12 +236,12 @@ class BaseModel(nn.Module):
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
"""
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
csd = model.float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
if verbose:
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
def loss(self, batch, preds=None):
"""
@ -211,33 +251,33 @@ class BaseModel(nn.Module):
batch (dict): Batch to compute loss on
preds (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
preds = self.forward(batch["img"]) if preds is None else preds
return self.criterion(preds, batch)
def init_criterion(self):
"""Initialize the loss criterion for the BaseModel."""
raise NotImplementedError("compute_loss() needs to be implemented by task heads")
class DetectionModel(BaseModel):
"""YOLOv8 detection model."""
def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
"""Initialize the YOLOv8 detection model with the given config and parameters."""
super().__init__()
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
if nc and nc != self.yaml["nc"]:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml["nc"] = nc # override YAML value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
self.inplace = self.yaml.get("inplace", True)
# Build strides
m = self.model[-1] # Detect()
@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
initialize_weights(self)
if verbose:
self.info()
LOGGER.info("")
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
@@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
def _clip_augmented(self, y):
"""Clip YOLO augmented inference tails."""
nl = self.model[-1].nl # number of detection layers (P3-P5)
g = sum(4**x for x in range(nl)) # grid points
e = 1 # exclude layer count
i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
y[0] = y[0][..., :-i] # large
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
y[-1] = y[-1][..., i:] # small
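# Worked numbers for the _clip_augmented() tail clipping above, assuming nl=3 detection
# layers and 8400 anchors in y[0] (a hypothetical 640px image at strides 8/16/32).
nl, e = 3, 1
g = sum(4**x for x in range(nl))               # 4**0 + 4**1 + 4**2 = 21 grid-point units
i = (8400 // g) * sum(4**x for x in range(e))  # (8400 // 21) * 1 = 400 anchors clipped
print(g, i)                                    # 21 400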
@@ -301,7 +341,7 @@ class DetectionModel(BaseModel):
class OBBModel(DetectionModel):
""""YOLOv8 Oriented Bounding Box (OBB) model."""
def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 OBB model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@@ -313,7 +353,7 @@ class OBBModel(DetectionModel):
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 segmentation model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@@ -325,13 +365,13 @@ class SegmentationModel(DetectionModel):
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
"""Initialize YOLOv8 Pose model."""
if not isinstance(cfg, dict):
cfg = yaml_model_load(cfg) # load model YAML
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
cfg["kpt_shape"] = data_kpt_shape
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
@@ -342,7 +382,7 @@ class PoseModel(DetectionModel):
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
super().__init__()
self._from_yaml(cfg, ch, nc, verbose)
@@ -352,21 +392,21 @@ class ClassificationModel(BaseModel):
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
if nc and nc != self.yaml["nc"]:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml["nc"] = nc # override YAML value
elif not nc and not self.yaml.get("nc", None):
raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.stride = torch.Tensor([1]) # no stride constraints
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
self.info()
@staticmethod
def reshape_outputs(model, nc):
"""Update a TorchVision classification model to class count 'n' if required."""
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
if isinstance(m, Classify): # YOLO Classify() head
if m.linear.out_features != nc:
m.linear = nn.Linear(m.linear.in_features, nc)
@@ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
predict: Performs a forward pass through the network and returns the output.
"""
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
"""
Initialize the RTDETRDetectionModel.
@@ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
Returns:
(tuple): A tuple containing the total loss and main three losses in a tensor.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
img = batch["img"]
# NOTE: preprocess gt_bbox and gt_labels into the targets dict expected by the criterion.
bs = len(img)
batch_idx = batch["batch_idx"]
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
targets = {
"cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
"bboxes": batch["bboxes"].to(device=img.device),
"batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
"gt_groups": gt_groups,
}
preds = self.predict(img, batch=targets) if preds is None else preds
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
if dn_meta is None:
dn_bboxes, dn_scores = None, None
else:
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
loss = self.criterion(
(dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
)
# NOTE: RT-DETR computes roughly a dozen losses; backpropagate through all of them but report only the main three.
return sum(loss.values()), torch.as_tensor(
[loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
)
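# Illustrative sketch of the denoising/regular query split above: torch.split along
# dim=2 with a two-element size list. All shapes and dn_num_split values here are
# hypothetical.
import torch

dec_bboxes = torch.randn(6, 2, 498, 4)  # (decoder layers, batch, dn + 300 queries, 4)
dn_num_split = [198, 300]               # stand-in for dn_meta["dn_num_split"]
dn, regular = torch.split(dec_bboxes, dn_num_split, dim=2)
print(dn.shape, regular.shape)          # torch.Size([6, 2, 198, 4]) torch.Size([6, 2, 300, 4])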
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
"""
@@ -553,6 +593,7 @@ def temporary_modules(modules=None):
import importlib
import sys
try:
# Set modules in sys.modules under their old name
for old, new in modules.items():
@@ -580,30 +621,38 @@ def torch_safe_load(weight):
"""
from ultralytics.utils.downloads import attempt_download_asset
check_suffix(file=weight, suffix=".pt")
file = attempt_download_asset(weight) # search online if missing locally
try:
with temporary_modules(
{
"ultralytics.yolo.utils": "ultralytics.utils",
"ultralytics.yolo.v8": "ultralytics.models.yolo",
"ultralytics.yolo.data": "ultralytics.data",
}
): # for legacy 8.0 Classify and Pose models
return torch.load(file, map_location="cpu"), file # load
except ModuleNotFoundError as e: # e.name is missing module name
if e.name == "models":
raise TypeError(
emojis(
f"ERROR ❌ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
)
) from e
LOGGER.warning(
f"WARNING ⚠ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
)
check_requirements(e.name) # install missing module
return torch.load(file, map_location="cpu"), file # load
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
@@ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt, w = torch_safe_load(w) # load ckpt
args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
# Model compatibility updates
model.args = args # attach args to model
model.pt_path = w # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, "stride"):
model.stride = torch.tensor([32.0])
# Append
ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
# Module updates
for m in ensemble.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model
@@ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
return ensemble[-1]
# Return ensemble
LOGGER.info(f"Ensemble created with {weights}\n")
for k in "names", "nc", "yaml":
setattr(ensemble, k, getattr(ensemble[0], k))
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
return ensemble
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
"""Loads a single model weights."""
ckpt, weight = torch_safe_load(weight) # load ckpt
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, "stride"):
model.stride = torch.tensor([32.0])
model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
# Module updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt
@@ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
import ast
# Args
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
if scales:
scale = d.get("scale")
if not scale:
scale = tuple(scales.keys())[0]
LOGGER.warning(f"WARNING ⚠ no model scale passed. Assuming scale='{scale}'.")
@@ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in (
Classify,
Conv,
ConvTranspose,
GhostConv,
Bottleneck,
GhostBottleneck,
SPP,
SPPF,
DWConv,
Focus,
BottleneckCSP,
C1,
C2,
C2f,
C3,
C3TR,
C3Ghost,
nn.ConvTranspose2d,
DWConvTranspose2d,
C3x,
RepC3,
):
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
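# Worked sketch of the channel scaling above: scale c2 by the width multiple, cap at
# max_channels, and round to a multiple of 8. make_divisible is re-implemented locally
# (assumed round-up behavior), and width=0.25 is a hypothetical yolov8n-like scale.
import math

def make_divisible(x, divisor=8):
    return math.ceil(x / divisor) * divisor

width, max_channels = 0.25, 1024
for c2 in (64, 128, 256):
    print(c2, "->", make_divisible(min(c2, max_channels) * width, 8))  # 16, 32, 64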
@@ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
c2 = ch[f]
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type
m.np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
if verbose:
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
@@ -757,16 +827,16 @@ def yaml_model_load(path):
import re
path = Path(path)
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
LOGGER.warning(f"WARNING ⚠ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
path = path.with_name(new_stem + path.suffix)
unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
d = yaml_load(yaml_file) # model dict
d["scale"] = guess_model_scale(path)
d["yaml_file"] = str(path)
return d
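# Illustrative run of the stem-unification regex above: the scale letter is stripped so
# every scale variant resolves to one shared YAML file.
import re

for name in ("yolov8n.yaml", "yolov8x.yaml", "yolov8m-seg.yaml"):
    print(name, "->", re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", name))
# yolov8n.yaml -> yolov8.yaml
# yolov8x.yaml -> yolov8.yaml
# yolov8m-seg.yaml -> yolov8-seg.yaml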
@@ -784,8 +854,9 @@ def guess_model_scale(model_path):
"""
with contextlib.suppress(AttributeError):
import re
return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
return ""
def guess_model_task(model):
@@ -804,17 +875,17 @@ def guess_model_task(model):
def cfg2task(cfg):
"""Guess from YAML dictionary."""
m = cfg["head"][-1][-2].lower() # output module name
if m in ("classify", "classifier", "cls", "fc"):
return "classify"
if m == "detect":
return "detect"
if m == "segment":
return "segment"
if m == "pose":
return "pose"
if m == "obb":
return "obb"
# Guess from model cfg
if isinstance(model, dict):
@@ -823,40 +894,42 @@ def guess_model_task(model):
# Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model
for x in "model.args", "model.model.args", "model.model.model.args":
with contextlib.suppress(Exception):
return eval(x)["task"]
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
with contextlib.suppress(Exception):
return cfg2task(eval(x))
for m in model.modules():
if isinstance(m, Detect):
return "detect"
elif isinstance(m, Segment):
return "segment"
elif isinstance(m, Classify):
return "classify"
elif isinstance(m, Pose):
return "pose"
elif isinstance(m, OBB):
return "obb"
# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model)
if "-seg" in model.stem or "segment" in model.parts:
return "segment"
elif "-cls" in model.stem or "classify" in model.parts:
return "classify"
elif "-pose" in model.stem or "pose" in model.parts:
return "pose"
elif "-obb" in model.stem or "obb" in model.parts:
return "obb"
elif "detect" in model.parts:
return "detect"
# Unable to determine task from model
LOGGER.warning("WARNING ⚠ Unable to automatically guess model task, assuming 'task=detect'. "
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
return 'detect' # assume detect
LOGGER.warning(
"WARNING ⚠ Unable to automatically guess model task, assuming 'task=detect'. "
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
)
return "detect" # assume detect

@@ -26,7 +26,7 @@ class AIGym:
self.angle = None
self.count = None
self.stage = None
self.pose_type = "pushup"
self.kpts_to_check = None
# Visual Information
@@ -36,13 +36,15 @@ class AIGym:
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(
self,
kpts_to_check,
line_thickness=2,
view_img=False,
pose_up_angle=145.0,
pose_down_angle=90.0,
pose_type="pullup",
):
"""
Configures the AIGym keypoints to check, line thickness, image display, pose angle thresholds, and pose type.
Args:
@@ -72,65 +74,75 @@ class AIGym:
if frame_count == 1:
self.count = [0] * len(results[0])
self.angle = [0] * len(results[0])
self.stage = ["-" for _ in results[0]]
self.keypoints = results[0].keypoints.data
self.annotator = Annotator(im0, line_width=2)
for ind, k in enumerate(reversed(self.keypoints)):
if self.pose_type == "pushup" or self.pose_type == "pullup":
self.angle[ind] = self.annotator.estimate_pose_angle(
k[int(self.kpts_to_check[0])].cpu(),
k[int(self.kpts_to_check[1])].cpu(),
k[int(self.kpts_to_check[2])].cpu(),
)
self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
if self.pose_type == "abworkout":
self.angle[ind] = self.annotator.estimate_pose_angle(
k[int(self.kpts_to_check[0])].cpu(),
k[int(self.kpts_to_check[1])].cpu(),
k[int(self.kpts_to_check[2])].cpu(),
)
self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
if self.angle[ind] > self.poseup_angle:
self.stage[ind] = "down"
if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
self.stage[ind] = "up"
self.count[ind] += 1
self.annotator.plot_angle_and_count_and_stage(
angle_text=self.angle[ind],
count_text=self.count[ind],
stage_text=self.stage[ind],
center_kpt=k[int(self.kpts_to_check[1])],
line_thickness=self.tf,
)
if self.pose_type == "pushup":
if self.angle[ind] > self.poseup_angle:
self.stage[ind] = "up"
if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
self.stage[ind] = "down"
self.count[ind] += 1
self.annotator.plot_angle_and_count_and_stage(
angle_text=self.angle[ind],
count_text=self.count[ind],
stage_text=self.stage[ind],
center_kpt=k[int(self.kpts_to_check[1])],
line_thickness=self.tf,
)
if self.pose_type == "pullup":
if self.angle[ind] > self.poseup_angle:
self.stage[ind] = "down"
if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
self.stage[ind] = "up"
self.count[ind] += 1
self.annotator.plot_angle_and_count_and_stage(
angle_text=self.angle[ind],
count_text=self.count[ind],
stage_text=self.stage[ind],
center_kpt=k[int(self.kpts_to_check[1])],
line_thickness=self.tf,
)
self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)
if self.env_check and self.view_img:
cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
return self.im0
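# Illustrative sketch: the pullup branch above reduced to a plain state machine. A rep
# is counted on each down -> up transition; thresholds mirror pose_up_angle and
# pose_down_angle, and the angle sequence is hypothetical.
def count_reps(angles, up=145.0, down=90.0):
    stage, count = "-", 0
    for a in angles:
        if a > up:
            stage = "down"
        if a < down and stage == "down":
            stage, count = "up", count + 1
    return count

print(count_reps([160, 150, 80, 70, 155, 85]))  # 2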
if __name__ == "__main__":
AIGym()

@@ -41,13 +41,15 @@ class DistanceCalculation:
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(
self,
names,
pixels_per_meter=10,
view_img=False,
line_thickness=2,
line_color=(255, 255, 0),
centroid_color=(255, 0, 255),
):
"""
Configures the distance calculation and display parameters.
@@ -129,8 +131,9 @@ class DistanceCalculation:
distance (float): Distance between two centroids
"""
cv2.rectangle(self.im0, (15, 25), (280, 70), (255, 255, 255), -1)
cv2.putText(
self.im0, f"Distance : {distance:.2f}m", (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, cv2.LINE_AA
)
cv2.line(self.im0, self.centroids[0], self.centroids[1], self.line_color, 3)
cv2.circle(self.im0, self.centroids[0], 6, self.centroid_color, -1)
cv2.circle(self.im0, self.centroids[1], 6, self.centroid_color, -1)
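# Illustrative sketch (an assumption, since the computation itself is outside this hunk):
# the displayed value is, in effect, a Euclidean pixel distance between two centroids
# divided by a pixels_per_meter calibration constant. Values here are hypothetical.
import math

def distance_m(c1, c2, pixels_per_meter=10):
    return math.hypot(c1[0] - c2[0], c1[1] - c2[1]) / pixels_per_meter

print(f"Distance : {distance_m((100, 100), (160, 180)):.2f}m")  # Distance : 10.00m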
@@ -179,13 +182,13 @@ class DistanceCalculation:
def display_frames(self):
"""Display frame."""
cv2.namedWindow("Ultralytics Distance Estimation")
cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance)
cv2.imshow("Ultralytics Distance Estimation", self.im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
if __name__ == "__main__":
DistanceCalculation()
