diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 654496f1..e183b3e1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -95,7 +95,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ['3.10'] + python-version: ['3.11'] model: [yolov8n] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 00000000..48a56e3d --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,25 @@ +# Ultralytics 🚀 - AGPL-3.0 license +# Ultralytics Actions https://github.com/ultralytics/actions +# This workflow automatically formats code and documentation in PRs to official Ultralytics standards + +name: Ultralytics Actions + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + format: + runs-on: ubuntu-latest + steps: + - name: Run Ultralytics Formatting + uses: ultralytics/actions@main + with: + token: ${{ secrets.GITHUB_TOKEN }} # automatically generated + python: true + docstrings: true + markdown: true + spelling: true + links: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72cef360..07ae822f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,6 @@ repos: - id: check-case-conflict # - id: check-yaml - id: check-docstring-first - - id: double-quote-string-fixer - id: detect-private-key - repo: https://github.com/asottile/pyupgrade @@ -64,7 +63,7 @@ repos: - id: codespell exclude: 'docs/de|docs/fr|docs/pt|docs/es|docs/mkdocs_de.yml' args: - - --ignore-words-list=crate,nd,strack,dota,ane,segway,fo,gool,winn + - --ignore-words-list=crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall - repo: https://github.com/PyCQA/docformatter rev: v1.7.5 diff --git a/docs/build_docs.py b/docs/build_docs.py index 914f2fe8..9c9d75ed 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -30,45 +30,47 @@ import subprocess from pathlib import Path DOCS = Path(__file__).parent.resolve() -SITE = DOCS.parent / 'site' +SITE = DOCS.parent / "site" def build_docs(): """Build docs using mkdocs.""" if SITE.exists(): - print(f'Removing existing {SITE}') + print(f"Removing existing {SITE}") shutil.rmtree(SITE) # Build the main documentation - print(f'Building docs from {DOCS}') - subprocess.run(f'mkdocs build -f {DOCS}/mkdocs.yml', check=True, shell=True) + print(f"Building docs from {DOCS}") + subprocess.run(f"mkdocs build -f {DOCS}/mkdocs.yml", check=True, shell=True) # Build other localized documentations - for file in DOCS.glob('mkdocs_*.yml'): - print(f'Building MkDocs site with configuration file: {file}') - subprocess.run(f'mkdocs build -f {file}', check=True, shell=True) - print(f'Site built at {SITE}') + for file in DOCS.glob("mkdocs_*.yml"): + print(f"Building MkDocs site with configuration file: {file}") + subprocess.run(f"mkdocs build -f {file}", check=True, shell=True) + print(f"Site built at {SITE}") def update_html_links(): """Update href links in HTML files to remove '.md' and '/index.md', excluding links starting with 'https://'.""" - html_files = Path(SITE).rglob('*.html') + html_files = Path(SITE).rglob("*.html") total_updated_links = 0 for html_file in html_files: - with open(html_file, 'r+', encoding='utf-8') as file: + with open(html_file, "r+", encoding="utf-8") as file: content = file.read() # Find all links to be updated, excluding those starting with 'https://' links_to_update = re.findall(r'href="(?!https://)([^"]+?)(/index)?\.md"', content) # Update the content 
and count the number of links updated - updated_content, number_of_links_updated = re.subn(r'href="(?!https://)([^"]+?)(/index)?\.md"', - r'href="\1"', content) + updated_content, number_of_links_updated = re.subn( + r'href="(?!https://)([^"]+?)(/index)?\.md"', r'href="\1"', content + ) total_updated_links += number_of_links_updated # Special handling for '/index' links - updated_content, number_of_index_links_updated = re.subn(r'href="([^"]+)/index"', r'href="\1/"', - updated_content) + updated_content, number_of_index_links_updated = re.subn( + r'href="([^"]+)/index"', r'href="\1/"', updated_content + ) total_updated_links += number_of_index_links_updated # Write the updated content back to the file @@ -78,23 +80,23 @@ def update_html_links(): # Print updated links for this file for link in links_to_update: - print(f'Updated link in {html_file}: {link[0]}') + print(f"Updated link in {html_file}: {link[0]}") - print(f'Total number of links updated: {total_updated_links}') + print(f"Total number of links updated: {total_updated_links}") def update_page_title(file_path: Path, new_title: str): """Update the title of an HTML file.""" # Read the content of the file - with open(file_path, encoding='utf-8') as file: + with open(file_path, encoding="utf-8") as file: content = file.read() # Replace the existing title with the new title - updated_content = re.sub(r'<title>.*?</title>', f'<title>{new_title}</title>', content) + updated_content = re.sub(r"<title>.*?</title>", f"<title>{new_title}</title>", content) # Write the updated content back to the file - with open(file_path, 'w', encoding='utf-8') as file: + with open(file_path, "w", encoding="utf-8") as file: file.write(updated_content) @@ -109,8 +111,8 @@ def main(): print('Serve site at http://localhost:8000 with "python -m http.server --directory site"') # Update titles - update_page_title(SITE / '404.html', new_title='Ultralytics Docs - Not Found') + update_page_title(SITE / "404.html", new_title="Ultralytics Docs - Not Found") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/docs/build_reference.py b/docs/build_reference.py index cb15d34a..65736ac0 100644 --- a/docs/build_reference.py +++ b/docs/build_reference.py @@ -14,14 +14,14 @@ from ultralytics.utils import ROOT NEW_YAML_DIR = ROOT.parent CODE_DIR = ROOT -REFERENCE_DIR = ROOT.parent / 'docs/en/reference' +REFERENCE_DIR = ROOT.parent / "docs/en/reference" def extract_classes_and_functions(filepath: Path) -> tuple: """Extracts class and function names from a given Python file.""" content = filepath.read_text() - class_pattern = r'(?:^|\n)class\s(\w+)(?:\(|:)' - func_pattern = r'(?:^|\n)def\s(\w+)\(' + class_pattern = r"(?:^|\n)class\s(\w+)(?:\(|:)" + func_pattern = r"(?:^|\n)def\s(\w+)\(" classes = re.findall(class_pattern, content) functions = re.findall(func_pattern, content) @@ -31,31 +31,31 @@ def create_markdown(py_filepath: Path, module_path: str, classes: list, functions: list): """Creates a Markdown file containing the API reference for the given Python module.""" - md_filepath = py_filepath.with_suffix('.md') + md_filepath = py_filepath.with_suffix(".md") # Read existing content and keep header content between first two --- - header_content = '' + header_content = "" if md_filepath.exists(): existing_content = md_filepath.read_text() - header_parts = existing_content.split('---') + header_parts = existing_content.split("---") for part in header_parts: - if 'description:' in part or 'comments:' in part: - header_content += f'---{part}---\n\n' + if 
"description:" in part or "comments:" in part: + header_content += f"---{part}---\n\n" - module_name = module_path.replace('.__init__', '') - module_path = module_path.replace('.', '/') - url = f'https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py' - edit = f'https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py' + module_name = module_path.replace(".__init__", "") + module_path = module_path.replace(".", "/") + url = f"https://github.com/ultralytics/ultralytics/blob/main/{module_path}.py" + edit = f"https://github.com/ultralytics/ultralytics/edit/main/{module_path}.py" title_content = ( - f'# Reference for `{module_path}.py`\n\n' - f'!!! Note\n\n' - f' This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n' + f"# Reference for `{module_path}.py`\n\n" + f"!!! Note\n\n" + f" This file is available at [{url}]({url}). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\n\n" ) - md_content = ['

\n'] + [f'## ::: {module_name}.{class_name}\n\n

\n' for class_name in classes] - md_content.extend(f'## ::: {module_name}.{func_name}\n\n

\n' for func_name in functions) - md_content = header_content + title_content + '\n'.join(md_content) - if not md_content.endswith('\n'): - md_content += '\n' + md_content = ["

\n"] + [f"## ::: {module_name}.{class_name}\n\n

\n" for class_name in classes] + md_content.extend(f"## ::: {module_name}.{func_name}\n\n

\n" for func_name in functions) + md_content = header_content + title_content + "\n".join(md_content) + if not md_content.endswith("\n"): + md_content += "\n" md_filepath.parent.mkdir(parents=True, exist_ok=True) md_filepath.write_text(md_content) @@ -80,28 +80,28 @@ def create_nav_menu_yaml(nav_items: list): for item_str in nav_items: item = Path(item_str) parts = item.parts - current_level = nav_tree['reference'] + current_level = nav_tree["reference"] for part in parts[2:-1]: # skip the first two parts (docs and reference) and the last part (filename) current_level = current_level[part] - md_file_name = parts[-1].replace('.md', '') + md_file_name = parts[-1].replace(".md", "") current_level[md_file_name] = item nav_tree_sorted = sort_nested_dict(nav_tree) def _dict_to_yaml(d, level=0): """Converts a nested dictionary to a YAML-formatted string with indentation.""" - yaml_str = '' - indent = ' ' * level + yaml_str = "" + indent = " " * level for k, v in d.items(): if isinstance(v, dict): - yaml_str += f'{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}' + yaml_str += f"{indent}- {k}:\n{_dict_to_yaml(v, level + 1)}" else: yaml_str += f"{indent}- {k}: {str(v).replace('docs/en/', '')}\n" return yaml_str # Print updated YAML reference section - print('Scan complete, new mkdocs.yaml reference section is:\n\n', _dict_to_yaml(nav_tree_sorted)) + print("Scan complete, new mkdocs.yaml reference section is:\n\n", _dict_to_yaml(nav_tree_sorted)) # Save new YAML reference section # (NEW_YAML_DIR / 'nav_menu_updated.yml').write_text(_dict_to_yaml(nav_tree_sorted)) @@ -111,7 +111,7 @@ def main(): """Main function to extract class and function names, create Markdown files, and generate a YAML navigation menu.""" nav_items = [] - for py_filepath in CODE_DIR.rglob('*.py'): + for py_filepath in CODE_DIR.rglob("*.py"): classes, functions = extract_classes_and_functions(py_filepath) if classes or functions: @@ -124,5 +124,5 @@ def main(): create_nav_menu_yaml(nav_items) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/docs/update_translations.py b/docs/update_translations.py index 9c27c700..8b64a4a4 100644 --- a/docs/update_translations.py +++ b/docs/update_translations.py @@ -22,69 +22,232 @@ class MarkdownLinkFixer: self.base_dir = Path(base_dir) self.update_links = update_links self.update_text = update_text - self.md_link_regex = re.compile(r'\[([^]]+)]\(([^:)]+)\.md\)') + self.md_link_regex = re.compile(r"\[([^]]+)]\(([^:)]+)\.md\)") @staticmethod def replace_front_matter(content, lang_dir): """Ensure front matter keywords remain in English.""" - english = ['comments', 'description', 'keywords'] + english = ["comments", "description", "keywords"] translations = { - 'zh': ['评论', '描述', '关键词'], # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字 - 'es': ['comentarios', 'descripción', 'palabras clave'], # Spanish - 'ru': ['комментарии', 'описание', 'ключевые слова'], # Russian - 'pt': ['comentários', 'descrição', 'palavras-chave'], # Portuguese - 'fr': ['commentaires', 'description', 'mots-clés'], # French - 'de': ['kommentare', 'beschreibung', 'schlüsselwörter'], # German - 'ja': ['コメント', '説明', 'キーワード'], # Japanese - 'ko': ['댓글', '설명', '키워드'], # Korean - 'hi': ['टिप्पणियाँ', 'विवरण', 'कीवर्ड'], # Hindi - 'ar': ['التعليقات', 'الوصف', 'الكلمات الرئيسية'] # Arabic + "zh": ["评论", "描述", "关键词"], # Mandarin Chinese (Simplified) warning, sometimes translates as 关键字 + "es": ["comentarios", "descripción", "palabras clave"], # Spanish + "ru": ["комментарии", "описание", 
"ключевые слова"], # Russian + "pt": ["comentários", "descrição", "palavras-chave"], # Portuguese + "fr": ["commentaires", "description", "mots-clés"], # French + "de": ["kommentare", "beschreibung", "schlüsselwörter"], # German + "ja": ["コメント", "説明", "キーワード"], # Japanese + "ko": ["댓글", "설명", "키워드"], # Korean + "hi": ["टिप्पणियाँ", "विवरण", "कीवर्ड"], # Hindi + "ar": ["التعليقات", "الوصف", "الكلمات الرئيسية"], # Arabic } # front matter translations for comments, description, keyword for term, eng_key in zip(translations.get(lang_dir.stem, []), english): - content = re.sub(rf'{term} *[::].*', f'{eng_key}: true', content, flags=re.IGNORECASE) if \ - eng_key == 'comments' else re.sub(rf'{term} *[::] *', f'{eng_key}: ', content, flags=re.IGNORECASE) + content = ( + re.sub(rf"{term} *[::].*", f"{eng_key}: true", content, flags=re.IGNORECASE) + if eng_key == "comments" + else re.sub(rf"{term} *[::] *", f"{eng_key}: ", content, flags=re.IGNORECASE) + ) return content @staticmethod def replace_admonitions(content, lang_dir): """Ensure front matter keywords remain in English.""" english = [ - 'Note', 'Summary', 'Tip', 'Info', 'Success', 'Question', 'Warning', 'Failure', 'Danger', 'Bug', 'Example', - 'Quote', 'Abstract', 'Seealso', 'Admonition'] + "Note", + "Summary", + "Tip", + "Info", + "Success", + "Question", + "Warning", + "Failure", + "Danger", + "Bug", + "Example", + "Quote", + "Abstract", + "Seealso", + "Admonition", + ] translations = { - 'en': - english, - 'zh': ['笔记', '摘要', '提示', '信息', '成功', '问题', '警告', '失败', '危险', '故障', '示例', '引用', '摘要', '另见', '警告'], - 'es': [ - 'Nota', 'Resumen', 'Consejo', 'Información', 'Éxito', 'Pregunta', 'Advertencia', 'Fracaso', 'Peligro', - 'Error', 'Ejemplo', 'Cita', 'Abstracto', 'Véase También', 'Amonestación'], - 'ru': [ - 'Заметка', 'Сводка', 'Совет', 'Информация', 'Успех', 'Вопрос', 'Предупреждение', 'Неудача', 'Опасность', - 'Ошибка', 'Пример', 'Цитата', 'Абстракт', 'См. 
Также', 'Предостережение'], - 'pt': [ - 'Nota', 'Resumo', 'Dica', 'Informação', 'Sucesso', 'Questão', 'Aviso', 'Falha', 'Perigo', 'Bug', - 'Exemplo', 'Citação', 'Abstrato', 'Veja Também', 'Advertência'], - 'fr': [ - 'Note', 'Résumé', 'Conseil', 'Info', 'Succès', 'Question', 'Avertissement', 'Échec', 'Danger', 'Bug', - 'Exemple', 'Citation', 'Abstrait', 'Voir Aussi', 'Admonestation'], - 'de': [ - 'Hinweis', 'Zusammenfassung', 'Tipp', 'Info', 'Erfolg', 'Frage', 'Warnung', 'Ausfall', 'Gefahr', - 'Fehler', 'Beispiel', 'Zitat', 'Abstrakt', 'Siehe Auch', 'Ermahnung'], - 'ja': ['ノート', '要約', 'ヒント', '情報', '成功', '質問', '警告', '失敗', '危険', 'バグ', '例', '引用', '抄録', '参照', '訓告'], - 'ko': ['노트', '요약', '팁', '정보', '성공', '질문', '경고', '실패', '위험', '버그', '예제', '인용', '추상', '참조', '경고'], - 'hi': [ - 'नोट', 'सारांश', 'सुझाव', 'जानकारी', 'सफलता', 'प्रश्न', 'चेतावनी', 'विफलता', 'खतरा', 'बग', 'उदाहरण', - 'उद्धरण', 'सार', 'देखें भी', 'आगाही'], - 'ar': [ - 'ملاحظة', 'ملخص', 'نصيحة', 'معلومات', 'نجاح', 'سؤال', 'تحذير', 'فشل', 'خطر', 'عطل', 'مثال', 'اقتباس', - 'ملخص', 'انظر أيضاً', 'تحذير']} + "en": english, + "zh": [ + "笔记", + "摘要", + "提示", + "信息", + "成功", + "问题", + "警告", + "失败", + "危险", + "故障", + "示例", + "引用", + "摘要", + "另见", + "警告", + ], + "es": [ + "Nota", + "Resumen", + "Consejo", + "Información", + "Éxito", + "Pregunta", + "Advertencia", + "Fracaso", + "Peligro", + "Error", + "Ejemplo", + "Cita", + "Abstracto", + "Véase También", + "Amonestación", + ], + "ru": [ + "Заметка", + "Сводка", + "Совет", + "Информация", + "Успех", + "Вопрос", + "Предупреждение", + "Неудача", + "Опасность", + "Ошибка", + "Пример", + "Цитата", + "Абстракт", + "См. Также", + "Предостережение", + ], + "pt": [ + "Nota", + "Resumo", + "Dica", + "Informação", + "Sucesso", + "Questão", + "Aviso", + "Falha", + "Perigo", + "Bug", + "Exemplo", + "Citação", + "Abstrato", + "Veja Também", + "Advertência", + ], + "fr": [ + "Note", + "Résumé", + "Conseil", + "Info", + "Succès", + "Question", + "Avertissement", + "Échec", + "Danger", + "Bug", + "Exemple", + "Citation", + "Abstrait", + "Voir Aussi", + "Admonestation", + ], + "de": [ + "Hinweis", + "Zusammenfassung", + "Tipp", + "Info", + "Erfolg", + "Frage", + "Warnung", + "Ausfall", + "Gefahr", + "Fehler", + "Beispiel", + "Zitat", + "Abstrakt", + "Siehe Auch", + "Ermahnung", + ], + "ja": [ + "ノート", + "要約", + "ヒント", + "情報", + "成功", + "質問", + "警告", + "失敗", + "危険", + "バグ", + "例", + "引用", + "抄録", + "参照", + "訓告", + ], + "ko": [ + "노트", + "요약", + "팁", + "정보", + "성공", + "질문", + "경고", + "실패", + "위험", + "버그", + "예제", + "인용", + "추상", + "참조", + "경고", + ], + "hi": [ + "नोट", + "सारांश", + "सुझाव", + "जानकारी", + "सफलता", + "प्रश्न", + "चेतावनी", + "विफलता", + "खतरा", + "बग", + "उदाहरण", + "उद्धरण", + "सार", + "देखें भी", + "आगाही", + ], + "ar": [ + "ملاحظة", + "ملخص", + "نصيحة", + "معلومات", + "نجاح", + "سؤال", + "تحذير", + "فشل", + "خطر", + "عطل", + "مثال", + "اقتباس", + "ملخص", + "انظر أيضاً", + "تحذير", + ], + } for term, eng_key in zip(translations.get(lang_dir.stem, []), english): - if lang_dir.stem != 'en': - content = re.sub(rf'!!! *{eng_key} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE) - content = re.sub(rf'!!! *{term} *\n', f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE) - content = re.sub(rf'!!! *{term}', f'!!! {eng_key}', content, flags=re.IGNORECASE) + if lang_dir.stem != "en": + content = re.sub(rf"!!! *{eng_key} *\n", f'!!! {eng_key} "{term}"\n', content, flags=re.IGNORECASE) + content = re.sub(rf"!!! *{term} *\n", f'!!! 
{eng_key} "{term}"\n', content, flags=re.IGNORECASE) + content = re.sub(rf"!!! *{term}", f"!!! {eng_key}", content, flags=re.IGNORECASE) content = re.sub(r'!!! *"', '!!! Example "', content, flags=re.IGNORECASE) return content @@ -92,30 +255,30 @@ class MarkdownLinkFixer: @staticmethod def update_iframe(content): """Update the 'allow' attribute of iframe if it does not contain the specific English permissions.""" - english = 'accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share' + english = "accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" pattern = re.compile(f'allow="(?!{re.escape(english)}).+?"') return pattern.sub(f'allow="{english}"', content) def link_replacer(self, match, parent_dir, lang_dir, use_abs_link=False): """Replace broken links with corresponding links in the /en/ directory.""" text, path = match.groups() - linked_path = (parent_dir / path).resolve().with_suffix('.md') + linked_path = (parent_dir / path).resolve().with_suffix(".md") if not linked_path.exists(): - en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / 'en'))) + en_linked_path = Path(str(linked_path).replace(str(lang_dir), str(lang_dir.parent / "en"))) if en_linked_path.exists(): if use_abs_link: # Use absolute links WARNING: BUGS, DO NOT USE docs_root_relative_path = en_linked_path.relative_to(lang_dir.parent) - updated_path = str(docs_root_relative_path).replace('en/', '/../') + updated_path = str(docs_root_relative_path).replace("en/", "/../") else: # Use relative links steps_up = len(parent_dir.relative_to(self.base_dir).parts) - updated_path = Path('../' * steps_up) / en_linked_path.relative_to(self.base_dir) - updated_path = str(updated_path).replace('/en/', '/') + updated_path = Path("../" * steps_up) / en_linked_path.relative_to(self.base_dir) + updated_path = str(updated_path).replace("/en/", "/") print(f"Redirecting link '[{text}]({path})' from {parent_dir} to {updated_path}") - return f'[{text}]({updated_path})' + return f"[{text}]({updated_path})" else: print(f"Warning: Broken link '[{text}]({path})' found in {parent_dir} does not exist in /docs/en/.") @@ -124,28 +287,30 @@ class MarkdownLinkFixer: @staticmethod def update_html_tags(content): """Updates HTML tags in docs.""" - alt_tag = 'MISSING' + alt_tag = "MISSING" # Remove closing slashes from self-closing HTML tags - pattern = re.compile(r'<([^>]+?)\s*/>') - content = re.sub(pattern, r'<\1>', content) + pattern = re.compile(r"<([^>]+?)\s*/>") + content = re.sub(pattern, r"<\1>", content) # Find all images without alt tags and add placeholder alt text - pattern = re.compile(r'!\[(.*?)\]\((.*?)\)') - content, num_replacements = re.subn(pattern, lambda match: f'![{match.group(1) or alt_tag}]({match.group(2)})', - content) + pattern = re.compile(r"!\[(.*?)\]\((.*?)\)") + content, num_replacements = re.subn( + pattern, lambda match: f"![{match.group(1) or alt_tag}]({match.group(2)})", content + ) # Add missing alt tags to HTML images pattern = re.compile(r']*src=["\'](.*?)["\'][^>]*>') - content, num_replacements = re.subn(pattern, lambda match: match.group(0).replace('>', f' alt="{alt_tag}">', 1), - content) + content, num_replacements = re.subn( + pattern, lambda match: match.group(0).replace(">", f' alt="{alt_tag}">', 1), content + ) return content def process_markdown_file(self, md_file_path, lang_dir): """Process each markdown file in the language directory.""" - print(f'Processing file: {md_file_path}') - with 
open(md_file_path, encoding='utf-8') as file: + print(f"Processing file: {md_file_path}") + with open(md_file_path, encoding="utf-8") as file: content = file.read() if self.update_links: @@ -157,23 +322,23 @@ class MarkdownLinkFixer: content = self.update_iframe(content) content = self.update_html_tags(content) - with open(md_file_path, 'w', encoding='utf-8') as file: + with open(md_file_path, "w", encoding="utf-8") as file: file.write(content) def process_language_directory(self, lang_dir): """Process each language-specific directory.""" - print(f'Processing language directory: {lang_dir}') - for md_file in lang_dir.rglob('*.md'): + print(f"Processing language directory: {lang_dir}") + for md_file in lang_dir.rglob("*.md"): self.process_markdown_file(md_file, lang_dir) def run(self): """Run the link fixing and front matter updating process for each language-specific directory.""" for subdir in self.base_dir.iterdir(): - if subdir.is_dir() and re.match(r'^\w\w$', subdir.name): + if subdir.is_dir() and re.match(r"^\w\w$", subdir.name): self.process_language_directory(subdir) -if __name__ == '__main__': +if __name__ == "__main__": # Set the path to your MkDocs 'docs' directory here docs_dir = str(Path(__file__).parent.resolve()) fixer = MarkdownLinkFixer(docs_dir, update_links=True, update_text=True) diff --git a/examples/YOLOv8-ONNXRuntime/main.py b/examples/YOLOv8-ONNXRuntime/main.py index ec768713..e3755a44 100644 --- a/examples/YOLOv8-ONNXRuntime/main.py +++ b/examples/YOLOv8-ONNXRuntime/main.py @@ -28,7 +28,7 @@ class YOLOv8: self.iou_thres = iou_thres # Load the class names from the COCO dataset - self.classes = yaml_load(check_yaml('coco128.yaml'))['names'] + self.classes = yaml_load(check_yaml("coco128.yaml"))["names"] # Generate a color palette for the classes self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3)) @@ -57,7 +57,7 @@ class YOLOv8: cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2) # Create the label text with class name and score - label = f'{self.classes[class_id]}: {score:.2f}' + label = f"{self.classes[class_id]}: {score:.2f}" # Calculate the dimensions of the label text (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) @@ -67,8 +67,9 @@ class YOLOv8: label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 # Draw a filled rectangle as the background for the label text - cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, - cv2.FILLED) + cv2.rectangle( + img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED + ) # Draw the label text on the image cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) @@ -182,7 +183,7 @@ class YOLOv8: output_img: The output image with drawn detections. 
""" # Create an inference session using the ONNX model and specify execution providers - session = ort.InferenceSession(self.onnx_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) # Get the model inputs model_inputs = session.get_inputs() @@ -202,17 +203,17 @@ class YOLOv8: return self.postprocess(self.img, outputs) # output image -if __name__ == '__main__': +if __name__ == "__main__": # Create an argument parser to handle command-line arguments parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='yolov8n.onnx', help='Input your ONNX model.') - parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.') - parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold') - parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold') + parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.") + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.") + parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold") + parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold") args = parser.parse_args() # Check the requirements and select the appropriate backend (CPU or GPU) - check_requirements('onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime') + check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime") # Create an instance of the YOLOv8 class with the specified arguments detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres) @@ -221,8 +222,8 @@ if __name__ == '__main__': output_image = detection.main() # Display the output image in a window - cv2.namedWindow('Output', cv2.WINDOW_NORMAL) - cv2.imshow('Output', output_image) + cv2.namedWindow("Output", cv2.WINDOW_NORMAL) + cv2.imshow("Output", output_image) # Wait for a key press to exit cv2.waitKey(0) diff --git a/examples/YOLOv8-OpenCV-ONNX-Python/main.py b/examples/YOLOv8-OpenCV-ONNX-Python/main.py index 78b0b08e..c0564d15 100644 --- a/examples/YOLOv8-OpenCV-ONNX-Python/main.py +++ b/examples/YOLOv8-OpenCV-ONNX-Python/main.py @@ -6,7 +6,7 @@ import numpy as np from ultralytics.utils import ASSETS, yaml_load from ultralytics.utils.checks import check_yaml -CLASSES = yaml_load(check_yaml('coco128.yaml'))['names'] +CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"] colors = np.random.uniform(0, 255, size=(len(CLASSES), 3)) @@ -23,7 +23,7 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h): x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box. y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box. 
""" - label = f'{CLASSES[class_id]} ({confidence:.2f})' + label = f"{CLASSES[class_id]} ({confidence:.2f})" color = colors[class_id] cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2) cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) @@ -76,8 +76,11 @@ def main(onnx_model, input_image): (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores) if maxScore >= 0.25: box = [ - outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]), - outputs[0][i][2], outputs[0][i][3]] + outputs[0][i][0] - (0.5 * outputs[0][i][2]), + outputs[0][i][1] - (0.5 * outputs[0][i][3]), + outputs[0][i][2], + outputs[0][i][3], + ] boxes.append(box) scores.append(maxScore) class_ids.append(maxClassIndex) @@ -92,26 +95,34 @@ def main(onnx_model, input_image): index = result_boxes[i] box = boxes[index] detection = { - 'class_id': class_ids[index], - 'class_name': CLASSES[class_ids[index]], - 'confidence': scores[index], - 'box': box, - 'scale': scale} + "class_id": class_ids[index], + "class_name": CLASSES[class_ids[index]], + "confidence": scores[index], + "box": box, + "scale": scale, + } detections.append(detection) - draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale), - round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale)) + draw_bounding_box( + original_image, + class_ids[index], + scores[index], + round(box[0] * scale), + round(box[1] * scale), + round((box[0] + box[2]) * scale), + round((box[1] + box[3]) * scale), + ) # Display the image with bounding boxes - cv2.imshow('image', original_image) + cv2.imshow("image", original_image) cv2.waitKey(0) cv2.destroyAllWindows() return detections -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.') - parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.') + parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.") + parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.") args = parser.parse_args() main(args.model, args.img) diff --git a/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py b/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py index 04b39e9e..9c23173c 100644 --- a/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py +++ b/examples/YOLOv8-OpenCV-int8-tflite-Python/main.py @@ -13,14 +13,9 @@ img_height = 640 class LetterBox: - - def __init__(self, - new_shape=(img_width, img_height), - auto=False, - scaleFill=False, - scaleup=True, - center=True, - stride=32): + def __init__( + self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32 + ): self.new_shape = new_shape self.auto = auto self.scaleFill = scaleFill @@ -33,9 +28,9 @@ class LetterBox: if labels is None: labels = {} - img = labels.get('img') if image is None else image + img = labels.get("img") if image is None else image shape = img.shape[:2] # current shape [height, width] - new_shape = labels.pop('rect_shape', self.new_shape) + new_shape = labels.pop("rect_shape", self.new_shape) if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -63,15 +58,16 @@ class LetterBox: img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) left, right = int(round(dw - 0.1)) if 
self.center else 0, int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, - value=(114, 114, 114)) # add border - if labels.get('ratio_pad'): - labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) # add border + if labels.get("ratio_pad"): + labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation if len(labels): labels = self._update_labels(labels, ratio, dw, dh) - labels['img'] = img - labels['resized_shape'] = new_shape + labels["img"] = img + labels["resized_shape"] = new_shape return labels else: return img @@ -79,15 +75,14 @@ class LetterBox: def _update_labels(self, labels, ratio, padw, padh): """Update labels.""" - labels['instances'].convert_bbox(format='xyxy') - labels['instances'].denormalize(*labels['img'].shape[:2][::-1]) - labels['instances'].scale(*ratio) - labels['instances'].add_padding(padw, padh) + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(*labels["img"].shape[:2][::-1]) + labels["instances"].scale(*ratio) + labels["instances"].add_padding(padw, padh) return labels class Yolov8TFLite: - def __init__(self, tflite_model, input_image, confidence_thres, iou_thres): """ Initializes an instance of the Yolov8TFLite class. @@ -105,7 +100,7 @@ class Yolov8TFLite: self.iou_thres = iou_thres # Load the class names from the COCO dataset - self.classes = yaml_load(check_yaml('coco128.yaml'))['names'] + self.classes = yaml_load(check_yaml("coco128.yaml"))["names"] # Generate a color palette for the classes self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3)) @@ -134,7 +129,7 @@ class Yolov8TFLite: cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2) # Create the label text with class name and score - label = f'{self.classes[class_id]}: {score:.2f}' + label = f"{self.classes[class_id]}: {score:.2f}" # Calculate the dimensions of the label text (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) @@ -144,8 +139,13 @@ class Yolov8TFLite: label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 # Draw a filled rectangle as the background for the label text - cv2.rectangle(img, (int(label_x), int(label_y - label_height)), - (int(label_x + label_width), int(label_y + label_height)), color, cv2.FILLED) + cv2.rectangle( + img, + (int(label_x), int(label_y - label_height)), + (int(label_x + label_width), int(label_y + label_height)), + color, + cv2.FILLED, + ) # Draw the label text on the image cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) @@ -161,7 +161,7 @@ class Yolov8TFLite: # Read the input image using OpenCV self.img = cv2.imread(self.input_image) - print('image befor', self.img) + print("image before", self.img) # Get the height and width of the input image self.img_height, self.img_width = self.img.shape[:2] @@ -209,8 +209,10 @@ class Yolov8TFLite: # Get the box, score, and class ID corresponding to the index box = boxes[i] gain = min(img_width / self.img_width, img_height / self.img_height) - pad = round((img_width - self.img_width * gain) / 2 - - 0.1), round((img_height - self.img_height * gain) / 2 - 0.1) + pad = ( + round((img_width - self.img_width * gain) / 2 - 0.1), + round((img_height - self.img_height * gain) / 2 - 0.1), + ) box[0] = (box[0] - pad[0]) / gain box[1] = (box[1] - pad[1]) / gain box[2] = 
box[2] / gain @@ -242,7 +244,7 @@ class Yolov8TFLite: output_details = interpreter.get_output_details() # Store the shape of the input for later use - input_shape = input_details[0]['shape'] + input_shape = input_details[0]["shape"] self.input_width = input_shape[1] self.input_height = input_shape[2] @@ -251,19 +253,19 @@ class Yolov8TFLite: img_data = img_data # img_data = img_data.cpu().numpy() # Set the input tensor to the interpreter - print(input_details[0]['index']) + print(input_details[0]["index"]) print(img_data.shape) img_data = img_data.transpose((0, 2, 3, 1)) - scale, zero_point = input_details[0]['quantization'] - interpreter.set_tensor(input_details[0]['index'], img_data) + scale, zero_point = input_details[0]["quantization"] + interpreter.set_tensor(input_details[0]["index"], img_data) # Run inference interpreter.invoke() # Get the output tensor from the interpreter - output = interpreter.get_tensor(output_details[0]['index']) - scale, zero_point = output_details[0]['quantization'] + output = interpreter.get_tensor(output_details[0]["index"]) + scale, zero_point = output_details[0]["quantization"] output = (output.astype(np.float32) - zero_point) * scale output[:, [0, 2]] *= img_width @@ -273,16 +275,15 @@ class Yolov8TFLite: return self.postprocess(self.img, output) -if __name__ == '__main__': +if __name__ == "__main__": # Create an argument parser to handle command-line arguments parser = argparse.ArgumentParser() - parser.add_argument('--model', - type=str, - default='yolov8n_full_integer_quant.tflite', - help='Input your TFLite model.') - parser.add_argument('--img', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image.') - parser.add_argument('--conf-thres', type=float, default=0.5, help='Confidence threshold') - parser.add_argument('--iou-thres', type=float, default=0.5, help='NMS IoU threshold') + parser.add_argument( + "--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model." 
+ ) + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.") + parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold") + parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold") args = parser.parse_args() # Create an instance of the Yolov8TFLite class with the specified arguments @@ -292,7 +293,7 @@ if __name__ == '__main__': output_image = detection.main() # Display the output image in a window - cv2.imshow('Output', output_image) + cv2.imshow("Output", output_image) # Wait for a key press to exit cv2.waitKey(0) diff --git a/examples/YOLOv8-Region-Counter/yolov8_region_counter.py b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py index 5379fd3b..70bcfd5a 100644 --- a/examples/YOLOv8-Region-Counter/yolov8_region_counter.py +++ b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py @@ -16,21 +16,22 @@ track_history = defaultdict(list) current_region = None counting_regions = [ { - 'name': 'YOLOv8 Polygon Region', - 'polygon': Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points - 'counts': 0, - 'dragging': False, - 'region_color': (255, 42, 4), # BGR Value - 'text_color': (255, 255, 255) # Region Text Color + "name": "YOLOv8 Polygon Region", + "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points + "counts": 0, + "dragging": False, + "region_color": (255, 42, 4), # BGR Value + "text_color": (255, 255, 255), # Region Text Color }, { - 'name': 'YOLOv8 Rectangle Region', - 'polygon': Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points - 'counts': 0, - 'dragging': False, - 'region_color': (37, 255, 225), # BGR Value - 'text_color': (0, 0, 0), # Region Text Color - }, ] + "name": "YOLOv8 Rectangle Region", + "polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points + "counts": 0, + "dragging": False, + "region_color": (37, 255, 225), # BGR Value + "text_color": (0, 0, 0), # Region Text Color + }, +] def mouse_callback(event, x, y, flags, param): @@ -40,32 +41,33 @@ def mouse_callback(event, x, y, flags, param): # Mouse left button down event if event == cv2.EVENT_LBUTTONDOWN: for region in counting_regions: - if region['polygon'].contains(Point((x, y))): + if region["polygon"].contains(Point((x, y))): current_region = region - current_region['dragging'] = True - current_region['offset_x'] = x - current_region['offset_y'] = y + current_region["dragging"] = True + current_region["offset_x"] = x + current_region["offset_y"] = y # Mouse move event elif event == cv2.EVENT_MOUSEMOVE: - if current_region is not None and current_region['dragging']: - dx = x - current_region['offset_x'] - dy = y - current_region['offset_y'] - current_region['polygon'] = Polygon([ - (p[0] + dx, p[1] + dy) for p in current_region['polygon'].exterior.coords]) - current_region['offset_x'] = x - current_region['offset_y'] = y + if current_region is not None and current_region["dragging"]: + dx = x - current_region["offset_x"] + dy = y - current_region["offset_y"] + current_region["polygon"] = Polygon( + [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords] + ) + current_region["offset_x"] = x + current_region["offset_y"] = y # Mouse left button up event elif event == cv2.EVENT_LBUTTONUP: - if current_region is not None and current_region['dragging']: - current_region['dragging'] = False + if current_region is not None and current_region["dragging"]: + 
current_region["dragging"] = False def run( - weights='yolov8n.pt', + weights="yolov8n.pt", source=None, - device='cpu', + device="cpu", view_img=False, save_img=False, exist_ok=False, @@ -100,8 +102,8 @@ def run( raise FileNotFoundError(f"Source path '{source}' does not exist.") # Setup Model - model = YOLO(f'{weights}') - model.to('cuda') if device == '0' else model.to('cpu') + model = YOLO(f"{weights}") + model.to("cuda") if device == "0" else model.to("cpu") # Extract classes names names = model.model.names @@ -109,12 +111,12 @@ def run( # Video setup videocapture = cv2.VideoCapture(source) frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) - fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v') + fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") # Output setup - save_dir = increment_path(Path('ultralytics_rc_output') / 'exp', exist_ok) + save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok) save_dir.mkdir(parents=True, exist_ok=True) - video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height)) + video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) # Iterate over video frames while videocapture.isOpened(): @@ -146,43 +148,48 @@ def run( # Check if detection inside region for region in counting_regions: - if region['polygon'].contains(Point((bbox_center[0], bbox_center[1]))): - region['counts'] += 1 + if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))): + region["counts"] += 1 # Draw regions (Polygons/Rectangles) for region in counting_regions: - region_label = str(region['counts']) - region_color = region['region_color'] - region_text_color = region['text_color'] + region_label = str(region["counts"]) + region_color = region["region_color"] + region_text_color = region["text_color"] - polygon_coords = np.array(region['polygon'].exterior.coords, dtype=np.int32) - centroid_x, centroid_y = int(region['polygon'].centroid.x), int(region['polygon'].centroid.y) + polygon_coords = np.array(region["polygon"].exterior.coords, dtype=np.int32) + centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y) - text_size, _ = cv2.getTextSize(region_label, - cv2.FONT_HERSHEY_SIMPLEX, - fontScale=0.7, - thickness=line_thickness) + text_size, _ = cv2.getTextSize( + region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness + ) text_x = centroid_x - text_size[0] // 2 text_y = centroid_y + text_size[1] // 2 - cv2.rectangle(frame, (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5), - region_color, -1) - cv2.putText(frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, - line_thickness) + cv2.rectangle( + frame, + (text_x - 5, text_y - text_size[1] - 5), + (text_x + text_size[0] + 5, text_y + 5), + region_color, + -1, + ) + cv2.putText( + frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness + ) cv2.polylines(frame, [polygon_coords], isClosed=True, color=region_color, thickness=region_thickness) if view_img: if vid_frame_count == 1: - cv2.namedWindow('Ultralytics YOLOv8 Region Counter Movable') - cv2.setMouseCallback('Ultralytics YOLOv8 Region Counter Movable', mouse_callback) - cv2.imshow('Ultralytics YOLOv8 Region Counter Movable', frame) + cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable") + 
cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback) + cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame) if save_img: video_writer.write(frame) for region in counting_regions: # Reinitialize count for each region - region['counts'] = 0 + region["counts"] = 0 - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): break del vid_frame_count @@ -194,16 +201,16 @@ def run( def parse_opt(): """Parse command line arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path') - parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') - parser.add_argument('--source', type=str, required=True, help='video file path') - parser.add_argument('--view-img', action='store_true', help='show results') - parser.add_argument('--save-img', action='store_true', help='save results') - parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') - parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3') - parser.add_argument('--line-thickness', type=int, default=2, help='bounding box thickness') - parser.add_argument('--track-thickness', type=int, default=2, help='Tracking line thickness') - parser.add_argument('--region-thickness', type=int, default=4, help='Region thickness') + parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") + parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu") + parser.add_argument("--source", type=str, required=True, help="video file path") + parser.add_argument("--view-img", action="store_true", help="show results") + parser.add_argument("--save-img", action="store_true", help="save results") + parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") + parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3") + parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness") + parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness") + parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness") return parser.parse_args() @@ -213,6 +220,6 @@ def main(opt): run(**vars(opt)) -if __name__ == '__main__': +if __name__ == "__main__": opt = parse_opt() main(opt) diff --git a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py index 7ab84417..37d78114 100644 --- a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py +++ b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py @@ -9,7 +9,7 @@ from sahi.utils.yolov8 import download_yolov8s_model from ultralytics.utils.files import increment_path -def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False): +def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False): """ Run object detection on a video using YOLOv8 and SAHI. 
@@ -25,41 +25,41 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, if not Path(source).exists(): raise FileNotFoundError(f"Source path '{source}' does not exist.") - yolov8_model_path = f'models/{weights}' + yolov8_model_path = f"models/{weights}" download_yolov8s_model(yolov8_model_path) - detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8', - model_path=yolov8_model_path, - confidence_threshold=0.3, - device='cpu') + detection_model = AutoDetectionModel.from_pretrained( + model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" + ) # Video setup videocapture = cv2.VideoCapture(source) frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) - fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v') + fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") # Output setup - save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok) + save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) save_dir.mkdir(parents=True, exist_ok=True) - video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height)) + video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) while videocapture.isOpened(): success, frame = videocapture.read() if not success: break - results = get_sliced_prediction(frame, - detection_model, - slice_height=512, - slice_width=512, - overlap_height_ratio=0.2, - overlap_width_ratio=0.2) + results = get_sliced_prediction( + frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2 + ) object_prediction_list = results.object_prediction_list boxes_list = [] clss_list = [] for ind, _ in enumerate(object_prediction_list): - boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \ - object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy + boxes = ( + object_prediction_list[ind].bbox.minx, + object_prediction_list[ind].bbox.miny, + object_prediction_list[ind].bbox.maxx, + object_prediction_list[ind].bbox.maxy, + ) clss = object_prediction_list[ind].category.name boxes_list.append(boxes) clss_list.append(clss) @@ -69,21 +69,19 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2) label = str(cls) t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0] - cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), - -1) - cv2.putText(frame, - label, (int(x1), int(y1) - 2), - 0, - 0.6, [255, 255, 255], - thickness=1, - lineType=cv2.LINE_AA) + cv2.rectangle( + frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1 + ) + cv2.putText( + frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA + ) if view_img: cv2.imshow(Path(source).stem, frame) if save_img: video_writer.write(frame) - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): break video_writer.release() videocapture.release() @@ -93,11 +91,11 @@ def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, def parse_opt(): """Parse command line arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('--weights', type=str, 
default='yolov8n.pt', help='initial weights path') - parser.add_argument('--source', type=str, required=True, help='video file path') - parser.add_argument('--view-img', action='store_true', help='show results') - parser.add_argument('--save-img', action='store_true', help='save results') - parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") + parser.add_argument("--source", type=str, required=True, help="video file path") + parser.add_argument("--view-img", action="store_true", help="show results") + parser.add_argument("--save-img", action="store_true", help="save results") + parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") return parser.parse_args() @@ -106,6 +104,6 @@ def main(opt): run(**vars(opt)) -if __name__ == '__main__': +if __name__ == "__main__": opt = parse_opt() main(opt) diff --git a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py index b13eab35..7dd11dc3 100644 --- a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py +++ b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py @@ -21,18 +21,21 @@ class YOLOv8Seg: """ # Build Ort session - self.session = ort.InferenceSession(onnx_model, - providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] - if ort.get_device() == 'GPU' else ['CPUExecutionProvider']) + self.session = ort.InferenceSession( + onnx_model, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + if ort.get_device() == "GPU" + else ["CPUExecutionProvider"], + ) # Numpy dtype: support both FP32 and FP16 onnx model - self.ndtype = np.half if self.session.get_inputs()[0].type == 'tensor(float16)' else np.single + self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single # Get model width and height(YOLOv8-seg only has one input) self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:] # Load COCO class names - self.classes = yaml_load(check_yaml('coco128.yaml'))['names'] + self.classes = yaml_load(check_yaml("coco128.yaml"))["names"] # Create color palette self.color_palette = Colors() @@ -60,14 +63,16 @@ class YOLOv8Seg: preds = self.session.run(None, {self.session.get_inputs()[0].name: im}) # Post-process - boxes, segments, masks = self.postprocess(preds, - im0=im0, - ratio=ratio, - pad_w=pad_w, - pad_h=pad_h, - conf_threshold=conf_threshold, - iou_threshold=iou_threshold, - nm=nm) + boxes, segments, masks = self.postprocess( + preds, + im0=im0, + ratio=ratio, + pad_w=pad_w, + pad_h=pad_h, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + nm=nm, + ) return boxes, segments, masks def preprocess(self, img): @@ -98,7 +103,7 @@ class YOLOv8Seg: img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional) - img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0 + img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0 img_process = img[None] if len(img.shape) == 3 else img return img_process, ratio, (pad_w, pad_h) @@ -124,7 +129,7 @@ class YOLOv8Seg: x, protos = preds[0], preds[1] # Two outputs: predictions and protos # Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, 
Num_anchors, xywh_conf_cls_nm) - x = np.einsum('bcn->bnc', x) + x = np.einsum("bcn->bnc", x) # Predictions filtering by conf-threshold x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold] @@ -138,7 +143,6 @@ class YOLOv8Seg: # Decode and return if len(x) > 0: - # Bounding boxes format change: cxcywh -> xyxy x[..., [0, 1]] -= x[..., [2, 3]] / 2 x[..., [2, 3]] += x[..., [0, 1]] @@ -173,13 +177,13 @@ class YOLOv8Seg: segments (List): list of segment masks. """ segments = [] - for x in masks.astype('uint8'): + for x in masks.astype("uint8"): c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE if c: c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) else: c = np.zeros((0, 2)) # no segments found - segments.append(c.astype('float32')) + segments.append(c.astype("float32")) return segments @staticmethod @@ -219,7 +223,7 @@ class YOLOv8Seg: masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN masks = np.ascontiguousarray(masks) masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape - masks = np.einsum('HWN -> NHW', masks) # HWN -> NHW + masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW masks = self.crop_mask(masks, bboxes) return np.greater(masks, 0.5) @@ -250,8 +254,9 @@ class YOLOv8Seg: if len(masks.shape) < 2: raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') masks = masks[top:bottom, left:right] - masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]), - interpolation=cv2.INTER_LINEAR) # INTER_CUBIC would be better + masks = cv2.resize( + masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR + ) # INTER_CUBIC would be better if len(masks.shape) == 2: masks = masks[:, :, None] return masks @@ -279,32 +284,46 @@ class YOLOv8Seg: cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True)) # draw bbox rectangle - cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), - self.color_palette(int(cls_), bgr=True), 1, cv2.LINE_AA) - cv2.putText(im, f'{self.classes[cls_]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.color_palette(int(cls_), bgr=True), 2, cv2.LINE_AA) + cv2.rectangle( + im, + (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), + self.color_palette(int(cls_), bgr=True), + 1, + cv2.LINE_AA, + ) + cv2.putText( + im, + f"{self.classes[cls_]}: {conf:.3f}", + (int(box[0]), int(box[1] - 9)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + self.color_palette(int(cls_), bgr=True), + 2, + cv2.LINE_AA, + ) # Mix image im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0) # Show image if vis: - cv2.imshow('demo', im) + cv2.imshow("demo", im) cv2.waitKey(0) cv2.destroyAllWindows() # Save image if save: - cv2.imwrite('demo.jpg', im) + cv2.imwrite("demo.jpg", im) -if __name__ == '__main__': +if __name__ == "__main__": # Create an argument parser to handle command-line arguments parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, required=True, help='Path to ONNX model') - parser.add_argument('--source', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image') - parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold') - parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold') + parser.add_argument("--model", type=str, required=True, help="Path to ONNX model") + parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), 
help="Path to input image") + parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold") + parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold") args = parser.parse_args() # Build model diff --git a/pyproject.toml b/pyproject.toml index 7b734356..a994e2ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,9 +107,9 @@ export = [ ] explorer = [ - "lancedb", # vector search - "duckdb", # SQL queries, supports lancedb tables - "streamlit", # visualizing with GUI + "lancedb", # vector search + "duckdb", # SQL queries, supports lancedb tables + "streamlit", # visualizing with GUI ] # tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos) @@ -179,5 +179,5 @@ pre-summary-newline = true close-quotes-on-newline = true [tool.codespell] -ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend" -skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml' +ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall" +skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml" diff --git a/tests/test_explorer.py b/tests/test_explorer.py index 23121c43..6a0995bc 100644 --- a/tests/test_explorer.py +++ b/tests/test_explorer.py @@ -3,6 +3,8 @@ import PIL from ultralytics import Explorer from ultralytics.utils import ASSETS +import PIL + def test_similarity(): """Test similarity calculations and SQL queries for correctness and response length.""" diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 4656c5a9..a8a23d32 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.238' +__version__ = "8.0.239" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO @@ -10,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings from ultralytics.utils.checks import check_yolo as checks from ultralytics.utils.downloads import download -__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'Explorer' +__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer" diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 2488aa6b..7d9aca7e 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -8,34 +8,53 @@ from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Union -from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR, - SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks, - colorstr, deprecation_warn, yaml_load, yaml_print) +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_YAML, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + yaml_load, + yaml_print, +) # Define valid tasks and modes -MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' -TASKS = 'detect', 'segment', 'classify', 'pose', 'obb' +MODES = "train", "val", "predict", "export", "track", "benchmark" +TASKS = "detect", "segment", "classify", "pose", "obb" TASK2DATA = { - 'detect': 'coco8.yaml', - 
'segment': 'coco8-seg.yaml', - 'classify': 'imagenet10', - 'pose': 'coco8-pose.yaml', - 'obb': 'dota8-obb.yaml'} # not implemented yet + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8-obb.yaml", +} TASK2MODEL = { - 'detect': 'yolov8n.pt', - 'segment': 'yolov8n-seg.pt', - 'classify': 'yolov8n-cls.pt', - 'pose': 'yolov8n-pose.pt', - 'obb': 'yolov8n-obb.pt'} + "detect": "yolov8n.pt", + "segment": "yolov8n-seg.pt", + "classify": "yolov8n-cls.pt", + "pose": "yolov8n-pose.pt", + "obb": "yolov8n-obb.pt", +} TASK2METRIC = { - 'detect': 'metrics/mAP50-95(B)', - 'segment': 'metrics/mAP50-95(M)', - 'classify': 'metrics/accuracy_top1', - 'pose': 'metrics/mAP50-95(P)', - 'obb': 'metrics/mAP50-95(OBB)'} - -CLI_HELP_MSG = \ - f""" + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(OBB)", +} + +CLI_HELP_MSG = f""" Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax: yolo TASK MODE ARGS @@ -74,16 +93,83 @@ CLI_HELP_MSG = \ """ # Define keys for arg type checks -CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time' -CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', - 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', - 'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0 -CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride', - 'line_width', 'workspace', 'nbs', 'save_period') -CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val', - 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop', - 'save_frames', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', - 'show_boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile', 'multi_scale') +CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time" +CFG_FRACTION_KEYS = ( + "dropout", + "iou", + "lr0", + "lrf", + "momentum", + "weight_decay", + "warmup_momentum", + "warmup_bias_lr", + "label_smoothing", + "hsv_h", + "hsv_s", + "hsv_v", + "translate", + "scale", + "perspective", + "flipud", + "fliplr", + "mosaic", + "mixup", + "copy_paste", + "conf", + "iou", + "fraction", +) # fraction floats 0.0 - 1.0 +CFG_INT_KEYS = ( + "epochs", + "patience", + "batch", + "workers", + "seed", + "close_mosaic", + "mask_ratio", + "max_det", + "vid_stride", + "line_width", + "workspace", + "nbs", + "save_period", +) +CFG_BOOL_KEYS = ( + "save", + "exist_ok", + "verbose", + "deterministic", + "single_cls", + "rect", + "cos_lr", + "overlap_mask", + "val", + "save_json", + "save_hybrid", + "half", + "dnn", + "plots", + "show", + "save_txt", + "save_conf", + "save_crop", + "save_frames", + "show_labels", + "show_conf", + "visualize", + "augment", + "agnostic_nms", + "retina_masks", + "show_boxes", + "keras", + "optimize", + "int8", + "dynamic", + "simplify", + "nms", + "profile", + "multi_scale", +) def cfg2dict(cfg): @@ -119,38 +205,44 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove # Merge overrides if overrides: overrides = cfg2dict(overrides) - if 'save_dir' not in cfg: - 
overrides.pop('save_dir', None) # special override keys to ignore + if "save_dir" not in cfg: + overrides.pop("save_dir", None) # special override keys to ignore check_dict_alignment(cfg, overrides) cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) # Special handling for numeric project/name - for k in 'project', 'name': + for k in "project", "name": if k in cfg and isinstance(cfg[k], (int, float)): cfg[k] = str(cfg[k]) - if cfg.get('name') == 'model': # assign model to 'name' arg - cfg['name'] = cfg.get('model', '').split('.')[0] + if cfg.get("name") == "model": # assign model to 'name' arg + cfg["name"] = cfg.get("model", "").split(".")[0] LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") # Type and Value checks for k, v in cfg.items(): if v is not None: # None values may be from optional args if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')") + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) elif k in CFG_FRACTION_KEYS: if not isinstance(v, (int, float)): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')") + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) if not (0.0 <= v <= 1.0): - raise ValueError(f"'{k}={v}' is an invalid value. " - f"Valid '{k}' values are between 0.0 and 1.0.") + raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.") elif k in CFG_INT_KEYS and not isinstance(v, int): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"'{k}' must be an int (i.e. '{k}=8')") + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" + ) elif k in CFG_BOOL_KEYS and not isinstance(v, bool): - raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " - f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')") + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"'{k}' must be a bool (i.e. 
'{k}=True' or '{k}=False')" + ) # Return instance return IterableSimpleNamespace(**cfg) @@ -159,13 +251,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove def get_save_dir(args, name=None): """Return save_dir as created from train/val/predict arguments.""" - if getattr(args, 'save_dir', None): + if getattr(args, "save_dir", None): save_dir = args.save_dir else: from ultralytics.utils.files import increment_path - project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task - name = name or args.name or f'{args.mode}' + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) return Path(save_dir) @@ -175,18 +267,18 @@ def _handle_deprecation(custom): """Hardcoded function to handle deprecated config keys.""" for key in custom.copy().keys(): - if key == 'boxes': - deprecation_warn(key, 'show_boxes') - custom['show_boxes'] = custom.pop('boxes') - if key == 'hide_labels': - deprecation_warn(key, 'show_labels') - custom['show_labels'] = custom.pop('hide_labels') == 'False' - if key == 'hide_conf': - deprecation_warn(key, 'show_conf') - custom['show_conf'] = custom.pop('hide_conf') == 'False' - if key == 'line_thickness': - deprecation_warn(key, 'line_width') - custom['line_width'] = custom.pop('line_thickness') + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") return custom @@ -207,11 +299,11 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None): if mismatched: from difflib import get_close_matches - string = '' + string = "" for x in mismatched: matches = get_close_matches(x, base_keys) # key list - matches = [f'{k}={base[k]}' if base.get(k) is not None else k for k in matches] - match_str = f'Similar arguments are i.e. {matches}.' if matches else '' + matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] + match_str = f"Similar arguments are i.e. {matches}." if matches else "" string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. 
{match_str}\n" raise SyntaxError(string + CLI_HELP_MSG) from e @@ -229,13 +321,13 @@ def merge_equals_args(args: List[str]) -> List[str]: """ new_args = [] for i, arg in enumerate(args): - if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] - new_args[-1] += f'={args[i + 1]}' + if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] + new_args[-1] += f"={args[i + 1]}" del args[i + 1] - elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val'] - new_args.append(f'{arg}{args[i + 1]}') + elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val'] + new_args.append(f"{arg}{args[i + 1]}") del args[i + 1] - elif arg.startswith('=') and i > 0: # merge ['arg', '=val'] + elif arg.startswith("=") and i > 0: # merge ['arg', '=val'] new_args[-1] += arg else: new_args.append(arg) @@ -259,11 +351,11 @@ def handle_yolo_hub(args: List[str]) -> None: """ from ultralytics import hub - if args[0] == 'login': - key = args[1] if len(args) > 1 else '' + if args[0] == "login": + key = args[1] if len(args) > 1 else "" # Log in to Ultralytics HUB using the provided API key hub.login(key) - elif args[0] == 'logout': + elif args[0] == "logout": # Log out from Ultralytics HUB hub.logout() @@ -283,19 +375,19 @@ def handle_yolo_settings(args: List[str]) -> None: python my_script.py yolo settings reset ``` """ - url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings' # help URL + url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL try: if any(args): - if args[0] == 'reset': + if args[0] == "reset": SETTINGS_YAML.unlink() # delete the settings file SETTINGS.reset() # create new settings - LOGGER.info('Settings reset successfully') # inform the user that settings have been reset + LOGGER.info("Settings reset successfully") # inform the user that settings have been reset else: # save a new setting new = dict(parse_key_value_pair(a) for a in args) check_dict_alignment(SETTINGS, new) SETTINGS.update(new) - LOGGER.info(f'💡 Learn about settings at {url}') + LOGGER.info(f"💡 Learn about settings at {url}") yaml_print(SETTINGS_YAML) # print the current settings except Exception as e: LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. 
Please see {url} for help.") @@ -303,13 +395,13 @@ def handle_yolo_settings(args: List[str]) -> None: def handle_explorer(): """Open the Ultralytics Explorer GUI.""" - checks.check_requirements('streamlit') - subprocess.run(['streamlit', 'run', ROOT / 'data/explorer/gui/dash.py', '--server.maxMessageSize', '2048']) + checks.check_requirements("streamlit") + subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]) def parse_key_value_pair(pair): """Parse one 'key=value' pair and return key and value.""" - k, v = pair.split('=', 1) # split on first '=' sign + k, v = pair.split("=", 1) # split on first '=' sign k, v = k.strip(), v.strip() # remove spaces assert v, f"missing '{k}' value" return k, smart_value(v) @@ -318,11 +410,11 @@ def parse_key_value_pair(pair): def smart_value(v): """Convert a string to an underlying type such as int, float, bool, etc.""" v_lower = v.lower() - if v_lower == 'none': + if v_lower == "none": return None - elif v_lower == 'true': + elif v_lower == "true": return True - elif v_lower == 'false': + elif v_lower == "false": return False else: with contextlib.suppress(Exception): @@ -330,7 +422,7 @@ def smart_value(v): return v -def entrypoint(debug=''): +def entrypoint(debug=""): """ This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed to the package. @@ -345,139 +437,150 @@ def entrypoint(debug=''): It uses the package's default cfg and initializes it using the passed overrides. Then it calls the CLI function with the composed cfg """ - args = (debug.split(' ') if debug else sys.argv)[1:] + args = (debug.split(" ") if debug else sys.argv)[1:] if not args: # no arguments passed LOGGER.info(CLI_HELP_MSG) return special = { - 'help': lambda: LOGGER.info(CLI_HELP_MSG), - 'checks': checks.collect_system_info, - 'version': lambda: LOGGER.info(__version__), - 'settings': lambda: handle_yolo_settings(args[1:]), - 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), - 'hub': lambda: handle_yolo_hub(args[1:]), - 'login': lambda: handle_yolo_hub(args), - 'copy-cfg': copy_default_cfg, - 'explorer': lambda: handle_explorer()} + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "explorer": lambda: handle_explorer(), + } full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} # Define common misuses of special commands, i.e. -h, -help, --help special.update({k[0]: v for k, v in special.items()}) # singular - special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular - special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}} + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} overrides = {} # basic overrides, i.e. 
imgsz=320 for a in merge_equals_args(args): # merge spaces around '=' sign - if a.startswith('--'): + if a.startswith("--"): LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") a = a[2:] - if a.endswith(','): + if a.endswith(","): LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") a = a[:-1] - if '=' in a: + if "=" in a: try: k, v = parse_key_value_pair(a) - if k == 'cfg' and v is not None: # custom.yaml passed - LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}') - overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'} + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} else: overrides[k] = v except (NameError, SyntaxError, ValueError, AssertionError) as e: - check_dict_alignment(full_args_dict, {a: ''}, e) + check_dict_alignment(full_args_dict, {a: ""}, e) elif a in TASKS: - overrides['task'] = a + overrides["task"] = a elif a in MODES: - overrides['mode'] = a + overrides["mode"] = a elif a.lower() in special: special[a.lower()]() return elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True elif a in DEFAULT_CFG_DICT: - raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " - f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}") + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) else: - check_dict_alignment(full_args_dict, {a: ''}) + check_dict_alignment(full_args_dict, {a: ""}) # Check keys check_dict_alignment(full_args_dict, overrides) # Mode - mode = overrides.get('mode') + mode = overrides.get("mode") if mode is None: - mode = DEFAULT_CFG.mode or 'predict' + mode = DEFAULT_CFG.mode or "predict" LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") elif mode not in MODES: raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") # Task - task = overrides.pop('task', None) + task = overrides.pop("task", None) if task: if task not in TASKS: raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") - if 'model' not in overrides: - overrides['model'] = TASK2MODEL[task] + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] # Model - model = overrides.pop('model', DEFAULT_CFG.model) + model = overrides.pop("model", DEFAULT_CFG.model) if model is None: - model = 'yolov8n.pt' + model = "yolov8n.pt" LOGGER.warning(f"WARNING ⚠️ 'model' is missing. 
Using default 'model={model}'.") - overrides['model'] = model + overrides["model"] = model stem = Path(model).stem.lower() - if 'rtdetr' in stem: # guess architecture + if "rtdetr" in stem: # guess architecture from ultralytics import RTDETR + model = RTDETR(model) # no task argument - elif 'fastsam' in stem: + elif "fastsam" in stem: from ultralytics import FastSAM + model = FastSAM(model) - elif 'sam' in stem: + elif "sam" in stem: from ultralytics import SAM + model = SAM(model) else: from ultralytics import YOLO + model = YOLO(model, task=task) - if isinstance(overrides.get('pretrained'), str): - model.load(overrides['pretrained']) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) # Task Update if task != model.task: if task: - LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " - f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.") + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) task = model.task # Mode - if mode in ('predict', 'track') and 'source' not in overrides: - overrides['source'] = DEFAULT_CFG.source or ASSETS + if mode in ("predict", "track") and "source" not in overrides: + overrides["source"] = DEFAULT_CFG.source or ASSETS LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") - elif mode in ('train', 'val'): - if 'data' not in overrides and 'resume' not in overrides: - overrides['data'] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + elif mode in ("train", "val"): + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") - elif mode == 'export': - if 'format' not in overrides: - overrides['format'] = DEFAULT_CFG.format or 'torchscript' + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" LOGGER.warning(f"WARNING ⚠️ 'format' is missing. 
Using default 'format={overrides['format']}'.") # Run command in python getattr(model, mode)(**overrides) # default args from model # Show help - LOGGER.info(f'💡 Learn more at https://docs.ultralytics.com/modes/{mode}') + LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}") # Special modes -------------------------------------------------------------------------------------------------------- def copy_default_cfg(): """Copy and create a new default configuration file with '_copy' appended to its name.""" - new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml') + new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") shutil.copy2(DEFAULT_CFG_PATH, new_file) - LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n' - f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8") + LOGGER.info( + f"{DEFAULT_CFG_PATH} copied to {new_file}\n" + f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8" + ) -if __name__ == '__main__': +if __name__ == "__main__": # Example: entrypoint(debug='yolo predict model=yolov8n.pt') - entrypoint(debug='') + entrypoint(debug="") diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py index 6fa7e845..9f91ce97 100644 --- a/ultralytics/data/__init__.py +++ b/ultralytics/data/__init__.py @@ -4,5 +4,12 @@ from .base import BaseDataset from .build import build_dataloader, build_yolo_dataset, load_inference_source from .dataset import ClassificationDataset, SemanticDataset, YOLODataset -__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset', - 'build_dataloader', 'load_inference_source') +__all__ = ( + "BaseDataset", + "ClassificationDataset", + "SemanticDataset", + "YOLODataset", + "build_yolo_dataset", + "build_dataloader", + "load_inference_source", +) diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py index b4e08c76..cd0444f3 100644 --- a/ultralytics/data/annotator.py +++ b/ultralytics/data/annotator.py @@ -5,7 +5,7 @@ from pathlib import Path from ultralytics import SAM, YOLO -def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): +def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None): """ Automatically annotates images using a YOLO object detection model and a SAM segmentation model. 
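Illustration (not part of the patch): a minimal call for auto_annotate() as refactored here; the image directory is illustrative and both sets of weights auto-download on first use.

from ultralytics.data.annotator import auto_annotate

# Writes one YOLO-format segmentation label file per image; with output_dir=None the labels
# land in 'path/to/images_auto_annotate_labels' next to the source folder, as derived below.
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt")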
@@ -29,7 +29,7 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', data = Path(data) if not output_dir: - output_dir = data.parent / f'{data.stem}_auto_annotate_labels' + output_dir = data.parent / f"{data.stem}_auto_annotate_labels" Path(output_dir).mkdir(exist_ok=True, parents=True) det_results = det_model(data, stream=True, device=device) @@ -41,10 +41,10 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) segments = sam_results[0].masks.xyn # noqa - with open(f'{str(Path(output_dir) / Path(result.path).stem)}.txt', 'w') as f: + with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f: for i in range(len(segments)): s = segments[i] if len(s) == 0: continue segment = map(str, segments[i].reshape(-1).tolist()) - f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') + f.write(f"{class_ids[i]} " + " ".join(segment) + "\n") diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index a1cd1c3c..9dd5ef59 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -117,11 +117,11 @@ class BaseMixTransform: if self.pre_transform is not None: for i, data in enumerate(mix_labels): mix_labels[i] = self.pre_transform(data) - labels['mix_labels'] = mix_labels + labels["mix_labels"] = mix_labels # Mosaic or MixUp labels = self._mix_transform(labels) - labels.pop('mix_labels', None) + labels.pop("mix_labels", None) return labels def _mix_transform(self, labels): @@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform): def __init__(self, dataset, imgsz=640, p=1.0, n=4): """Initializes the object with a dataset, image size, probability, and border.""" - assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.' - assert n in (4, 9), 'grid must be equal to 4 or 9.' + assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." + assert n in (4, 9), "grid must be equal to 4 or 9." super().__init__(dataset=dataset, p=p) self.dataset = dataset self.imgsz = imgsz @@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform): def _mix_transform(self, labels): """Apply mixup transformation to the input image and labels.""" - assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.' - assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.' - return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9( - labels) # This code is modified for mosaic3 method. + assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive." + assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment." + return ( + self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels) + ) # This code is modified for mosaic3 method. 
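Illustration (not part of the patch): in the _mosaic3/_mosaic4/_mosaic9 placement methods that follow, each pasted patch's labels are shifted by its paste offset via _update_labels(); a worked example of that bookkeeping with illustrative numbers:

# A box (10, 20, 110, 220) in a 640x640 patch pasted at offset padw=512, padh=256 lands at
# (522, 276, 622, 476) on the 1280x1280 mosaic canvas; _cat_labels() then clips all instances
# to imgsz * 2 and drops boxes whose area collapses to zero.
padw, padh = 512, 256
x1, y1, x2, y2 = 10, 20, 110, 220
print((x1 + padw, y1 + padh, x2 + padw, y2 + padh))  # (522, 276, 622, 476)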
def _mosaic3(self, labels): """Create a 1x3 image mosaic.""" mosaic_labels = [] s = self.imgsz for i in range(3): - labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] # Load image - img = labels_patch['img'] - h, w = labels_patch.pop('resized_shape') + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") # Place img in img3 if i == 0: # center @@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform): padw, padh = c[:2] x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords - img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img3[ymin:ymax, xmin:xmax] + img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax] # hp, wp = h, w # height, width previous for next iteration # Labels assuming imgsz*2 mosaic size @@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform): mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] + final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] return final_labels def _mosaic4(self, labels): @@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform): s = self.imgsz yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y for i in range(4): - labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] # Load image - img = labels_patch['img'] - h, w = labels_patch.pop('resized_shape') + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") # Place img in img4 if i == 0: # top left @@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform): labels_patch = self._update_labels(labels_patch, padw, padh) mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img4 + final_labels["img"] = img4 return final_labels def _mosaic9(self, labels): @@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform): s = self.imgsz hp, wp = -1, -1 # height, width previous for i in range(9): - labels_patch = labels if i == 0 else labels['mix_labels'][i - 1] + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] # Load image - img = labels_patch['img'] - h, w = labels_patch.pop('resized_shape') + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") # Place img in img9 if i == 0: # center @@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform): x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords # Image - img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax] + img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax] hp, wp = h, w # height, width previous for next iteration # Labels assuming imgsz*2 mosaic size @@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform): mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] + final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] return final_labels @staticmethod def _update_labels(labels, padw, padh): """Update labels.""" - nh, nw = labels['img'].shape[:2] - labels['instances'].convert_bbox(format='xyxy') - labels['instances'].denormalize(nw, nh) - labels['instances'].add_padding(padw, padh) + nh, nw = labels["img"].shape[:2] + labels["instances"].convert_bbox(format="xyxy") + 
labels["instances"].denormalize(nw, nh) + labels["instances"].add_padding(padw, padh) return labels def _cat_labels(self, mosaic_labels): @@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform): instances = [] imgsz = self.imgsz * 2 # mosaic imgsz for labels in mosaic_labels: - cls.append(labels['cls']) - instances.append(labels['instances']) + cls.append(labels["cls"]) + instances.append(labels["instances"]) + # Final labels final_labels = { - 'im_file': mosaic_labels[0]['im_file'], - 'ori_shape': mosaic_labels[0]['ori_shape'], - 'resized_shape': (imgsz, imgsz), - 'cls': np.concatenate(cls, 0), - 'instances': Instances.concatenate(instances, axis=0), - 'mosaic_border': self.border} # final_labels - final_labels['instances'].clip(imgsz, imgsz) - good = final_labels['instances'].remove_zero_area_boxes() - final_labels['cls'] = final_labels['cls'][good] + "im_file": mosaic_labels[0]["im_file"], + "ori_shape": mosaic_labels[0]["ori_shape"], + "resized_shape": (imgsz, imgsz), + "cls": np.concatenate(cls, 0), + "instances": Instances.concatenate(instances, axis=0), + "mosaic_border": self.border, + } + final_labels["instances"].clip(imgsz, imgsz) + good = final_labels["instances"].remove_zero_area_boxes() + final_labels["cls"] = final_labels["cls"][good] return final_labels @@ -335,10 +338,10 @@ class MixUp(BaseMixTransform): def _mix_transform(self, labels): """Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf.""" r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 - labels2 = labels['mix_labels'][0] - labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8) - labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0) - labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0) + labels2 = labels["mix_labels"][0] + labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8) + labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) + labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0) return labels @@ -366,14 +369,9 @@ class RandomPerspective: box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation. """ - def __init__(self, - degrees=0.0, - translate=0.1, - scale=0.5, - shear=0.0, - perspective=0.0, - border=(0, 0), - pre_transform=None): + def __init__( + self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None + ): """Initializes RandomPerspective object with transformation parameters.""" self.degrees = degrees @@ -519,18 +517,18 @@ class RandomPerspective: Args: labels (dict): a dict of `bboxes`, `segments`, `keypoints`. 
""" - if self.pre_transform and 'mosaic_border' not in labels: + if self.pre_transform and "mosaic_border" not in labels: labels = self.pre_transform(labels) - labels.pop('ratio_pad', None) # do not need ratio pad + labels.pop("ratio_pad", None) # do not need ratio pad - img = labels['img'] - cls = labels['cls'] - instances = labels.pop('instances') + img = labels["img"] + cls = labels["cls"] + instances = labels.pop("instances") # Make sure the coord formats are right - instances.convert_bbox(format='xyxy') + instances.convert_bbox(format="xyxy") instances.denormalize(*img.shape[:2][::-1]) - border = labels.pop('mosaic_border', self.border) + border = labels.pop("mosaic_border", self.border) self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h # M is affine matrix # Scale for func:`box_candidates` @@ -546,20 +544,20 @@ class RandomPerspective: if keypoints is not None: keypoints = self.apply_keypoints(keypoints, M) - new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False) + new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False) # Clip new_instances.clip(*self.size) # Filter instances instances.scale(scale_w=scale, scale_h=scale, bbox_only=True) # Make the bboxes have the same scale with new_bboxes - i = self.box_candidates(box1=instances.bboxes.T, - box2=new_instances.bboxes.T, - area_thr=0.01 if len(segments) else 0.10) - labels['instances'] = new_instances[i] - labels['cls'] = cls[i] - labels['img'] = img - labels['resized_shape'] = img.shape[:2] + i = self.box_candidates( + box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10 + ) + labels["instances"] = new_instances[i] + labels["cls"] = cls[i] + labels["img"] = img + labels["resized_shape"] = img.shape[:2] return labels def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): @@ -611,7 +609,7 @@ class RandomHSV: The modified image replaces the original image in the input 'labels' dict. """ - img = labels['img'] + img = labels["img"] if self.hgain or self.sgain or self.vgain: r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) @@ -634,7 +632,7 @@ class RandomFlip: Also updates any instances (bounding boxes, keypoints, etc.) accordingly. """ - def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None: + def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None: """ Initializes the RandomFlip class with probability and direction. @@ -644,7 +642,7 @@ class RandomFlip: Default is 'horizontal'. flip_idx (array-like, optional): Index mapping for flipping keypoints, if any. """ - assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}' + assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}" assert 0 <= p <= 1.0 self.p = p @@ -662,25 +660,25 @@ class RandomFlip: Returns: (dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys. 
""" - img = labels['img'] - instances = labels.pop('instances') - instances.convert_bbox(format='xywh') + img = labels["img"] + instances = labels.pop("instances") + instances.convert_bbox(format="xywh") h, w = img.shape[:2] h = 1 if instances.normalized else h w = 1 if instances.normalized else w # Flip up-down - if self.direction == 'vertical' and random.random() < self.p: + if self.direction == "vertical" and random.random() < self.p: img = np.flipud(img) instances.flipud(h) - if self.direction == 'horizontal' and random.random() < self.p: + if self.direction == "horizontal" and random.random() < self.p: img = np.fliplr(img) instances.fliplr(w) # For keypoints if self.flip_idx is not None and instances.keypoints is not None: instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :]) - labels['img'] = np.ascontiguousarray(img) - labels['instances'] = instances + labels["img"] = np.ascontiguousarray(img) + labels["instances"] = instances return labels @@ -700,9 +698,9 @@ class LetterBox: """Return updated labels and image with added border.""" if labels is None: labels = {} - img = labels.get('img') if image is None else image + img = labels.get("img") if image is None else image shape = img.shape[:2] # current shape [height, width] - new_shape = labels.pop('rect_shape', self.new_shape) + new_shape = labels.pop("rect_shape", self.new_shape) if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -730,25 +728,26 @@ class LetterBox: img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, - value=(114, 114, 114)) # add border - if labels.get('ratio_pad'): - labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) # add border + if labels.get("ratio_pad"): + labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation if len(labels): labels = self._update_labels(labels, ratio, dw, dh) - labels['img'] = img - labels['resized_shape'] = new_shape + labels["img"] = img + labels["resized_shape"] = new_shape return labels else: return img def _update_labels(self, labels, ratio, padw, padh): """Update labels.""" - labels['instances'].convert_bbox(format='xyxy') - labels['instances'].denormalize(*labels['img'].shape[:2][::-1]) - labels['instances'].scale(*ratio) - labels['instances'].add_padding(padw, padh) + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(*labels["img"].shape[:2][::-1]) + labels["instances"].scale(*ratio) + labels["instances"].add_padding(padw, padh) return labels @@ -785,11 +784,11 @@ class CopyPaste: 1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work. 2. This method modifies the input dictionary 'labels' in place. 
""" - im = labels['img'] - cls = labels['cls'] + im = labels["img"] + cls = labels["cls"] h, w = im.shape[:2] - instances = labels.pop('instances') - instances.convert_bbox(format='xyxy') + instances = labels.pop("instances") + instances.convert_bbox(format="xyxy") instances.denormalize(w, h) if self.p and len(instances.segments): n = len(instances) @@ -812,9 +811,9 @@ class CopyPaste: i = cv2.flip(im_new, 1).astype(bool) im[i] = result[i] - labels['img'] = im - labels['cls'] = cls - labels['instances'] = instances + labels["img"] = im + labels["cls"] = cls + labels["instances"] = instances return labels @@ -831,12 +830,13 @@ class Albumentations: """Initialize the transform object for YOLO bbox formatted params.""" self.p = p self.transform = None - prefix = colorstr('albumentations: ') + prefix = colorstr("albumentations: ") try: import albumentations as A - check_version(A.__version__, '1.0.3', hard=True) # version requirement + check_version(A.__version__, "1.0.3", hard=True) # version requirement + # Transforms T = [ A.Blur(p=0.01), A.MedianBlur(p=0.01), @@ -844,31 +844,32 @@ class Albumentations: A.CLAHE(p=0.01), A.RandomBrightnessContrast(p=0.0), A.RandomGamma(p=0.0), - A.ImageCompression(quality_lower=75, p=0.0)] # transforms - self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])) + A.ImageCompression(quality_lower=75, p=0.0), + ] + self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) - LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p)) + LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) except ImportError: # package not installed, skip pass except Exception as e: - LOGGER.info(f'{prefix}{e}') + LOGGER.info(f"{prefix}{e}") def __call__(self, labels): """Generates object detections and returns a dictionary with detection results.""" - im = labels['img'] - cls = labels['cls'] + im = labels["img"] + cls = labels["cls"] if len(cls): - labels['instances'].convert_bbox('xywh') - labels['instances'].normalize(*im.shape[:2][::-1]) - bboxes = labels['instances'].bboxes + labels["instances"].convert_bbox("xywh") + labels["instances"].normalize(*im.shape[:2][::-1]) + bboxes = labels["instances"].bboxes # TODO: add supports of segments and keypoints if self.transform and random.random() < self.p: new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed - if len(new['class_labels']) > 0: # skip update if no bbox in new im - labels['img'] = new['image'] - labels['cls'] = np.array(new['class_labels']) - bboxes = np.array(new['bboxes'], dtype=np.float32) - labels['instances'].update(bboxes=bboxes) + if len(new["class_labels"]) > 0: # skip update if no bbox in new im + labels["img"] = new["image"] + labels["cls"] = np.array(new["class_labels"]) + bboxes = np.array(new["bboxes"], dtype=np.float32) + labels["instances"].update(bboxes=bboxes) return labels @@ -888,15 +889,17 @@ class Format: batch_idx (bool): Keep batch indexes. Default is True. 
""" - def __init__(self, - bbox_format='xywh', - normalize=True, - return_mask=False, - return_keypoint=False, - return_obb=False, - mask_ratio=4, - mask_overlap=True, - batch_idx=True): + def __init__( + self, + bbox_format="xywh", + normalize=True, + return_mask=False, + return_keypoint=False, + return_obb=False, + mask_ratio=4, + mask_overlap=True, + batch_idx=True, + ): """Initializes the Format class with given parameters.""" self.bbox_format = bbox_format self.normalize = normalize @@ -909,10 +912,10 @@ class Format: def __call__(self, labels): """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'.""" - img = labels.pop('img') + img = labels.pop("img") h, w = img.shape[:2] - cls = labels.pop('cls') - instances = labels.pop('instances') + cls = labels.pop("cls") + instances = labels.pop("instances") instances.convert_bbox(format=self.bbox_format) instances.denormalize(w, h) nl = len(instances) @@ -922,22 +925,24 @@ class Format: masks, instances, cls = self._format_segments(instances, cls, w, h) masks = torch.from_numpy(masks) else: - masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, - img.shape[1] // self.mask_ratio) - labels['masks'] = masks + masks = torch.zeros( + 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio + ) + labels["masks"] = masks if self.normalize: instances.normalize(w, h) - labels['img'] = self._format_img(img) - labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl) - labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) + labels["img"] = self._format_img(img) + labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) + labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) if self.return_keypoint: - labels['keypoints'] = torch.from_numpy(instances.keypoints) + labels["keypoints"] = torch.from_numpy(instances.keypoints) if self.return_obb: - labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len( - instances.segments) else torch.zeros((0, 5)) + labels["bboxes"] = ( + xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) + ) # Then we can use collate_fn if self.batch_idx: - labels['batch_idx'] = torch.zeros(nl) + labels["batch_idx"] = torch.zeros(nl) return labels def _format_img(self, img): @@ -964,33 +969,39 @@ class Format: def v8_transforms(dataset, imgsz, hyp, stretch=False): """Convert images to a size suitable for YOLOv8 training.""" - pre_transform = Compose([ - Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), - CopyPaste(p=hyp.copy_paste), - RandomPerspective( - degrees=hyp.degrees, - translate=hyp.translate, - scale=hyp.scale, - shear=hyp.shear, - perspective=hyp.perspective, - pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), - )]) - flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation + pre_transform = Compose( + [ + Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), + CopyPaste(p=hyp.copy_paste), + RandomPerspective( + degrees=hyp.degrees, + translate=hyp.translate, + scale=hyp.scale, + shear=hyp.shear, + perspective=hyp.perspective, + pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), + ), + ] + ) + flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation if dataset.use_keypoints: - kpt_shape = dataset.data.get('kpt_shape', None) + kpt_shape = dataset.data.get("kpt_shape", None) if len(flip_idx) == 0 and hyp.fliplr > 
0.0: hyp.fliplr = 0.0 LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") elif flip_idx and (len(flip_idx) != kpt_shape[0]): - raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}') + raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") - return Compose([ - pre_transform, - MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), - Albumentations(p=1.0), - RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), - RandomFlip(direction='vertical', p=hyp.flipud), - RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms + return Compose( + [ + pre_transform, + MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), + Albumentations(p=1.0), + RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), + RandomFlip(direction="vertical", p=hyp.flipud), + RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), + ] + ) # transforms # Classification augmentations ----------------------------------------------------------------------------------------- @@ -1031,10 +1042,13 @@ def classify_transforms( tfl = [T.Resize(scale_size)] tfl += [T.CenterCrop(size)] - tfl += [T.ToTensor(), T.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std), - )] + tfl += [ + T.ToTensor(), + T.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std), + ), + ] return T.Compose(tfl) @@ -1053,7 +1067,7 @@ def classify_augmentations( hsv_s=0.4, # image HSV-Saturation augmentation (fraction) hsv_v=0.4, # image HSV-Value augmentation (fraction) force_color_jitter=False, - erasing=0., + erasing=0.0, interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, ): """ @@ -1080,13 +1094,13 @@ def classify_augmentations( """ # Transforms to apply if albumentations not installed if not isinstance(size, int): - raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)') + raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range - ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range + ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] - if hflip > 0.: + if hflip > 0.0: primary_tfl += [T.RandomHorizontalFlip(p=hflip)] - if vflip > 0.: + if vflip > 0.0: primary_tfl += [T.RandomVerticalFlip(p=vflip)] secondary_tfl = [] @@ -1097,27 +1111,29 @@ def classify_augmentations( # this allows override without breaking old hparm cfgs disable_color_jitter = not force_color_jitter - if auto_augment == 'randaugment': + if auto_augment == "randaugment": if TORCHVISION_0_11: secondary_tfl += [T.RandAugment(interpolation=interpolation)] else: LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') - elif auto_augment == 'augmix': + elif auto_augment == "augmix": if TORCHVISION_0_13: secondary_tfl += [T.AugMix(interpolation=interpolation)] else: LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.') - elif auto_augment == 'autoaugment': + elif auto_augment == "autoaugment": if TORCHVISION_0_10: secondary_tfl += [T.AutoAugment(interpolation=interpolation)] else: LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. 
Disabling it.') else: - raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' - f'"augmix", "autoaugment" or None') + raise ValueError( + f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' + f'"augmix", "autoaugment" or None' + ) if not disable_color_jitter: secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)] @@ -1125,7 +1141,8 @@ def classify_augmentations( final_tfl = [ T.ToTensor(), T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), - T.RandomErasing(p=erasing, inplace=True)] + T.RandomErasing(p=erasing, inplace=True), + ] return T.Compose(primary_tfl + secondary_tfl + final_tfl) @@ -1177,7 +1194,7 @@ class ClassifyLetterBox: # Create padded image im_out = np.full((hs, ws, 3), 114, dtype=im.dtype) - im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) return im_out @@ -1205,7 +1222,7 @@ class CenterCrop: imh, imw = im.shape[:2] m = min(imh, imw) # min dimension top, left = (imh - m) // 2, (imw - m) // 2 - return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) + return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) # NOTE: keep this class for backward compatibility diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index 1df546b3..1a392e0c 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -47,20 +47,22 @@ class BaseDataset(Dataset): transforms (callable): Image transformation function. """ - def __init__(self, - img_path, - imgsz=640, - cache=False, - augment=True, - hyp=DEFAULT_CFG, - prefix='', - rect=False, - batch_size=16, - stride=32, - pad=0.5, - single_cls=False, - classes=None, - fraction=1.0): + def __init__( + self, + img_path, + imgsz=640, + cache=False, + augment=True, + hyp=DEFAULT_CFG, + prefix="", + rect=False, + batch_size=16, + stride=32, + pad=0.5, + single_cls=False, + classes=None, + fraction=1.0, + ): """Initialize BaseDataset with given configuration and options.""" super().__init__() self.img_path = img_path @@ -86,10 +88,10 @@ class BaseDataset(Dataset): self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 # Cache images - if cache == 'ram' and not self.check_cache_ram(): + if cache == "ram" and not self.check_cache_ram(): cache = False self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni - self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] + self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] if cache: self.cache_images(cache) @@ -103,23 +105,23 @@ class BaseDataset(Dataset): for p in img_path if isinstance(img_path, list) else [img_path]: p = Path(p) # os-agnostic if p.is_dir(): # dir - f += glob.glob(str(p / '**' / '*.*'), recursive=True) + f += glob.glob(str(p / "**" / "*.*"), recursive=True) # F = list(p.rglob('*.*')) # pathlib elif p.is_file(): # file with open(p) as t: t = t.read().strip().splitlines() parent = str(p.parent) + os.sep - f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path + f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib) else: - raise FileNotFoundError(f'{self.prefix}{p} does not 
exist') - im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS) + raise FileNotFoundError(f"{self.prefix}{p} does not exist") + im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib - assert im_files, f'{self.prefix}No images found in {img_path}' + assert im_files, f"{self.prefix}No images found in {img_path}" except Exception as e: - raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e + raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e if self.fraction < 1: - im_files = im_files[:round(len(im_files) * self.fraction)] + im_files = im_files[: round(len(im_files) * self.fraction)] return im_files def update_labels(self, include_class: Optional[list]): @@ -127,19 +129,19 @@ class BaseDataset(Dataset): include_class_array = np.array(include_class).reshape(1, -1) for i in range(len(self.labels)): if include_class is not None: - cls = self.labels[i]['cls'] - bboxes = self.labels[i]['bboxes'] - segments = self.labels[i]['segments'] - keypoints = self.labels[i]['keypoints'] + cls = self.labels[i]["cls"] + bboxes = self.labels[i]["bboxes"] + segments = self.labels[i]["segments"] + keypoints = self.labels[i]["keypoints"] j = (cls == include_class_array).any(1) - self.labels[i]['cls'] = cls[j] - self.labels[i]['bboxes'] = bboxes[j] + self.labels[i]["cls"] = cls[j] + self.labels[i]["bboxes"] = bboxes[j] if segments: - self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx] + self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx] if keypoints is not None: - self.labels[i]['keypoints'] = keypoints[j] + self.labels[i]["keypoints"] = keypoints[j] if self.single_cls: - self.labels[i]['cls'][:, 0] = 0 + self.labels[i]["cls"][:, 0] = 0 def load_image(self, i, rect_mode=True): """Loads 1 image from dataset index 'i', returns (im, resized hw).""" @@ -149,13 +151,13 @@ class BaseDataset(Dataset): try: im = np.load(fn) except Exception as e: - LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}') + LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") Path(fn).unlink(missing_ok=True) im = cv2.imread(f) # BGR else: # read image im = cv2.imread(f) # BGR if im is None: - raise FileNotFoundError(f'Image Not Found {f}') + raise FileNotFoundError(f"Image Not Found {f}") h0, w0 = im.shape[:2] # orig hw if rect_mode: # resize long side to imgsz while maintaining aspect ratio @@ -181,17 +183,17 @@ class BaseDataset(Dataset): def cache_images(self, cache): """Cache images to memory or disk.""" b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes - fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image + fcn = self.cache_images_to_disk if cache == "disk" else self.load_image with ThreadPool(NUM_THREADS) as pool: results = pool.imap(fcn, range(self.ni)) pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0) for i, x in pbar: - if cache == 'disk': + if cache == "disk": b += self.npy_files[i].stat().st_size else: # 'ram' self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) b += self.ims[i].nbytes - pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})' + pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})" pbar.close() def 
cache_images_to_disk(self, i): @@ -207,15 +209,17 @@ class BaseDataset(Dataset): for _ in range(n): im = cv2.imread(random.choice(self.im_files)) # sample image ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio - b += im.nbytes * ratio ** 2 + b += im.nbytes * ratio**2 mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM mem = psutil.virtual_memory() cache = mem_required < mem.available # to cache or not to cache, that is the question if not cache: - LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' - f'with {int(safety_margin * 100)}% safety margin but only ' - f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' - f"{'caching images ✅' if cache else 'not caching images ⚠️'}") + LOGGER.info( + f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' + f'with {int(safety_margin * 100)}% safety margin but only ' + f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' + f"{'caching images ✅' if cache else 'not caching images ⚠️'}" + ) return cache def set_rectangle(self): @@ -223,7 +227,7 @@ class BaseDataset(Dataset): bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index nb = bi[-1] + 1 # number of batches - s = np.array([x.pop('shape') for x in self.labels]) # hw + s = np.array([x.pop("shape") for x in self.labels]) # hw ar = s[:, 0] / s[:, 1] # aspect ratio irect = ar.argsort() self.im_files = [self.im_files[i] for i in irect] @@ -250,12 +254,14 @@ class BaseDataset(Dataset): def get_image_and_label(self, index): """Get and return label information from the dataset.""" label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 - label.pop('shape', None) # shape is for rect, remove it - label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index) - label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0], - label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation + label.pop("shape", None) # shape is for rect, remove it + label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) + label["ratio_pad"] = ( + label["resized_shape"][0] / label["ori_shape"][0], + label["resized_shape"][1] / label["ori_shape"][1], + ) # for evaluation if self.rect: - label['rect_shape'] = self.batch_shapes[self.batch[index]] + label["rect_shape"] = self.batch_shapes[self.batch[index]] return self.update_labels_info(label) def __len__(self): diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py index 99fed0c4..0892a2d9 100644 --- a/ultralytics/data/build.py +++ b/ultralytics/data/build.py @@ -9,8 +9,16 @@ import torch from PIL import Image from torch.utils.data import dataloader, distributed -from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor, - SourceTypes, autocast_list) +from ultralytics.data.loaders import ( + LOADERS, + LoadImages, + LoadPilAndNumpy, + LoadScreenshots, + LoadStreams, + LoadTensor, + SourceTypes, + autocast_list, +) from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.utils import RANK, colorstr from ultralytics.utils.checks import check_file @@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader): def __init__(self, *args, **kwargs): """Dataloader that infinitely recycles workers, inherits from DataLoader.""" super().__init__(*args, **kwargs) - object.__setattr__(self, 'batch_sampler', 
_RepeatSampler(self.batch_sampler)) + object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): @@ -70,29 +78,30 @@ class _RepeatSampler: def seed_worker(worker_id): # noqa """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader.""" - worker_seed = torch.initial_seed() % 2 ** 32 + worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) -def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32): +def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32): """Build YOLO Dataset.""" return YOLODataset( img_path=img_path, imgsz=cfg.imgsz, batch_size=batch, - augment=mode == 'train', # augmentation + augment=mode == "train", # augmentation hyp=cfg, # TODO: probably add a get_hyps_from_cfg function rect=cfg.rect or rect, # rectangular batches cache=cfg.cache or None, single_cls=cfg.single_cls or False, stride=int(stride), - pad=0.0 if mode == 'train' else 0.5, - prefix=colorstr(f'{mode}: '), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), task=cfg.task, classes=cfg.classes, data=data, - fraction=cfg.fraction if mode == 'train' else 1.0) + fraction=cfg.fraction if mode == "train" else 1.0, + ) def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): @@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) generator = torch.Generator() generator.manual_seed(6148914691236517205 + RANK) - return InfiniteDataLoader(dataset=dataset, - batch_size=batch, - shuffle=shuffle and sampler is None, - num_workers=nw, - sampler=sampler, - pin_memory=PIN_MEMORY, - collate_fn=getattr(dataset, 'collate_fn', None), - worker_init_fn=seed_worker, - generator=generator) + return InfiniteDataLoader( + dataset=dataset, + batch_size=batch, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + worker_init_fn=seed_worker, + generator=generator, + ) def check_source(source): @@ -120,9 +131,9 @@ def check_source(source): if isinstance(source, (str, int, Path)): # int for local usb camera source = str(source) is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) - is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')) - webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) - screenshot = source.lower() == 'screen' + is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")) + webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file) + screenshot = source.lower() == "screen" if is_url and is_file: source = check_file(source) # download elif isinstance(source, LOADERS): @@ -135,7 +146,7 @@ def check_source(source): elif isinstance(source, torch.Tensor): tensor = True else: - raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict') + raise TypeError("Unsupported image type. 
For supported types see https://docs.ultralytics.com/modes/predict") return source, webcam, screenshot, from_img, in_memory, tensor @@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False): dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride) # Attach source types to the dataset - setattr(dataset, 'source_type', source_type) + setattr(dataset, "source_type", source_type) return dataset diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py index 5714320f..1a875f3b 100644 --- a/ultralytics/data/converter.py +++ b/ultralytics/data/converter.py @@ -20,10 +20,98 @@ def coco91_to_coco80_class(): corresponding 91-index class ID. """ return [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None, - None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, - 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, - None, 73, 74, 75, 76, 77, 78, 79, None] + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + None, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + None, + 24, + 25, + None, + None, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + None, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + None, + 60, + None, + None, + 61, + None, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + None, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + None, + ] def coco80_to_coco91_class(): @@ -42,16 +130,96 @@ def coco80_to_coco91_class(): ``` """ return [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] - - -def convert_coco(labels_dir='../coco/annotations/', - save_dir='coco_converted/', - use_segments=False, - use_keypoints=False, - cls91to80=True): + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 27, + 28, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 67, + 70, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + ] + + +def convert_coco( + labels_dir="../coco/annotations/", + save_dir="coco_converted/", + use_segments=False, + use_keypoints=False, + cls91to80=True, +): """ Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. 
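A minimal usage sketch for the reflowed `convert_coco` signature above; the annotation directory is a placeholder, and the keyword values simply exercise the parameters visible in this hunk:

```python
from ultralytics.data.converter import convert_coco

# Convert COCO instance annotations to YOLO-format *.txt labels
# (labels_dir is hypothetical; save_dir is auto-incremented if it already exists)
convert_coco(
    labels_dir="../datasets/coco/annotations/",
    save_dir="coco_converted/",
    use_segments=True,  # also write normalized polygon segments
    cls91to80=True,  # remap 91-index COCO category ids to the 80-class scheme
)
```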
@@ -75,76 +243,78 @@ def convert_coco(labels_dir='../coco/annotations/', # Create dataset directory save_dir = increment_path(save_dir) # increment if save directory already exists - for p in save_dir / 'labels', save_dir / 'images': + for p in save_dir / "labels", save_dir / "images": p.mkdir(parents=True, exist_ok=True) # make dir # Convert classes coco80 = coco91_to_coco80_class() # Import json - for json_file in sorted(Path(labels_dir).resolve().glob('*.json')): - fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name + for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): + fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name fn.mkdir(parents=True, exist_ok=True) with open(json_file) as f: data = json.load(f) # Create image dict - images = {f'{x["id"]:d}': x for x in data['images']} + images = {f'{x["id"]:d}': x for x in data["images"]} # Create image-annotations dict imgToAnns = defaultdict(list) - for ann in data['annotations']: - imgToAnns[ann['image_id']].append(ann) + for ann in data["annotations"]: + imgToAnns[ann["image_id"]].append(ann) # Write labels file - for img_id, anns in TQDM(imgToAnns.items(), desc=f'Annotations {json_file}'): - img = images[f'{img_id:d}'] - h, w, f = img['height'], img['width'], img['file_name'] + for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): + img = images[f"{img_id:d}"] + h, w, f = img["height"], img["width"], img["file_name"] bboxes = [] segments = [] keypoints = [] for ann in anns: - if ann['iscrowd']: + if ann["iscrowd"]: continue # The COCO box format is [top left x, top left y, width, height] - box = np.array(ann['bbox'], dtype=np.float64) + box = np.array(ann["bbox"], dtype=np.float64) box[:2] += box[2:] / 2 # xy top-left corner to center box[[0, 2]] /= w # normalize x box[[1, 3]] /= h # normalize y if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0 continue - cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class + cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class box = [cls] + box.tolist() if box not in bboxes: bboxes.append(box) - if use_segments and ann.get('segmentation') is not None: - if len(ann['segmentation']) == 0: + if use_segments and ann.get("segmentation") is not None: + if len(ann["segmentation"]) == 0: segments.append([]) continue - elif len(ann['segmentation']) > 1: - s = merge_multi_segment(ann['segmentation']) + elif len(ann["segmentation"]) > 1: + s = merge_multi_segment(ann["segmentation"]) s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist() else: - s = [j for i in ann['segmentation'] for j in i] # all segments concatenated + s = [j for i in ann["segmentation"] for j in i] # all segments concatenated s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist() s = [cls] + s segments.append(s) - if use_keypoints and ann.get('keypoints') is not None: - keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) / - np.array([w, h, 1])).reshape(-1).tolist()) + if use_keypoints and ann.get("keypoints") is not None: + keypoints.append( + box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist() + ) # Write - with open((fn / f).with_suffix('.txt'), 'a') as file: + with open((fn / f).with_suffix(".txt"), "a") as file: for i in range(len(bboxes)): if use_keypoints: - line = *(keypoints[i]), # cls, box, keypoints + line = (*(keypoints[i]),) # cls, box, keypoints else: - line = 
*(segments[i] - if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments - file.write(('%g ' * len(line)).rstrip() % line + '\n') + line = ( + *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]), + ) # cls, box or segments + file.write(("%g " * len(line)).rstrip() % line + "\n") - LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}') + LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}") def convert_dota_to_yolo_obb(dota_root_path: str): @@ -184,31 +354,32 @@ def convert_dota_to_yolo_obb(dota_root_path: str): # Class names to indices mapping class_mapping = { - 'plane': 0, - 'ship': 1, - 'storage-tank': 2, - 'baseball-diamond': 3, - 'tennis-court': 4, - 'basketball-court': 5, - 'ground-track-field': 6, - 'harbor': 7, - 'bridge': 8, - 'large-vehicle': 9, - 'small-vehicle': 10, - 'helicopter': 11, - 'roundabout': 12, - 'soccer-ball-field': 13, - 'swimming-pool': 14, - 'container-crane': 15, - 'airport': 16, - 'helipad': 17} + "plane": 0, + "ship": 1, + "storage-tank": 2, + "baseball-diamond": 3, + "tennis-court": 4, + "basketball-court": 5, + "ground-track-field": 6, + "harbor": 7, + "bridge": 8, + "large-vehicle": 9, + "small-vehicle": 10, + "helicopter": 11, + "roundabout": 12, + "soccer-ball-field": 13, + "swimming-pool": 14, + "container-crane": 15, + "airport": 16, + "helipad": 17, + } def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir): """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory.""" - orig_label_path = orig_label_dir / f'{image_name}.txt' - save_path = save_dir / f'{image_name}.txt' + orig_label_path = orig_label_dir / f"{image_name}.txt" + save_path = save_dir / f"{image_name}.txt" - with orig_label_path.open('r') as f, save_path.open('w') as g: + with orig_label_path.open("r") as f, save_path.open("w") as g: lines = f.readlines() for line in lines: parts = line.strip().split() @@ -218,20 +389,21 @@ def convert_dota_to_yolo_obb(dota_root_path: str): class_idx = class_mapping[class_name] coords = [float(p) for p in parts[:8]] normalized_coords = [ - coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)] - formatted_coords = ['{:.6g}'.format(coord) for coord in normalized_coords] + coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8) + ] + formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords] g.write(f"{class_idx} {' '.join(formatted_coords)}\n") - for phase in ['train', 'val']: - image_dir = dota_root_path / 'images' / phase - orig_label_dir = dota_root_path / 'labels' / f'{phase}_original' - save_dir = dota_root_path / 'labels' / phase + for phase in ["train", "val"]: + image_dir = dota_root_path / "images" / phase + orig_label_dir = dota_root_path / "labels" / f"{phase}_original" + save_dir = dota_root_path / "labels" / phase save_dir.mkdir(parents=True, exist_ok=True) image_paths = list(image_dir.iterdir()) - for image_path in TQDM(image_paths, desc=f'Processing {phase} images'): - if image_path.suffix != '.png': + for image_path in TQDM(image_paths, desc=f"Processing {phase} images"): + if image_path.suffix != ".png": continue image_name_without_ext = image_path.stem img = cv2.imread(str(image_path)) @@ -293,7 +465,7 @@ def merge_multi_segment(segments): s.append(segments[i]) else: idx = [0, idx[1] - idx[0]] - s.append(segments[i][idx[0]:idx[1] + 1]) + s.append(segments[i][idx[0] : idx[1] + 1]) 
else: for i in range(len(idx_list) - 1, -1, -1): diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index ad0ba56e..0475508e 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -18,7 +18,7 @@ from .base import BaseDataset from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 -DATASET_CACHE_VERSION = '1.0.3' +DATASET_CACHE_VERSION = "1.0.3" class YOLODataset(BaseDataset): @@ -33,16 +33,16 @@ class YOLODataset(BaseDataset): (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. """ - def __init__(self, *args, data=None, task='detect', **kwargs): + def __init__(self, *args, data=None, task="detect", **kwargs): """Initializes the YOLODataset with optional configurations for segments and keypoints.""" - self.use_segments = task == 'segment' - self.use_keypoints = task == 'pose' - self.use_obb = task == 'obb' + self.use_segments = task == "segment" + self.use_keypoints = task == "pose" + self.use_obb = task == "obb" self.data = data - assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.' + assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." super().__init__(*args, **kwargs) - def cache_labels(self, path=Path('./labels.cache')): + def cache_labels(self, path=Path("./labels.cache")): """ Cache dataset labels, check images and read shapes. @@ -51,19 +51,29 @@ class YOLODataset(BaseDataset): Returns: (dict): labels. """ - x = {'labels': []} + x = {"labels": []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages - desc = f'{self.prefix}Scanning {path.parent / path.stem}...' + desc = f"{self.prefix}Scanning {path.parent / path.stem}..." total = len(self.im_files) - nkpt, ndim = self.data.get('kpt_shape', (0, 0)) + nkpt, ndim = self.data.get("kpt_shape", (0, 0)) if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)): - raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " - "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'") + raise ValueError( + "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " + "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 
'kpt_shape: [17, 3]'" + ) with ThreadPool(NUM_THREADS) as pool: - results = pool.imap(func=verify_image_label, - iterable=zip(self.im_files, self.label_files, repeat(self.prefix), - repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt), - repeat(ndim))) + results = pool.imap( + func=verify_image_label, + iterable=zip( + self.im_files, + self.label_files, + repeat(self.prefix), + repeat(self.use_keypoints), + repeat(len(self.data["names"])), + repeat(nkpt), + repeat(ndim), + ), + ) pbar = TQDM(results, desc=desc, total=total) for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: nm += nm_f @@ -71,7 +81,7 @@ class YOLODataset(BaseDataset): ne += ne_f nc += nc_f if im_file: - x['labels'].append( + x["labels"].append( dict( im_file=im_file, shape=shape, @@ -80,60 +90,63 @@ class YOLODataset(BaseDataset): segments=segments, keypoints=keypoint, normalized=True, - bbox_format='xywh')) + bbox_format="xywh", + ) + ) if msg: msgs.append(msg) - pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt' + pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" pbar.close() if msgs: - LOGGER.info('\n'.join(msgs)) + LOGGER.info("\n".join(msgs)) if nf == 0: - LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}') - x['hash'] = get_hash(self.label_files + self.im_files) - x['results'] = nf, nm, ne, nc, len(self.im_files) - x['msgs'] = msgs # warnings + LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") + x["hash"] = get_hash(self.label_files + self.im_files) + x["results"] = nf, nm, ne, nc, len(self.im_files) + x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x) return x def get_labels(self): """Returns dictionary of labels for YOLO training.""" self.label_files = img2label_paths(self.im_files) - cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') + cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") try: cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file - assert cache['version'] == DATASET_CACHE_VERSION # matches current version - assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash except (FileNotFoundError, AssertionError, AttributeError): cache, exists = self.cache_labels(cache_path), False # run cache ops # Display cache - nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total + nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total if exists and LOCAL_RANK in (-1, 0): - d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt' + d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results - if cache['msgs']: - LOGGER.info('\n'.join(cache['msgs'])) # display warnings + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings # Read cache - [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items - labels = cache['labels'] + [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items + labels = cache["labels"] if not labels: - LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. 
{HELP_URL}') - self.im_files = [lb['im_file'] for lb in labels] # update im_files + LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}") + self.im_files = [lb["im_file"] for lb in labels] # update im_files # Check if the dataset is all boxes or all segments - lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels) + lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels) len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) if len_segments and len_boxes != len_segments: LOGGER.warning( - f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, ' - f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. ' - 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.') + f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " + f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " + "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." + ) for lb in labels: - lb['segments'] = [] + lb["segments"] = [] if len_cls == 0: - LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}') + LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}") return labels def build_transforms(self, hyp=None): @@ -145,14 +158,17 @@ class YOLODataset(BaseDataset): else: transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)]) transforms.append( - Format(bbox_format='xywh', - normalize=True, - return_mask=self.use_segments, - return_keypoint=self.use_keypoints, - return_obb=self.use_obb, - batch_idx=True, - mask_ratio=hyp.mask_ratio, - mask_overlap=hyp.overlap_mask)) + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + return_obb=self.use_obb, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + ) return transforms def close_mosaic(self, hyp): @@ -166,11 +182,11 @@ class YOLODataset(BaseDataset): """Custom your label format here.""" # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label # We can make it also support classification and semantic segmentation by add or remove some dict keys there. 
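For orientation, a sketch of the single-image label dict this method consumes, inferred from the keys written by `cache_labels()` and popped immediately below; field shapes are illustrative, not normative:

```python
import numpy as np

# Hypothetical label entry as produced by cache_labels() above
label = {
    "im_file": "images/train/0001.jpg",
    "cls": np.zeros((0, 1), dtype=np.float32),
    "bboxes": np.zeros((0, 4), dtype=np.float32),  # popped and wrapped into Instances
    "segments": [],  # resampled to a fixed point count when non-empty
    "keypoints": None,
    "bbox_format": "xywh",
    "normalized": True,
}
```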
- bboxes = label.pop('bboxes') - segments = label.pop('segments', []) - keypoints = label.pop('keypoints', None) - bbox_format = label.pop('bbox_format') - normalized = label.pop('normalized') + bboxes = label.pop("bboxes") + segments = label.pop("segments", []) + keypoints = label.pop("keypoints", None) + bbox_format = label.pop("bbox_format") + normalized = label.pop("normalized") # NOTE: do NOT resample oriented boxes segment_resamples = 100 if self.use_obb else 1000 @@ -180,7 +196,7 @@ class YOLODataset(BaseDataset): segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0) else: segments = np.zeros((0, segment_resamples, 2), dtype=np.float32) - label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) + label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) return label @staticmethod @@ -191,15 +207,15 @@ class YOLODataset(BaseDataset): values = list(zip(*[list(b.values()) for b in batch])) for i, k in enumerate(keys): value = values[i] - if k == 'img': + if k == "img": value = torch.stack(value, 0) - if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']: + if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]: value = torch.cat(value, 0) new_batch[k] = value - new_batch['batch_idx'] = list(new_batch['batch_idx']) - for i in range(len(new_batch['batch_idx'])): - new_batch['batch_idx'][i] += i # add target image index for build_targets() - new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0) + new_batch["batch_idx"] = list(new_batch["batch_idx"]) + for i in range(len(new_batch["batch_idx"])): + new_batch["batch_idx"][i] += i # add target image index for build_targets() + new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0) return new_batch @@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True. """ - def __init__(self, root, args, augment=False, cache=False, prefix=''): + def __init__(self, root, args, augment=False, cache=False, prefix=""): """ Initialize YOLO object with root, image size, augmentations, and cache settings. 
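A hedged construction sketch for this dataset; it assumes `DEFAULT_CFG` carries the augmentation hyperparameters read in the body below (`fraction`, `scale`, `fliplr`, etc.), and the root path is a placeholder:

```python
from ultralytics.data.dataset import ClassificationDataset
from ultralytics.utils import DEFAULT_CFG

# Torchvision ImageFolder-style dataset with optional RAM or disk caching
ds = ClassificationDataset(root="datasets/imagenette/train", args=DEFAULT_CFG, augment=True, cache="ram")
sample = ds[0]  # returns {"img": transformed tensor, "cls": class index}
```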
@@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): """ super().__init__(root=root) if augment and args.fraction < 1.0: # reduce training fraction - self.samples = self.samples[:round(len(self.samples) * args.fraction)] - self.prefix = colorstr(f'{prefix}: ') if prefix else '' - self.cache_ram = cache is True or cache == 'ram' - self.cache_disk = cache == 'disk' + self.samples = self.samples[: round(len(self.samples) * args.fraction)] + self.prefix = colorstr(f"{prefix}: ") if prefix else "" + self.cache_ram = cache is True or cache == "ram" + self.cache_disk = cache == "disk" self.samples = self.verify_images() # filter out bad images - self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im + self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) - self.torch_transforms = classify_augmentations(size=args.imgsz, - scale=scale, - hflip=args.fliplr, - vflip=args.flipud, - erasing=args.erasing, - auto_augment=args.auto_augment, - hsv_h=args.hsv_h, - hsv_s=args.hsv_s, - hsv_v=args.hsv_v) if augment else classify_transforms( - size=args.imgsz, crop_fraction=args.crop_fraction) + self.torch_transforms = ( + classify_augmentations( + size=args.imgsz, + scale=scale, + hflip=args.fliplr, + vflip=args.flipud, + erasing=args.erasing, + auto_augment=args.auto_augment, + hsv_h=args.hsv_h, + hsv_s=args.hsv_s, + hsv_v=args.hsv_v, + ) + if augment + else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction) + ) def __getitem__(self, i): """Returns subset of data and targets corresponding to given indices.""" @@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): # Convert NumPy array to PIL image im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) sample = self.torch_transforms(im) - return {'img': sample, 'cls': j} + return {"img": sample, "cls": j} def __len__(self) -> int: """Return the total number of samples in the dataset.""" @@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): def verify_images(self): """Verify all images in dataset.""" - desc = f'{self.prefix}Scanning {self.root}...' - path = Path(self.root).with_suffix('.cache') # *.cache file path + desc = f"{self.prefix}Scanning {self.root}..." 
+ path = Path(self.root).with_suffix(".cache") # *.cache file path with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError): cache = load_dataset_cache_file(path) # attempt to load a *.cache file - assert cache['version'] == DATASET_CACHE_VERSION # matches current version - assert cache['hash'] == get_hash([x[0] for x in self.samples]) # identical hash - nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total if LOCAL_RANK in (-1, 0): - d = f'{desc} {nf} images, {nc} corrupt' + d = f"{desc} {nf} images, {nc} corrupt" TQDM(None, desc=d, total=n, initial=n) - if cache['msgs']: - LOGGER.info('\n'.join(cache['msgs'])) # display warnings + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings return samples # Run scan if *.cache retrieval failed @@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): msgs.append(msg) nf += nf_f nc += nc_f - pbar.desc = f'{desc} {nf} images, {nc} corrupt' + pbar.desc = f"{desc} {nf} images, {nc} corrupt" pbar.close() if msgs: - LOGGER.info('\n'.join(msgs)) - x['hash'] = get_hash([x[0] for x in self.samples]) - x['results'] = nf, nc, len(samples), samples - x['msgs'] = msgs # warnings + LOGGER.info("\n".join(msgs)) + x["hash"] = get_hash([x[0] for x in self.samples]) + x["results"] = nf, nc, len(samples), samples + x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x) return samples @@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): def load_dataset_cache_file(path): """Load an Ultralytics *.cache dictionary from path.""" import gc + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 cache = np.load(str(path), allow_pickle=True).item() # load dict gc.enable() @@ -320,15 +342,15 @@ def load_dataset_cache_file(path): def save_dataset_cache_file(prefix, path, x): """Save an Ultralytics dataset *.cache dictionary x to path.""" - x['version'] = DATASET_CACHE_VERSION # add cache version + x["version"] = DATASET_CACHE_VERSION # add cache version if is_dir_writeable(path.parent): if path.exists(): path.unlink() # remove *.cache file if exists np.save(str(path), x) # save cache for next time - path.with_suffix('.cache.npy').rename(path) # remove .npy suffix - LOGGER.info(f'{prefix}New cache created: {path}') + path.with_suffix(".cache.npy").rename(path) # remove .npy suffix + LOGGER.info(f"{prefix}New cache created: {path}") else: - LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.') + LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") # TODO: support semantic segmentation diff --git a/ultralytics/data/explorer/__init__.py b/ultralytics/data/explorer/__init__.py index f0344e2d..ce594dc1 100644 --- a/ultralytics/data/explorer/__init__.py +++ b/ultralytics/data/explorer/__init__.py @@ -2,4 +2,4 @@ from .utils import plot_query_result -__all__ = ['plot_query_result'] +__all__ = ["plot_query_result"] diff --git a/ultralytics/data/explorer/explorer.py b/ultralytics/data/explorer/explorer.py index 4a8595d6..5381a0a1 100644 --- a/ultralytics/data/explorer/explorer.py +++ b/ultralytics/data/explorer/explorer.py @@ -22,7 +22,6 @@ from .utils import 
get_sim_index_schema, get_table_schema, plot_query_result, pr class ExplorerDataset(YOLODataset): - def __init__(self, *args, data: dict = None, **kwargs) -> None: super().__init__(*args, data=data, **kwargs) @@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset): else: # read image im = cv2.imread(f) # BGR if im is None: - raise FileNotFoundError(f'Image Not Found {f}') + raise FileNotFoundError(f"Image Not Found {f}") h0, w0 = im.shape[:2] # orig hw return im, (h0, w0), im.shape[:2] @@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset): def build_transforms(self, hyp: IterableSimpleNamespace = None): """Creates transforms for dataset images without resizing.""" return Format( - bbox_format='xyxy', + bbox_format="xyxy", normalize=False, return_mask=self.use_segments, return_keypoint=self.use_keypoints, @@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset): class Explorer: - - def __init__(self, - data: Union[str, Path] = 'coco128.yaml', - model: str = 'yolov8n.pt', - uri: str = '~/ultralytics/explorer') -> None: - checks.check_requirements(['lancedb>=0.4.3', 'duckdb']) + def __init__( + self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer" + ) -> None: + checks.check_requirements(["lancedb>=0.4.3", "duckdb"]) import lancedb self.connection = lancedb.connect(uri) - self.table_name = Path(data).name.lower() + '_' + model.lower() - self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower( + self.table_name = Path(data).name.lower() + "_" + model.lower() + self.sim_idx_base_name = ( + f"{self.table_name}_sim_idx".lower() ) # Use this name and append thres and top_k to reuse the table self.model = YOLO(model) self.data = data # None @@ -74,7 +72,7 @@ class Explorer: self.table = None self.progress = 0 - def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None: + def create_embeddings_table(self, force: bool = False, split: str = "train") -> None: """ Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it already exists. Pass force=True to overwrite the existing table. @@ -90,20 +88,20 @@ class Explorer: ``` """ if self.table is not None and not force: - LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.') + LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.") return if self.table_name in self.connection.table_names() and not force: - LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.') + LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.") self.table = self.connection.open_table(self.table_name) self.progress = 1 return if self.data is None: - raise ValueError('Data must be provided to create embeddings table') + raise ValueError("Data must be provided to create embeddings table") data_info = check_det_dataset(self.data) if split not in data_info: raise ValueError( - f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}' + f"Split {split} is not found in the dataset. 
Available keys in the dataset are {list(data_info.keys())}" ) choice_set = data_info[split] @@ -113,13 +111,16 @@ class Explorer: # Create the table schema batch = dataset[0] - vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0] - table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite') + vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0] + table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite") table.add( - self._yield_batches(dataset, - data_info, - self.model, - exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx'])) + self._yield_batches( + dataset, + data_info, + self.model, + exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"], + ) + ) self.table = table @@ -131,12 +132,12 @@ class Explorer: for k in exclude_keys: batch.pop(k, None) batch = sanitize_batch(batch, data_info) - batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() + batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist() yield [batch] - def query(self, - imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, - limit: int = 25) -> Any: # pyarrow.Table + def query( + self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25 + ) -> Any: # pyarrow.Table """ Query the table for similar images. Accepts a single image or a list of images. @@ -157,18 +158,18 @@ class Explorer: ``` """ if self.table is None: - raise ValueError('Table is not created. Please create the table first.') + raise ValueError("Table is not created. Please create the table first.") if isinstance(imgs, str): imgs = [imgs] - assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}' + assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}" embeds = self.model.embed(imgs) # Get avg if multiple images are passed (len > 1) embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() return self.table.search(embeds).limit(limit).to_arrow() - def sql_query(self, - query: str, - return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table + def sql_query( + self, query: str, return_type: str = "pandas" + ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table """ Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. @@ -187,27 +188,29 @@ class Explorer: result = exp.sql_query(query) ``` """ - assert return_type in ['pandas', - 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' + assert return_type in [ + "pandas", + "arrow", + ], f"Return type should be either `pandas` or `arrow`, but got {return_type}" import duckdb if self.table is None: - raise ValueError('Table is not created. Please create the table first.') + raise ValueError("Table is not created. Please create the table first.") # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB - if not query.startswith('SELECT') and not query.startswith('WHERE'): + if not query.startswith("SELECT") and not query.startswith("WHERE"): raise ValueError( - f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. 
found {query}' + f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}" ) - if query.startswith('WHERE'): + if query.startswith("WHERE"): query = f"SELECT * FROM 'table' {query}" - LOGGER.info(f'Running query: {query}') + LOGGER.info(f"Running query: {query}") rs = duckdb.sql(query) - if return_type == 'pandas': + if return_type == "pandas": return rs.df() - elif return_type == 'arrow': + elif return_type == "arrow": return rs.arrow() def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: @@ -228,18 +231,20 @@ class Explorer: result = exp.plot_sql_query(query) ``` """ - result = self.sql_query(query, return_type='arrow') + result = self.sql_query(query, return_type="arrow") if len(result) == 0: - LOGGER.info('No results found.') + LOGGER.info("No results found.") return None img = plot_query_result(result, plot_labels=labels) return Image.fromarray(img) - def get_similar(self, - img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, - idx: Union[int, List[int]] = None, - limit: int = 25, - return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table + def get_similar( + self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + return_type: str = "pandas", + ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table """ Query the table for similar images. Accepts a single image or a list of images. @@ -259,21 +264,25 @@ class Explorer: similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') ``` """ - assert return_type in ['pandas', - 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' + assert return_type in [ + "pandas", + "arrow", + ], f"Return type should be either `pandas` or `arrow`, but got {return_type}" img = self._check_imgs_or_idxs(img, idx) similar = self.query(img, limit=limit) - if return_type == 'pandas': + if return_type == "pandas": return similar.to_pandas() - elif return_type == 'arrow': + elif return_type == "arrow": return similar - def plot_similar(self, - img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, - idx: Union[int, List[int]] = None, - limit: int = 25, - labels: bool = True) -> Image.Image: + def plot_similar( + self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + labels: bool = True, + ) -> Image.Image: """ Plot the similar images. Accepts images or indexes. @@ -293,9 +302,9 @@ class Explorer: similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg') ``` """ - similar = self.get_similar(img, idx, limit, return_type='arrow') + similar = self.get_similar(img, idx, limit, return_type="arrow") if len(similar) == 0: - LOGGER.info('No results found.') + LOGGER.info("No results found.") return None img = plot_query_result(similar, plot_labels=labels) return Image.fromarray(img) @@ -323,34 +332,37 @@ class Explorer: ``` """ if self.table is None: - raise ValueError('Table is not created. Please create the table first.') - sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower() + raise ValueError("Table is not created. Please create the table first.") + sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower() if sim_idx_table_name in self.connection.table_names() and not force: - LOGGER.info('Similarity matrix already exists. Reusing it. 
Pass force=True to overwrite it.') + LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.") return self.connection.open_table(sim_idx_table_name).to_pandas() if top_k and not (1.0 >= top_k >= 0.0): - raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}') + raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}") if max_dist < 0.0: - raise ValueError(f'max_dist must be greater than 0. Got {max_dist}') + raise ValueError(f"max_dist must be greater than 0. Got {max_dist}") top_k = int(top_k * len(self.table)) if top_k else len(self.table) top_k = max(top_k, 1) - features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict() - im_files = features['im_file'] - embeddings = features['vector'] + features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict() + im_files = features["im_file"] + embeddings = features["vector"] - sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite') + sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite") def _yield_sim_idx(): """Generates a dataframe with similarity indices and distances for images.""" for i in tqdm(range(len(embeddings))): - sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}') - yield [{ - 'idx': i, - 'im_file': im_files[i], - 'count': len(sim_idx), - 'sim_im_files': sim_idx['im_file'].tolist()}] + sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}") + yield [ + { + "idx": i, + "im_file": im_files[i], + "count": len(sim_idx), + "sim_im_files": sim_idx["im_file"].tolist(), + } + ] sim_table.add(_yield_sim_idx()) self.sim_index = sim_table @@ -381,7 +393,7 @@ class Explorer: ``` """ sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) - sim_count = sim_idx['count'].tolist() + sim_count = sim_idx["count"].tolist() sim_count = np.array(sim_count) indices = np.arange(len(sim_count)) @@ -390,25 +402,26 @@ class Explorer: plt.bar(indices, sim_count) # Customize the plot (optional) - plt.xlabel('data idx') - plt.ylabel('Count') - plt.title('Similarity Count') + plt.xlabel("data idx") + plt.ylabel("Count") + plt.title("Similarity Count") buffer = BytesIO() - plt.savefig(buffer, format='png') + plt.savefig(buffer, format="png") buffer.seek(0) # Use Pillow to open the image from the buffer return Image.fromarray(np.array(Image.open(buffer))) - def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], - idx: Union[None, int, List[int]]) -> List[np.ndarray]: + def _check_imgs_or_idxs( + self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]] + ) -> List[np.ndarray]: if img is None and idx is None: - raise ValueError('Either img or idx must be provided.') + raise ValueError("Either img or idx must be provided.") if img is not None and idx is not None: - raise ValueError('Only one of img or idx must be provided.') + raise ValueError("Only one of img or idx must be provided.") if idx is not None: idx = idx if isinstance(idx, list) else [idx] - img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file'] + img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"] return img if isinstance(img, list) else [img] @@ -433,7 +446,7 @@ class Explorer: try: df = self.sql_query(result) except Exception as e: - 
LOGGER.error('AI generated query is not valid. Please try again with a different prompt') + LOGGER.error("AI generated query is not valid. Please try again with a different prompt") LOGGER.error(e) return None return df diff --git a/ultralytics/data/explorer/gui/dash.py b/ultralytics/data/explorer/gui/dash.py index d2b5ec21..1ef7fe42 100644 --- a/ultralytics/data/explorer/gui/dash.py +++ b/ultralytics/data/explorer/gui/dash.py @@ -9,100 +9,114 @@ from ultralytics import Explorer from ultralytics.utils import ROOT, SETTINGS from ultralytics.utils.checks import check_requirements -check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2')) +check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2")) + import streamlit as st from streamlit_select import image_select def _get_explorer(): """Initializes and returns an instance of the Explorer class.""" - exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model')) - thread = Thread(target=exp.create_embeddings_table, - kwargs={'force': st.session_state.get('force_recreate_embeddings')}) + exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model")) + thread = Thread( + target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")} + ) thread.start() - progress_bar = st.progress(0, text='Creating embeddings table...') + progress_bar = st.progress(0, text="Creating embeddings table...") while exp.progress < 1: time.sleep(0.1) - progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%') + progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%") thread.join() - st.session_state['explorer'] = exp + st.session_state["explorer"] = exp progress_bar.empty() def init_explorer_form(): """Initializes an Explorer instance and creates embeddings table with progress tracking.""" - datasets = ROOT / 'cfg' / 'datasets' - ds = [d.name for d in datasets.glob('*.yaml')] + datasets = ROOT / "cfg" / "datasets" + ds = [d.name for d in datasets.glob("*.yaml")] models = [ - 'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt', - 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt', - 'yolov8l-pose.pt', 'yolov8x-pose.pt'] - with st.form(key='explorer_init_form'): + "yolov8n.pt", + "yolov8s.pt", + "yolov8m.pt", + "yolov8l.pt", + "yolov8x.pt", + "yolov8n-seg.pt", + "yolov8s-seg.pt", + "yolov8m-seg.pt", + "yolov8l-seg.pt", + "yolov8x-seg.pt", + "yolov8n-pose.pt", + "yolov8s-pose.pt", + "yolov8m-pose.pt", + "yolov8l-pose.pt", + "yolov8x-pose.pt", + ] + with st.form(key="explorer_init_form"): col1, col2 = st.columns(2) with col1: - st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml')) + st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml")) with col2: - st.selectbox('Select model', models, key='model') - st.checkbox('Force recreate embeddings', key='force_recreate_embeddings') + st.selectbox("Select model", models, key="model") + st.checkbox("Force recreate embeddings", key="force_recreate_embeddings") - st.form_submit_button('Explore', on_click=_get_explorer) + st.form_submit_button("Explore", on_click=_get_explorer) def query_form(): """Sets up a form in Streamlit to initialize Explorer with dataset and model selection.""" - with st.form('query_form'): + with st.form("query_form"): col1, col2 = st.columns([0.8, 0.2]) with col1: - st.text_input('Query', - "WHERE 
labels LIKE '%person%' AND labels LIKE '%dog%'", - label_visibility='collapsed', - key='query') + st.text_input( + "Query", + "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", + label_visibility="collapsed", + key="query", + ) with col2: - st.form_submit_button('Query', on_click=run_sql_query) + st.form_submit_button("Query", on_click=run_sql_query) def ai_query_form(): """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection.""" - with st.form('ai_query_form'): + with st.form("ai_query_form"): col1, col2 = st.columns([0.8, 0.2]) with col1: - st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query') + st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query") with col2: - st.form_submit_button('Ask AI', on_click=run_ai_query) + st.form_submit_button("Ask AI", on_click=run_ai_query) def find_similar_imgs(imgs): """Initializes a Streamlit form for AI-based image querying with custom input.""" - exp = st.session_state['explorer'] - similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow') - paths = similar.to_pydict()['im_file'] - st.session_state['imgs'] = paths + exp = st.session_state["explorer"] + similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow") + paths = similar.to_pydict()["im_file"] + st.session_state["imgs"] = paths def similarity_form(selected_imgs): """Initializes a form for AI-based image querying with custom input in Streamlit.""" - st.write('Similarity Search') - with st.form('similarity_form'): + st.write("Similarity Search") + with st.form("similarity_form"): subcol1, subcol2 = st.columns([1, 1]) with subcol1: - st.number_input('limit', - min_value=None, - max_value=None, - value=25, - label_visibility='collapsed', - key='limit') + st.number_input( + "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit" + ) with subcol2: disabled = not len(selected_imgs) - st.write('Selected: ', len(selected_imgs)) + st.write("Selected: ", len(selected_imgs)) st.form_submit_button( - 'Search', + "Search", disabled=disabled, on_click=find_similar_imgs, - args=(selected_imgs, ), + args=(selected_imgs,), ) if disabled: - st.error('Select at least one image to search.') + st.error("Select at least one image to search.") # def persist_reset_form(): @@ -117,100 +131,108 @@ def similarity_form(selected_imgs): def run_sql_query(): """Executes an SQL query and returns the results.""" - st.session_state['error'] = None - query = st.session_state.get('query') + st.session_state["error"] = None + query = st.session_state.get("query") if query.rstrip().lstrip(): - exp = st.session_state['explorer'] - res = exp.sql_query(query, return_type='arrow') - st.session_state['imgs'] = res.to_pydict()['im_file'] + exp = st.session_state["explorer"] + res = exp.sql_query(query, return_type="arrow") + st.session_state["imgs"] = res.to_pydict()["im_file"] def run_ai_query(): """Execute SQL query and update session state with query results.""" - if not SETTINGS['openai_api_key']: + if not SETTINGS["openai_api_key"]: st.session_state[ - 'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' + "error" + ] = 'OpenAI API key not found in settings. 
Please run yolo settings openai_api_key="..."' return - st.session_state['error'] = None - query = st.session_state.get('ai_query') + st.session_state["error"] = None + query = st.session_state.get("ai_query") if query.rstrip().lstrip(): - exp = st.session_state['explorer'] + exp = st.session_state["explorer"] res = exp.ask_ai(query) if not isinstance(res, pd.DataFrame) or res.empty: - st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.' + st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it." return - st.session_state['imgs'] = res['im_file'].to_list() + st.session_state["imgs"] = res["im_file"].to_list() def reset_explorer(): """Resets the explorer to its initial state by clearing session variables.""" - st.session_state['explorer'] = None - st.session_state['imgs'] = None - st.session_state['error'] = None + st.session_state["explorer"] = None + st.session_state["imgs"] = None + st.session_state["error"] = None def utralytics_explorer_docs_callback(): """Resets the explorer to its initial state by clearing session variables.""" with st.container(border=True): - st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg', - width=100) + st.image( + "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg", + width=100, + ) st.markdown( "
This demo is built using Ultralytics Explorer API. Visit API docs to try examples & learn more
", unsafe_allow_html=True, - help=None) - st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/') + help=None, + ) + st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/") def layout(): """Resets explorer session variables and provides documentation with a link to API docs.""" - st.set_page_config(layout='wide', initial_sidebar_state='collapsed') + st.set_page_config(layout="wide", initial_sidebar_state="collapsed") st.markdown("
Ultralytics Explorer Demo
", unsafe_allow_html=True) - if st.session_state.get('explorer') is None: + if st.session_state.get("explorer") is None: init_explorer_form() return - st.button(':arrow_backward: Select Dataset', on_click=reset_explorer) - exp = st.session_state.get('explorer') - col1, col2 = st.columns([0.75, 0.25], gap='small') + st.button(":arrow_backward: Select Dataset", on_click=reset_explorer) + exp = st.session_state.get("explorer") + col1, col2 = st.columns([0.75, 0.25], gap="small") imgs = [] - if st.session_state.get('error'): - st.error(st.session_state['error']) + if st.session_state.get("error"): + st.error(st.session_state["error"]) else: - imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file'] + imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] total_imgs, selected_imgs = len(imgs), [] with col1: subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) with subcol1: - st.write('Max Images Displayed:') + st.write("Max Images Displayed:") with subcol2: - num = st.number_input('Max Images Displayed', - min_value=0, - max_value=total_imgs, - value=min(500, total_imgs), - key='num_imgs_displayed', - label_visibility='collapsed') + num = st.number_input( + "Max Images Displayed", + min_value=0, + max_value=total_imgs, + value=min(500, total_imgs), + key="num_imgs_displayed", + label_visibility="collapsed", + ) with subcol3: - st.write('Start Index:') + st.write("Start Index:") with subcol4: - start_idx = st.number_input('Start Index', - min_value=0, - max_value=total_imgs, - value=0, - key='start_index', - label_visibility='collapsed') + start_idx = st.number_input( + "Start Index", + min_value=0, + max_value=total_imgs, + value=0, + key="start_index", + label_visibility="collapsed", + ) with subcol5: - reset = st.button('Reset', use_container_width=False, key='reset') + reset = st.button("Reset", use_container_width=False, key="reset") if reset: - st.session_state['imgs'] = None + st.session_state["imgs"] = None st.experimental_rerun() query_form() ai_query_form() if total_imgs: - imgs_displayed = imgs[start_idx:start_idx + num] + imgs_displayed = imgs[start_idx : start_idx + num] selected_imgs = image_select( - f'Total samples: {total_imgs}', + f"Total samples: {total_imgs}", images=imgs_displayed, use_container_width=False, # indices=[i for i in range(num)] if select_all else None, @@ -222,5 +244,5 @@ def layout(): utralytics_explorer_docs_callback() -if __name__ == '__main__': +if __name__ == "__main__": layout() diff --git a/ultralytics/data/explorer/utils.py b/ultralytics/data/explorer/utils.py index 0064d362..c70722bc 100644 --- a/ultralytics/data/explorer/utils.py +++ b/ultralytics/data/explorer/utils.py @@ -46,14 +46,13 @@ def get_sim_index_schema(): def sanitize_batch(batch, dataset_info): """Sanitizes input batch for inference, ensuring correct format and dimensions.""" - batch['cls'] = batch['cls'].flatten().int().tolist() - box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1]) - batch['bboxes'] = [box for box, _ in box_cls_pair] - batch['cls'] = [cls for _, cls in box_cls_pair] - batch['labels'] = [dataset_info['names'][i] for i in batch['cls']] - batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]] - batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]] - + batch["cls"] = batch["cls"].flatten().int().tolist() + box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), 
key=lambda x: x[1]) + batch["bboxes"] = [box for box, _ in box_cls_pair] + batch["cls"] = [cls for _, cls in box_cls_pair] + batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]] + batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]] + batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]] return batch @@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True): similar_set (list): Pyarrow or pandas object containing the similar data points plot_labels (bool): Whether to plot labels or not """ - similar_set = similar_set.to_dict( - orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() + similar_set = ( + similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() + ) empty_masks = [[[]]] empty_boxes = [[]] - images = similar_set.get('im_file', []) - bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else [] - masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else [] - kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else [] - cls = similar_set.get('cls', []) + images = similar_set.get("im_file", []) + bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else [] + masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else [] + kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else [] + cls = similar_set.get("cls", []) plot_size = 640 imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], [] @@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True): batch_idx = np.concatenate(batch_idx, axis=0) cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) - return plot_images(imgs, - batch_idx, - cls, - bboxes=boxes, - masks=masks, - kpts=kpts, - max_subplots=len(images), - save=False, - threaded=False) + return plot_images( + imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False + ) def prompt_sql_query(query): """Plots images with optional labels from a similar data set.""" - check_requirements('openai>=1.6.1') + check_requirements("openai>=1.6.1") from openai import OpenAI - if not SETTINGS['openai_api_key']: - logger.warning('OpenAI API key not found in settings. Please enter your API key below.') - openai_api_key = getpass.getpass('OpenAI API key: ') - SETTINGS.update({'openai_api_key': openai_api_key}) - openai = OpenAI(api_key=SETTINGS['openai_api_key']) + if not SETTINGS["openai_api_key"]: + logger.warning("OpenAI API key not found in settings. Please enter your API key below.") + openai_api_key = getpass.getpass("OpenAI API key: ") + SETTINGS.update({"openai_api_key": openai_api_key}) + openai = OpenAI(api_key=SETTINGS["openai_api_key"]) messages = [ { - 'role': - 'system', - 'content': - ''' + "role": "system", + "content": """ You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on the following schema and a user request. 
You only need to output the format with fixed selection statement that selects everything from "'table'", like `SELECT * from 'table'` @@ -165,10 +157,10 @@ def prompt_sql_query(query): request - Get all data points that contain 2 or more people and at least one dog correct query- SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1; - '''}, - { - 'role': 'user', - 'content': f'{query}'}, ] + """, + }, + {"role": "user", "content": f"{query}"}, + ] - response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages) + response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages) return response.choices[0].message.content diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py index 2545e9ab..a2cc2be7 100644 --- a/ultralytics/data/loaders.py +++ b/ultralytics/data/loaders.py @@ -23,6 +23,7 @@ from ultralytics.utils.checks import check_requirements @dataclass class SourceTypes: """Class to represent various types of input sources for predictions.""" + webcam: bool = False screenshot: bool = False from_img: bool = False @@ -59,12 +60,12 @@ class LoadStreams: __len__: Return the length of the sources object. """ - def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False): + def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False): """Initialize instance variables and check for consistent input stream shapes.""" torch.backends.cudnn.benchmark = True # faster for fixed-size inference self.buffer = buffer # buffer input streams self.running = True # running flag for Thread - self.mode = 'stream' + self.mode = "stream" self.imgsz = imgsz self.vid_stride = vid_stride # video frame-rate stride @@ -79,33 +80,36 @@ class LoadStreams: self.sources = [ops.clean_str(x) for x in sources] # clean source names for later for i, s in enumerate(sources): # index, source # Start thread to read frames from video stream - st = f'{i + 1}/{n}: {s}... ' - if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video + st = f"{i + 1}/{n}: {s}... " + if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4' s = get_best_youtube_url(s) s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam if s == 0 and (is_colab() or is_kaggle()): - raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. " - "Try running 'source=0' in a local environment.") + raise NotImplementedError( + "'source=0' webcam not supported in Colab and Kaggle notebooks. " + "Try running 'source=0' in a local environment." 
+ ) self.caps[i] = cv2.VideoCapture(s) # store video capture object if not self.caps[i].isOpened(): - raise ConnectionError(f'{st}Failed to open {s}') + raise ConnectionError(f"{st}Failed to open {s}") w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float( - 'inf') # infinite stream fallback + "inf" + ) # infinite stream fallback self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback success, im = self.caps[i].read() # guarantee first frame if not success or im is None: - raise ConnectionError(f'{st}Failed to read images from {s}') + raise ConnectionError(f"{st}Failed to read images from {s}") self.imgs[i].append(im) self.shape[i] = im.shape self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True) - LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)') + LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)") self.threads[i].start() - LOGGER.info('') # newline + LOGGER.info("") # newline # Check for common shapes self.bs = self.__len__() @@ -121,7 +125,7 @@ class LoadStreams: success, im = cap.retrieve() if not success: im = np.zeros(self.shape[i], dtype=np.uint8) - LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') + LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.") cap.open(stream) # re-open stream if signal was lost if self.buffer: self.imgs[i].append(im) @@ -140,7 +144,7 @@ class LoadStreams: try: cap.release() # release video capture except Exception as e: - LOGGER.warning(f'WARNING ⚠️ Could not release VideoCapture object: {e}') + LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}") cv2.destroyAllWindows() def __iter__(self): @@ -154,16 +158,15 @@ class LoadStreams: images = [] for i, x in enumerate(self.imgs): - # Wait until a frame is available in each buffer while not x: - if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit + if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit self.close() raise StopIteration time.sleep(1 / min(self.fps)) x = self.imgs[i] if not x: - LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}') + LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}") # Get and remove the first frame from imgs buffer if self.buffer: @@ -174,7 +177,7 @@ class LoadStreams: images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) x.clear() - return self.sources, images, None, '' + return self.sources, images, None, "" def __len__(self): """Return the length of the sources object.""" @@ -209,7 +212,7 @@ class LoadScreenshots: def __init__(self, source, imgsz=640): """Source = [screen_number left top width height] (pixels).""" - check_requirements('mss') + check_requirements("mss") import mss # noqa source, *params = source.split() @@ -221,18 +224,18 @@ class LoadScreenshots: elif len(params) == 5: self.screen, left, top, width, height = (int(x) for x in params) self.imgsz = imgsz - self.mode = 'stream' + self.mode = "stream" self.frame = 0 self.sct = mss.mss() self.bs = 1 # Parse monitor shape monitor = self.sct.monitors[self.screen] - self.top = monitor['top'] if top is None else (monitor['top'] + top) - self.left = monitor['left'] if left is None else 
(monitor['left'] + left) - self.width = width or monitor['width'] - self.height = height or monitor['height'] - self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height} + self.top = monitor["top"] if top is None else (monitor["top"] + top) + self.left = monitor["left"] if left is None else (monitor["left"] + left) + self.width = width or monitor["width"] + self.height = height or monitor["height"] + self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} def __iter__(self): """Returns an iterator of the object.""" @@ -241,7 +244,7 @@ class LoadScreenshots: def __next__(self): """mss screen capture: get raw pixels from the screen as np array.""" im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR - s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: ' + s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " self.frame += 1 return [str(self.screen)], [im0], None, s # screen, img, vid_cap, string @@ -274,32 +277,32 @@ class LoadImages: def __init__(self, path, imgsz=640, vid_stride=1): """Initialize the Dataloader and raise FileNotFoundError if file not found.""" parent = None - if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line + if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line parent = Path(path).parent path = Path(path).read_text().splitlines() # list of sources files = [] for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912 - if '*' in a: + if "*" in a: files.extend(sorted(glob.glob(a, recursive=True))) # glob elif os.path.isdir(a): - files.extend(sorted(glob.glob(os.path.join(a, '*.*')))) # dir + files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir elif os.path.isfile(a): files.append(a) # files (absolute or relative to CWD) elif parent and (parent / p).is_file(): files.append(str((parent / p).absolute())) # files (relative to *.txt file parent) else: - raise FileNotFoundError(f'{p} does not exist') + raise FileNotFoundError(f"{p} does not exist") - images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS] - videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS] + images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] + videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] ni, nv = len(images), len(videos) self.imgsz = imgsz self.files = images + videos self.nf = ni + nv # number of files self.video_flag = [False] * ni + [True] * nv - self.mode = 'image' + self.mode = "image" self.vid_stride = vid_stride # video frame-rate stride self.bs = 1 if any(videos): @@ -307,8 +310,10 @@ class LoadImages: else: self.cap = None if self.nf == 0: - raise FileNotFoundError(f'No images or videos found in {p}. ' - f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}') + raise FileNotFoundError( + f"No images or videos found in {p}. 
" + f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" + ) def __iter__(self): """Returns an iterator object for VideoStream or ImageFolder.""" @@ -323,7 +328,7 @@ class LoadImages: if self.video_flag[self.count]: # Read video - self.mode = 'video' + self.mode = "video" for _ in range(self.vid_stride): self.cap.grab() success, im0 = self.cap.retrieve() @@ -338,15 +343,15 @@ class LoadImages: self.frame += 1 # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False - s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' + s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: " else: # Read image self.count += 1 im0 = cv2.imread(path) # BGR if im0 is None: - raise FileNotFoundError(f'Image Not Found {path}') - s = f'image {self.count}/{self.nf} {path}: ' + raise FileNotFoundError(f"Image Not Found {path}") + s = f"image {self.count}/{self.nf} {path}: " return [path], [im0], self.cap, s @@ -385,20 +390,20 @@ class LoadPilAndNumpy: """Initialize PIL and Numpy Dataloader.""" if not isinstance(im0, list): im0 = [im0] - self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] self.im0 = [self._single_check(im) for im in im0] self.imgsz = imgsz - self.mode = 'image' + self.mode = "image" # Generate fake paths self.bs = len(self.im0) @staticmethod def _single_check(im): """Validate and format an image to numpy array.""" - assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}' + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" if isinstance(im, Image.Image): - if im.mode != 'RGB': - im = im.convert('RGB') + if im.mode != "RGB": + im = im.convert("RGB") im = np.asarray(im)[:, :, ::-1] im = np.ascontiguousarray(im) # contiguous return im @@ -412,7 +417,7 @@ class LoadPilAndNumpy: if self.count == 1: # loop only once as it's batch inference raise StopIteration self.count += 1 - return self.paths, self.im0, None, '' + return self.paths, self.im0, None, "" def __iter__(self): """Enables iteration for class LoadPilAndNumpy.""" @@ -441,14 +446,16 @@ class LoadTensor: """Initialize Tensor Dataloader.""" self.im0 = self._single_check(im0) self.bs = self.im0.shape[0] - self.mode = 'image' - self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] + self.mode = "image" + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] @staticmethod def _single_check(im, stride=32): """Validate and format an image to torch.Tensor.""" - s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \ - f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.' + s = ( + f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) " + f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible." + ) if len(im.shape) != 4: if len(im.shape) != 3: raise ValueError(s) @@ -457,8 +464,10 @@ class LoadTensor: if im.shape[2] % stride or im.shape[3] % stride: raise ValueError(s) if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07 - LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. ' - f'Dividing input by 255.') + LOGGER.warning( + f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. 
" + f"Dividing input by 255." + ) im = im.float() / 255.0 return im @@ -473,7 +482,7 @@ class LoadTensor: if self.count == 1: raise StopIteration self.count += 1 - return self.paths, self.im0, None, '' + return self.paths, self.im0, None, "" def __len__(self): """Returns the batch size.""" @@ -485,12 +494,14 @@ def autocast_list(source): files = [] for im in source: if isinstance(im, (str, Path)): # filename or uri - files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im)) + files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)) elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image files.append(im) else: - raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n' - f'See https://docs.ultralytics.com/modes/predict for supported source types.') + raise TypeError( + f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n" + f"See https://docs.ultralytics.com/modes/predict for supported source types." + ) return files @@ -513,16 +524,18 @@ def get_best_youtube_url(url, use_pafy=True): (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found. """ if use_pafy: - check_requirements(('pafy', 'youtube_dl==2020.12.2')) + check_requirements(("pafy", "youtube_dl==2020.12.2")) import pafy # noqa - return pafy.new(url).getbestvideo(preftype='mp4').url + + return pafy.new(url).getbestvideo(preftype="mp4").url else: - check_requirements('yt-dlp') + check_requirements("yt-dlp") import yt_dlp - with yt_dlp.YoutubeDL({'quiet': True}) as ydl: + + with yt_dlp.YoutubeDL({"quiet": True}) as ydl: info_dict = ydl.extract_info(url, download=False) # extract info - for f in reversed(info_dict.get('formats', [])): # reversed because best is usually last + for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size - good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080 - if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4': - return f.get('url') + good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080 + if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4": + return f.get("url") diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py index aea97493..6e7169b7 100644 --- a/ultralytics/data/split_dota.py +++ b/ultralytics/data/split_dota.py @@ -14,7 +14,7 @@ from tqdm import tqdm from ultralytics.data.utils import exif_size, img2label_paths from ultralytics.utils.checks import check_requirements -check_requirements('shapely') +check_requirements("shapely") from shapely.geometry import Polygon @@ -54,7 +54,7 @@ def bbox_iof(polygon1, bbox2, eps=1e-6): return outputs -def load_yolo_dota(data_root, split='train'): +def load_yolo_dota(data_root, split="train"): """ Load DOTA dataset. @@ -72,10 +72,10 @@ def load_yolo_dota(data_root, split='train'): - train - val """ - assert split in ['train', 'val'] - im_dir = os.path.join(data_root, f'images/{split}') + assert split in ["train", "val"] + im_dir = os.path.join(data_root, f"images/{split}") assert Path(im_dir).exists(), f"Can't find {im_dir}, please check your data root." 
- im_files = glob(os.path.join(data_root, f'images/{split}/*')) + im_files = glob(os.path.join(data_root, f"images/{split}/*")) lb_files = img2label_paths(im_files) annos = [] for im_file, lb_file in zip(im_files, lb_files): @@ -100,7 +100,7 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0 h, w = im_size windows = [] for crop_size, gap in zip(crop_sizes, gaps): - assert crop_size > gap, f'invaild crop_size gap pair [{crop_size} {gap}]' + assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]" step = crop_size - gap xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1) @@ -132,8 +132,8 @@ def get_windows(im_size, crop_sizes=[1024], gaps=[200], im_rate_thr=0.6, eps=0.0 def get_window_obj(anno, windows, iof_thr=0.7): """Get objects for each window.""" - h, w = anno['ori_size'] - label = anno['label'] + h, w = anno["ori_size"] + label = anno["label"] if len(label): label[:, 1::2] *= w label[:, 2::2] *= h @@ -166,15 +166,15 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): - train - val """ - im = cv2.imread(anno['filepath']) - name = Path(anno['filepath']).stem + im = cv2.imread(anno["filepath"]) + name = Path(anno["filepath"]).stem for i, window in enumerate(windows): x_start, y_start, x_stop, y_stop = window.tolist() - new_name = name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start) + new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start) patch_im = im[y_start:y_stop, x_start:x_stop] ph, pw = patch_im.shape[:2] - cv2.imwrite(os.path.join(im_dir, f'{new_name}.jpg'), patch_im) + cv2.imwrite(os.path.join(im_dir, f"{new_name}.jpg"), patch_im) label = window_objs[i] if len(label) == 0: continue @@ -183,13 +183,13 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): label[:, 1::2] /= pw label[:, 2::2] /= ph - with open(os.path.join(lb_dir, f'{new_name}.txt'), 'w') as f: + with open(os.path.join(lb_dir, f"{new_name}.txt"), "w") as f: for lb in label: - formatted_coords = ['{:.6g}'.format(coord) for coord in lb[1:]] + formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]] f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n") -def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024], gaps=[200]): +def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=[1024], gaps=[200]): """ Split both images and labels. 
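A note on the window arithmetic in get_windows: with step = crop_size - gap, window origins are spaced step apart, and the last origin is clamped so the window never runs past the image edge. A one-axis sketch under those assumptions (the helper name is ours):

from math import ceil

def window_origins(size, crop_size=1024, gap=200):
    """Window start offsets along one axis, mirroring the logic above."""
    assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
    step = crop_size - gap
    n = 1 if size <= crop_size else ceil((size - crop_size) / step + 1)
    starts = [step * i for i in range(n)]
    if len(starts) > 1 and starts[-1] + crop_size > size:
        starts[-1] = size - crop_size  # clamp the final window inside the image
    return starts

print(window_origins(1824))  # [0, 800]: two 1024px windows overlapping by 224px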
@@ -207,14 +207,14 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024 - labels - split """ - im_dir = Path(save_dir) / 'images' / split + im_dir = Path(save_dir) / "images" / split im_dir.mkdir(parents=True, exist_ok=True) - lb_dir = Path(save_dir) / 'labels' / split + lb_dir = Path(save_dir) / "labels" / split lb_dir.mkdir(parents=True, exist_ok=True) annos = load_yolo_dota(data_root, split=split) for anno in tqdm(annos, total=len(annos), desc=split): - windows = get_windows(anno['ori_size'], crop_sizes, gaps) + windows = get_windows(anno["ori_size"], crop_sizes, gaps) window_objs = get_window_obj(anno, windows) crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir)) @@ -245,7 +245,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): for r in rates: crop_sizes.append(int(crop_size / r)) gaps.append(int(gap / r)) - for split in ['train', 'val']: + for split in ["train", "val"]: split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps) @@ -267,30 +267,30 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]): for r in rates: crop_sizes.append(int(crop_size / r)) gaps.append(int(gap / r)) - save_dir = Path(save_dir) / 'images' / 'test' + save_dir = Path(save_dir) / "images" / "test" save_dir.mkdir(parents=True, exist_ok=True) - im_dir = Path(os.path.join(data_root, 'images/test')) + im_dir = Path(os.path.join(data_root, "images/test")) assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root." - im_files = glob(str(im_dir / '*')) - for im_file in tqdm(im_files, total=len(im_files), desc='test'): + im_files = glob(str(im_dir / "*")) + for im_file in tqdm(im_files, total=len(im_files), desc="test"): w, h = exif_size(Image.open(im_file)) windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps) im = cv2.imread(im_file) name = Path(im_file).stem for window in windows: x_start, y_start, x_stop, y_stop = window.tolist() - new_name = (name + '__' + str(x_stop - x_start) + '__' + str(x_start) + '___' + str(y_start)) + new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start) patch_im = im[y_start:y_stop, x_start:x_stop] - cv2.imwrite(os.path.join(str(save_dir), f'{new_name}.jpg'), patch_im) + cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im) -if __name__ == '__main__': +if __name__ == "__main__": split_trainval( - data_root='DOTAv2', - save_dir='DOTAv2-split', + data_root="DOTAv2", + save_dir="DOTAv2-split", ) split_test( - data_root='DOTAv2', - save_dir='DOTAv2-split', + data_root="DOTAv2", + save_dir="DOTAv2-split", ) diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index c3447394..9214f8ce 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -17,36 +17,47 @@ import numpy as np from PIL import Image, ImageOps from ultralytics.nn.autobackend import check_class_names -from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr, - emojis, yaml_load, yaml_save) +from ultralytics.utils import ( + DATASETS_DIR, + LOGGER, + NUM_THREADS, + ROOT, + SETTINGS_YAML, + TQDM, + clean_url, + colorstr, + emojis, + yaml_load, + yaml_save, +) from ultralytics.utils.checks import check_file, check_font, is_ascii from ultralytics.utils.downloads import download, safe_download, unzip_file from ultralytics.utils.ops import segments2boxes -HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.' 
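One note on the rates argument handled in the split_dota.py hunks above: each rate r divides the base crop_size and gap, so r > 1 yields smaller (zoomed-in) windows and r < 1 larger ones. A quick sketch of that expansion with illustrative defaults:

def scales(crop_size=1024, gap=200, rates=(0.5, 1.0, 1.5)):
    """Expand one (crop_size, gap) pair into one pair per rate, as split_trainval does."""
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    return crop_sizes, gaps

print(scales())  # ([2048, 1024, 682], [400, 200, 133])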
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes -VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes -PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders +HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance." +IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes +VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes +PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders def img2label_paths(img_paths): """Define label paths as a function of image paths.""" - sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings - return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths] + sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings + return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] def get_hash(paths): """Returns a single hash value of a list of paths (files or dirs).""" size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes h = hashlib.sha256(str(size).encode()) # hash sizes - h.update(''.join(paths).encode()) # hash paths + h.update("".join(paths).encode()) # hash paths return h.hexdigest() # return hash def exif_size(img: Image.Image): """Returns exif-corrected PIL size.""" s = img.size # (width, height) - if img.format == 'JPEG': # only support JPEG images + if img.format == "JPEG": # only support JPEG images with contextlib.suppress(Exception): exif = img.getexif() if exif: @@ -60,24 +71,24 @@ def verify_image(args): """Verify one image.""" (im_file, cls), prefix = args # Number (found, corrupt), message - nf, nc, msg = 0, 0, '' + nf, nc, msg = 0, 0, "" try: im = Image.open(im_file) im.verify() # PIL verify shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw - assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' - assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' - if im.format.lower() in ('jpg', 'jpeg'): - with open(im_file, 'rb') as f: + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" + if im.format.lower() in ("jpg", "jpeg"): + with open(im_file, "rb") as f: f.seek(-2, 2) - if f.read() != b'\xff\xd9': # corrupt JPEG - ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) - msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" nf = 1 except Exception as e: nc = 1 - msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" return (im_file, cls), nf, nc, msg @@ -85,21 +96,21 @@ def verify_image_label(args): """Verify one image-label pair.""" im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args # Number (missing, found, empty, corrupt), message, segments, keypoints - nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None + nm, nf, ne, nc, 
msg, segments, keypoints = 0, 0, 0, 0, "", [], None try: # Verify images im = Image.open(im_file) im.verify() # PIL verify shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw - assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' - assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' - if im.format.lower() in ('jpg', 'jpeg'): - with open(im_file, 'rb') as f: + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" + if im.format.lower() in ("jpg", "jpeg"): + with open(im_file, "rb") as f: f.seek(-2, 2) - if f.read() != b'\xff\xd9': # corrupt JPEG - ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) - msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" # Verify labels if os.path.isfile(lb_file): @@ -114,25 +125,26 @@ def verify_image_label(args): nl = len(lb) if nl: if keypoint: - assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each' + assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" points = lb[:, 5:].reshape(-1, ndim)[:, :2] else: - assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' + assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" points = lb[:, 1:] - assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}' - assert lb.min() >= 0, f'negative label values {lb[lb < 0]}' + assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}" + assert lb.min() >= 0, f"negative label values {lb[lb < 0]}" # All labels max_cls = lb[:, 0].max() # max label count - assert max_cls <= num_cls, \ - f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \ - f'Possible class labels are 0-{num_cls - 1}' + assert max_cls <= num_cls, ( + f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. 
" + f"Possible class labels are 0-{num_cls - 1}" + ) _, i = np.unique(lb, axis=0, return_index=True) if len(i) < nl: # duplicate row check lb = lb[i] # remove duplicates if segments: segments = [segments[x] for x in i] - msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed' + msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed" else: ne = 1 # label empty lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32) @@ -148,7 +160,7 @@ def verify_image_label(args): return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg except Exception as e: nc = 1 - msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" return [None, None, None, None, None, nm, nf, ne, nc, msg] @@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1): def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): """Return a (640, 640) overlap mask.""" - masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), - dtype=np.int32 if len(segments) > 255 else np.uint8) + masks = np.zeros( + (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), + dtype=np.int32 if len(segments) > 255 else np.uint8, + ) areas = [] ms = [] for si in range(len(segments)): @@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path: Returns: (Path): The path of the found YAML file. """ - files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml')) # try root level first and then recursive + files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive assert files, f"No YAML file found in '{path.resolve()}'" if len(files) > 1: files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match @@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True): file = check_file(dataset) # Download (optional) - extract_dir = '' + extract_dir = "" if zipfile.is_zipfile(file) or is_tarfile(file): new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) file = find_dataset_yaml(DATASETS_DIR / new_dir) @@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True): data = yaml_load(file, append_filename=True) # dictionary # Checks - for k in 'train', 'val': + for k in "train", "val": if k not in data: - if k != 'val' or 'validation' not in data: + if k != "val" or "validation" not in data: raise SyntaxError( - emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")) + emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") + ) LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") - data['val'] = data.pop('validation') # replace 'validation' key with 'val' key - if 'names' not in data and 'nc' not in data: + data["val"] = data.pop("validation") # replace 'validation' key with 'val' key + if "names" not in data and "nc" not in data: raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs.")) - if 'names' in data and 'nc' in data and len(data['names']) != data['nc']: + if "names" in data and "nc" in data and len(data["names"]) != data["nc"]: raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.")) - if 'names' not in data: - data['names'] = [f'class_{i}' for i in range(data['nc'])] + if "names" not in data: + data["names"] = 
[f"class_{i}" for i in range(data["nc"])] else: - data['nc'] = len(data['names']) + data["nc"] = len(data["names"]) - data['names'] = check_class_names(data['names']) + data["names"] = check_class_names(data["names"]) # Resolve paths - path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root + path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root if not path.is_absolute(): path = (DATASETS_DIR / path).resolve() # Set paths - data['path'] = path # download scripts - for k in 'train', 'val', 'test': + data["path"] = path # download scripts + for k in "train", "val", "test": if data.get(k): # prepend path if isinstance(data[k], str): x = (path / data[k]).resolve() - if not x.exists() and data[k].startswith('../'): + if not x.exists() and data[k].startswith("../"): x = (path / data[k][3:]).resolve() data[k] = str(x) else: data[k] = [str((path / x).resolve()) for x in data[k]] # Parse YAML - val, s = (data.get(x) for x in ('val', 'download')) + val, s = (data.get(x) for x in ("val", "download")) if val: val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path if not all(x.exists() for x in val): @@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True): raise FileNotFoundError(m) t = time.time() r = None # success - if s.startswith('http') and s.endswith('.zip'): # URL + if s.startswith("http") and s.endswith(".zip"): # URL safe_download(url=s, dir=DATASETS_DIR, delete=True) - elif s.startswith('bash '): # bash script - LOGGER.info(f'Running {s} ...') + elif s.startswith("bash "): # bash script + LOGGER.info(f"Running {s} ...") r = os.system(s) else: # python script - exec(s, {'yaml': data}) - dt = f'({round(time.time() - t, 1)}s)' - s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌' - LOGGER.info(f'Dataset download {s}\n') - check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts + exec(s, {"yaml": data}) + dt = f"({round(time.time() - t, 1)}s)" + s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" + LOGGER.info(f"Dataset download {s}\n") + check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts return data # dictionary -def check_cls_dataset(dataset, split=''): +def check_cls_dataset(dataset, split=""): """ Checks a classification dataset such as Imagenet. 
@@ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''): """ # Download (optional if dataset=https://file.zip is passed directly) - if str(dataset).startswith(('http:/', 'https:/')): + if str(dataset).startswith(("http:/", "https:/")): dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False) dataset = Path(dataset) data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve() if not data_dir.is_dir(): - LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') + LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...") t = time.time() - if str(dataset) == 'imagenet': + if str(dataset) == "imagenet": subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) else: - url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' + url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip" download(url, dir=data_dir.parent) s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" LOGGER.info(s) - train_set = data_dir / 'train' - val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \ - (data_dir / 'validation').exists() else None # data/test or data/val - test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test - if split == 'val' and not val_set: + train_set = data_dir / "train" + val_set = ( + data_dir / "val" + if (data_dir / "val").exists() + else data_dir / "validation" + if (data_dir / "validation").exists() + else None + ) # data/test or data/val + test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test + if split == "val" and not val_set: LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") - elif split == 'test' and not test_set: + elif split == "test" and not test_set: LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.") - nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes - names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list + nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes + names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list names = dict(enumerate(sorted(names))) # Print to console - for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items(): + for k, v in {"train": train_set, "val": val_set, "test": test_set}.items(): prefix = f'{colorstr(f"{k}:")} {v}...' 
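Class discovery in check_cls_dataset reduces to listing the subdirectories of train/; a standalone sketch, with the folder layout purely illustrative:

from pathlib import Path

def discover_classes(train_dir):
    """Return {index: name} for each class subdirectory, sorted by name."""
    names = [x.name for x in Path(train_dir).iterdir() if x.is_dir()]
    return dict(enumerate(sorted(names)))

# A train/ folder containing cat/ and dog/ yields {0: 'cat', 1: 'dog'};
# nc is then simply the length of that mapping.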
if v is None: LOGGER.info(prefix) else: - files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS] + files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS] nf = len(files) # number of files nd = len({file.parent for file in files}) # number of directories if nf == 0: - if k == 'train': + if k == "train": raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ ")) else: - LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found') + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found") elif nd != nc: - LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}') + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}") else: - LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ') + LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ") - return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names} + return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names} class HUBDatasetStats: @@ -423,42 +443,43 @@ class HUBDatasetStats: ``` """ - def __init__(self, path='coco8.yaml', task='detect', autodownload=False): + def __init__(self, path="coco8.yaml", task="detect", autodownload=False): """Initialize class.""" path = Path(path).resolve() - LOGGER.info(f'Starting HUB dataset checks for {path}....') + LOGGER.info(f"Starting HUB dataset checks for {path}....") self.task = task # detect, segment, pose, classify - if self.task == 'classify': + if self.task == "classify": unzip_dir = unzip_file(path) data = check_cls_dataset(unzip_dir) - data['path'] = unzip_dir + data["path"] = unzip_dir else: # detect, segment, pose _, data_dir, yaml_path = self._unzip(Path(path)) try: # Load YAML with checks data = yaml_load(yaml_path) - data['path'] = '' # strip path since YAML should be in dataset root for all HUB datasets + data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets yaml_save(yaml_path, data) data = check_det_dataset(yaml_path, autodownload) # dict - data['path'] = data_dir # YAML path should be set to '' (relative) or parent (absolute) + data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute) except Exception as e: - raise Exception('error/HUB/dataset_stats/init') from e + raise Exception("error/HUB/dataset_stats/init") from e self.hub_dir = Path(f'{data["path"]}-hub') - self.im_dir = self.hub_dir / 'images' + self.im_dir = self.hub_dir / "images" self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images - self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary + self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary self.data = data @staticmethod def _unzip(path): """Unzip data.zip.""" - if not str(path).endswith('.zip'): # path is data.yaml + if not str(path).endswith(".zip"): # path is data.yaml return False, None, path unzip_dir = unzip_file(path, path=path.parent) - assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \ - f'path/to/abc.zip MUST unzip to path/to/abc/' + assert unzip_dir.is_dir(), ( + f"Error unzipping {path}, {unzip_dir} not found. 
" f"path/to/abc.zip MUST unzip to path/to/abc/" + ) return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path def _hub_ops(self, f): @@ -470,31 +491,31 @@ class HUBDatasetStats: def _round(labels): """Update labels to integer class and 4 decimal place floats.""" - if self.task == 'detect': - coordinates = labels['bboxes'] - elif self.task == 'segment': - coordinates = [x.flatten() for x in labels['segments']] - elif self.task == 'pose': - n = labels['keypoints'].shape[0] - coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1) + if self.task == "detect": + coordinates = labels["bboxes"] + elif self.task == "segment": + coordinates = [x.flatten() for x in labels["segments"]] + elif self.task == "pose": + n = labels["keypoints"].shape[0] + coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1) else: - raise ValueError('Undefined dataset task.') - zipped = zip(labels['cls'], coordinates) + raise ValueError("Undefined dataset task.") + zipped = zip(labels["cls"], coordinates) return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped] - for split in 'train', 'val', 'test': + for split in "train", "val", "test": self.stats[split] = None # predefine path = self.data.get(split) # Check split if path is None: # no split continue - files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS] # image files in split + files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split if not files: # no images continue # Get dataset statistics - if self.task == 'classify': + if self.task == "classify": from torchvision.datasets import ImageFolder dataset = ImageFolder(self.data[split]) @@ -504,38 +525,35 @@ class HUBDatasetStats: x[im[1]] += 1 self.stats[split] = { - 'instance_stats': { - 'total': len(dataset), - 'per_class': x.tolist()}, - 'image_stats': { - 'total': len(dataset), - 'unlabelled': 0, - 'per_class': x.tolist()}, - 'labels': [{ - Path(k).name: v} for k, v in dataset.imgs]} + "instance_stats": {"total": len(dataset), "per_class": x.tolist()}, + "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()}, + "labels": [{Path(k).name: v} for k, v in dataset.imgs], + } else: from ultralytics.data import YOLODataset dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task) - x = np.array([ - np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc']) - for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80) + x = np.array( + [ + np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"]) + for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics") + ] + ) # shape(128x80) self.stats[split] = { - 'instance_stats': { - 'total': int(x.sum()), - 'per_class': x.sum(0).tolist()}, - 'image_stats': { - 'total': len(dataset), - 'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}, - 'labels': [{ - Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]} + "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()}, + "image_stats": { + "total": len(dataset), + "unlabelled": int(np.all(x == 0, 1).sum()), + "per_class": (x > 0).sum(0).tolist(), + }, + "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)], + } # Save, print and return if save: - stats_path = self.hub_dir / 'stats.json' - LOGGER.info(f'Saving 
{stats_path.resolve()}...') - with open(stats_path, 'w') as f: + stats_path = self.hub_dir / "stats.json" + LOGGER.info(f"Saving {stats_path.resolve()}...") + with open(stats_path, "w") as f: json.dump(self.stats, f) # save stats.json if verbose: LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) @@ -545,14 +563,14 @@ class HUBDatasetStats: """Compress images for Ultralytics HUB.""" from ultralytics.data import YOLODataset # ClassificationDataset - for split in 'train', 'val', 'test': + for split in "train", "val", "test": if self.data.get(split) is None: continue dataset = YOLODataset(img_path=self.data[split], data=self.data) with ThreadPool(NUM_THREADS) as pool: - for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'): + for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"): pass - LOGGER.info(f'Done. All images saved to {self.im_dir}') + LOGGER.info(f"Done. All images saved to {self.im_dir}") return self.im_dir @@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): r = max_dim / max(im.height, im.width) # ratio if r < 1.0: # image too large im = im.resize((int(im.width * r), int(im.height * r))) - im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save + im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save except Exception as e: # use OpenCV - LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') + LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}") im = cv2.imread(f) im_height, im_width = im.shape[:2] r = max_dim / max(im_height, im_width) # ratio @@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50): cv2.imwrite(str(f_new or f), im) -def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False): +def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False): """ Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files. 
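The heart of autosplit (its body continues in the hunk below) is a single weighted draw per image via random.choices; a small sketch with the default weights:

import random

random.seed(0)  # fixed seed keeps the split reproducible, as in autosplit
weights = (0.9, 0.1, 0.0)  # train, val, test fractions
indices = random.choices([0, 1, 2], weights=weights, k=10)  # one split id per image
print(indices)  # mostly 0s (train) with an occasional 1 (val); never 2, since test weight is 0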
@@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot """ path = Path(path) # images dir - files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only + files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only n = len(files) # number of files random.seed(0) # for reproducibility indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split - txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files + txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files for x in txt: if (path.parent / x).exists(): (path.parent / x).unlink() # remove existing - LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only) + LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only) for i, img in TQDM(zip(indices, files), total=n): if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label - with open(path.parent / txt[i], 'a') as f: - f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file + with open(path.parent / txt[i], "a") as f: + f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 6819088f..e3d2a86a 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -67,8 +67,20 @@ from ultralytics.data.utils import check_det_dataset from ultralytics.nn.autobackend import check_class_names, default_class_names from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder from ultralytics.nn.tasks import DetectionModel, SegmentationModel -from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks, - colorstr, get_default_args, yaml_save) +from ultralytics.utils import ( + ARM64, + DEFAULT_CFG, + LINUX, + LOGGER, + MACOS, + ROOT, + WINDOWS, + __version__, + callbacks, + colorstr, + get_default_args, + yaml_save, +) from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version from ultralytics.utils.downloads import attempt_download_asset, get_github_assets from ultralytics.utils.files import file_size, spaces_in_path @@ -79,21 +91,23 @@ from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart def export_formats(): """YOLOv8 export formats.""" import pandas + x = [ - ['PyTorch', '-', '.pt', True, True], - ['TorchScript', 'torchscript', '.torchscript', True, True], - ['ONNX', 'onnx', '.onnx', True, True], - ['OpenVINO', 'openvino', '_openvino_model', True, False], - ['TensorRT', 'engine', '.engine', False, True], - ['CoreML', 'coreml', '.mlpackage', True, False], - ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], - ['TensorFlow GraphDef', 'pb', '.pb', True, True], - ['TensorFlow Lite', 'tflite', '.tflite', True, False], - ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False], - ['TensorFlow.js', 'tfjs', '_web_model', True, False], - ['PaddlePaddle', 'paddle', '_paddle_model', True, True], - ['ncnn', 'ncnn', '_ncnn_model', True, True], ] - return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) + ["PyTorch", "-", ".pt", True, True], + ["TorchScript", "torchscript", ".torchscript", True, True], + ["ONNX", "onnx", ".onnx", True, True], 
+ ["OpenVINO", "openvino", "_openvino_model", True, False], + ["TensorRT", "engine", ".engine", False, True], + ["CoreML", "coreml", ".mlpackage", True, False], + ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], + ["TensorFlow GraphDef", "pb", ".pb", True, True], + ["TensorFlow Lite", "tflite", ".tflite", True, False], + ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False], + ["TensorFlow.js", "tfjs", "_web_model", True, False], + ["PaddlePaddle", "paddle", "_paddle_model", True, True], + ["ncnn", "ncnn", "_ncnn_model", True, True], + ] + return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"]) def gd_outputs(gd): @@ -102,7 +116,7 @@ def gd_outputs(gd): for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef name_list.append(node.name) input_list.extend(node.input) - return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp')) + return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp")) def try_export(inner_func): @@ -111,14 +125,14 @@ def try_export(inner_func): def outer_func(*args, **kwargs): """Export a model.""" - prefix = inner_args['prefix'] + prefix = inner_args["prefix"] try: with Profile() as dt: f, model = inner_func(*args, **kwargs) LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)") return f, model except Exception as e: - LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}') + LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}") raise e return outer_func @@ -143,8 +157,8 @@ class Exporter: _callbacks (dict, optional): Dictionary of callback functions. Defaults to None. """ self.args = get_cfg(cfg, overrides) - if self.args.format.lower() in ('coreml', 'mlmodel'): # fix attempt for protobuf<3.20.x errors - os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' # must run before TensorBoard callback + if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback self.callbacks = _callbacks or callbacks.get_default_callbacks() callbacks.add_integration_callbacks(self) @@ -152,45 +166,46 @@ class Exporter: @smart_inference_mode() def __call__(self, model=None): """Returns list of exported files/dirs after running callbacks.""" - self.run_callbacks('on_export_start') + self.run_callbacks("on_export_start") t = time.time() fmt = self.args.format.lower() # to lowercase - if fmt in ('tensorrt', 'trt'): # 'engine' aliases - fmt = 'engine' - if fmt in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios', 'coreml'): # 'coreml' aliases - fmt = 'coreml' - fmts = tuple(export_formats()['Argument'][1:]) # available export formats + if fmt in ("tensorrt", "trt"): # 'engine' aliases + fmt = "engine" + if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases + fmt = "coreml" + fmts = tuple(export_formats()["Argument"][1:]) # available export formats flags = [x == fmt for x in fmts] if sum(flags) != 1: raise ValueError(f"Invalid export format='{fmt}'. 
Valid formats are {fmts}") jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans # Device - if fmt == 'engine' and self.args.device is None: - LOGGER.warning('WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0') - self.args.device = '0' - self.device = select_device('cpu' if self.args.device is None else self.args.device) + if fmt == "engine" and self.args.device is None: + LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0") + self.args.device = "0" + self.device = select_device("cpu" if self.args.device is None else self.args.device) # Checks - if not hasattr(model, 'names'): + if not hasattr(model, "names"): model.names = default_class_names() model.names = check_class_names(model.names) - if self.args.half and onnx and self.device.type == 'cpu': - LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0') + if self.args.half and onnx and self.device.type == "cpu": + LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0") self.args.half = False - assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.' + assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one." self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size if self.args.optimize: assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False" - assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'" + assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'" if edgetpu and not LINUX: - raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/') + raise SystemError("Edge TPU export only supported on Linux. 
See https://coral.ai/docs/edgetpu/compiler/") # Input im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) file = Path( - getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', '')) - if file.suffix in {'.yaml', '.yml'}: + getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "") + ) + if file.suffix in {".yaml", ".yml"}: file = Path(file.name) # Update model @@ -212,42 +227,48 @@ class Exporter: y = None for _ in range(2): y = model(im) # dry runs - if self.args.half and (engine or onnx) and self.device.type != 'cpu': + if self.args.half and (engine or onnx) and self.device.type != "cpu": im, model = im.half(), model.half() # to FP16 # Filter warnings - warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning - warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning - warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning + warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning + warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning # Assign self.im = im self.model = model self.file = file - self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple( - tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) - self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO') - data = model.args['data'] if hasattr(model, 'args') and isinstance(model.args, dict) else '' + self.output_shape = ( + tuple(y.shape) + if isinstance(y, torch.Tensor) + else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) + ) + self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO") + data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else "" description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}' self.metadata = { - 'description': description, - 'author': 'Ultralytics', - 'license': 'AGPL-3.0 https://ultralytics.com/license', - 'date': datetime.now().isoformat(), - 'version': __version__, - 'stride': int(max(model.stride)), - 'task': model.task, - 'batch': self.args.batch, - 'imgsz': self.imgsz, - 'names': model.names} # model metadata - if model.task == 'pose': - self.metadata['kpt_shape'] = model.model[-1].kpt_shape - - LOGGER.info(f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " - f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)') + "description": description, + "author": "Ultralytics", + "license": "AGPL-3.0 https://ultralytics.com/license", + "date": datetime.now().isoformat(), + "version": __version__, + "stride": int(max(model.stride)), + "task": model.task, + "batch": self.args.batch, + "imgsz": self.imgsz, + "names": model.names, + } # model metadata + if model.task == "pose": + self.metadata["kpt_shape"] = model.model[-1].kpt_shape + + LOGGER.info( + f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " + f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)' + ) # Exports - f = [''] * len(fmts) # exported filenames + f = [""] * len(fmts) # exported 
filenames if jit or ncnn: # TorchScript f[0], _ = self.export_torchscript() if engine: # TensorRT required before ONNX @@ -266,7 +287,7 @@ class Exporter: if tflite: f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms) if edgetpu: - f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite') + f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite") if tfjs: f[9], _ = self.export_tfjs() if paddle: # PaddlePaddle @@ -279,58 +300,65 @@ class Exporter: if any(f): f = str(Path(f[-1])) square = self.imgsz[0] == self.imgsz[1] - s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \ - f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." - imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '') - predict_data = f'data={data}' if model.task == 'segment' and fmt == 'pb' else '' - q = 'int8' if self.args.int8 else 'half' if self.args.half else '' # quantization - LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' - f"\nResults saved to {colorstr('bold', file.parent.resolve())}" - f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}' - f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}' - f'\nVisualize: https://netron.app') - - self.run_callbacks('on_export_end') + s = ( + "" + if square + else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " + f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." + ) + imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "") + predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else "" + q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization + LOGGER.info( + f'\nExport complete ({time.time() - t:.1f}s)' + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}' + f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}' + f'\nVisualize: https://netron.app' + ) + + self.run_callbacks("on_export_end") return f # return list of exported files/dirs @try_export - def export_torchscript(self, prefix=colorstr('TorchScript:')): + def export_torchscript(self, prefix=colorstr("TorchScript:")): """YOLOv8 TorchScript model export.""" - LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') - f = self.file.with_suffix('.torchscript') + LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...") + f = self.file.with_suffix(".torchscript") ts = torch.jit.trace(self.model, self.im, strict=False) - extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap() + extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap() if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html - LOGGER.info(f'{prefix} optimizing for mobile...') + LOGGER.info(f"{prefix} optimizing for mobile...") from torch.utils.mobile_optimizer import optimize_for_mobile + optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) else: ts.save(str(f), _extra_files=extra_files) return f, None @try_export - def export_onnx(self, prefix=colorstr('ONNX:')): + def export_onnx(self, prefix=colorstr("ONNX:")): """YOLOv8 ONNX export.""" - requirements = 
['onnx>=1.12.0'] + requirements = ["onnx>=1.12.0"] if self.args.simplify: - requirements += ['onnxsim>=0.4.33', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'] + requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"] check_requirements(requirements) import onnx # noqa opset_version = self.args.opset or get_latest_opset() - LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...') - f = str(self.file.with_suffix('.onnx')) + LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...") + f = str(self.file.with_suffix(".onnx")) - output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0'] + output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"] dynamic = self.args.dynamic if dynamic: - dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640) if isinstance(self.model, SegmentationModel): - dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 116, 8400) - dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) + dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400) + dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160) elif isinstance(self.model, DetectionModel): - dynamic['output0'] = {0: 'batch', 2: 'anchors'} # shape(1, 84, 8400) + dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400) torch.onnx.export( self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu @@ -339,9 +367,10 @@ class Exporter: verbose=False, opset_version=opset_version, do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False - input_names=['images'], + input_names=["images"], output_names=output_names, - dynamic_axes=dynamic or None) + dynamic_axes=dynamic or None, + ) # Checks model_onnx = onnx.load(f) # load onnx model @@ -352,12 +381,12 @@ class Exporter: try: import onnxsim - LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...') + LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...") # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True) model_onnx, check = onnxsim.simplify(model_onnx) - assert check, 'Simplified ONNX model could not be validated' + assert check, "Simplified ONNX model could not be validated" except Exception as e: - LOGGER.info(f'{prefix} simplifier failure: {e}') + LOGGER.info(f"{prefix} simplifier failure: {e}") # Metadata for k, v in self.metadata.items(): @@ -368,58 +397,56 @@ class Exporter: return f, model_onnx @try_export - def export_openvino(self, prefix=colorstr('OpenVINO:')): + def export_openvino(self, prefix=colorstr("OpenVINO:")): """YOLOv8 OpenVINO export.""" - check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/ + check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/ import openvino.runtime as ov # noqa from openvino.tools import mo # noqa - LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...') - f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}') - fq = str(self.file).replace(self.file.suffix, f'_int8_openvino_model{os.sep}') - f_onnx = self.file.with_suffix('.onnx') - f_ov = 
str(Path(f) / self.file.with_suffix('.xml').name) - fq_ov = str(Path(fq) / self.file.with_suffix('.xml').name) + LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...") + f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}") + fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}") + f_onnx = self.file.with_suffix(".onnx") + f_ov = str(Path(f) / self.file.with_suffix(".xml").name) + fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name) def serialize(ov_model, file): """Set RT info, serialize and save metadata YAML.""" - ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type']) - ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels']) - ov_model.set_rt_info(114, ['model_info', 'pad_value']) - ov_model.set_rt_info([255.0], ['model_info', 'scale_values']) - ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold']) - ov_model.set_rt_info([v.replace(' ', '_') for v in self.model.names.values()], ['model_info', 'labels']) - if self.model.task != 'classify': - ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type']) + ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"]) + ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"]) + ov_model.set_rt_info(114, ["model_info", "pad_value"]) + ov_model.set_rt_info([255.0], ["model_info", "scale_values"]) + ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"]) + ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"]) + if self.model.task != "classify": + ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"]) ov.serialize(ov_model, file) # save - yaml_save(Path(file).parent / 'metadata.yaml', self.metadata) # add metadata.yaml + yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml - ov_model = mo.convert_model(f_onnx, - model_name=self.pretty_name, - framework='onnx', - compress_to_fp16=self.args.half) # export + ov_model = mo.convert_model( + f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half + ) # export if self.args.int8: assert self.args.data, "INT8 export requires a data argument for calibration, i.e. 
'data=coco8.yaml'" - check_requirements('nncf>=2.5.0') + check_requirements("nncf>=2.5.0") import nncf def transform_fn(data_item): """Quantization transform function.""" - im = data_item['img'].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0 + im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0 return np.expand_dims(im, 0) if im.ndim == 3 else im # Generate calibration data for integer quantization LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") data = check_det_dataset(self.args.data) - dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False) + dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False) quantization_dataset = nncf.Dataset(dataset, transform_fn) - ignored_scope = nncf.IgnoredScope(types=['Multiply', 'Subtract', 'Sigmoid']) # ignore operation - quantized_ov_model = nncf.quantize(ov_model, - quantization_dataset, - preset=nncf.QuantizationPreset.MIXED, - ignored_scope=ignored_scope) + ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"]) # ignore operation + quantized_ov_model = nncf.quantize( + ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope + ) serialize(quantized_ov_model, fq_ov) return fq, None @@ -427,48 +454,49 @@ class Exporter: return f, None @try_export - def export_paddle(self, prefix=colorstr('PaddlePaddle:')): + def export_paddle(self, prefix=colorstr("PaddlePaddle:")): """YOLOv8 Paddle export.""" - check_requirements(('paddlepaddle', 'x2paddle')) + check_requirements(("paddlepaddle", "x2paddle")) import x2paddle # noqa from x2paddle.convert import pytorch2paddle # noqa - LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...') - f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}') + LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...") + f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}") - pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export - yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml + pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml return f, None @try_export - def export_ncnn(self, prefix=colorstr('ncnn:')): + def export_ncnn(self, prefix=colorstr("ncnn:")): """ YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx. """ - check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn import ncnn # noqa - LOGGER.info(f'\n{prefix} starting export with ncnn {ncnn.__version__}...') - f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}')) - f_ts = self.file.with_suffix('.torchscript') + LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...") + f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}")) + f_ts = self.file.with_suffix(".torchscript") - name = Path('pnnx.exe' if WINDOWS else 'pnnx') # PNNX filename + name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename pnnx = name if name.is_file() else ROOT / name if not pnnx.is_file(): LOGGER.warning( - f'{prefix} WARNING ⚠️ PNNX not found. 
Attempting to download binary file from ' - 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory ' - f'or in {ROOT}. See PNNX repo for full installation instructions.') - system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux'] # operating system + f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from " + "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory " + f"or in {ROOT}. See PNNX repo for full installation instructions." + ) + system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"] # operating system try: - _, assets = get_github_assets(repo='pnnx/pnnx', retry=True) + _, assets = get_github_assets(repo="pnnx/pnnx", retry=True) url = [x for x in assets if any(s in x for s in system)][0] except Exception as e: - url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip' - LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}') - asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest') + url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip" + LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}") + asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest") if check_is_path_safe(Path.cwd(), asset): # avoid path traversal security vulnerability - unzip_dir = Path(asset).with_suffix('') + unzip_dir = Path(asset).with_suffix("") (unzip_dir / name).rename(pnnx) # move binary to ROOT shutil.rmtree(unzip_dir) # delete unzip dir Path(asset).unlink() # delete zip @@ -477,53 +505,56 @@ class Exporter: ncnn_args = [ f'ncnnparam={f / "model.ncnn.param"}', f'ncnnbin={f / "model.ncnn.bin"}', - f'ncnnpy={f / "model_ncnn.py"}', ] + f'ncnnpy={f / "model_ncnn.py"}', + ] pnnx_args = [ f'pnnxparam={f / "model.pnnx.param"}', f'pnnxbin={f / "model.pnnx.bin"}', f'pnnxpy={f / "model_pnnx.py"}', - f'pnnxonnx={f / "model.pnnx.onnx"}', ] + f'pnnxonnx={f / "model.pnnx.onnx"}', + ] cmd = [ str(pnnx), str(f_ts), *ncnn_args, *pnnx_args, - f'fp16={int(self.args.half)}', - f'device={self.device.type}', - f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ] + f"fp16={int(self.args.half)}", + f"device={self.device.type}", + f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', + ] f.mkdir(exist_ok=True) # make ncnn_model directory LOGGER.info(f"{prefix} running '{' '.join(cmd)}'") subprocess.run(cmd, check=True) # Remove debug files - pnnx_files = [x.split('=')[-1] for x in pnnx_args] - for f_debug in ('debug.bin', 'debug.param', 'debug2.bin', 'debug2.param', *pnnx_files): + pnnx_files = [x.split("=")[-1] for x in pnnx_args] + for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files): Path(f_debug).unlink(missing_ok=True) - yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml + yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml return str(f), None @try_export - def export_coreml(self, prefix=colorstr('CoreML:')): + def export_coreml(self, prefix=colorstr("CoreML:")): """YOLOv8 CoreML export.""" - mlmodel = self.args.format.lower() == 'mlmodel' # legacy *.mlmodel export format requested - check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0') + mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested + check_requirements("coremltools>=6.0,<=6.2" if mlmodel else 
"coremltools>=7.0") import coremltools as ct # noqa - LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') - f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage') + LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...") + f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage") if f.is_dir(): shutil.rmtree(f) bias = [0.0, 0.0, 0.0] scale = 1 / 255 classifier_config = None - if self.model.task == 'classify': + if self.model.task == "classify": classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None model = self.model - elif self.model.task == 'detect': + elif self.model.task == "detect": model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model else: if self.args.nms: @@ -532,69 +563,73 @@ class Exporter: model = self.model ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model - ct_model = ct.convert(ts, - inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)], - classifier_config=classifier_config, - convert_to='neuralnetwork' if mlmodel else 'mlprogram') - bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None) + ct_model = ct.convert( + ts, + inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)], + classifier_config=classifier_config, + convert_to="neuralnetwork" if mlmodel else "mlprogram", + ) + bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None) if bits < 32: - if 'kmeans' in mode: - check_requirements('scikit-learn') # scikit-learn package required for k-means quantization + if "kmeans" in mode: + check_requirements("scikit-learn") # scikit-learn package required for k-means quantization if mlmodel: ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) elif bits == 8: # mlprogram already quantized to FP16 import coremltools.optimize.coreml as cto - op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512) + + op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512) config = cto.OptimizationConfig(global_config=op_config) ct_model = cto.palettize_weights(ct_model, config=config) - if self.args.nms and self.model.task == 'detect': + if self.args.nms and self.model.task == "detect": if mlmodel: import platform # coremltools<=6.2 NMS export requires Python<3.11 - check_version(platform.python_version(), '<3.11', name='Python ', hard=True) + check_version(platform.python_version(), "<3.11", name="Python ", hard=True) weights_dir = None else: ct_model.save(str(f)) # save otherwise weights_dir does not exist - weights_dir = str(f / 'Data/com.apple.CoreML/weights') + weights_dir = str(f / "Data/com.apple.CoreML/weights") ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir) m = self.metadata # metadata dict - ct_model.short_description = m.pop('description') - ct_model.author = m.pop('author') - ct_model.license = m.pop('license') - ct_model.version = m.pop('version') + ct_model.short_description = m.pop("description") + ct_model.author = m.pop("author") + ct_model.license = m.pop("license") + ct_model.version = m.pop("version") ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()}) try: ct_model.save(str(f)) # save *.mlpackage except Exception as e: LOGGER.warning( - f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. 
' - f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.') - f = f.with_suffix('.mlmodel') + f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. " + f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928." + ) + f = f.with_suffix(".mlmodel") ct_model.save(str(f)) return f, ct_model @try_export - def export_engine(self, prefix=colorstr('TensorRT:')): + def export_engine(self, prefix=colorstr("TensorRT:")): """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt.""" - assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'" + assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'" f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016 try: import tensorrt as trt # noqa except ImportError: if LINUX: - check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') + check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com") import tensorrt as trt # noqa - check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 + check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0 self.args.simplify = True - LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') - assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}' - f = self.file.with_suffix('.engine') # TensorRT engine file + LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...") + assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}" + f = self.file.with_suffix(".engine") # TensorRT engine file logger = trt.Logger(trt.Logger.INFO) if self.args.verbose: logger.min_severity = trt.Logger.Severity.VERBOSE @@ -604,11 +639,11 @@ class Exporter: config.max_workspace_size = self.args.workspace * 1 << 30 # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice - flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network = builder.create_network(flag) parser = trt.OnnxParser(network, logger) if not parser.parse_from_file(f_onnx): - raise RuntimeError(f'failed to load ONNX file: {f_onnx}') + raise RuntimeError(f"failed to load ONNX file: {f_onnx}") inputs = [network.get_input(i) for i in range(network.num_inputs)] outputs = [network.get_output(i) for i in range(network.num_outputs)] @@ -627,7 +662,8 @@ class Exporter: config.add_optimization_profile(profile) LOGGER.info( - f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}') + f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}" + ) if builder.platform_has_fast_fp16 and self.args.half: config.set_flag(trt.BuilderFlag.FP16) @@ -635,10 +671,10 @@ class Exporter: torch.cuda.empty_cache() # Write file - with builder.build_engine(network, config) as engine, open(f, 'wb') as t: + with builder.build_engine(network, config) as engine, open(f, "wb") as t: # Metadata meta = json.dumps(self.metadata) - t.write(len(meta).to_bytes(4, byteorder='little', signed=True)) + t.write(len(meta).to_bytes(4, byteorder="little", signed=True)) t.write(meta.encode()) # Model t.write(engine.serialize()) @@ -646,7 +682,7 @@ class Exporter: return f, 
None @try_export - def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')): + def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")): """YOLOv8 TensorFlow SavedModel export.""" cuda = torch.cuda.is_available() try: @@ -655,44 +691,55 @@ class Exporter: check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}") import tensorflow as tf # noqa check_requirements( - ('onnx', 'onnx2tf>=1.15.4,<=1.17.5', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.33', 'onnx_graphsurgeon>=0.3.26', - 'tflite_support', 'onnxruntime-gpu' if cuda else 'onnxruntime'), - cmds='--extra-index-url https://pypi.ngc.nvidia.com') # onnx_graphsurgeon only on NVIDIA - - LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') - check_version(tf.__version__, - '<=2.13.1', - name='tensorflow', - verbose=True, - msg='https://github.com/ultralytics/ultralytics/issues/5161') - f = Path(str(self.file).replace(self.file.suffix, '_saved_model')) + ( + "onnx", + "onnx2tf>=1.15.4,<=1.17.5", + "sng4onnx>=1.0.1", + "onnxsim>=0.4.33", + "onnx_graphsurgeon>=0.3.26", + "tflite_support", + "onnxruntime-gpu" if cuda else "onnxruntime", + ), + cmds="--extra-index-url https://pypi.ngc.nvidia.com", + ) # onnx_graphsurgeon only on NVIDIA + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + check_version( + tf.__version__, + "<=2.13.1", + name="tensorflow", + verbose=True, + msg="https://github.com/ultralytics/ultralytics/issues/5161", + ) + f = Path(str(self.file).replace(self.file.suffix, "_saved_model")) if f.is_dir(): import shutil + shutil.rmtree(f) # delete output folder # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545 - onnx2tf_file = Path('calibration_image_sample_data_20x128x128x3_float32.npy') + onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy") if not onnx2tf_file.exists(): - attempt_download_asset(f'{onnx2tf_file}.zip', unzip=True, delete=True) + attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True) # Export to ONNX self.args.simplify = True f_onnx, _ = self.export_onnx() # Export to TF - tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file + tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file if self.args.int8: - verbosity = '--verbosity info' + verbosity = "--verbosity info" if self.args.data: # Generate calibration data for integer quantization LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") data = check_det_dataset(self.args.data) - dataset = YOLODataset(data['val'], data=data, imgsz=self.imgsz[0], augment=False) + dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False) images = [] for i, batch in enumerate(dataset): if i >= 100: # maximum number of calibration images break - im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC + im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC images.append(im) f.mkdir() images = torch.cat(images, 0).float() @@ -701,38 +748,38 @@ class Exporter: np.save(str(tmp_file), images.numpy()) # BHWC int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"' else: - int8 = '-oiqt -qt per-tensor' + int8 = "-oiqt -qt per-tensor" else: - verbosity = '--non_verbose' - int8 = '' + verbosity = "--non_verbose" + int8 = "" cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip() 
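         # For reference, the assembled command resembles the following illustrative
         # sketch (hypothetical yolov8n filenames, int8=True with calibration images;
         # the exact flags depend on self.args):
         #   onnx2tf -i "yolov8n.onnx" -o "yolov8n_saved_model" -nuo --verbosity info \
         #       -oiqt -qt per-tensor -cind images "tmp_tflite_int8_calibration_images.npy" \
         #       "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"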
         LOGGER.info(f"{prefix} running '{cmd}'")
         subprocess.run(cmd, shell=True)
-        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
+        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

         # Remove/rename TFLite models
         if self.args.int8:
             tmp_file.unlink(missing_ok=True)
-            for file in f.rglob('*_dynamic_range_quant.tflite'):
-                file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
-            for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
+            for file in f.rglob("*_dynamic_range_quant.tflite"):
+                file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+            for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
                 file.unlink()  # delete extra fp16 activation TFLite files

         # Add TFLite metadata
-        for file in f.rglob('*.tflite'):
-            f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
+        for file in f.rglob("*.tflite"):
+            file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)

         return str(f), tf.saved_model.load(f, tags=None, options=None)  # load saved_model as Keras model

     @try_export
-    def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
+    def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
         """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
         import tensorflow as tf  # noqa
         from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

-        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        f = self.file.with_suffix('.pb')
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        f = self.file.with_suffix(".pb")

         m = tf.function(lambda x: keras_model(x))  # full model
         m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
@@ -742,40 +789,43 @@ class Exporter:
         return f, None

     @try_export
-    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
         """YOLOv8 TensorFlow Lite export."""
         import tensorflow as tf  # noqa

-        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
         if self.args.int8:
-            f = saved_model / f'{self.file.stem}_int8.tflite'  # fp32 in/out
+            f = saved_model / f"{self.file.stem}_int8.tflite"  # fp32 in/out
         elif self.args.half:
-            f = saved_model / f'{self.file.stem}_float16.tflite'  # fp32 in/out
+            f = saved_model / f"{self.file.stem}_float16.tflite"  # fp32 in/out
         else:
-            f = saved_model / f'{self.file.stem}_float32.tflite'
+            f = saved_model / f"{self.file.stem}_float32.tflite"
         return str(f), None

     @try_export
-    def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
+    def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
         """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
-        LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
+        LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")

-        cmd = 'edgetpu_compiler --version'
-        help_url = 
'https://coral.ai/docs/edgetpu/compiler/' - assert LINUX, f'export only supported on Linux. See {help_url}' + cmd = "edgetpu_compiler --version" + help_url = "https://coral.ai/docs/edgetpu/compiler/" + assert LINUX, f"export only supported on Linux. See {help_url}" if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0: - LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') - sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system - for c in ('curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', - 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' - 'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 'sudo apt-get update', - 'sudo apt-get install edgetpu-compiler'): - subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) + LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}") + sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0 # sudo installed on system + for c in ( + "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -", + 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' + "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list", + "sudo apt-get update", + "sudo apt-get install edgetpu-compiler", + ): + subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True) ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] - LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') - f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model + LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...") + f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"' LOGGER.info(f"{prefix} running '{cmd}'") @@ -784,30 +834,30 @@ class Exporter: return f, None @try_export - def export_tfjs(self, prefix=colorstr('TensorFlow.js:')): + def export_tfjs(self, prefix=colorstr("TensorFlow.js:")): """YOLOv8 TensorFlow.js export.""" # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978 - check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs']) + check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"]) import tensorflow as tf import tensorflowjs as tfjs # noqa - LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') - f = str(self.file).replace(self.file.suffix, '_web_model') # js dir - f_pb = str(self.file.with_suffix('.pb')) # *.pb path + LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...") + f = str(self.file).replace(self.file.suffix, "_web_model") # js dir + f_pb = str(self.file.with_suffix(".pb")) # *.pb path gd = tf.Graph().as_graph_def() # TF GraphDef - with open(f_pb, 'rb') as file: + with open(f_pb, "rb") as file: gd.ParseFromString(file.read()) - outputs = ','.join(gd_outputs(gd)) - LOGGER.info(f'\n{prefix} output node names: {outputs}') + outputs = ",".join(gd_outputs(gd)) + LOGGER.info(f"\n{prefix} output node names: {outputs}") - quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else '' + quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if 
self.args.int8 else "" with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"' LOGGER.info(f"{prefix} running '{cmd}'") subprocess.run(cmd, shell=True) - if ' ' in f: + if " " in f: LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.") # f_json = Path(f) / 'model.json' # *.json path @@ -824,7 +874,7 @@ class Exporter: # f_json.read_text(), # ) # j.write(subst) - yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml return f, None def _add_tflite_metadata(self, file): @@ -835,14 +885,14 @@ class Exporter: # Create model info model_meta = _metadata_fb.ModelMetadataT() - model_meta.name = self.metadata['description'] - model_meta.version = self.metadata['version'] - model_meta.author = self.metadata['author'] - model_meta.license = self.metadata['license'] + model_meta.name = self.metadata["description"] + model_meta.version = self.metadata["version"] + model_meta.author = self.metadata["author"] + model_meta.license = self.metadata["license"] # Label file - tmp_file = Path(file).parent / 'temp_meta.txt' - with open(tmp_file, 'w') as f: + tmp_file = Path(file).parent / "temp_meta.txt" + with open(tmp_file, "w") as f: f.write(str(self.metadata)) label_file = _metadata_fb.AssociatedFileT() @@ -851,8 +901,8 @@ class Exporter: # Create input info input_meta = _metadata_fb.TensorMetadataT() - input_meta.name = 'image' - input_meta.description = 'Input image to be detected.' + input_meta.name = "image" + input_meta.description = "Input image to be detected." 
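         # Note: the metadata assembled here can later be read back for verification with
         # the same tflite_support package, e.g. (sketch; the filename is only an example):
         #   from tflite_support import metadata
         #   displayer = metadata.MetadataDisplayer.with_model_file("yolov8n_float32.tflite")
         #   print(displayer.get_metadata_json())  # model/input/output descriptions as JSON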
input_meta.content = _metadata_fb.ContentT() input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB @@ -860,19 +910,19 @@ class Exporter: # Create output info output1 = _metadata_fb.TensorMetadataT() - output1.name = 'output' - output1.description = 'Coordinates of detected objects, class labels, and confidence score' + output1.name = "output" + output1.description = "Coordinates of detected objects, class labels, and confidence score" output1.associatedFiles = [label_file] - if self.model.task == 'segment': + if self.model.task == "segment": output2 = _metadata_fb.TensorMetadataT() - output2.name = 'output' - output2.description = 'Mask protos' + output2.name = "output" + output2.description = "Mask protos" output2.associatedFiles = [label_file] # Create subgraph info subgraph = _metadata_fb.SubGraphMetadataT() subgraph.inputTensorMetadata = [input_meta] - subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1] + subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] model_meta.subgraphMetadata = [subgraph] b = flatbuffers.Builder(0) @@ -885,11 +935,11 @@ class Exporter: populator.populate() tmp_file.unlink() - def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')): + def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")): """YOLOv8 CoreML pipeline.""" import coremltools as ct # noqa - LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...') + LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...") _, _, h, w = list(self.im.shape) # BCHW # Output shapes @@ -897,8 +947,9 @@ class Exporter: out0, out1 = iter(spec.description.output) if MACOS: from PIL import Image - img = Image.new('RGB', (w, h)) # w=192, h=320 - out = model.predict({'image': img}) + + img = Image.new("RGB", (w, h)) # w=192, h=320 + out = model.predict({"image": img}) out0_shape = out[out0.name].shape # (3780, 80) out1_shape = out[out1.name].shape # (3780, 4) else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y @@ -906,11 +957,11 @@ class Exporter: out1_shape = self.output_shape[2], 4 # (3780, 4) # Checks - names = self.metadata['names'] + names = self.metadata["names"] nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height _, nc = out0_shape # number of anchors, number of classes # _, nc = out0.type.multiArrayType.shape - assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check + assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check # Define output shapes (missing) out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80) @@ -944,8 +995,8 @@ class Exporter: nms_spec.description.output.add() nms_spec.description.output[i].ParseFromString(decoder_output) - nms_spec.description.output[0].name = 'confidence' - nms_spec.description.output[1].name = 'coordinates' + nms_spec.description.output[0].name = "confidence" + nms_spec.description.output[1].name = "coordinates" output_sizes = [nc, 4] for i in range(2): @@ -961,10 +1012,10 @@ class Exporter: nms = nms_spec.nonMaximumSuppression nms.confidenceInputFeatureName = out0.name # 1x507x80 nms.coordinatesInputFeatureName = out1.name # 1x507x4 - nms.confidenceOutputFeatureName = 'confidence' - nms.coordinatesOutputFeatureName = 'coordinates' - 
nms.iouThresholdInputFeatureName = 'iouThreshold' - nms.confidenceThresholdInputFeatureName = 'confidenceThreshold' + nms.confidenceOutputFeatureName = "confidence" + nms.coordinatesOutputFeatureName = "coordinates" + nms.iouThresholdInputFeatureName = "iouThreshold" + nms.confidenceThresholdInputFeatureName = "confidenceThreshold" nms.iouThreshold = 0.45 nms.confidenceThreshold = 0.25 nms.pickTop.perClass = True @@ -972,10 +1023,14 @@ class Exporter: nms_model = ct.models.MLModel(nms_spec) # 4. Pipeline models together - pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)), - ('iouThreshold', ct.models.datatypes.Double()), - ('confidenceThreshold', ct.models.datatypes.Double())], - output_features=['confidence', 'coordinates']) + pipeline = ct.models.pipeline.Pipeline( + input_features=[ + ("image", ct.models.datatypes.Array(3, ny, nx)), + ("iouThreshold", ct.models.datatypes.Double()), + ("confidenceThreshold", ct.models.datatypes.Double()), + ], + output_features=["confidence", "coordinates"], + ) pipeline.add_model(model) pipeline.add_model(nms_model) @@ -986,19 +1041,20 @@ class Exporter: # Update metadata pipeline.spec.specificationVersion = 5 - pipeline.spec.description.metadata.userDefined.update({ - 'IoU threshold': str(nms.iouThreshold), - 'Confidence threshold': str(nms.confidenceThreshold)}) + pipeline.spec.description.metadata.userDefined.update( + {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)} + ) # Save the model model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) - model.input_description['image'] = 'Input image' - model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})' - model.input_description['confidenceThreshold'] = \ - f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})' - model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")' - model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)' - LOGGER.info(f'{prefix} pipeline success') + model.input_description["image"] = "Input image" + model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})" + model.input_description[ + "confidenceThreshold" + ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" + model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' + model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" + LOGGER.info(f"{prefix} pipeline success") return model def add_callback(self, event: str, callback): diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 66052f30..01c81682 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -53,7 +53,7 @@ class Model(nn.Module): list(ultralytics.engine.results.Results): The prediction results. """ - def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None: + def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None) -> None: """ Initializes the YOLO model. @@ -89,7 +89,7 @@ class Model(nn.Module): # Load or create new YOLO model model = checks.check_model_file_from_stem(model) # add suffix, i.e. 
yolov8n -> yolov8n.pt
-        if Path(model).suffix in ('.yaml', '.yml'):
+        if Path(model).suffix in (".yaml", ".yml"):
             self._new(model, task)
         else:
             self._load(model, task)
@@ -112,16 +112,20 @@ class Model(nn.Module):
     def is_triton_model(model):
         """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
         from urllib.parse import urlsplit
+
         url = urlsplit(model)
-        return url.netloc and url.path and url.scheme in {'http', 'grpc'}
+        return url.netloc and url.path and url.scheme in {"http", "grpc"}

     @staticmethod
     def is_hub_model(model):
         """Check if the provided model is a HUB model."""
-        return any((
-            model.startswith(f'{HUB_WEB_ROOT}/models/'),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
-            [len(x) for x in model.split('_')] == [42, 20],  # APIKEY_MODELID
-            len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\')))  # MODELID
+        return any(
+            (
+                model.startswith(f"{HUB_WEB_ROOT}/models/"),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
+                [len(x) for x in model.split("_")] == [42, 20],  # APIKEY_MODELID
+                len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),
+            )
+        )  # MODELID

     def _new(self, cfg: str, task=None, model=None, verbose=True):
         """
@@ -136,9 +140,9 @@ class Model(nn.Module):
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
         self.task = task or guess_model_task(cfg_dict)
-        self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
-        self.overrides['model'] = self.cfg
-        self.overrides['task'] = self.task
+        self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model
+        self.overrides["model"] = self.cfg
+        self.overrides["task"] = self.task

         # Below added to allow export from YAMLs
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
@@ -153,9 +157,9 @@ class Model(nn.Module):
             task (str | None): model task
         """
         suffix = Path(weights).suffix
-        if suffix == '.pt':
+        if suffix == ".pt":
             self.model, self.ckpt = attempt_load_one_weight(weights)
-            self.task = self.model.args['task']
+            self.task = self.model.args["task"]
             self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
             self.ckpt_path = self.model.pt_path
         else:
@@ -163,12 +167,12 @@ class Model(nn.Module):
             self.model, self.ckpt = weights, None
             self.task = task or guess_model_task(weights)
             self.ckpt_path = weights
-        self.overrides['model'] = weights
-        self.overrides['task'] = self.task
+        self.overrides["model"] = weights
+        self.overrides["task"] = self.task

     def _check_is_pytorch_model(self):
         """Raises TypeError if model is not a PyTorch model."""
-        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
+        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
         pt_module = isinstance(self.model, nn.Module)
         if not (pt_module or pt_str):
             raise TypeError(
@@ -176,19 +180,20 @@ class Model(nn.Module):
                 f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
                 f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
                 f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
-                f"argument directly in your inference command, i.e. 
'model.predict(source=..., device=0)'" + ) def reset_weights(self): """Resets the model modules parameters to randomly initialized values, losing all training information.""" self._check_is_pytorch_model() for m in self.model.modules(): - if hasattr(m, 'reset_parameters'): + if hasattr(m, "reset_parameters"): m.reset_parameters() for p in self.model.parameters(): p.requires_grad = True return self - def load(self, weights='yolov8n.pt'): + def load(self, weights="yolov8n.pt"): """Transfers parameters with matching names and shapes from 'weights' to model.""" self._check_is_pytorch_model() if isinstance(weights, (str, Path)): @@ -226,8 +231,8 @@ class Model(nn.Module): Returns: (List[torch.Tensor]): A list of image embeddings. """ - if not kwargs.get('embed'): - kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed + if not kwargs.get("embed"): + kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed return self.predict(source, stream, **kwargs) def predict(self, source=None, stream=False, predictor=None, **kwargs): @@ -249,21 +254,22 @@ class Model(nn.Module): source = ASSETS LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( - x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) + is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any( + x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track") + ) - custom = {'conf': 0.25, 'save': is_cli} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right - prompts = args.pop('prompts', None) # for SAM-type models + custom = {"conf": 0.25, "save": is_cli} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "predict"} # highest priority args on the right + prompts = args.pop("prompts", None) # for SAM-type models if not self.predictor: - self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks) + self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, args) - if 'project' in args or 'name' in args: + if "project" in args or "name" in args: self.predictor.save_dir = get_save_dir(self.predictor.args) - if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models + if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models self.predictor.set_prompts(prompts) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) @@ -280,11 +286,12 @@ class Model(nn.Module): Returns: (List[ultralytics.engine.results.Results]): The tracking results. 
""" - if not hasattr(self.predictor, 'trackers'): + if not hasattr(self.predictor, "trackers"): from ultralytics.trackers import register_tracker + register_tracker(self, persist) - kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input - kwargs['mode'] = 'track' + kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input + kwargs["mode"] = "track" return self.predict(source=source, stream=stream, **kwargs) def val(self, validator=None, **kwargs): @@ -295,10 +302,10 @@ class Model(nn.Module): validator (BaseValidator): Customized validator. **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs """ - custom = {'rect': True} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right + custom = {"rect": True} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right - validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks) + validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) validator(model=self.model) self.metrics = validator.metrics return validator.metrics @@ -313,16 +320,17 @@ class Model(nn.Module): self._check_is_pytorch_model() from ultralytics.utils.benchmarks import benchmark - custom = {'verbose': False} # method defaults - args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'} + custom = {"verbose": False} # method defaults + args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} return benchmark( model=self, - data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets - imgsz=args['imgsz'], - half=args['half'], - int8=args['int8'], - device=args['device'], - verbose=kwargs.get('verbose')) + data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets + imgsz=args["imgsz"], + half=args["half"], + int8=args["int8"], + device=args["device"], + verbose=kwargs.get("verbose"), + ) def export(self, **kwargs): """ @@ -334,8 +342,8 @@ class Model(nn.Module): self._check_is_pytorch_model() from .exporter import Exporter - custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right + custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) def train(self, trainer=None, **kwargs): @@ -347,32 +355,32 @@ class Model(nn.Module): **kwargs (Any): Any number of arguments representing the training configuration. 
""" self._check_is_pytorch_model() - if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model + if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model if any(kwargs): - LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') + LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") kwargs = self.session.train_args # overwrite kwargs checks.check_pip_update_available() - overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides - custom = {'data': DEFAULT_CFG_DICT['data'] or TASK2DATA[self.task]} # method defaults - args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right - if args.get('resume'): - args['resume'] = self.ckpt_path + overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides + custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults + args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + if args.get("resume"): + args["resume"] = self.ckpt_path - self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks) - if not args.get('resume'): # manually set model only if not resuming + self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) + if not args.get("resume"): # manually set model only if not resuming self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.model = self.trainer.model - if SETTINGS['hub'] is True and not self.session: + if SETTINGS["hub"] is True and not self.session: # Create a model in HUB try: self.session = self._get_hub_session(self.model_name) if self.session: self.session.create_model(args) # Check model was created - if not getattr(self.session.model, 'id', None): + if not getattr(self.session.model, "id", None): self.session = None except PermissionError: # Ignore permission error @@ -385,7 +393,7 @@ class Model(nn.Module): ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last self.model, _ = attempt_load_one_weight(ckpt) self.overrides = self.model.args - self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP + self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP return self.metrics def tune(self, use_ray=False, iterations=10, *args, **kwargs): @@ -398,12 +406,13 @@ class Model(nn.Module): self._check_is_pytorch_model() if use_ray: from ultralytics.utils.tuner import run_ray_tune + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) else: from .tuner import Tuner custom = {} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) def _apply(self, fn): @@ -411,13 +420,13 @@ class Model(nn.Module): self._check_is_pytorch_model() self = super()._apply(fn) # noqa self.predictor = None # reset predictor as device may have changed - self.overrides['device'] = self.device # was str(self.device) i.e. 
device(type='cuda', index=0) -> 'cuda:0' + self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' return self @property def names(self): """Returns class names of the loaded model.""" - return self.model.names if hasattr(self.model, 'names') else None + return self.model.names if hasattr(self.model, "names") else None @property def device(self): @@ -427,7 +436,7 @@ class Model(nn.Module): @property def transforms(self): """Returns transform of the loaded model.""" - return self.model.transforms if hasattr(self.model, 'transforms') else None + return self.model.transforms if hasattr(self.model, "transforms") else None def add_callback(self, event: str, func): """Add a callback.""" @@ -445,7 +454,7 @@ class Model(nn.Module): @staticmethod def _reset_ckpt_args(args): """Reset arguments when loading a PyTorch model.""" - include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model + include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model return {k: v for k, v in args.items() if k in include} # def __getattr__(self, attr): @@ -461,7 +470,8 @@ class Model(nn.Module): name = self.__class__.__name__ mode = inspect.stack()[1][3] # get the function name. raise NotImplementedError( - emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e @property def task_map(self): @@ -471,4 +481,4 @@ class Model(nn.Module): Returns: task_map (dict): The map of model task to mode classes. """ - raise NotImplementedError('Please provide task map for your model!') + raise NotImplementedError("Please provide task map for your model!") diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index f78de2fa..f6c029c4 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -132,8 +132,11 @@ class BasePredictor: def inference(self, im, *args, **kwargs): """Runs inference on a given image using the specified model and arguments.""" - visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem, - mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False + visualize = ( + increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True) + if self.args.visualize and (not self.source_type.tensor) + else False + ) return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) def pre_transform(self, im): @@ -153,35 +156,38 @@ class BasePredictor: def write_results(self, idx, results, batch): """Write inference results to a file or directory.""" p, im, _ = batch - log_string = '' + log_string = "" if len(im.shape) == 3: im = im[None] # expand for batch dim if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 - log_string += f'{idx}: ' + log_string += f"{idx}: " frame = self.dataset.count else: - frame = getattr(self.dataset, 'frame', 0) + frame = getattr(self.dataset, "frame", 0) self.data_path = p - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string + self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}") + log_string += "%gx%g " % im.shape[2:] # print string result 
= results[idx] log_string += result.verbose() if self.args.save or self.args.show: # Add bbox to image plot_args = { - 'line_width': self.args.line_width, - 'boxes': self.args.show_boxes, - 'conf': self.args.show_conf, - 'labels': self.args.show_labels} + "line_width": self.args.line_width, + "boxes": self.args.show_boxes, + "conf": self.args.show_conf, + "labels": self.args.show_labels, + } if not self.args.retina_masks: - plot_args['im_gpu'] = im[idx] + plot_args["im_gpu"] = im[idx] self.plotted_img = result.plot(**plot_args) # Write if self.args.save_txt: - result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf) + result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) if self.args.save_crop: - result.save_crop(save_dir=self.save_dir / 'crops', - file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}')) + result.save_crop( + save_dir=self.save_dir / "crops", + file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"), + ) return log_string @@ -210,17 +216,24 @@ class BasePredictor: def setup_source(self, source): """Sets up source and inference mode.""" self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size - self.transforms = getattr( - self.model.model, 'transforms', classify_transforms( - self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None - self.dataset = load_inference_source(source=source, - imgsz=self.imgsz, - vid_stride=self.args.vid_stride, - buffer=self.args.stream_buffer) + self.transforms = ( + getattr( + self.model.model, + "transforms", + classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), + ) + if self.args.task == "classify" + else None + ) + self.dataset = load_inference_source( + source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer + ) self.source_type = self.dataset.source_type - if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams - len(self.dataset) > 1000 or # images - any(getattr(self.dataset, 'video_flag', [False]))): # videos + if not getattr(self, "stream", True) and ( + self.dataset.mode == "stream" # streams + or len(self.dataset) > 1000 # images + or any(getattr(self.dataset, "video_flag", [False])) + ): # videos LOGGER.warning(STREAM_WARNING) self.vid_path = [None] * self.dataset.bs self.vid_writer = [None] * self.dataset.bs @@ -230,7 +243,7 @@ class BasePredictor: def stream_inference(self, source=None, model=None, *args, **kwargs): """Streams real-time inference on camera feed and saves results to file.""" if self.args.verbose: - LOGGER.info('') + LOGGER.info("") # Setup model if not self.model: @@ -242,7 +255,7 @@ class BasePredictor: # Check if save_dir/ label file exists if self.args.save or self.args.save_txt: - (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) # Warmup model if not self.done_warmup: @@ -250,10 +263,10 @@ class BasePredictor: self.done_warmup = True self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile()) - self.run_callbacks('on_predict_start') + self.run_callbacks("on_predict_start") for batch in self.dataset: - self.run_callbacks('on_predict_batch_start') + self.run_callbacks("on_predict_batch_start") self.batch = batch path, im0s, vid_cap, s = batch @@ 
-272,15 +285,16 @@ class BasePredictor: with profilers[2]: self.results = self.postprocess(preds, im, im0s) - self.run_callbacks('on_predict_postprocess_end') + self.run_callbacks("on_predict_postprocess_end") # Visualize, save, write results n = len(im0s) for i in range(n): self.seen += 1 self.results[i].speed = { - 'preprocess': profilers[0].dt * 1E3 / n, - 'inference': profilers[1].dt * 1E3 / n, - 'postprocess': profilers[2].dt * 1E3 / n} + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy() p = Path(p) @@ -293,12 +307,12 @@ class BasePredictor: if self.args.save and self.plotted_img is not None: self.save_preds(vid_cap, i, str(self.save_dir / p.name)) - self.run_callbacks('on_predict_batch_end') + self.run_callbacks("on_predict_batch_end") yield from self.results # Print time (inference-only) if self.args.verbose: - LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms') + LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms") # Release assets if isinstance(self.vid_writer[-1], cv2.VideoWriter): @@ -306,25 +320,29 @@ class BasePredictor: # Print results if self.args.verbose and self.seen: - t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image - LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape ' - f'{(1, 3, *im.shape[2:])}' % t) + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(1, 3, *im.shape[2:])}" % t + ) if self.args.save or self.args.save_txt or self.args.save_crop: - nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels - s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") - self.run_callbacks('on_predict_end') + self.run_callbacks("on_predict_end") def setup_model(self, model, verbose=True): """Initialize YOLO model with given parameters and set it to evaluation mode.""" - self.model = AutoBackend(model or self.args.model, - device=select_device(self.args.device, verbose=verbose), - dnn=self.args.dnn, - data=self.args.data, - fp16=self.args.half, - fuse=True, - verbose=verbose) + self.model = AutoBackend( + model or self.args.model, + device=select_device(self.args.device, verbose=verbose), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + fuse=True, + verbose=verbose, + ) self.device = self.model.device # update device self.args.half = self.model.fp16 # update half @@ -333,18 +351,18 @@ class BasePredictor: def show(self, p): """Display an image in a window using OpenCV imshow().""" im0 = self.plotted_img - if platform.system() == 'Linux' and p not in self.windows: + if platform.system() == "Linux" and p not in self.windows: self.windows.append(p) cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) cv2.imshow(str(p), im0) - cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond + cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond def save_preds(self, vid_cap, idx, save_path): """Save 
video predictions as mp4 at specified path.""" im0 = self.plotted_img # Save imgs - if self.dataset.mode == 'image': + if self.dataset.mode == "image": cv2.imwrite(save_path, im0) else: # 'video' or 'stream' frames_path = f'{save_path.split(".", 1)[0]}_frames/' @@ -361,15 +379,16 @@ class BasePredictor: h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) else: # stream fps, w, h = 30, im0.shape[1], im0.shape[0] - suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG') - self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)), - cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) + suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") + self.vid_writer[idx] = cv2.VideoWriter( + str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h) + ) # Write video self.vid_writer[idx].write(im0) # Write frame if self.args.save_frames: - cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0) + cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0) self.vid_frame[idx] += 1 def run_callbacks(self, event: str): diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 21e11078..82d654ed 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -98,15 +98,15 @@ class Results(SimpleClass): self.probs = Probs(probs) if probs is not None else None self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None self.obb = OBB(obb, self.orig_shape) if obb is not None else None - self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image + self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image self.names = names self.path = path self.save_dir = None - self._keys = 'boxes', 'masks', 'probs', 'keypoints', 'obb' + self._keys = "boxes", "masks", "probs", "keypoints", "obb" def __getitem__(self, idx): """Return a Results object for the specified index.""" - return self._apply('__getitem__', idx) + return self._apply("__getitem__", idx) def __len__(self): """Return the number of detections in the Results object.""" @@ -146,19 +146,19 @@ class Results(SimpleClass): def cpu(self): """Return a copy of the Results object with all tensors on CPU memory.""" - return self._apply('cpu') + return self._apply("cpu") def numpy(self): """Return a copy of the Results object with all tensors as numpy arrays.""" - return self._apply('numpy') + return self._apply("numpy") def cuda(self): """Return a copy of the Results object with all tensors on GPU memory.""" - return self._apply('cuda') + return self._apply("cuda") def to(self, *args, **kwargs): """Return a copy of the Results object with tensors on the specified device and dtype.""" - return self._apply('to', *args, **kwargs) + return self._apply("to", *args, **kwargs) def new(self): """Return a new Results object with the same image, path, and names.""" @@ -169,7 +169,7 @@ class Results(SimpleClass): conf=True, line_width=None, font_size=None, - font='Arial.ttf', + font="Arial.ttf", pil=False, img=None, im_gpu=None, @@ -229,14 +229,20 @@ class Results(SimpleClass): font_size, font, pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True - example=names) + example=names, + ) # Plot Segment results if pred_masks and show_masks: if im_gpu is None: img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) - im_gpu = torch.as_tensor(img, dtype=torch.float16, 
device=pred_masks.data.device).permute( - 2, 0, 1).flip(0).contiguous() / 255 + im_gpu = ( + torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device) + .permute(2, 0, 1) + .flip(0) + .contiguous() + / 255 + ) idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) @@ -244,14 +250,14 @@ class Results(SimpleClass): if pred_boxes is not None and show_boxes: for d in reversed(pred_boxes): c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) - name = ('' if id is None else f'id:{id} ') + names[c] - label = (f'{name} {conf:.2f}' if conf else name) if labels else None + name = ("" if id is None else f"id:{id} ") + names[c] + label = (f"{name} {conf:.2f}" if conf else name) if labels else None box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() annotator.box_label(box, label, color=colors(c, True), rotated=is_obb) # Plot Classify results if pred_probs is not None and show_probs: - text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5) + text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5) x = round(self.orig_shape[0] * 0.03) annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors @@ -264,11 +270,11 @@ class Results(SimpleClass): def verbose(self): """Return log string for each task.""" - log_string = '' + log_string = "" probs = self.probs boxes = self.boxes if len(self) == 0: - return log_string if probs is not None else f'{log_string}(no detections), ' + return log_string if probs is not None else f"{log_string}(no detections), " if probs is not None: log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " if boxes: @@ -293,7 +299,7 @@ class Results(SimpleClass): texts = [] if probs is not None: # Classify - [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5] + [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5] elif boxes: # Detect/segment/pose for j, d in enumerate(boxes): @@ -304,16 +310,16 @@ class Results(SimpleClass): line = (c, *seg) if kpts is not None: kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn - line += (*kpt.reshape(-1).tolist(), ) - line += (conf, ) * save_conf + (() if id is None else (id, )) - texts.append(('%g ' * len(line)).rstrip() % line) + line += (*kpt.reshape(-1).tolist(),) + line += (conf,) * save_conf + (() if id is None else (id,)) + texts.append(("%g " * len(line)).rstrip() % line) if texts: Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory - with open(txt_file, 'a') as f: - f.writelines(text + '\n' for text in texts) + with open(txt_file, "a") as f: + f.writelines(text + "\n" for text in texts) - def save_crop(self, save_dir, file_name=Path('im.jpg')): + def save_crop(self, save_dir, file_name=Path("im.jpg")): """ Save cropped predictions to `save_dir/cls/file_name.jpg`. @@ -322,21 +328,23 @@ class Results(SimpleClass): file_name (str | pathlib.Path): File name. 
""" if self.probs is not None: - LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.') + LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.") return if self.obb is not None: - LOGGER.warning('WARNING ⚠️ OBB task do not support `save_crop`.') + LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.") return for d in self.boxes: - save_one_box(d.xyxy, - self.orig_img.copy(), - file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name)}.jpg', - BGR=True) + save_one_box( + d.xyxy, + self.orig_img.copy(), + file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg", + BGR=True, + ) def tojson(self, normalize=False): """Convert the object to JSON format.""" if self.probs is not None: - LOGGER.warning('Warning: Classify task do not support `tojson` yet.') + LOGGER.warning("Warning: Classify task do not support `tojson` yet.") return import json @@ -346,19 +354,19 @@ class Results(SimpleClass): data = self.boxes.data.cpu().tolist() h, w = self.orig_shape if normalize else (1, 1) for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id - box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h} + box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h} conf = row[-2] class_id = int(row[-1]) name = self.names[class_id] - result = {'name': name, 'class': class_id, 'confidence': conf, 'box': box} + result = {"name": name, "class": class_id, "confidence": conf, "box": box} if self.boxes.is_track: - result['track_id'] = int(row[-3]) # track ID + result["track_id"] = int(row[-3]) # track ID if self.masks: x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array - result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()} + result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()} if self.keypoints is not None: x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor - result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()} + result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()} results.append(result) # Convert detections to JSON @@ -397,7 +405,7 @@ class Boxes(BaseTensor): if boxes.ndim == 1: boxes = boxes[None, :] n = boxes.shape[-1] - assert n in (6, 7), f'expected 6 or 7 values but got {n}' # xyxy, track_id, conf, cls + assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 7 self.orig_shape = orig_shape @@ -474,7 +482,8 @@ class Masks(BaseTensor): """Return normalized segments.""" return [ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True) - for x in ops.masks2segments(self.data)] + for x in ops.masks2segments(self.data) + ] @property @lru_cache(maxsize=1) @@ -482,7 +491,8 @@ class Masks(BaseTensor): """Return segments in pixel coordinates.""" return [ ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False) - for x in ops.masks2segments(self.data)] + for x in ops.masks2segments(self.data) + ] class Keypoints(BaseTensor): @@ -610,7 +620,7 @@ class OBB(BaseTensor): if boxes.ndim == 1: boxes = boxes[None, :] n = boxes.shape[-1] - assert n in (7, 8), f'expected 7 or 8 values but got {n}' # xywh, rotation, track_id, conf, cls + assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 8 self.orig_shape = orig_shape diff --git 
a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 15eefb23..0fa98396 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -23,14 +23,31 @@ from torch import nn, optim from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights -from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis, - yaml_save) +from ultralytics.utils import ( + DEFAULT_CFG, + LOGGER, + RANK, + TQDM, + __version__, + callbacks, + clean_url, + colorstr, + emojis, + yaml_save, +) from ultralytics.utils.autobatch import check_train_batch_size from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.files import get_latest_run -from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device, - strip_optimizer) +from ultralytics.utils.torch_utils import ( + EarlyStopping, + ModelEMA, + de_parallel, + init_seeds, + one_cycle, + select_device, + strip_optimizer, +) class BaseTrainer: @@ -89,12 +106,12 @@ class BaseTrainer: # Dirs self.save_dir = get_save_dir(self.args) self.args.name = self.save_dir.name # update name for loggers - self.wdir = self.save_dir / 'weights' # weights dir + self.wdir = self.save_dir / "weights" # weights dir if RANK in (-1, 0): self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.args.save_dir = str(self.save_dir) - yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args - self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths + yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args + self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths self.save_period = self.args.save_period self.batch_size = self.args.batch @@ -104,18 +121,18 @@ class BaseTrainer: print_args(vars(self.args)) # Device - if self.device.type in ('cpu', 'mps'): + if self.device.type in ("cpu", "mps"): self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading # Model and Dataset self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. 
yolov8n -> yolov8n.pt try: - if self.args.task == 'classify': + if self.args.task == "classify": self.data = check_cls_dataset(self.args.data) - elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'): + elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"): self.data = check_det_dataset(self.args.data) - if 'yaml_file' in self.data: - self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage + if "yaml_file" in self.data: + self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage except Exception as e: raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e @@ -131,8 +148,8 @@ class BaseTrainer: self.fitness = None self.loss = None self.tloss = None - self.loss_names = ['Loss'] - self.csv = self.save_dir / 'results.csv' + self.loss_names = ["Loss"] + self.csv = self.save_dir / "results.csv" self.plot_idx = [0, 1, 2] # Callbacks @@ -156,7 +173,7 @@ class BaseTrainer: def train(self): """Allow device='', device=None on Multi-GPU systems to default to device=0.""" if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' - world_size = len(self.args.device.split(',')) + world_size = len(self.args.device.split(",")) elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list) world_size = len(self.args.device) elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number @@ -165,14 +182,16 @@ class BaseTrainer: world_size = 0 # Run subprocess if DDP training, else train normally - if world_size > 1 and 'LOCAL_RANK' not in os.environ: + if world_size > 1 and "LOCAL_RANK" not in os.environ: # Argument checks if self.args.rect: LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") self.args.rect = False if self.args.batch == -1: - LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " - "default 'batch=16'") + LOGGER.warning( + "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting " + "default 'batch=16'" + ) self.args.batch = 16 # Command @@ -199,37 +218,45 @@ class BaseTrainer: def _setup_ddp(self, world_size): """Initializes and sets the DistributedDataParallel parameters for training.""" torch.cuda.set_device(RANK) - self.device = torch.device('cuda', RANK) + self.device = torch.device("cuda", RANK) # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') - os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout + os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout dist.init_process_group( - 'nccl' if dist.is_nccl_available() else 'gloo', + "nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=10800), # 3 hours rank=RANK, - world_size=world_size) + world_size=world_size, + ) def _setup_train(self, world_size): """Builds dataloaders and optimizer on correct rank process.""" # Model - self.run_callbacks('on_pretrain_routine_start') + self.run_callbacks("on_pretrain_routine_start") ckpt = self.setup_model() self.model = self.model.to(self.device) self.set_model_attributes() # Freeze layers - freeze_list = self.args.freeze if isinstance( - self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else [] - always_freeze_names = ['.dfl'] # always freeze these layers - 
freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names + freeze_list = ( + self.args.freeze + if isinstance(self.args.freeze, list) + else range(self.args.freeze) + if isinstance(self.args.freeze, int) + else [] + ) + always_freeze_names = [".dfl"] # always freeze these layers + freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names for k, v in self.model.named_parameters(): # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) if any(x in k for x in freeze_layer_names): LOGGER.info(f"Freezing layer '{k}'") v.requires_grad = False elif not v.requires_grad: - LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " - 'See ultralytics.engine.trainer for customization of frozen layers.') + LOGGER.info( + f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " + "See ultralytics.engine.trainer for customization of frozen layers." + ) v.requires_grad = True # Check AMP @@ -246,7 +273,7 @@ class BaseTrainer: self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK]) # Check imgsz - gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride) + gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride) self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) self.stride = gs # for multi-scale training @@ -256,15 +283,14 @@ class BaseTrainer: # Dataloaders batch_size = self.batch_size // max(world_size, 1) - self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train') + self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") if RANK in (-1, 0): # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects. 
- self.test_loader = self.get_dataloader(self.testset, - batch_size=batch_size if self.args.task == 'obb' else batch_size * 2, - rank=-1, - mode='val') + self.test_loader = self.get_dataloader( + self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" + ) self.validator = self.get_validator() - metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val') + metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val") self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) self.ema = ModelEMA(self.model) if self.args.plots: @@ -274,18 +300,20 @@ class BaseTrainer: self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs - self.optimizer = self.build_optimizer(model=self.model, - name=self.args.optimizer, - lr=self.args.lr0, - momentum=self.args.momentum, - decay=weight_decay, - iterations=iterations) + self.optimizer = self.build_optimizer( + model=self.model, + name=self.args.optimizer, + lr=self.args.lr0, + momentum=self.args.momentum, + decay=weight_decay, + iterations=iterations, + ) # Scheduler self._setup_scheduler() self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False self.resume_training(ckpt) self.scheduler.last_epoch = self.start_epoch - 1 # do not move - self.run_callbacks('on_pretrain_routine_end') + self.run_callbacks("on_pretrain_routine_end") def _do_train(self, world_size=1): """Train completed, evaluate and plot if specified by arguments.""" @@ -299,19 +327,23 @@ class BaseTrainer: self.epoch_time = None self.epoch_time_start = time.time() self.train_time_start = time.time() - self.run_callbacks('on_train_start') - LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' - f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' - f"Logging results to {colorstr('bold', self.save_dir)}\n" - f'Starting training for ' - f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...') + self.run_callbacks("on_train_start") + LOGGER.info( + f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' + f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' + f"Logging results to {colorstr('bold', self.save_dir)}\n" + f'Starting training for ' + f'{self.args.time} hours...' + if self.args.time + else f"{self.epochs} epochs..." 
+ ) if self.args.close_mosaic: base_idx = (self.epochs - self.args.close_mosaic) * nb self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) epoch = self.epochs # predefine for resume fully trained model edge cases for epoch in range(self.start_epoch, self.epochs): self.epoch = epoch - self.run_callbacks('on_train_epoch_start') + self.run_callbacks("on_train_epoch_start") self.model.train() if RANK != -1: self.train_loader.sampler.set_epoch(epoch) @@ -327,7 +359,7 @@ class BaseTrainer: self.tloss = None self.optimizer.zero_grad() for i, batch in pbar: - self.run_callbacks('on_train_batch_start') + self.run_callbacks("on_train_batch_start") # Warmup ni = i + nb * epoch if ni <= nw: @@ -335,10 +367,11 @@ class BaseTrainer: self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())) for j, x in enumerate(self.optimizer.param_groups): # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 - x['lr'] = np.interp( - ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)]) - if 'momentum' in x: - x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) + x["lr"] = np.interp( + ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)] + ) + if "momentum" in x: + x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) # Forward with torch.cuda.amp.autocast(self.amp): @@ -346,8 +379,9 @@ class BaseTrainer: self.loss, self.loss_items = self.model(batch) if RANK != -1: self.loss *= world_size - self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ - else self.loss_items + self.tloss = ( + (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items + ) # Backward self.scaler.scale(self.loss).backward() @@ -368,24 +402,25 @@ class BaseTrainer: break # Log - mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) + mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) if RANK in (-1, 0): pbar.set_description( - ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % - (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1])) - self.run_callbacks('on_batch_end') + ("%11s" * 2 + "%11.4g" * (2 + loss_len)) + % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]) + ) + self.run_callbacks("on_batch_end") if self.args.plots and ni in self.plot_idx: self.plot_training_samples(batch, ni) - self.run_callbacks('on_train_batch_end') + self.run_callbacks("on_train_batch_end") - self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers - self.run_callbacks('on_train_epoch_end') + self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.run_callbacks("on_train_epoch_end") if RANK in (-1, 0): final_epoch = epoch + 1 == self.epochs - self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) + self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) # Validation if self.args.val or final_epoch or self.stopper.possible_stop or self.stop: @@ -398,14 +433,14 @@ class BaseTrainer: # Save model if self.args.save or final_epoch: self.save_model() - 
self.run_callbacks('on_model_save') + self.run_callbacks("on_model_save") # Scheduler t = time.time() self.epoch_time = t - self.epoch_time_start self.epoch_time_start = t with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()' + warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()' if self.args.time: mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1) self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time) @@ -413,7 +448,7 @@ class BaseTrainer: self.scheduler.last_epoch = self.epoch # do not move self.stop |= epoch >= self.epochs # stop if exceeded epochs self.scheduler.step() - self.run_callbacks('on_fit_epoch_end') + self.run_callbacks("on_fit_epoch_end") torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors # Early Stopping @@ -426,39 +461,43 @@ class BaseTrainer: if RANK in (-1, 0): # Do final val with best.pt - LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in ' - f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') + LOGGER.info( + f"\n{epoch - self.start_epoch + 1} epochs completed in " + f"{(time.time() - self.train_time_start) / 3600:.3f} hours." + ) self.final_eval() if self.args.plots: self.plot_metrics() - self.run_callbacks('on_train_end') + self.run_callbacks("on_train_end") torch.cuda.empty_cache() - self.run_callbacks('teardown') + self.run_callbacks("teardown") def save_model(self): """Save model training checkpoints with additional metadata.""" import pandas as pd # scope for faster startup - metrics = {**self.metrics, **{'fitness': self.fitness}} - results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient='list').items()} + + metrics = {**self.metrics, **{"fitness": self.fitness}} + results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} ckpt = { - 'epoch': self.epoch, - 'best_fitness': self.best_fitness, - 'model': deepcopy(de_parallel(self.model)).half(), - 'ema': deepcopy(self.ema.ema).half(), - 'updates': self.ema.updates, - 'optimizer': self.optimizer.state_dict(), - 'train_args': vars(self.args), # save as dict - 'train_metrics': metrics, - 'train_results': results, - 'date': datetime.now().isoformat(), - 'version': __version__} + "epoch": self.epoch, + "best_fitness": self.best_fitness, + "model": deepcopy(de_parallel(self.model)).half(), + "ema": deepcopy(self.ema.ema).half(), + "updates": self.ema.updates, + "optimizer": self.optimizer.state_dict(), + "train_args": vars(self.args), # save as dict + "train_metrics": metrics, + "train_results": results, + "date": datetime.now().isoformat(), + "version": __version__, + } # Save last and best torch.save(ckpt, self.last) if self.best_fitness == self.fitness: torch.save(ckpt, self.best) if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): - torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt') + torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt") @staticmethod def get_dataset(data): @@ -467,7 +506,7 @@ class BaseTrainer: Returns None if data format is not recognized. 
""" - return data['train'], data.get('val') or data.get('test') + return data["train"], data.get("val") or data.get("test") def setup_model(self): """Load/create/download model for any task.""" @@ -476,9 +515,9 @@ class BaseTrainer: model, weights = self.model, None ckpt = None - if str(model).endswith('.pt'): + if str(model).endswith(".pt"): weights, ckpt = attempt_load_one_weight(model) - cfg = ckpt['model'].yaml + cfg = ckpt["model"].yaml else: cfg = model self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights) @@ -505,7 +544,7 @@ class BaseTrainer: The returned dict is expected to contain "fitness" key. """ metrics = self.validator(self) - fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found + fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found if not self.best_fitness or self.best_fitness < fitness: self.best_fitness = fitness return metrics, fitness @@ -516,24 +555,24 @@ class BaseTrainer: def get_validator(self): """Returns a NotImplementedError when the get_validator function is called.""" - raise NotImplementedError('get_validator function not implemented in trainer') + raise NotImplementedError("get_validator function not implemented in trainer") - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Returns dataloader derived from torch.data.Dataloader.""" - raise NotImplementedError('get_dataloader function not implemented in trainer') + raise NotImplementedError("get_dataloader function not implemented in trainer") - def build_dataset(self, img_path, mode='train', batch=None): + def build_dataset(self, img_path, mode="train", batch=None): """Build dataset.""" - raise NotImplementedError('build_dataset function not implemented in trainer') + raise NotImplementedError("build_dataset function not implemented in trainer") - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """Returns a loss dict with labelled training loss items tensor.""" # Not needed for classification but necessary for segmentation & detection - return {'loss': loss_items} if loss_items is not None else ['loss'] + return {"loss": loss_items} if loss_items is not None else ["loss"] def set_model_attributes(self): """To set or update model parameters before training.""" - self.model.names = self.data['names'] + self.model.names = self.data["names"] def build_targets(self, preds, targets): """Builds target tensors for training YOLO model.""" @@ -541,7 +580,7 @@ class BaseTrainer: def progress_string(self): """Returns a string describing training progress.""" - return '' + return "" # TODO: may need to put these following functions into callback def plot_training_samples(self, batch, ni): @@ -556,9 +595,9 @@ class BaseTrainer: """Saves training metrics to a CSV file.""" keys, vals = list(metrics.keys()), list(metrics.values()) n = len(metrics) + 1 # number of cols - s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header - with open(self.csv, 'a') as f: - f.write(s + ('%23.5g,' * n % tuple([self.epoch + 1] + vals)).rstrip(',') + '\n') + s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header + with open(self.csv, "a") as f: + f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + 
vals)).rstrip(",") + "\n") def plot_metrics(self): """Plot and display metrics visually.""" @@ -567,7 +606,7 @@ class BaseTrainer: def on_plot(self, name, data=None): """Registers plots (e.g. to be consumed in callbacks)""" path = Path(name) - self.plots[path] = {'data': data, 'timestamp': time.time()} + self.plots[path] = {"data": data, "timestamp": time.time()} def final_eval(self): """Performs final evaluation and validation for object detection YOLO model.""" @@ -575,11 +614,11 @@ class BaseTrainer: if f.exists(): strip_optimizer(f) # strip optimizers if f is self.best: - LOGGER.info(f'\nValidating {f}...') + LOGGER.info(f"\nValidating {f}...") self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) - self.metrics.pop('fitness', None) - self.run_callbacks('on_fit_epoch_end') + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") def check_resume(self, overrides): """Check if resume checkpoint exists and update arguments accordingly.""" @@ -591,19 +630,21 @@ class BaseTrainer: # Check that resume data YAML exists, otherwise strip to force re-download of dataset ckpt_args = attempt_load_weights(last).args - if not Path(ckpt_args['data']).exists(): - ckpt_args['data'] = self.args.data + if not Path(ckpt_args["data"]).exists(): + ckpt_args["data"] = self.args.data resume = True self.args = get_cfg(ckpt_args) self.args.model = str(last) # reinstate model - for k in 'imgsz', 'batch': # allow arg updates to reduce memory on resume if crashed due to CUDA OOM + for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM if k in overrides: setattr(self.args, k, overrides[k]) except Exception as e: - raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, ' - "i.e. 'yolo train resume model=path/to/last.pt'") from e + raise FileNotFoundError( + "Resume checkpoint not found. Please pass a valid checkpoint to resume from, " + "i.e. 'yolo train resume model=path/to/last.pt'" + ) from e self.resume = resume def resume_training(self, ckpt): @@ -611,23 +652,26 @@ class BaseTrainer: if ckpt is None: return best_fitness = 0.0 - start_epoch = ckpt['epoch'] + 1 - if ckpt['optimizer'] is not None: - self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer - best_fitness = ckpt['best_fitness'] - if self.ema and ckpt.get('ema'): - self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA - self.ema.updates = ckpt['updates'] + start_epoch = ckpt["epoch"] + 1 + if ckpt["optimizer"] is not None: + self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer + best_fitness = ckpt["best_fitness"] + if self.ema and ckpt.get("ema"): + self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA + self.ema.updates = ckpt["updates"] if self.resume: - assert start_epoch > 0, \ - f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ + assert start_epoch > 0, ( + f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" + ) LOGGER.info( - f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs') + f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs" + ) if self.epochs < start_epoch: LOGGER.info( - f"{self.model} has been trained for {ckpt['epoch']} epochs. 
Fine-tuning for {self.epochs} more epochs.") - self.epochs += ckpt['epoch'] # finetune additional epochs + f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." + ) + self.epochs += ckpt["epoch"] # finetune additional epochs self.best_fitness = best_fitness self.start_epoch = start_epoch if start_epoch > (self.epochs - self.args.close_mosaic): @@ -635,13 +679,13 @@ class BaseTrainer: def _close_dataloader_mosaic(self): """Update dataloaders to stop using mosaic augmentation.""" - if hasattr(self.train_loader.dataset, 'mosaic'): + if hasattr(self.train_loader.dataset, "mosaic"): self.train_loader.dataset.mosaic = False - if hasattr(self.train_loader.dataset, 'close_mosaic'): - LOGGER.info('Closing dataloader mosaic') + if hasattr(self.train_loader.dataset, "close_mosaic"): + LOGGER.info("Closing dataloader mosaic") self.train_loader.dataset.close_mosaic(hyp=self.args) - def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): + def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): """ Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, weight decay, and number of iterations. @@ -661,41 +705,45 @@ class BaseTrainer: """ g = [], [], [] # optimizer parameter groups - bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d() - if name == 'auto': - LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, " - f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " - f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ") - nc = getattr(model, 'nc', 10) # number of classes + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + LOGGER.info( + f"{colorstr('optimizer:')} 'optimizer=auto' found, " + f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... 
" + ) + nc = getattr(model, "nc", 10) # number of classes lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places - name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9) + name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9) self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): - fullname = f'{module_name}.{param_name}' if module_name else param_name - if 'bias' in fullname: # bias (no decay) + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) g[2].append(param) elif isinstance(module, bn): # weight (no decay) g[1].append(param) else: # weight (with decay) g[0].append(param) - if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'): + if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) - elif name == 'RMSProp': + elif name == "RMSProp": optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) - elif name == 'SGD': + elif name == "SGD": optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) else: raise NotImplementedError( f"Optimizer '{name}' not found in list of available optimizers " - f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].' - 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.') + f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." + "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics." + ) - optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay - optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights) + optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay + optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights) LOGGER.info( f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups " - f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)') + f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)' + ) return optimizer diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py index 5765bce1..80a554d4 100644 --- a/ultralytics/engine/tuner.py +++ b/ultralytics/engine/tuner.py @@ -73,40 +73,43 @@ class Tuner: Args: args (dict, optional): Configuration for hyperparameter evolution. """ - self.space = args.pop('space', None) or { # key: (min, max, gain(optional)) + self.space = args.pop("space", None) or { # key: (min, max, gain(optional)) # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), - 'lr0': (1e-5, 1e-1), # initial learning rate (i.e. 
diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py
index 5765bce1..80a554d4 100644
--- a/ultralytics/engine/tuner.py
+++ b/ultralytics/engine/tuner.py
@@ -73,40 +73,43 @@ class Tuner:
         Args:
             args (dict, optional): Configuration for hyperparameter evolution.
         """
-        self.space = args.pop('space', None) or {  # key: (min, max, gain(optional))
+        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
-            'lr0': (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
-            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
-            'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
-            'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
-            'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
-            'box': (1.0, 20.0),  # box loss gain
-            'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
-            'dfl': (0.4, 6.0),  # dfl loss gain
-            'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
-            'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
-            'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
-            'degrees': (0.0, 45.0),  # image rotation (+/- deg)
-            'translate': (0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (0.0, 0.95),  # image scale (+/- gain)
-            'shear': (0.0, 10.0),  # image shear (+/- deg)
-            'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
-            'flipud': (0.0, 1.0),  # image flip up-down (probability)
-            'fliplr': (0.0, 1.0),  # image flip left-right (probability)
-            'mosaic': (0.0, 1.0),  # image mixup (probability)
-            'mixup': (0.0, 1.0),  # image mixup (probability)
-            'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
+            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
+            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
+            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
+            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
+            "box": (1.0, 20.0),  # box loss gain
+            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
+            "dfl": (0.4, 6.0),  # dfl loss gain
+            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
+            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
+            "translate": (0.0, 0.9),  # image translation (+/- fraction)
+            "scale": (0.0, 0.95),  # image scale (+/- gain)
+            "shear": (0.0, 10.0),  # image shear (+/- deg)
+            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+            "flipud": (0.0, 1.0),  # image flip up-down (probability)
+            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
+            "mosaic": (0.0, 1.0),  # image mosaic (probability)
+            "mixup": (0.0, 1.0),  # image mixup (probability)
+            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
+        }
         self.args = get_cfg(overrides=args)
-        self.tune_dir = get_save_dir(self.args, name='tune')
-        self.tune_csv = self.tune_dir / 'tune_results.csv'
+        self.tune_dir = get_save_dir(self.args, name="tune")
+        self.tune_csv = self.tune_dir / "tune_results.csv"
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        self.prefix = colorstr('Tuner: ')
+        self.prefix = colorstr("Tuner: ")
         callbacks.add_integration_callbacks(self)
-        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
-                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
+        LOGGER.info(
+            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
+        )

-    def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
+    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
         """
         Mutates the hyperparameters based on bounds and
scaling factors specified in `self.space`. @@ -121,15 +124,15 @@ class Tuner: """ if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate # Select parent(s) - x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1) + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) fitness = x[:, 0] # first column n = min(n, len(x)) # number of previous results to consider x = x[np.argsort(-fitness)][:n] # top n mutations - w = x[:, 0] - x[:, 0].min() + 1E-6 # weights (sum > 0) - if parent == 'single' or len(x) == 1: + w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0) + if parent == "single" or len(x) == 1: # x = x[random.randint(0, n - 1)] # random selection x = x[random.choices(range(n), weights=w)[0]] # weighted selection - elif parent == 'weighted': + elif parent == "weighted": x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination # Mutate @@ -174,44 +177,44 @@ class Tuner: t0 = time.time() best_save_dir, best_metrics = None, None - (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True) + (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True) for i in range(iterations): # Mutate hyperparameters mutated_hyp = self._mutate() - LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}') + LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}") metrics = {} train_args = {**vars(self.args), **mutated_hyp} save_dir = get_save_dir(get_cfg(train_args)) - weights_dir = save_dir / 'weights' - ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt') + weights_dir = save_dir / "weights" + ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt") try: # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang) - cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())] + cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())] return_code = subprocess.run(cmd, check=True).returncode - metrics = torch.load(ckpt_file)['train_metrics'] - assert return_code == 0, 'training failed' + metrics = torch.load(ckpt_file)["train_metrics"] + assert return_code == 0, "training failed" except Exception as e: - LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}') + LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}") # Save results and mutated_hyp to CSV - fitness = metrics.get('fitness', 0.0) + fitness = metrics.get("fitness", 0.0) log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()] - headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n') - with open(self.tune_csv, 'a') as f: - f.write(headers + ','.join(map(str, log_row)) + '\n') + headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n") + with open(self.tune_csv, "a") as f: + f.write(headers + ",".join(map(str, log_row)) + "\n") # Get best results - x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1) + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) fitness = x[:, 0] # first column best_idx = fitness.argmax() best_is_current = best_idx == i if best_is_current: best_save_dir = save_dir best_metrics = {k: round(v, 5) for k, v in metrics.items()} - for ckpt in weights_dir.glob('*.pt'): - shutil.copy2(ckpt, self.tune_dir / 'weights') + for ckpt in 
weights_dir.glob("*.pt"): + shutil.copy2(ckpt, self.tune_dir / "weights") elif cleanup: shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space @@ -219,15 +222,19 @@ class Tuner: plot_tune_results(self.tune_csv) # Save and print tune results - header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n' - f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n' - f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n' - f'{self.prefix}Best fitness metrics are {best_metrics}\n' - f'{self.prefix}Best fitness model is {best_save_dir}\n' - f'{self.prefix}Best fitness hyperparameters are printed below.\n') - LOGGER.info('\n' + header) + header = ( + f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n' + f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n' + f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n' + f'{self.prefix}Best fitness metrics are {best_metrics}\n' + f'{self.prefix}Best fitness model is {best_save_dir}\n' + f'{self.prefix}Best fitness hyperparameters are printed below.\n' + ) + LOGGER.info("\n" + header) data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())} - yaml_save(self.tune_dir / 'best_hyperparameters.yaml', - data=data, - header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n') - yaml_print(self.tune_dir / 'best_hyperparameters.yaml') + yaml_save( + self.tune_dir / "best_hyperparameters.yaml", + data=data, + header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n", + ) + yaml_print(self.tune_dir / "best_hyperparameters.yaml") diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index cf906d9e..5761113f 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -89,10 +89,10 @@ class BaseValidator: self.nc = None self.iouv = None self.jdict = None - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} self.save_dir = save_dir or get_save_dir(self.args) - (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) if self.args.conf is None: self.args.conf = 0.001 # default conf=0.001 self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) @@ -110,7 +110,7 @@ class BaseValidator: if self.training: self.device = trainer.device self.data = trainer.data - self.args.half = self.device.type != 'cpu' # force FP16 val during training + self.args.half = self.device.type != "cpu" # force FP16 val during training model = trainer.ema.ema or trainer.model model = model.half() if self.args.half else model.float() # self.model = model @@ -119,11 +119,13 @@ class BaseValidator: model.eval() else: callbacks.add_integration_callbacks(self) - model = AutoBackend(model or self.args.model, - device=select_device(self.args.device, self.args.batch), - dnn=self.args.dnn, - data=self.args.data, - fp16=self.args.half) + model = AutoBackend( + model or self.args.model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + ) # self.model = model self.device = model.device # update device self.args.half = model.fp16 # update half @@ -133,16 +135,16 @@ class BaseValidator: 
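The Tuner changes above are quote/format-only, but the mutation logic they touch is worth spelling out: each iteration loads tune_results.csv, picks a parent from the top-n rows with fitness-proportional weights, multiplies each hyperparameter by clipped Gaussian noise, and bounds the result by the (min, max) pairs declared in self.space. A minimal self-contained sketch of that step; the search-space subset, gain handling, and clip range here are illustrative, not the Tuner defaults:

import random

import numpy as np

# Illustrative subset of the search space above: name -> (min, max) or (min, max, gain)
SPACE = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3), "weight_decay": (0.0, 0.001)}


def mutate(parents: np.ndarray, fitness: np.ndarray, mutation=0.8, sigma=0.2, seed=None):
    """Pick a fitness-weighted parent row and perturb it within SPACE bounds (sketch of Tuner._mutate)."""
    rng = np.random.default_rng(seed)
    w = fitness - fitness.min() + 1e-6  # weights must sum > 0
    x = parents[random.choices(range(len(parents)), weights=w)[0]]  # weighted parent selection
    gains = np.array([v[2] if len(v) == 3 else 1.0 for v in SPACE.values()])
    v = np.ones(len(SPACE))
    while np.all(v == 1):  # re-draw until at least one gene actually changes
        draw = gains * rng.standard_normal(len(SPACE)) * sigma + 1
        v = np.where(rng.random(len(SPACE)) < mutation, draw, 1.0).clip(0.3, 3.0)
    hyp = {k: float(x[i] * v[i]) for i, k in enumerate(SPACE)}
    return {k: min(max(hyp[k], SPACE[k][0]), SPACE[k][1]) for k in hyp}  # clip to declared bounds


# Usage: two previous CSV rows; the higher-fitness row is favoured as parent
parents = np.array([[0.01, 0.9, 5e-4], [0.02, 0.8, 1e-4]])
print(mutate(parents, fitness=np.array([0.6, 0.4]), seed=0))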
self.args.batch = model.batch_size elif not pt and not jit: self.args.batch = 1 # export.py models default to batch-size 1 - LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") - if str(self.args.data).split('.')[-1] in ('yaml', 'yml'): + if str(self.args.data).split(".")[-1] in ("yaml", "yml"): self.data = check_det_dataset(self.args.data) - elif self.args.task == 'classify': + elif self.args.task == "classify": self.data = check_cls_dataset(self.args.data, split=self.args.split) else: raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) - if self.device.type in ('cpu', 'mps'): + if self.device.type in ("cpu", "mps"): self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading if not pt: self.args.rect = False @@ -152,13 +154,13 @@ class BaseValidator: model.eval() model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup - self.run_callbacks('on_val_start') + self.run_callbacks("on_val_start") dt = Profile(), Profile(), Profile(), Profile() bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) self.init_metrics(de_parallel(model)) self.jdict = [] # empty before each val for batch_i, batch in enumerate(bar): - self.run_callbacks('on_val_batch_start') + self.run_callbacks("on_val_batch_start") self.batch_i = batch_i # Preprocess with dt[0]: @@ -166,7 +168,7 @@ class BaseValidator: # Inference with dt[1]: - preds = model(batch['img'], augment=augment) + preds = model(batch["img"], augment=augment) # Loss with dt[2]: @@ -182,23 +184,25 @@ class BaseValidator: self.plot_val_samples(batch, batch_i) self.plot_predictions(batch, preds, batch_i) - self.run_callbacks('on_val_batch_end') + self.run_callbacks("on_val_batch_end") stats = self.get_stats() self.check_stats(stats) - self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt))) + self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt))) self.finalize_metrics() self.print_results() - self.run_callbacks('on_val_end') + self.run_callbacks("on_val_end") if self.training: model.float() - results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')} + results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats else: - LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' % - tuple(self.speed.values())) + LOGGER.info( + "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" + % tuple(self.speed.values()) + ) if self.args.save_json and self.jdict: - with open(str(self.save_dir / 'predictions.json'), 'w') as f: - LOGGER.info(f'Saving {f.name}...') + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") json.dump(self.jdict, f) # flatten and save stats = self.eval_json(stats) # update stats if self.args.plots or self.args.save_json: @@ -228,6 +232,7 @@ class BaseValidator: if use_scipy: # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708 import scipy # scope import to avoid importing for all commands + cost_matrix = iou * (iou >= threshold) if cost_matrix.any(): labels_idx, 
detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True) @@ -257,11 +262,11 @@ class BaseValidator: def get_dataloader(self, dataset_path, batch_size): """Get data loader from dataset path and batch size.""" - raise NotImplementedError('get_dataloader function not implemented for this validator') + raise NotImplementedError("get_dataloader function not implemented for this validator") def build_dataset(self, img_path): """Build dataset.""" - raise NotImplementedError('build_dataset function not implemented in validator') + raise NotImplementedError("build_dataset function not implemented in validator") def preprocess(self, batch): """Preprocesses an input batch.""" @@ -306,7 +311,7 @@ class BaseValidator: def on_plot(self, name, data=None): """Registers plots (e.g. to be consumed in callbacks)""" - self.plots[Path(name)] = {'data': data, 'timestamp': time.time()} + self.plots[Path(name)] = {"data": data, "timestamp": time.time()} # TODO: may need to put these following functions into callback def plot_val_samples(self, batch, ni): diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index 13bc9246..df17da5e 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -21,10 +21,10 @@ def login(api_key: str = None, save=True) -> bool: Returns: bool: True if authentication is successful, False otherwise. """ - api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys' # Set the redirect URL - saved_key = SETTINGS.get('api_key') + api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL + saved_key = SETTINGS.get("api_key") active_key = api_key or saved_key - credentials = {'api_key': active_key} if active_key and active_key != '' else None # Set credentials + credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials client = HUBClient(credentials) # initialize HUBClient @@ -32,17 +32,18 @@ def login(api_key: str = None, save=True) -> bool: # Successfully authenticated with HUB if save and client.api_key != saved_key: - SETTINGS.update({'api_key': client.api_key}) # update settings with valid API key + SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key # Set message based on whether key was provided or retrieved from settings - log_message = ('New authentication successful ✅' - if client.api_key == api_key or not credentials else 'Authenticated ✅') - LOGGER.info(f'{PREFIX}{log_message}') + log_message = ( + "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅" + ) + LOGGER.info(f"{PREFIX}{log_message}") return True else: # Failed to authenticate with HUB - LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}') + LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}") return False @@ -57,50 +58,50 @@ def logout(): hub.logout() ``` """ - SETTINGS['api_key'] = '' + SETTINGS["api_key"] = "" SETTINGS.save() LOGGER.info(f"{PREFIX}logged out ✅. 
To log in again, use 'yolo hub login'.") -def reset_model(model_id=''): +def reset_model(model_id=""): """Reset a trained model to an untrained state.""" - r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key}) + r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}) if r.status_code == 200: - LOGGER.info(f'{PREFIX}Model reset successfully') + LOGGER.info(f"{PREFIX}Model reset successfully") return - LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}') + LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}") def export_fmts_hub(): """Returns a list of HUB-supported export formats.""" from ultralytics.engine.exporter import export_formats - return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml'] + return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"] -def export_model(model_id='', format='torchscript'): + +def export_model(model_id="", format="torchscript"): """Export a model to all formats.""" assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" - r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export', - json={'format': format}, - headers={'x-api-key': Auth().api_key}) - assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}' - LOGGER.info(f'{PREFIX}{format} export started ✅') + r = requests.post( + f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key} + ) + assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" + LOGGER.info(f"{PREFIX}{format} export started ✅") -def get_export(model_id='', format='torchscript'): +def get_export(model_id="", format="torchscript"): """Get an exported model dictionary with download URL.""" assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" - r = requests.post(f'{HUB_API_ROOT}/get-export', - json={ - 'apiKey': Auth().api_key, - 'modelId': model_id, - 'format': format}, - headers={'x-api-key': Auth().api_key}) - assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' + r = requests.post( + f"{HUB_API_ROOT}/get-export", + json={"apiKey": Auth().api_key, "modelId": model_id, "format": format}, + headers={"x-api-key": Auth().api_key}, + ) + assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" return r.json() -def check_dataset(path='', task='detect'): +def check_dataset(path="", task="detect"): """ Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded to the HUB. Usage examples are given below. @@ -119,4 +120,4 @@ def check_dataset(path='', task='detect'): ``` """ HUBDatasetStats(path=path, task=task).get_json() - LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.') + LOGGER.info(f"Checks completed correctly ✅. 
Upload this dataset to {HUB_WEB_ROOT}/datasets/.") diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py index 202ca8c4..72aad109 100644 --- a/ultralytics/hub/auth.py +++ b/ultralytics/hub/auth.py @@ -6,7 +6,7 @@ from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT from ultralytics.hub.utils import PREFIX, request_with_credentials from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab -API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys' +API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys" class Auth: @@ -23,9 +23,10 @@ class Auth: api_key (str or bool): API key for authentication, initialized as False. model_key (bool): Placeholder for model key, initialized as False. """ + id_token = api_key = model_key = False - def __init__(self, api_key='', verbose=False): + def __init__(self, api_key="", verbose=False): """ Initialize the Auth class with an optional API key. @@ -33,18 +34,18 @@ class Auth: api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id """ # Split the input API key in case it contains a combined key_model and keep only the API key part - api_key = api_key.split('_')[0] + api_key = api_key.split("_")[0] # Set API key attribute as value passed or SETTINGS API key if none passed - self.api_key = api_key or SETTINGS.get('api_key', '') + self.api_key = api_key or SETTINGS.get("api_key", "") # If an API key is provided if self.api_key: # If the provided API key matches the API key in the SETTINGS - if self.api_key == SETTINGS.get('api_key'): + if self.api_key == SETTINGS.get("api_key"): # Log that the user is already logged in if verbose: - LOGGER.info(f'{PREFIX}Authenticated ✅') + LOGGER.info(f"{PREFIX}Authenticated ✅") return else: # Attempt to authenticate with the provided API key @@ -59,12 +60,12 @@ class Auth: # Update SETTINGS with the new API key after successful authentication if success: - SETTINGS.update({'api_key': self.api_key}) + SETTINGS.update({"api_key": self.api_key}) # Log that the new login was successful if verbose: - LOGGER.info(f'{PREFIX}New authentication successful ✅') + LOGGER.info(f"{PREFIX}New authentication successful ✅") elif verbose: - LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}') + LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}") def request_api_key(self, max_attempts=3): """ @@ -73,13 +74,14 @@ class Auth: Returns the model ID. """ import getpass + for attempts in range(max_attempts): - LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}') - input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ') - self.api_key = input_key.split('_')[0] # remove model id if present + LOGGER.info(f"{PREFIX}Login. 
Attempt {attempts + 1} of {max_attempts}") + input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ") + self.api_key = input_key.split("_")[0] # remove model id if present if self.authenticate(): return True - raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) + raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌")) def authenticate(self) -> bool: """ @@ -90,14 +92,14 @@ class Auth: """ try: if header := self.get_auth_header(): - r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header) - if not r.json().get('success', False): - raise ConnectionError('Unable to authenticate.') + r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) + if not r.json().get("success", False): + raise ConnectionError("Unable to authenticate.") return True - raise ConnectionError('User has not authenticated locally.') + raise ConnectionError("User has not authenticated locally.") except ConnectionError: self.id_token = self.api_key = False # reset invalid - LOGGER.warning(f'{PREFIX}Invalid API key ⚠️') + LOGGER.warning(f"{PREFIX}Invalid API key ⚠️") return False def auth_with_cookies(self) -> bool: @@ -111,12 +113,12 @@ class Auth: if not is_colab(): return False # Currently only works with Colab try: - authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto') - if authn.get('success', False): - self.id_token = authn.get('data', {}).get('idToken', None) + authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") + if authn.get("success", False): + self.id_token = authn.get("data", {}).get("idToken", None) self.authenticate() return True - raise ConnectionError('Unable to fetch browser authentication details.') + raise ConnectionError("Unable to fetch browser authentication details.") except ConnectionError: self.id_token = False # reset invalid return False @@ -129,7 +131,7 @@ class Auth: (dict): The authentication header if id_token or API key is set, None otherwise. """ if self.id_token: - return {'authorization': f'Bearer {self.id_token}'} + return {"authorization": f"Bearer {self.id_token}"} elif self.api_key: - return {'x-api-key': self.api_key} + return {"x-api-key": self.api_key} # else returns None diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index dd3d01c0..150f4eb4 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -12,16 +12,13 @@ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab from ultralytics.utils.errors import HUBModelError -AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local') +AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local" class HUBTrainingSession: """ HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing. - Args: - url (str): Model identifier used to initialize the HUB training session. - Attributes: agent_id (str): Identifier for the instance communicating with the server. model_id (str): Identifier for the YOLO model being trained. @@ -40,17 +37,18 @@ class HUBTrainingSession: Initialize the HUBTrainingSession with the provided model identifier. Args: - url (str): Model identifier used to initialize the HUB training session. - It can be a URL string or a model key with specific format. + identifier (str): Model identifier used to initialize the HUB training session. + It can be a URL string or a model key with specific format. 
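The accepted formats listed in this docstring map onto plain string checks; the hunk that follows shows the real _parse_identifier, and a condensed stand-alone sketch of the same branching (slightly reordered, with the HUB_WEB_ROOT value assumed for illustration) looks like this:

HUB_WEB_ROOT = "https://hub.ultralytics.com"  # assumed value for illustration


def parse_identifier(identifier: str):
    """Return (api_key, model_id, filename); mirrors HUBTrainingSession._parse_identifier below."""
    api_key = model_id = filename = None
    if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
        model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]  # HUB URL
    elif identifier.endswith((".pt", ".yaml")):
        filename = identifier  # local weights or config file
    else:
        parts = identifier.split("_")
        if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:  # apiKey_modelId
            api_key, model_id = parts
        elif len(parts) == 1 and len(parts[0]) == 20:  # bare 20-char model ID
            model_id = parts[0]
        else:
            raise ValueError(f"Unrecognized identifier format: {identifier!r}")
    return api_key, model_id, filename


print(parse_identifier("yolov8n.pt"))  # (None, None, 'yolov8n.pt')
print(parse_identifier(f"{HUB_WEB_ROOT}/models/" + "x" * 20))  # (None, 'xxxxxxxxxxxxxxxxxxxx', None)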
Raises: ValueError: If the provided model identifier is invalid. ConnectionError: If connecting with global API key is not supported. """ self.rate_limits = { - 'metrics': 3.0, - 'ckpt': 900.0, - 'heartbeat': 300.0, } # rate limits (seconds) + "metrics": 3.0, + "ckpt": 900.0, + "heartbeat": 300.0, + } # rate limits (seconds) self.metrics_queue = {} # holds metrics for each epoch until upload self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py @@ -58,8 +56,8 @@ class HUBTrainingSession: api_key, model_id, self.filename = self._parse_identifier(identifier) # Get credentials - active_key = api_key or SETTINGS.get('api_key') - credentials = {'api_key': active_key} if active_key else None # set credentials + active_key = api_key or SETTINGS.get("api_key") + credentials = {"api_key": active_key} if active_key else None # set credentials # Initialize client self.client = HUBClient(credentials) @@ -72,35 +70,37 @@ class HUBTrainingSession: def load_model(self, model_id): # Initialize model self.model = self.client.model(model_id) - self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}' + self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" self._set_train_args() # Start heartbeats for HUB to monitor agent - self.model.start_heartbeat(self.rate_limits['heartbeat']) - LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') + self.model.start_heartbeat(self.rate_limits["heartbeat"]) + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") def create_model(self, model_args): # Initialize model payload = { - 'config': { - 'batchSize': model_args.get('batch', -1), - 'epochs': model_args.get('epochs', 300), - 'imageSize': model_args.get('imgsz', 640), - 'patience': model_args.get('patience', 100), - 'device': model_args.get('device', ''), - 'cache': model_args.get('cache', 'ram'), }, - 'dataset': { - 'name': model_args.get('data')}, - 'lineage': { - 'architecture': { - 'name': self.filename.replace('.pt', '').replace('.yaml', ''), }, - 'parent': {}, }, - 'meta': { - 'name': self.filename}, } - - if self.filename.endswith('.pt'): - payload['lineage']['parent']['name'] = self.filename + "config": { + "batchSize": model_args.get("batch", -1), + "epochs": model_args.get("epochs", 300), + "imageSize": model_args.get("imgsz", 640), + "patience": model_args.get("patience", 100), + "device": model_args.get("device", ""), + "cache": model_args.get("cache", "ram"), + }, + "dataset": {"name": model_args.get("data")}, + "lineage": { + "architecture": { + "name": self.filename.replace(".pt", "").replace(".yaml", ""), + }, + "parent": {}, + }, + "meta": {"name": self.filename}, + } + + if self.filename.endswith(".pt"): + payload["lineage"]["parent"]["name"] = self.filename self.model.create_model(payload) @@ -109,12 +109,12 @@ class HUBTrainingSession: if not self.model.id: return - self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}' + self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" # Start heartbeats for HUB to monitor agent - self.model.start_heartbeat(self.rate_limits['heartbeat']) + self.model.start_heartbeat(self.rate_limits["heartbeat"]) - LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") def _parse_identifier(self, identifier): """ @@ -125,13 +125,13 @@ class HUBTrainingSession: - An identifier containing an API key and a model ID separated by an underscore - An identifier that is solely a model ID of a fixed length - A local filename that ends with '.pt' or '.yaml' - + Args: identifier (str): The 
identifier string to be parsed. - + Returns: (tuple): A tuple containing the API key, model ID, and filename as applicable. - + Raises: HUBModelError: If the identifier format is not recognized. """ @@ -140,12 +140,12 @@ class HUBTrainingSession: api_key, model_id, filename = None, None, None # Check if identifier is a HUB URL - if identifier.startswith(f'{HUB_WEB_ROOT}/models/'): + if identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # Extract the model_id after the HUB_WEB_ROOT URL - model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1] + model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1] else: # Split the identifier based on underscores only if it's not a HUB URL - parts = identifier.split('_') + parts = identifier.split("_") # Check if identifier is in the format of API key and model ID if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20: @@ -154,43 +154,46 @@ class HUBTrainingSession: elif len(parts) == 1 and len(parts[0]) == 20: model_id = parts[0] # Check if identifier is a local filename - elif identifier.endswith('.pt') or identifier.endswith('.yaml'): + elif identifier.endswith(".pt") or identifier.endswith(".yaml"): filename = identifier else: raise HUBModelError( f"model='{identifier}' could not be parsed. Check format is correct. " - f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.') + f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file." + ) return api_key, model_id, filename def _set_train_args(self, **kwargs): if self.model.is_trained(): # Model is already trained - raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀')) + raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) if self.model.is_resumable(): # Model has saved weights - self.train_args = {'data': self.model.get_dataset_url(), 'resume': True} - self.model_file = self.model.get_weights_url('last') + self.train_args = {"data": self.model.get_dataset_url(), "resume": True} + self.model_file = self.model.get_weights_url("last") else: # Model has no saved weights def get_train_args(config): return { - 'batch': config['batchSize'], - 'epochs': config['epochs'], - 'imgsz': config['imageSize'], - 'patience': config['patience'], - 'device': config['device'], - 'cache': config['cache'], - 'data': self.model.get_dataset_url(), } - - self.train_args = get_train_args(self.model.data.get('config')) + "batch": config["batchSize"], + "epochs": config["epochs"], + "imgsz": config["imageSize"], + "patience": config["patience"], + "device": config["device"], + "cache": config["cache"], + "data": self.model.get_dataset_url(), + } + + self.train_args = get_train_args(self.model.data.get("config")) # Set the model file as either a *.pt or *.yaml file - self.model_file = (self.model.get_weights_url('parent') - if self.model.is_pretrained() else self.model.get_architecture()) + self.model_file = ( + self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() + ) - if not self.train_args.get('data'): - raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix + if not self.train_args.get("data"): + raise ValueError("Dataset may still be processing. 
Please wait a minute and try again.")  # RF fix self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u self.model_id = self.model.id @@ -206,12 +209,11 @@ class HUBTrainingSession: *args, **kwargs, ): - def retry_request(): t0 = time.time() # Record the start time for the timeout for i in range(retry + 1): if (time.time() - t0) > timeout: - LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}') + LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}") break # Timeout reached, exit loop response = request_func(*args, **kwargs) @@ -219,8 +221,8 @@ class HUBTrainingSession: self._show_upload_progress(progress_total, response) if response is None: - LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}') - time.sleep(2 ** i) # Exponential backoff before retrying + LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") + time.sleep(2**i) # Exponential backoff before retrying continue # Skip further processing and retry if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: @@ -231,13 +233,13 @@ class HUBTrainingSession: message = self._get_failure_message(response, retry, timeout) if verbose: - LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})') + LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})") if not self._should_retry(response.status_code): - LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}') + LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})") break # Not an error that should be retried, exit loop - time.sleep(2 ** i) # Exponential backoff for retries + time.sleep(2**i) # Exponential backoff for retries return response @@ -253,7 +255,8 @@ class HUBTrainingSession: retry_codes = { HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, - HTTPStatus.GATEWAY_TIMEOUT, } + HTTPStatus.GATEWAY_TIMEOUT, + } return True if status_code in retry_codes else False def _get_failure_message(self, response: requests.Response, retry: int, timeout: int): @@ -269,16 +272,18 @@ class HUBTrainingSession: str: The retry message. """ if self._should_retry(response.status_code): - return f'Retrying {retry}x for {timeout}s.' if retry else '' + return f"Retrying {retry}x for {timeout}s." if retry else "" elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit headers = response.headers - return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " - f"Please retry after {headers['Retry-After']}s.") + return ( + f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " + f"Please retry after {headers['Retry-After']}s." + ) else: try: - return response.json().get('message', 'No JSON message.') + return response.json().get("message", "No JSON message.") except AttributeError: - return 'Unable to read JSON.' + return "Unable to read JSON." def upload_metrics(self): """Upload model metrics to Ultralytics HUB.""" @@ -303,7 +308,7 @@ class HUBTrainingSession: final (bool): Indicates if the model is the final model after training.
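request_queue above implements its retry policy piecewise: exponential 2**i sleeps between attempts, an overall wall-clock timeout, early success on any 2xx, and an allowlist of retryable status codes. A stand-alone sketch of the same policy, with a stub response object standing in for requests:

import time
from http import HTTPStatus

RETRY_CODES = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT}


def retry_request(request_func, retry=3, timeout=30, **kwargs):
    """Call request_func with exponential backoff, mirroring HUBTrainingSession.request_queue above."""
    t0 = time.time()
    response = None
    for i in range(retry + 1):
        if time.time() - t0 > timeout:
            break  # give up once the overall timeout is exceeded
        response = request_func(**kwargs)
        if response is None:
            time.sleep(2**i)  # exponential backoff before retrying
            continue
        if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
            return response  # 2xx: success, stop retrying
        if response.status_code not in RETRY_CODES:
            break  # non-retryable failure (e.g. 401/404), exit early
        time.sleep(2**i)
    return response


# Usage with a stub that fails twice with 504, then succeeds
class _Resp:
    def __init__(self, code):
        self.status_code = code

codes = iter([504, 504, 200])
print(retry_request(lambda: _Resp(next(codes))).status_code)  # 200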
""" if Path(weights).is_file(): - progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final + progress_total = Path(weights).stat().st_size if final else None # Only show progress if final self.request_queue( self.model.upload_model, epoch=epoch, @@ -317,7 +322,7 @@ class HUBTrainingSession: progress_total=progress_total, ) else: - LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.') + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.") def _show_upload_progress(self, content_length: int, response: requests.Response) -> None: """ @@ -330,6 +335,6 @@ class HUBTrainingSession: Returns: (None) """ - with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar: + with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar: for data in response.iter_content(chunk_size=1024): pbar.update(len(data)) diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py index 1277c63a..2cdd0350 100644 --- a/ultralytics/hub/utils.py +++ b/ultralytics/hub/utils.py @@ -9,12 +9,26 @@ from pathlib import Path import requests -from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__, - colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package) +from ultralytics.utils import ( + ENVIRONMENT, + LOGGER, + ONLINE, + RANK, + SETTINGS, + TESTS_RUNNING, + TQDM, + TryExcept, + __version__, + colorstr, + get_git_origin_url, + is_colab, + is_git_dir, + is_pip_package, +) from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES -PREFIX = colorstr('Ultralytics HUB: ') -HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.' +PREFIX = colorstr("Ultralytics HUB: ") +HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance." def request_with_credentials(url: str) -> any: @@ -31,11 +45,13 @@ def request_with_credentials(url: str) -> any: OSError: If the function is not run in a Google Colab environment. """ if not is_colab(): - raise OSError('request_with_credentials() must run in a Colab environment') + raise OSError("request_with_credentials() must run in a Colab environment") from google.colab import output # noqa from IPython import display # noqa + display.display( - display.Javascript(""" + display.Javascript( + """ window._hub_tmp = new Promise((resolve, reject) => { const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000) fetch("%s", { @@ -50,8 +66,11 @@ def request_with_credentials(url: str) -> any: reject(err); }); }); - """ % url)) - return output.eval_js('_hub_tmp') + """ + % url + ) + ) + return output.eval_js("_hub_tmp") def requests_with_progress(method, url, **kwargs): @@ -71,13 +90,13 @@ def requests_with_progress(method, url, **kwargs): content length. - If 'progress' is a number then progress bar will display assuming content length = progress. 
""" - progress = kwargs.pop('progress', False) + progress = kwargs.pop("progress", False) if not progress: return requests.request(method, url, **kwargs) response = requests.request(method, url, stream=True, **kwargs) - total = int(response.headers.get('content-length', 0) if isinstance(progress, bool) else progress) # total size + total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size try: - pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024) + pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024) for data in response.iter_content(chunk_size=1024): pbar.update(len(data)) pbar.close() @@ -118,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful" break try: - m = r.json().get('message', 'No JSON message.') + m = r.json().get("message", "No JSON message.") except AttributeError: - m = 'Unable to read JSON.' + m = "Unable to read JSON." if i == 0: if r.status_code in retry_codes: - m += f' Retrying {retry}x for {timeout}s.' if retry else '' + m += f" Retrying {retry}x for {timeout}s." if retry else "" elif r.status_code == 429: # rate limit h = r.headers # response headers - m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \ + m = ( + f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " f"Please retry after {h['Retry-After']}s." + ) if verbose: - LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})') + LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})") if r.status_code not in retry_codes: return r - time.sleep(2 ** i) # exponential standoff + time.sleep(2**i) # exponential standoff return r args = method, url - kwargs['progress'] = progress + kwargs["progress"] = progress if thread: threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start() else: @@ -155,7 +176,7 @@ class Events: enabled (bool): A flag to enable or disable Events based on certain conditions. """ - url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw' + url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw" def __init__(self): """Initializes the Events object with default values for events, rate_limit, and metadata.""" @@ -163,19 +184,21 @@ class Events: self.rate_limit = 60.0 # rate limit (seconds) self.t = 0.0 # rate limit timer (seconds) self.metadata = { - 'cli': Path(sys.argv[0]).name == 'yolo', - 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other', - 'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10 - 'version': __version__, - 'env': ENVIRONMENT, - 'session_id': round(random.random() * 1E15), - 'engagement_time_msec': 1000} - self.enabled = \ - SETTINGS['sync'] and \ - RANK in (-1, 0) and \ - not TESTS_RUNNING and \ - ONLINE and \ - (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') + "cli": Path(sys.argv[0]).name == "yolo", + "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", + "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 
3.10 + "version": __version__, + "env": ENVIRONMENT, + "session_id": round(random.random() * 1e15), + "engagement_time_msec": 1000, + } + self.enabled = ( + SETTINGS["sync"] + and RANK in (-1, 0) + and not TESTS_RUNNING + and ONLINE + and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") + ) def __call__(self, cfg): """ @@ -191,11 +214,13 @@ class Events: # Attempt to add to events if len(self.events) < 25: # Events list limited to 25 events (drop any events past this) params = { - **self.metadata, 'task': cfg.task, - 'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'} - if cfg.mode == 'export': - params['format'] = cfg.format - self.events.append({'name': cfg.mode, 'params': params}) + **self.metadata, + "task": cfg.task, + "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom", + } + if cfg.mode == "export": + params["format"] = cfg.format + self.events.append({"name": cfg.mode, "params": params}) # Check rate limit t = time.time() @@ -204,10 +229,10 @@ class Events: return # Time is over rate limiter, send now - data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list + data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list # POST equivalent to requests.post(self.url, json=data) - smart_request('post', self.url, json=data, retry=0, verbose=False) + smart_request("post", self.url, json=data, retry=0, verbose=False) # Reset events and rate limit timer self.events = [] diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py index e96f893e..ef49faff 100644 --- a/ultralytics/models/__init__.py +++ b/ultralytics/models/__init__.py @@ -4,4 +4,4 @@ from .rtdetr import RTDETR from .sam import SAM from .yolo import YOLO -__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import +__all__ = "YOLO", "RTDETR", "SAM" # allow simpler import diff --git a/ultralytics/models/fastsam/__init__.py b/ultralytics/models/fastsam/__init__.py index 8f47772f..eabf5b9f 100644 --- a/ultralytics/models/fastsam/__init__.py +++ b/ultralytics/models/fastsam/__init__.py @@ -5,4 +5,4 @@ from .predict import FastSAMPredictor from .prompt import FastSAMPrompt from .val import FastSAMValidator -__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' +__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator" diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py index e6475faa..8904f8e9 100644 --- a/ultralytics/models/fastsam/model.py +++ b/ultralytics/models/fastsam/model.py @@ -21,14 +21,14 @@ class FastSAM(Model): ``` """ - def __init__(self, model='FastSAM-x.pt'): + def __init__(self, model="FastSAM-x.pt"): """Call the __init__ method of the parent class (YOLO) with the updated default model.""" - if str(model) == 'FastSAM.pt': - model = 'FastSAM-x.pt' - assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.' - super().__init__(model=model, task='segment') + if str(model) == "FastSAM.pt": + model = "FastSAM-x.pt" + assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models." 
+ super().__init__(model=model, task="segment") @property def task_map(self): """Returns a dictionary mapping segment task to corresponding predictor and validator classes.""" - return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}} + return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}} diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py index 4a3c2e9e..0ef18032 100644 --- a/ultralytics/models/fastsam/predict.py +++ b/ultralytics/models/fastsam/predict.py @@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor): _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. """ super().__init__(cfg, overrides, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" def postprocess(self, preds, img, orig_imgs): """ @@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor): agnostic=self.args.agnostic_nms, max_det=self.args.max_det, nc=1, # set to 1 class since SAM has no class predictions - classes=self.args.classes) + classes=self.args.classes, + ) full_box = torch.zeros(p[0].shape[1], device=p[0].device) full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 full_box = full_box.view(1, -1) diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py index 0f43441a..921d8322 100644 --- a/ultralytics/models/fastsam/prompt.py +++ b/ultralytics/models/fastsam/prompt.py @@ -23,7 +23,7 @@ class FastSAMPrompt: clip: CLIP model for linear assignment. """ - def __init__(self, source, results, device='cuda') -> None: + def __init__(self, source, results, device="cuda") -> None: """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment.""" self.device = device self.results = results @@ -34,7 +34,8 @@ class FastSAMPrompt: import clip # for linear_assignment except ImportError: from ultralytics.utils.checks import check_requirements - check_requirements('git+https://github.com/openai/CLIP.git') + + check_requirements("git+https://github.com/openai/CLIP.git") import clip self.clip = clip @@ -46,11 +47,11 @@ class FastSAMPrompt: x1, y1, x2, y2 = bbox segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] segmented_image = Image.fromarray(segmented_image_array) - black_image = Image.new('RGB', image.size, (255, 255, 255)) + black_image = Image.new("RGB", image.size, (255, 255, 255)) # transparency_mask = np.zeros_like((), dtype=np.uint8) transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) transparency_mask[y1:y2, x1:x2] = 255 - transparency_mask_image = Image.fromarray(transparency_mask, mode='L') + transparency_mask_image = Image.fromarray(transparency_mask, mode="L") black_image.paste(segmented_image, mask=transparency_mask_image) return black_image @@ -65,11 +66,12 @@ class FastSAMPrompt: mask = result.masks.data[i] == 1.0 if torch.sum(mask) >= filter: annotation = { - 'id': i, - 'segmentation': mask.cpu().numpy(), - 'bbox': result.boxes.data[i], - 'score': result.boxes.conf[i]} - annotation['area'] = annotation['segmentation'].sum() + "id": i, + "segmentation": mask.cpu().numpy(), + "bbox": result.boxes.data[i], + "score": result.boxes.conf[i], + } + annotation["area"] = annotation["segmentation"].sum() annotations.append(annotation) return annotations @@ -91,16 +93,18 @@ class FastSAMPrompt: y2 = max(y2, y_t + h_t) return [x1, y1, x2, y2] - def plot(self, - annotations, - output, - 
bbox=None, - points=None, - point_label=None, - mask_random_color=True, - better_quality=True, - retina=False, - with_contours=True): + def plot( + self, + annotations, + output, + bbox=None, + points=None, + point_label=None, + mask_random_color=True, + better_quality=True, + retina=False, + with_contours=True, + ): """ Plots annotations, bounding boxes, and points on images and saves the output. @@ -139,15 +143,17 @@ class FastSAMPrompt: mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) - self.fast_show_mask(masks, - plt.gca(), - random_color=mask_random_color, - bbox=bbox, - points=points, - pointlabel=point_label, - retinamask=retina, - target_height=original_h, - target_width=original_w) + self.fast_show_mask( + masks, + plt.gca(), + random_color=mask_random_color, + bbox=bbox, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w, + ) if with_contours: contour_all = [] @@ -166,10 +172,10 @@ class FastSAMPrompt: # Save the figure save_path = Path(output) / result_name save_path.parent.mkdir(exist_ok=True, parents=True) - plt.axis('off') - plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True) + plt.axis("off") + plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True) plt.close() - pbar.set_description(f'Saving {result_name} to {save_path}') + pbar.set_description(f"Saving {result_name} to {save_path}") @staticmethod def fast_show_mask( @@ -212,26 +218,26 @@ class FastSAMPrompt: mask_image = np.expand_dims(annotation, -1) * visual show = np.zeros((h, w, 4)) - h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij') + h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij") indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) show[h_indices, w_indices, :] = mask_image[indices] if bbox is not None: x1, y1, x2, y2 = bbox - ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1)) # Draw point if points is not None: plt.scatter( [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], s=20, - c='y', + c="y", ) plt.scatter( [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], s=20, - c='m', + c="m", ) if not retinamask: @@ -258,7 +264,7 @@ class FastSAMPrompt: image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB)) ori_w, ori_h = image.size annotations = format_results - mask_h, mask_w = annotations[0]['segmentation'].shape + mask_h, mask_w = annotations[0]["segmentation"].shape if ori_w != mask_w or ori_h != mask_h: image = image.resize((mask_w, mask_h)) cropped_boxes = [] @@ -266,19 +272,19 @@ class FastSAMPrompt: not_crop = [] filter_id = [] for _, mask in enumerate(annotations): - if np.sum(mask['segmentation']) <= 100: + if np.sum(mask["segmentation"]) <= 100: filter_id.append(_) continue - bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox - cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片 - cropped_images.append(bbox) # 保存裁剪的图片的bbox + bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask + 
cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image + cropped_images.append(bbox) # save cropped image bbox return cropped_boxes, cropped_images, not_crop, filter_id, annotations def box_prompt(self, bbox): """Modifies the bounding box properties and calculates IoU between masks and bounding box.""" if self.results[0].masks is not None: - assert (bbox[2] != 0 and bbox[3] != 0) + assert bbox[2] != 0 and bbox[3] != 0 if os.path.isdir(self.source): raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") masks = self.results[0].masks.data @@ -290,7 +296,8 @@ class FastSAMPrompt: int(bbox[0] * w / target_width), int(bbox[1] * h / target_height), int(bbox[2] * w / target_width), - int(bbox[3] * h / target_height), ] + int(bbox[3] * h / target_height), + ] bbox[0] = max(round(bbox[0]), 0) bbox[1] = max(round(bbox[1]), 0) bbox[2] = min(round(bbox[2]), w) @@ -299,7 +306,7 @@ class FastSAMPrompt: # IoUs = torch.zeros(len(masks), dtype=torch.float32) bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) - masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) + masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) orig_masks_area = torch.sum(masks, dim=(1, 2)) union = bbox_area + orig_masks_area - masks_area @@ -316,13 +323,13 @@ class FastSAMPrompt: raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.") masks = self._format_results(self.results[0], 0) target_height, target_width = self.results[0].orig_shape - h = masks[0]['segmentation'].shape[0] - w = masks[0]['segmentation'].shape[1] + h = masks[0]["segmentation"].shape[0] + w = masks[0]["segmentation"].shape[1] if h != target_height or w != target_width: points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] onemask = np.zeros((h, w)) for annotation in masks: - mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation + mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation for i, point in enumerate(points): if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: onemask += mask @@ -337,12 +344,12 @@ class FastSAMPrompt: if self.results[0].masks is not None: format_results = self._format_results(self.results[0], 0) cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) - clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device) + clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device) scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) max_idx = scores.argsort() max_idx = max_idx[-1] max_idx += sum(np.array(filter_id) <= int(max_idx)) - self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']])) + self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]])) return self.results def everything_prompt(self): diff --git a/ultralytics/models/fastsam/val.py b/ultralytics/models/fastsam/val.py index 4e1e0b01..9014b27a 100644 --- a/ultralytics/models/fastsam/val.py +++ b/ultralytics/models/fastsam/val.py @@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator): Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors. 
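box_prompt above ranks every predicted mask against the user box by IoU: the intersection is the mask area inside the box slice, and the union is bbox_area + mask_area - intersection. A tensor-level sketch of that scoring, assuming masks is an (N, H, W) 0/1 tensor and bbox is integer xyxy, with a toy 8x8 example:

import torch


def mask_box_iou(masks: torch.Tensor, bbox: list) -> torch.Tensor:
    """IoU between each (N, H, W) binary mask and an xyxy box, as in FastSAMPrompt.box_prompt above."""
    x1, y1, x2, y2 = bbox
    bbox_area = (y2 - y1) * (x2 - x1)
    masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))  # intersection: mask pixels inside the box
    orig_masks_area = torch.sum(masks, dim=(1, 2))
    union = bbox_area + orig_masks_area - masks_area
    return masks_area / union


masks = torch.zeros(2, 8, 8)
masks[0, 2:6, 2:6] = 1  # mask 0 exactly fills the box below -> IoU 1.0
masks[1, 0:8, 0:8] = 1  # mask 1 covers the whole image -> IoU 16/64
ious = mask_box_iou(masks, [2, 2, 6, 6])
print(int(ious.argmax()))  # 0, the best-matching mask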
""" super().__init__(dataloader, save_dir, pbar, args, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) diff --git a/ultralytics/models/nas/__init__.py b/ultralytics/models/nas/__init__.py index eec3837d..b095a050 100644 --- a/ultralytics/models/nas/__init__.py +++ b/ultralytics/models/nas/__init__.py @@ -4,4 +4,4 @@ from .model import NAS from .predict import NASPredictor from .val import NASValidator -__all__ = 'NASPredictor', 'NASValidator', 'NAS' +__all__ = "NASPredictor", "NASValidator", "NAS" diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 00d0b6ed..f990cb4e 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -44,20 +44,21 @@ class NAS(Model): YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files. """ - def __init__(self, model='yolo_nas_s.pt') -> None: + def __init__(self, model="yolo_nas_s.pt") -> None: """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" - assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.' - super().__init__(model, task='detect') + assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models." + super().__init__(model, task="detect") @smart_inference_mode() def _load(self, weights: str, task: str): """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" import super_gradients + suffix = Path(weights).suffix - if suffix == '.pt': + if suffix == ".pt": self.model = torch.load(weights) - elif suffix == '': - self.model = super_gradients.training.models.get(weights, pretrained_weights='coco') + elif suffix == "": + self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32]) @@ -65,7 +66,7 @@ class NAS(Model): self.model.is_fused = lambda: False # for info() self.model.yaml = {} # for info() self.model.pt_path = weights # for export() - self.model.task = 'detect' # for export() + self.model.task = "detect" # for export() def info(self, detailed=False, verbose=True): """ @@ -80,4 +81,4 @@ class NAS(Model): @property def task_map(self): """Returns a dictionary mapping tasks to respective predictor and validator classes.""" - return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}} + return {"detect": {"predictor": NASPredictor, "validator": NASValidator}} diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py index 0118527a..2e485462 100644 --- a/ultralytics/models/nas/predict.py +++ b/ultralytics/models/nas/predict.py @@ -39,12 +39,14 @@ class NASPredictor(BasePredictor): boxes = ops.xyxy2xywh(preds_in[0][0]) preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = 
ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py index 41f60c19..a4a4f990 100644 --- a/ultralytics/models/nas/val.py +++ b/ultralytics/models/nas/val.py @@ -5,7 +5,7 @@ import torch from ultralytics.models.yolo.detect import DetectionValidator from ultralytics.utils import ops -__all__ = ['NASValidator'] +__all__ = ["NASValidator"] class NASValidator(DetectionValidator): @@ -38,11 +38,13 @@ class NASValidator(DetectionValidator): """Apply Non-maximum suppression to prediction outputs.""" boxes = ops.xyxy2xywh(preds_in[0][0]) preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=False, - agnostic=self.args.single_cls, - max_det=self.args.max_det, - max_time_img=0.5) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=False, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + max_time_img=0.5, + ) diff --git a/ultralytics/models/rtdetr/__init__.py b/ultralytics/models/rtdetr/__init__.py index 4d121156..172c74b4 100644 --- a/ultralytics/models/rtdetr/__init__.py +++ b/ultralytics/models/rtdetr/__init__.py @@ -4,4 +4,4 @@ from .model import RTDETR from .predict import RTDETRPredictor from .val import RTDETRValidator -__all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' +__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR" diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index 6e582a8b..b8def485 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -24,7 +24,7 @@ class RTDETR(Model): model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. """ - def __init__(self, model='rtdetr-l.pt') -> None: + def __init__(self, model="rtdetr-l.pt") -> None: """ Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats. @@ -34,9 +34,9 @@ class RTDETR(Model): Raises: NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. """ - if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'): - raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.') - super().__init__(model=model, task='detect') + if model and model.split(".")[-1] not in ("pt", "yaml", "yml"): + raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.") + super().__init__(model=model, task="detect") @property def task_map(self) -> dict: @@ -47,8 +47,10 @@ class RTDETR(Model): dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model. """ return { - 'detect': { - 'predictor': RTDETRPredictor, - 'validator': RTDETRValidator, - 'trainer': RTDETRTrainer, - 'model': RTDETRDetectionModel}} + "detect": { + "predictor": RTDETRPredictor, + "validator": RTDETRValidator, + "trainer": RTDETRTrainer, + "model": RTDETRDetectionModel, + } + } diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py index 26b7ea68..973af649 100644 --- a/ultralytics/models/rtdetr/train.py +++ b/ultralytics/models/rtdetr/train.py @@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer): Returns: (RTDETRDetectionModel): Initialized model. 
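Both NAS postprocess hunks above funnel (boxes, scores) into ops.non_max_suppression; that helper is not part of this diff, but its core confidence-filter plus IoU-suppression step can be sketched with torchvision's nms as a stand-in (the thresholds here are common defaults, not values taken from this diff):

import torch
from torchvision.ops import nms


def simple_nms(boxes_xyxy: torch.Tensor, scores: torch.Tensor, conf=0.25, iou=0.45) -> torch.Tensor:
    """Confidence filter + IoU NMS for one image; a stand-in for ops.non_max_suppression used above."""
    keep = scores > conf  # drop low-confidence candidates first
    boxes_xyxy, scores = boxes_xyxy[keep], scores[keep]
    idx = nms(boxes_xyxy, scores, iou_threshold=iou)  # indices of surviving boxes, highest score first
    return torch.cat((boxes_xyxy[idx], scores[idx, None]), dim=1)  # (n, 5): xyxy + conf


boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
print(simple_nms(boxes, scores))  # the overlapping second box is suppressed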
""" - model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) return model - def build_dataset(self, img_path, mode='val', batch=None): + def build_dataset(self, img_path, mode="val", batch=None): """ Build and return an RT-DETR dataset for training or validation. @@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer): Returns: (RTDETRDataset): Dataset object for the specific mode. """ - return RTDETRDataset(img_path=img_path, - imgsz=self.args.imgsz, - batch_size=batch, - augment=mode == 'train', - hyp=self.args, - rect=False, - cache=self.args.cache or None, - prefix=colorstr(f'{mode}: '), - data=self.data) + return RTDETRDataset( + img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=mode == "train", + hyp=self.args, + rect=False, + cache=self.args.cache or None, + prefix=colorstr(f"{mode}: "), + data=self.data, + ) def get_validator(self): """ @@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer): Returns: (RTDETRValidator): Validator object for model validation. """ - self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' + self.loss_names = "giou_loss", "cls_loss", "l1_loss" return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) def preprocess_batch(self, batch): @@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer): (dict): Preprocessed batch. """ batch = super().preprocess_batch(batch) - bs = len(batch['img']) - batch_idx = batch['batch_idx'] + bs = len(batch["img"]) + batch_idx = batch["batch_idx"] gt_bbox, gt_class = [], [] for i in range(bs): - gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) - gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) + gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device)) + gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) return batch diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py index c52b5940..359edb5d 100644 --- a/ultralytics/models/rtdetr/val.py +++ b/ultralytics/models/rtdetr/val.py @@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms from ultralytics.models.yolo.detect import DetectionValidator from ultralytics.utils import colorstr, ops -__all__ = 'RTDETRValidator', # tuple or list +__all__ = ("RTDETRValidator",) # tuple or list class RTDETRDataset(YOLODataset): @@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset): # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) transforms = Compose([]) transforms.append( - Format(bbox_format='xywh', - normalize=True, - return_mask=self.use_segments, - return_keypoint=self.use_keypoints, - batch_idx=True, - mask_ratio=hyp.mask_ratio, - mask_overlap=hyp.overlap_mask)) + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + ) return transforms @@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator): For further details on the attributes and methods, refer to the parent DetectionValidator class. """ - def build_dataset(self, img_path, mode='val', batch=None): + def build_dataset(self, img_path, mode="val", batch=None): """ Build an RTDETR Dataset. 
@@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator): hyp=self.args, rect=False, # no rect cache=self.args.cache or None, - prefix=colorstr(f'{mode}: '), - data=self.data) + prefix=colorstr(f"{mode}: "), + data=self.data, + ) def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" @@ -108,12 +112,12 @@ class RTDETRValidator(DetectionValidator): def _prepare_batch(self, si, batch): """Prepares a batch for training or inference by applying transformations.""" - idx = batch['batch_idx'] == si - cls = batch['cls'][idx].squeeze(-1) - bbox = batch['bboxes'][idx] - ori_shape = batch['ori_shape'][si] - imgsz = batch['img'].shape[2:] - ratio_pad = batch['ratio_pad'][si] + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] if len(cls): bbox = ops.xywh2xyxy(bbox) # target boxes bbox[..., [0, 2]] *= ori_shape[1] # native-space pred @@ -124,6 +128,6 @@ class RTDETRValidator(DetectionValidator): def _prepare_pred(self, pred, pbatch): """Prepares and returns a batch with transformed bounding boxes and class labels.""" predn = pred.clone() - predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred - predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred + predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred + predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred return predn.float() diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index abf2eef5..8701fcce 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -3,4 +3,4 @@ from .model import SAM from .predict import Predictor -__all__ = 'SAM', 'Predictor' # tuple or list +__all__ = "SAM", "Predictor" # tuple or list diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py index d7751d6f..c4bb6d1b 100644 --- a/ultralytics/models/sam/amg.py +++ b/ultralytics/models/sam/amg.py @@ -8,10 +8,9 @@ import numpy as np import torch -def is_box_near_crop_edge(boxes: torch.Tensor, - crop_box: List[int], - orig_box: List[int], - atol: float = 20.0) -> torch.Tensor: +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: """Return a boolean tensor indicating if boxes are near the crop edge.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) @@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor, def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: """Yield batches of data from the input arguments.""" - assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.' + assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs." 
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): - yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args] + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: @@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh """ # One mask is always contained inside the other. # Save memory by preventing unnecessary cast to torch.int64 - intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, - dtype=torch.int32)) - unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)) + intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) return intersections / unions @@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray: def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: """Generate point grids for all crop layers.""" - return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)] + return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)] -def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int, - overlap_ratio: float) -> Tuple[List[List[int]], List[int]]: +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: """ Generates a list of crop boxes of different sizes. 
@@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator.""" import cv2 # type: ignore - assert mode in {'holes', 'islands'} - correct_holes = mode == 'holes' + assert mode in {"holes", "islands"} + correct_holes = mode == "holes" working_mask = (correct_holes ^ mask).astype(np.uint8) n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) sizes = stats[:, -1][1:] # Row 0 is background label diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index c27f2d09..cb3a7c68 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -64,46 +64,47 @@ def build_mobile_sam(checkpoint=None): ) -def _build_sam(encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, - mobile_sam=False): +def _build_sam( + encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False +): """Builds the selected SAM model architecture.""" prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size - image_encoder = (TinyViT( - img_size=1024, - in_chans=3, - num_classes=1000, - embed_dims=encoder_embed_dim, - depths=encoder_depth, - num_heads=encoder_num_heads, - window_sizes=[7, 7, 14, 7], - mlp_ratio=4.0, - drop_rate=0.0, - drop_path_rate=0.0, - use_checkpoint=False, - mbconv_expand_ratio=4.0, - local_conv_size=3, - layer_lr_decay=0.8, - ) if mobile_sam else ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - )) + image_encoder = ( + TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=encoder_embed_dim, + depths=encoder_depth, + num_heads=encoder_num_heads, + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) + if mobile_sam + else ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + ) sam = Sam( image_encoder=image_encoder, prompt_encoder=PromptEncoder( @@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim, ) if checkpoint is not None: checkpoint = attempt_download_asset(checkpoint) - with open(checkpoint, 'rb') as f: + with open(checkpoint, "rb") as f: state_dict = torch.load(f) sam.load_state_dict(state_dict) sam.eval() @@ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim, sam_model_map = { - 'sam_h.pt': build_sam_vit_h, - 'sam_l.pt': build_sam_vit_l, - 'sam_b.pt': build_sam_vit_b, - 'mobile_sam.pt': build_mobile_sam, } + "sam_h.pt": build_sam_vit_h, + "sam_l.pt": build_sam_vit_l, + "sam_b.pt": build_sam_vit_b, + "mobile_sam.pt": build_mobile_sam, +} -def build_sam(ckpt='sam_b.pt'): +def build_sam(ckpt="sam_b.pt"): """Build a SAM model specified by ckpt.""" model_builder = None ckpt = 
str(ckpt) # to allow Path ckpt types @@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'): model_builder = sam_model_map.get(k) if not model_builder: - raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}') + raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") return model_builder(ckpt) diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index 68acd22f..8ae9ebaa 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -32,7 +32,7 @@ class SAM(Model): dataset. """ - def __init__(self, model='sam_b.pt') -> None: + def __init__(self, model="sam_b.pt") -> None: """ Initializes the SAM model with a pre-trained model file. @@ -42,9 +42,9 @@ class SAM(Model): Raises: NotImplementedError: If the model file extension is not .pt or .pth. """ - if model and Path(model).suffix not in ('.pt', '.pth'): - raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.') - super().__init__(model=model, task='segment') + if model and Path(model).suffix not in (".pt", ".pth"): + raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") + super().__init__(model=model, task="segment") def _load(self, weights: str, task=None): """ @@ -70,7 +70,7 @@ class SAM(Model): Returns: (list): The model predictions. """ - overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) kwargs.update(overrides) prompts = dict(bboxes=bboxes, points=points, labels=labels) return super().predict(source, stream, prompts=prompts, **kwargs) @@ -112,4 +112,4 @@ class SAM(Model): Returns: (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'. 
""" - return {'segment': {'predictor': Predictor}} + return {"segment": {"predictor": Predictor}} diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py index 4ad1d9f1..41e3af5d 100644 --- a/ultralytics/models/sam/modules/decoders.py +++ b/ultralytics/models/sam/modules/decoders.py @@ -64,8 +64,9 @@ class MaskDecoder(nn.Module): nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), activation(), ) - self.output_hypernetworks_mlps = nn.ModuleList([ - MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]) + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) @@ -132,13 +133,14 @@ class MaskDecoder(nn.Module): # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] - mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [ - self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)] + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) + ] hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index f7771380..a43e115c 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -28,23 +28,23 @@ class ImageEncoderViT(nn.Module): """ def __init__( - self, - img_size: int = 1024, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - mlp_ratio: float = 4.0, - out_chans: int = 256, - qkv_bias: bool = True, - norm_layer: Type[nn.Module] = nn.LayerNorm, - act_layer: Type[nn.Module] = nn.GELU, - use_abs_pos: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - global_attn_indexes: Tuple[int, ...] = (), + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] 
= (), ) -> None: """ Args: @@ -283,9 +283,9 @@ class PromptEncoder(nn.Module): if masks is not None: dense_embeddings = self._embed_masks(masks) else: - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, - 1).expand(bs, -1, self.image_embedding_size[0], - self.image_embedding_size[1]) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) return sparse_embeddings, dense_embeddings @@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module): super().__init__() if scale is None or scale <= 0.0: scale = 1.0 - self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats))) + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' torch.use_deterministic_algorithms(False) @@ -425,14 +425,14 @@ class Attention(nn.Module): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: - assert (input_size is not None), 'Input size must be provided if using relative positional encoding.' + assert input_size is not None, "Input size must be provided if using relative positional encoding." # Initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) @@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T return windows, (Hp, Wp) -def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], - hw: Tuple[int, int]) -> torch.Tensor: +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. @@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, - mode='linear', + mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: @@ -567,11 +568,12 @@ def add_decomposed_rel_pos( B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) - rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( - B, q_h * q_w, k_h * k_w) + B, q_h * q_w, k_h * k_w + ) return attn @@ -580,12 +582,12 @@ class PatchEmbed(nn.Module): """Image to Patch Embedding.""" def __init__( - self, - kernel_size: Tuple[int, int] = (16, 16), - stride: Tuple[int, int] = (16, 16), - padding: Tuple[int, int] = (0, 0), - in_chans: int = 3, - embed_dim: int = 768, + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, ) -> None: """ Initialize PatchEmbed module. 
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 4097a228..95d9bbe6 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -30,8 +30,9 @@ class Sam(nn.Module): pixel_mean (List[float]): Mean pixel values for image normalization. pixel_std (List[float]): Standard deviation values for image normalization. """ + mask_threshold: float = 0.0 - image_format: str = 'RGB' + image_format: str = "RGB" def __init__( self, @@ -39,7 +40,7 @@ class Sam(nn.Module): prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = (123.675, 116.28, 103.53), - pixel_std: List[float] = (58.395, 57.12, 57.375) + pixel_std: List[float] = (58.395, 57.12, 57.375), ) -> None: """ Initialize the Sam class to predict object masks from an image and input prompts. @@ -60,5 +61,5 @@ class Sam(nn.Module): self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder - self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False) - self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False) + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py index 9955a261..ac378680 100644 --- a/ultralytics/models/sam/modules/tiny_encoder.py +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential): drop path. """ super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(bn.weight, bn_weight_init) torch.nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) + self.add_module("bn", bn) class PatchEmbed(nn.Module): @@ -146,11 +146,11 @@ class ConvLayer(nn.Module): input_resolution, depth, activation, - drop_path=0., + drop_path=0.0, downsample=None, use_checkpoint=False, out_dim=None, - conv_expand_ratio=4., + conv_expand_ratio=4.0, ): """ Initializes the ConvLayer with the given dimensions and settings. @@ -173,18 +173,25 @@ class ConvLayer(nn.Module): self.use_checkpoint = use_checkpoint # Build blocks - self.blocks = nn.ModuleList([ - MBConv( - dim, - dim, - conv_expand_ratio, - activation, - drop_path[i] if isinstance(drop_path, list) else drop_path, - ) for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ] + ) # Patch merging layer - self.downsample = None if downsample is None else downsample( - input_resolution, dim=dim, out_dim=out_dim, activation=activation) + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) def forward(self, x): """Processes the input through a series of convolutional layers and returns the activated output.""" @@ -200,7 +207,7 @@ class Mlp(nn.Module): This layer takes an input with in_features, applies layer normalization and two fully-connected layers. 
""" - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc.""" super().__init__() out_features = out_features or in_features @@ -232,12 +239,12 @@ class Attention(torch.nn.Module): """ def __init__( - self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=(14, 14), + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), ): """ Initializes the Attention module. @@ -256,7 +263,7 @@ class Attention(torch.nn.Module): assert isinstance(resolution, tuple) and len(resolution) == 2 self.num_heads = num_heads - self.scale = key_dim ** -0.5 + self.scale = key_dim**-0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) @@ -279,13 +286,13 @@ class Attention(torch.nn.Module): attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False) @torch.no_grad() def train(self, mode=True): """Sets the module in training mode and handles attribute 'ab' based on the mode.""" super().train(mode) - if mode and hasattr(self, 'ab'): + if mode and hasattr(self, "ab"): del self.ab else: self.ab = self.attention_biases[:, self.attention_bias_idxs] @@ -306,8 +313,9 @@ class Attention(torch.nn.Module): v = v.permute(0, 2, 1, 3) self.ab = self.ab.to(self.attention_biases.device) - attn = ((q @ k.transpose(-2, -1)) * self.scale + - (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab)) + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + ) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) return self.proj(x) @@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module): input_resolution, num_heads, window_size=7, - mlp_ratio=4., - drop=0., - drop_path=0., + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, local_conv_size=3, activation=nn.GELU, ): @@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module): self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads - assert window_size > 0, 'window_size must be greater than 0' + assert window_size > 0, "window_size must be greater than 0" self.window_size = window_size self.mlp_ratio = mlp_ratio @@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module): # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.drop_path = nn.Identity() - assert dim % num_heads == 0, 'dim must be divisible by num_heads' + assert dim % num_heads == 0, "dim must be divisible by num_heads" head_dim = dim // num_heads window_resolution = (window_size, window_size) @@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module): """ H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, 'input feature has wrong size' + assert L == H * W, "input feature has wrong size" res_x = x if H == self.window_size and W == self.window_size: x = self.attn(x) @@ -394,8 +402,11 @@ class TinyViTBlock(nn.Module): nH = pH // self.window_size nW = pW // self.window_size # Window partition - x = x.view(B, nH, self.window_size, nW, self.window_size, - C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C) + x = ( + x.view(B, nH, self.window_size, nW, self.window_size, C) + .transpose(2, 3) + .reshape(B * nH * nW, self.window_size * self.window_size, C) + ) x = self.attn(x) # Window reverse x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) @@ -417,8 +428,10 @@ class TinyViTBlock(nn.Module): """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of attentions heads, window size, and MLP ratio. """ - return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ - f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}' + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) class BasicLayer(nn.Module): @@ -431,9 +444,9 @@ class BasicLayer(nn.Module): depth, num_heads, window_size, - mlp_ratio=4., - drop=0., - drop_path=0., + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, downsample=None, use_checkpoint=False, local_conv_size=3, @@ -468,22 +481,29 @@ class BasicLayer(nn.Module): self.use_checkpoint = use_checkpoint # Build blocks - self.blocks = nn.ModuleList([ - TinyViTBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - drop=drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - local_conv_size=local_conv_size, - activation=activation, - ) for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) # Patch merging layer - self.downsample = None if downsample is None else downsample( - input_resolution, dim=dim, out_dim=out_dim, activation=activation) + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) def forward(self, x): """Performs forward propagation on the input tensor and returns a normalized tensor.""" @@ -493,7 +513,7 @@ class BasicLayer(nn.Module): def extra_repr(self) -> str: """Returns a string representation of the extra_repr function with the layer's parameters.""" - return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" class LayerNorm2d(nn.Module): @@ -549,8 +569,8 @@ class TinyViT(nn.Module): depths=[2, 2, 
6, 2], num_heads=[3, 6, 12, 24], window_sizes=[7, 7, 14, 7], - mlp_ratio=4., - drop_rate=0., + mlp_ratio=4.0, + drop_rate=0.0, drop_path_rate=0.1, use_checkpoint=False, mbconv_expand_ratio=4.0, @@ -585,10 +605,9 @@ class TinyViT(nn.Module): activation = nn.GELU - self.patch_embed = PatchEmbed(in_chans=in_chans, - embed_dim=embed_dims[0], - resolution=img_size, - activation=activation) + self.patch_embed = PatchEmbed( + in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation + ) patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution @@ -601,27 +620,30 @@ class TinyViT(nn.Module): for i_layer in range(self.num_layers): kwargs = dict( dim=embed_dims[i_layer], - input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), - patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))), + input_resolution=( + patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), # input_resolution=(patches_resolution[0] // (2 ** i_layer), # patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, - out_dim=embed_dims[min(i_layer + 1, - len(embed_dims) - 1)], + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], activation=activation, ) if i_layer == 0: layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs) else: - layer = BasicLayer(num_heads=num_heads[i_layer], - window_size=window_sizes[i_layer], - mlp_ratio=self.mlp_ratio, - drop=drop_rate, - local_conv_size=local_conv_size, - **kwargs) + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) self.layers.append(layer) # Classifier head @@ -680,7 +702,7 @@ class TinyViT(nn.Module): def _check_lr_scale(m): """Checks if the learning rate scale attribute is present in module's parameters.""" for p in m.parameters(): - assert hasattr(p, 'lr_scale'), p.param_name + assert hasattr(p, "lr_scale"), p.param_name self.apply(_check_lr_scale) @@ -698,7 +720,7 @@ class TinyViT(nn.Module): @torch.jit.ignore def no_weight_decay_keywords(self): """Returns a dictionary of parameter names where weight decay should not be applied.""" - return {'attention_biases'} + return {"attention_biases"} def forward_features(self, x): """Runs the input through the model layers and returns the transformed output.""" diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py index 5c06acd9..1ad07418 100644 --- a/ultralytics/models/sam/modules/transformer.py +++ b/ultralytics/models/sam/modules/transformer.py @@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module): activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), - )) + ) + ) self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm_final_attn = nn.LayerNorm(embedding_dim) @@ -227,7 +228,7 @@ class Attention(nn.Module): self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads - assert self.internal_dim % 
num_heads == 0, 'num_heads must divide embedding_dim.' + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 94362ecc..540ed6c6 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -19,8 +19,17 @@ from ultralytics.engine.results import Results from ultralytics.utils import DEFAULT_CFG, ops from ultralytics.utils.torch_utils import select_device -from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, - generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks) +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) from .build import build_sam @@ -58,7 +67,7 @@ class Predictor(BasePredictor): """ if overrides is None: overrides = {} - overrides.update(dict(task='segment', mode='predict', imgsz=1024)) + overrides.update(dict(task="segment", mode="predict", imgsz=1024)) super().__init__(cfg, overrides, _callbacks) self.args.retina_masks = True self.im = None @@ -107,7 +116,7 @@ class Predictor(BasePredictor): Returns: (List[np.ndarray]): List of transformed images. """ - assert len(im) == 1, 'SAM model does not currently support batched inference' + assert len(im) == 1, "SAM model does not currently support batched inference" letterbox = LetterBox(self.args.imgsz, auto=False, center=False) return [letterbox(image=x) for x in im] @@ -132,9 +141,9 @@ class Predictor(BasePredictor): - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. """ # Override prompts if any stored in self.prompts - bboxes = self.prompts.pop('bboxes', bboxes) - points = self.prompts.pop('points', points) - masks = self.prompts.pop('masks', masks) + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) if all(i is None for i in [bboxes, points, masks]): return self.generate(im, *args, **kwargs) @@ -199,18 +208,20 @@ class Predictor(BasePredictor): # `d` could be 1 or 3 depends on `multimask_output`. return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) - def generate(self, - im, - crop_n_layers=0, - crop_overlap_ratio=512 / 1500, - crop_downscale_factor=1, - point_grids=None, - points_stride=32, - points_batch_size=64, - conf_thres=0.88, - stability_score_thresh=0.95, - stability_score_offset=0.95, - crop_nms_thresh=0.7): + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): """ Perform image segmentation using the Segment Anything Model (SAM). 
@@ -248,19 +259,20 @@ class Predictor(BasePredictor): area = torch.tensor(w * h, device=im.device) points_scale = np.array([[w, h]]) # w, h # Crop image and interpolate to input size - crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False) + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) # (num_points, 2) points_for_image = point_grids[layer_idx] * points_scale crop_masks, crop_scores, crop_bboxes = [], [], [] - for (points, ) in batch_iterator(points_batch_size, points_for_image): + for (points,) in batch_iterator(points_batch_size, points_for_image): pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) # Interpolate predicted masks to input size - pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0] + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] idx = pred_score > conf_thres pred_mask, pred_score = pred_mask[idx], pred_score[idx] - stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold, - stability_score_offset) + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) idx = stability_score > stability_score_thresh pred_mask, pred_score = pred_mask[idx], pred_score[idx] # Bool type is much more memory-efficient. @@ -404,7 +416,7 @@ class Predictor(BasePredictor): model = build_sam(self.args.model) self.setup_model(model) self.setup_source(image) - assert len(self.dataset) == 1, '`set_image` only supports setting one image!' + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" for batch in self.dataset: im = self.preprocess(batch[1]) self.features = self.model.image_encoder(im) @@ -446,9 +458,9 @@ class Predictor(BasePredictor): scores = [] for mask in masks: mask = mask.cpu().numpy().astype(np.uint8) - mask, changed = remove_small_regions(mask, min_area, mode='holes') + mask, changed = remove_small_regions(mask, min_area, mode="holes") unchanged = not changed - mask, changed = remove_small_regions(mask, min_area, mode='islands') + mask, changed = remove_small_regions(mask, min_area, mode="islands") unchanged = unchanged and not changed new_masks.append(torch.as_tensor(mask).unsqueeze(0)) diff --git a/ultralytics/models/utils/loss.py b/ultralytics/models/utils/loss.py index abb54958..1251cc96 100644 --- a/ultralytics/models/utils/loss.py +++ b/ultralytics/models/utils/loss.py @@ -30,14 +30,9 @@ class DETRLoss(nn.Module): device (torch.device): Device on which tensors are stored. """ - def __init__(self, - nc=80, - loss_gain=None, - aux_loss=True, - use_fl=True, - use_vfl=False, - use_uni_match=False, - uni_match_ind=0): + def __init__( + self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0 + ): """ DETR loss function. 
@@ -52,9 +47,9 @@ class DETRLoss(nn.Module): super().__init__() if loss_gain is None: - loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1} + loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1} self.nc = nc - self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2}) + self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2}) self.loss_gain = loss_gain self.aux_loss = aux_loss self.fl = FocalLoss() if use_fl else None @@ -64,10 +59,10 @@ class DETRLoss(nn.Module): self.uni_match_ind = uni_match_ind self.device = None - def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''): + def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""): """Computes the classification loss based on predictions, target values, and ground truth scores.""" # Logits: [b, query, num_classes], gt_class: list[[n, 1]] - name_class = f'loss_class{postfix}' + name_class = f"loss_class{postfix}" bs, nq = pred_scores.shape[:2] # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes) one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device) @@ -82,28 +77,28 @@ class DETRLoss(nn.Module): loss_cls = self.fl(pred_scores, one_hot.float()) loss_cls /= max(num_gts, 1) / nq else: - loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss + loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss - return {name_class: loss_cls.squeeze() * self.loss_gain['class']} + return {name_class: loss_cls.squeeze() * self.loss_gain["class"]} - def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''): + def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""): """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding boxes. 
""" # Boxes: [b, query, 4], gt_bbox: list[[n, 4]] - name_bbox = f'loss_bbox{postfix}' - name_giou = f'loss_giou{postfix}' + name_bbox = f"loss_bbox{postfix}" + name_giou = f"loss_giou{postfix}" loss = {} if len(gt_bboxes) == 0: - loss[name_bbox] = torch.tensor(0., device=self.device) - loss[name_giou] = torch.tensor(0., device=self.device) + loss[name_bbox] = torch.tensor(0.0, device=self.device) + loss[name_giou] = torch.tensor(0.0, device=self.device) return loss - loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes) + loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes) loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True) loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes) - loss[name_giou] = self.loss_gain['giou'] * loss[name_giou] + loss[name_giou] = self.loss_gain["giou"] * loss[name_giou] return {k: v.squeeze() for k, v in loss.items()} # This function is for future RT-DETR Segment models @@ -137,50 +132,57 @@ class DETRLoss(nn.Module): # loss = 1 - (numerator + 1) / (denominator + 1) # return loss.sum() / num_gts - def _get_loss_aux(self, - pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - match_indices=None, - postfix='', - masks=None, - gt_mask=None): + def _get_loss_aux( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + match_indices=None, + postfix="", + masks=None, + gt_mask=None, + ): """Get auxiliary losses.""" # NOTE: loss class, bbox, giou, mask, dice loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device) if match_indices is None and self.use_uni_match: - match_indices = self.matcher(pred_bboxes[self.uni_match_ind], - pred_scores[self.uni_match_ind], - gt_bboxes, - gt_cls, - gt_groups, - masks=masks[self.uni_match_ind] if masks is not None else None, - gt_mask=gt_mask) + match_indices = self.matcher( + pred_bboxes[self.uni_match_ind], + pred_scores[self.uni_match_ind], + gt_bboxes, + gt_cls, + gt_groups, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask, + ) for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)): aux_masks = masks[i] if masks is not None else None - loss_ = self._get_loss(aux_bboxes, - aux_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=aux_masks, - gt_mask=gt_mask, - postfix=postfix, - match_indices=match_indices) - loss[0] += loss_[f'loss_class{postfix}'] - loss[1] += loss_[f'loss_bbox{postfix}'] - loss[2] += loss_[f'loss_giou{postfix}'] + loss_ = self._get_loss( + aux_bboxes, + aux_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=aux_masks, + gt_mask=gt_mask, + postfix=postfix, + match_indices=match_indices, + ) + loss[0] += loss_[f"loss_class{postfix}"] + loss[1] += loss_[f"loss_bbox{postfix}"] + loss[2] += loss_[f"loss_giou{postfix}"] # if masks is not None and gt_mask is not None: # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix) # loss[3] += loss_[f'loss_mask{postfix}'] # loss[4] += loss_[f'loss_dice{postfix}'] loss = { - f'loss_class_aux{postfix}': loss[0], - f'loss_bbox_aux{postfix}': loss[1], - f'loss_giou_aux{postfix}': loss[2]} + f"loss_class_aux{postfix}": loss[0], + f"loss_bbox_aux{postfix}": loss[1], + f"loss_giou_aux{postfix}": loss[2], + } # if masks is not None and gt_mask is not None: # loss[f'loss_mask_aux{postfix}'] = loss[3] # loss[f'loss_dice_aux{postfix}'] = loss[4] @@ -196,33 +198,37 @@ class DETRLoss(nn.Module): def 
_get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices): """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices.""" - pred_assigned = torch.cat([ - t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device) - for t, (I, _) in zip(pred_bboxes, match_indices)]) - gt_assigned = torch.cat([ - t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device) - for t, (_, J) in zip(gt_bboxes, match_indices)]) + pred_assigned = torch.cat( + [ + t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (I, _) in zip(pred_bboxes, match_indices) + ] + ) + gt_assigned = torch.cat( + [ + t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (_, J) in zip(gt_bboxes, match_indices) + ] + ) return pred_assigned, gt_assigned - def _get_loss(self, - pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=None, - gt_mask=None, - postfix='', - match_indices=None): + def _get_loss( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=None, + gt_mask=None, + postfix="", + match_indices=None, + ): """Get losses.""" if match_indices is None: - match_indices = self.matcher(pred_bboxes, - pred_scores, - gt_bboxes, - gt_cls, - gt_groups, - masks=masks, - gt_mask=gt_mask) + match_indices = self.matcher( + pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask + ) idx, gt_idx = self._get_index(match_indices) pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx] @@ -242,7 +248,7 @@ class DETRLoss(nn.Module): # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix)) return loss - def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs): + def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs): """ Args: pred_bboxes (torch.Tensor): [l, b, query, 4] @@ -254,21 +260,19 @@ class DETRLoss(nn.Module): postfix (str): postfix of loss name. 
""" self.device = pred_bboxes.device - match_indices = kwargs.get('match_indices', None) - gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups'] + match_indices = kwargs.get("match_indices", None) + gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"] - total_loss = self._get_loss(pred_bboxes[-1], - pred_scores[-1], - gt_bboxes, - gt_cls, - gt_groups, - postfix=postfix, - match_indices=match_indices) + total_loss = self._get_loss( + pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices + ) if self.aux_loss: total_loss.update( - self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, - postfix)) + self._get_loss_aux( + pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix + ) + ) return total_loss @@ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss): # Check for denoising metadata to compute denoising training loss if dn_meta is not None: - dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group'] - assert len(batch['gt_groups']) == len(dn_pos_idx) + dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"] + assert len(batch["gt_groups"]) == len(dn_pos_idx) # Get the match indices for denoising - match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups']) + match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"]) # Compute the denoising training loss - dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices) + dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices) total_loss.update(dn_loss) else: # If no denoising metadata is provided, set denoising loss to zero - total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()}) + total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()}) return total_loss @@ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss): if num_gt > 0: gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i] gt_idx = gt_idx.repeat(dn_num_group) - assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' - f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.' + assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, " + f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." 
dn_match_indices.append((dn_pos_idx[i], gt_idx)) else: dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long))) diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py index 902756db..4f66feef 100644 --- a/ultralytics/models/utils/ops.py +++ b/ultralytics/models/utils/ops.py @@ -37,7 +37,7 @@ class HungarianMatcher(nn.Module): """ super().__init__() if cost_gain is None: - cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} + cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1} self.cost_gain = cost_gain self.use_fl = use_fl self.with_mask = with_mask @@ -86,7 +86,7 @@ class HungarianMatcher(nn.Module): # Compute the classification cost pred_scores = pred_scores[:, gt_cls] if self.use_fl: - neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log()) + neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log()) pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) cost_class = pos_cost_class - neg_cost_class else: @@ -99,9 +99,11 @@ class HungarianMatcher(nn.Module): cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) # Final cost matrix - C = self.cost_gain['class'] * cost_class + \ - self.cost_gain['bbox'] * cost_bbox + \ - self.cost_gain['giou'] * cost_giou + C = ( + self.cost_gain["class"] * cost_class + + self.cost_gain["bbox"] * cost_bbox + + self.cost_gain["giou"] * cost_giou + ) # Compute the mask cost and dice cost if self.with_mask: C += self._cost_mask(bs, gt_groups, masks, gt_mask) @@ -111,10 +113,11 @@ class HungarianMatcher(nn.Module): C = C.view(bs, nq, -1).cpu() indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] - gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) - # (idx for queries, idx for gt) - return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) - for k, (i, j) in enumerate(indices)] + gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt) + return [ + (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) + for k, (i, j) in enumerate(indices) + ] # This function is for future RT-DETR Segment models # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None): @@ -147,14 +150,9 @@ class HungarianMatcher(nn.Module): # return C -def get_cdn_group(batch, - num_classes, - num_queries, - class_embed, - num_dn=100, - cls_noise_ratio=0.5, - box_noise_scale=1.0, - training=False): +def get_cdn_group( + batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False +): """ Get contrastive denoising training group. This function creates a contrastive denoising training group with positive and negative samples from the ground truths (gt). 
It applies noise to the class labels and bounding box coordinates, @@ -180,7 +178,7 @@ def get_cdn_group(batch, if (not training) or num_dn <= 0: return None, None, None, None - gt_groups = batch['gt_groups'] + gt_groups = batch["gt_groups"] total_num = sum(gt_groups) max_nums = max(gt_groups) if max_nums == 0: @@ -190,9 +188,9 @@ def get_cdn_group(batch, num_group = 1 if num_group == 0 else num_group # Pad gt to max_num of a batch bs = len(gt_groups) - gt_cls = batch['cls'] # (bs*num, ) - gt_bbox = batch['bboxes'] # bs*num, 4 - b_idx = batch['batch_idx'] + gt_cls = batch["cls"] # (bs*num, ) + gt_bbox = batch["bboxes"] # bs*num, 4 + b_idx = batch["batch_idx"] # Each group has positive and negative queries. dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, ) @@ -245,16 +243,21 @@ def get_cdn_group(batch, # Reconstruct cannot see each other for i in range(num_group): if i == 0: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True if i == num_group - 1: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True else: - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True - attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True dn_meta = { - 'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], - 'dn_num_group': num_group, - 'dn_num_split': [num_dn, num_queries]} - - return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to( - class_embed.device), dn_meta + "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], + "dn_num_group": num_group, + "dn_num_split": [num_dn, num_queries], + } + + return ( + padding_cls.to(class_embed.device), + padding_bbox.to(class_embed.device), + attn_mask.to(class_embed.device), + dn_meta, + ) diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py index 602307b8..230f53a8 100644 --- a/ultralytics/models/yolo/__init__.py +++ b/ultralytics/models/yolo/__init__.py @@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment from .model import YOLO -__all__ = 'classify', 'segment', 'detect', 'pose', 'obb', 'YOLO' +__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO" diff --git a/ultralytics/models/yolo/classify/__init__.py b/ultralytics/models/yolo/classify/__init__.py index 33d72e68..ca92f892 100644 --- a/ultralytics/models/yolo/classify/__init__.py +++ b/ultralytics/models/yolo/classify/__init__.py @@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor from ultralytics.models.yolo.classify.train import ClassificationTrainer from ultralytics.models.yolo.classify.val import ClassificationValidator -__all__ = 'ClassificationPredictor', 'ClassificationTrainer', 'ClassificationValidator' +__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator" diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py index 9047d8df..853ef048 100644 --- a/ultralytics/models/yolo/classify/predict.py +++ 
b/ultralytics/models/yolo/classify/predict.py @@ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """Initializes ClassificationPredictor setting the task to 'classify'.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'classify' - self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor' + self.args.task = "classify" + self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor" def preprocess(self, img): """Converts input image to model-compatible data type.""" if not isinstance(img, torch.Tensor): - is_legacy_transform = any(self._legacy_transform_name in str(transform) - for transform in self.transforms.transforms) + is_legacy_transform = any( + self._legacy_transform_name in str(transform) for transform in self.transforms.transforms + ) if is_legacy_transform: # to handle legacy transforms img = torch.stack([self.transforms(im) for im in img], dim=0) else: - img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], - dim=0) + img = torch.stack( + [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0 + ) img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index 3c23b2a0..a7bbe772 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer): """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.""" if overrides is None: overrides = {} - overrides['task'] = 'classify' - if overrides.get('imgsz') is None: - overrides['imgsz'] = 224 + overrides["task"] = "classify" + if overrides.get("imgsz") is None: + overrides["imgsz"] = 224 super().__init__(cfg, overrides, _callbacks) def set_model_attributes(self): """Set the YOLO model's class names from the loaded dataset.""" - self.model.names = self.data['names'] + self.model.names = self.data["names"] def get_model(self, cfg=None, weights=None, verbose=True): """Returns a modified PyTorch model configured for training YOLO.""" - model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) for m in model.modules(): - if not self.args.pretrained and hasattr(m, 'reset_parameters'): + if not self.args.pretrained and hasattr(m, "reset_parameters"): m.reset_parameters() if isinstance(m, torch.nn.Dropout) and self.args.dropout: m.p = self.args.dropout # set dropout @@ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer): model, ckpt = str(self.model), None # Load a YOLO model locally, from torchvision, or from Ultralytics assets - if model.endswith('.pt'): - self.model, ckpt = attempt_load_one_weight(model, device='cpu') + if model.endswith(".pt"): + self.model, ckpt = attempt_load_one_weight(model, device="cpu") for p in self.model.parameters(): p.requires_grad = True # for training - elif model.split('.')[-1] in ('yaml', 'yml'): + elif model.split(".")[-1] in ("yaml", "yml"): self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: - self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if 
self.args.pretrained else None) + self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) else: - FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') - ClassificationModel.reshape_outputs(self.model, self.data['nc']) + raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.") + ClassificationModel.reshape_outputs(self.model, self.data["nc"]) return ckpt - def build_dataset(self, img_path, mode='train', batch=None): + def build_dataset(self, img_path, mode="train", batch=None): """Creates a ClassificationDataset instance given an image path and mode (train/test etc.).""" - return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode) + return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode) - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = self.build_dataset(dataset_path, mode) loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) # Attach inference transforms - if mode != 'train': + if mode != "train": if is_parallel(self.model): self.model.module.transforms = loader.dataset.torch_transforms else: @@ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer): def preprocess_batch(self, batch): """Preprocesses a batch of images and classes.""" - batch['img'] = batch['img'].to(self.device) - batch['cls'] = batch['cls'].to(self.device) + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) return batch def progress_string(self): """Returns a formatted string showing training progress.""" - return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ - ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) def get_validator(self): """Returns an instance of ClassificationValidator for validation.""" - self.loss_names = ['loss'] + self.loss_names = ["loss"] return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """ Returns a loss dict with labelled training loss items tensor. 
Not needed for classification but necessary for segmentation & detection """ - keys = [f'{prefix}/{x}' for x in self.loss_names] + keys = [f"{prefix}/{x}" for x in self.loss_names] if loss_items is None: return keys loss_items = [round(float(loss_items), 5)] @@ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer): if f.exists(): strip_optimizer(f) # strip optimizers if f is self.best: - LOGGER.info(f'\nValidating {f}...') + LOGGER.info(f"\nValidating {f}...") self.validator.args.data = self.args.data self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) - self.metrics.pop('fitness', None) - self.run_callbacks('on_fit_epoch_end') + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") def plot_training_samples(self, batch, ni): """Plots training samples with their annotations.""" plot_images( - images=batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index 3ebf3808..de3cff2b 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -31,43 +31,42 @@ class ClassificationValidator(BaseValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.targets = None self.pred = None - self.args.task = 'classify' + self.args.task = "classify" self.metrics = ClassifyMetrics() def get_desc(self): """Returns a formatted string summarizing classification metrics.""" - return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') + return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc") def init_metrics(self, model): """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" self.names = model.names self.nc = len(model.names) - self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify') + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify") self.pred = [] self.targets = [] def preprocess(self, batch): """Preprocesses input batch and returns it.""" - batch['img'] = batch['img'].to(self.device, non_blocking=True) - batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() - batch['cls'] = batch['cls'].to(self.device) + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() + batch["cls"] = batch["cls"].to(self.device) return batch def update_metrics(self, preds, batch): """Updates running metrics with model predictions and batch targets.""" n5 = min(len(self.names), 5) self.pred.append(preds.argsort(1, descending=True)[:, :n5]) - self.targets.append(batch['cls']) + self.targets.append(batch["cls"]) def finalize_metrics(self, *args, **kwargs): """Finalizes metrics of the model such as confusion_matrix and speed.""" self.confusion_matrix.process_cls_preds(self.pred, self.targets) if self.args.plots: for normalize in True, False: - self.confusion_matrix.plot(save_dir=self.save_dir, - names=self.names.values(), - normalize=normalize, 
- on_plot=self.on_plot) + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) self.metrics.speed = self.speed self.metrics.confusion_matrix = self.confusion_matrix self.metrics.save_dir = self.save_dir @@ -88,24 +87,27 @@ class ClassificationValidator(BaseValidator): def print_results(self): """Prints evaluation metrics for YOLO object detection model.""" - pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format - LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) + pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5)) def plot_val_samples(self, batch, ni): """Plot validation image samples.""" plot_images( - images=batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models - fname=self.save_dir / f'val_batch{ni}_labels.jpg', + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"val_batch{ni}_labels.jpg", names=self.names, - on_plot=self.on_plot) + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predicted bounding boxes on input images and saves the result.""" - plot_images(batch['img'], - batch_idx=torch.arange(len(batch['img'])), - cls=torch.argmax(preds, dim=1), - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=torch.argmax(preds, dim=1), + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py index 20fc0c48..5f3e62c1 100644 --- a/ultralytics/models/yolo/detect/__init__.py +++ b/ultralytics/models/yolo/detect/__init__.py @@ -4,4 +4,4 @@ from .predict import DetectionPredictor from .train import DetectionTrainer from .val import DetectionValidator -__all__ = 'DetectionPredictor', 'DetectionTrainer', 'DetectionValidator' +__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator" diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py index 28cbd7ce..3a0c6287 100644 --- a/ultralytics/models/yolo/detect/predict.py +++ b/ultralytics/models/yolo/detect/predict.py @@ -22,12 +22,14 @@ class DetectionPredictor(BasePredictor): def postprocess(self, preds, img, orig_imgs): """Post-processes predictions and returns a list of Results objects.""" - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index fc656984..3326512b 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -30,7 +30,7 @@ class DetectionTrainer(BaseTrainer): ``` """ - def build_dataset(self, img_path, mode='train', 
batch=None): + def build_dataset(self, img_path, mode="train", batch=None): """ Build YOLO Dataset. @@ -40,33 +40,37 @@ class DetectionTrainer(BaseTrainer): batch (int, optional): Size of batches, this is for `rect`. Defaults to None. """ gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) - return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) - def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Construct and return dataloader.""" - assert mode in ['train', 'val'] + assert mode in ["train", "val"] with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = self.build_dataset(dataset_path, mode, batch_size) - shuffle = mode == 'train' - if getattr(dataset, 'rect', False) and shuffle: + shuffle = mode == "train" + if getattr(dataset, "rect", False) and shuffle: LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") shuffle = False - workers = self.args.workers if mode == 'train' else self.args.workers * 2 + workers = self.args.workers if mode == "train" else self.args.workers * 2 return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader def preprocess_batch(self, batch): """Preprocesses a batch of images by scaling and converting to float.""" - batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 + batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 if self.args.multi_scale: - imgs = batch['img'] - sz = (random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) // self.stride * - self.stride) # size + imgs = batch["img"] + sz = ( + random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride) + // self.stride + * self.stride + ) # size sf = sz / max(imgs.shape[2:]) # scale factor if sf != 1: - ns = [math.ceil(x * sf / self.stride) * self.stride - for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) - imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) - batch['img'] = imgs + ns = [ + math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] + ] # new shape (stretched to gs-multiple) + imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + batch["img"] = imgs return batch def set_model_attributes(self): @@ -74,33 +78,32 @@ class DetectionTrainer(BaseTrainer): # self.args.box *= 3 / nl # scale to layers # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers - self.model.nc = self.data['nc'] # attach number of classes to model - self.model.names = self.data['names'] # attach class names to model + self.model.nc = self.data["nc"] # attach number of classes to model + self.model.names = self.data["names"] # attach class names to model self.model.args = self.args # attach hyperparameters to model # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc def get_model(self, cfg=None, weights=None, verbose=True): """Return a YOLO detection model.""" - model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = DetectionModel(cfg, nc=self.data["nc"], 
verbose=verbose and RANK == -1) if weights: model.load(weights) return model def get_validator(self): """Returns a DetectionValidator for YOLO model validation.""" - self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' - return yolo.detect.DetectionValidator(self.test_loader, - save_dir=self.save_dir, - args=copy(self.args), - _callbacks=self.callbacks) + self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.detect.DetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) - def label_loss_items(self, loss_items=None, prefix='train'): + def label_loss_items(self, loss_items=None, prefix="train"): """ Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for segmentation & detection """ - keys = [f'{prefix}/{x}' for x in self.loss_names] + keys = [f"{prefix}/{x}" for x in self.loss_names] if loss_items is not None: loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats return dict(zip(keys, loss_items)) @@ -109,18 +112,25 @@ class DetectionTrainer(BaseTrainer): def progress_string(self): """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" - return ('\n' + '%11s' * - (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) def plot_training_samples(self, batch, ni): """Plots training samples with their annotations.""" - plot_images(images=batch['img'], - batch_idx=batch['batch_idx'], - cls=batch['cls'].squeeze(-1), - bboxes=batch['bboxes'], - paths=batch['im_file'], - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + plot_images( + images=batch["img"], + batch_idx=batch["batch_idx"], + cls=batch["cls"].squeeze(-1), + bboxes=batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots metrics from a CSV file.""" @@ -128,6 +138,6 @@ class DetectionTrainer(BaseTrainer): def plot_training_labels(self): """Create a labeled training plot of the YOLO model.""" - boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) - cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) - plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) + boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) + cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) + plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index 295c482c..7ac68e50 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -34,7 +34,7 @@ class DetectionValidator(BaseValidator): self.nt_per_class = None self.is_coco = False self.class_map = None - self.args.task = 'detect' + self.args.task = "detect" self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 self.niou = self.iouv.numel() @@ -42,25 +42,30 @@ class DetectionValidator(BaseValidator): def preprocess(self, batch): """Preprocesses batch of images for YOLO training.""" - 
batch['img'] = batch['img'].to(self.device, non_blocking=True) - batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255 - for k in ['batch_idx', 'cls', 'bboxes']: + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + for k in ["batch_idx", "cls", "bboxes"]: batch[k] = batch[k].to(self.device) if self.args.save_hybrid: - height, width = batch['img'].shape[2:] - nb = len(batch['img']) - bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device) - self.lb = [ - torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1) - for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) + self.lb = ( + [ + torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) + for i in range(nb) + ] + if self.args.save_hybrid + else [] + ) # for autolabelling return batch def init_metrics(self, model): """Initialize evaluation metrics for YOLO.""" - val = self.data.get(self.args.split, '') # validation path - self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO + val = self.data.get(self.args.split, "") # validation path + self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO self.names = model.names @@ -74,26 +79,28 @@ class DetectionValidator(BaseValidator): def get_desc(self): """Return a formatted string summarizing class metrics of YOLO model.""" - return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - agnostic=self.args.single_cls, - max_det=self.args.max_det) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + ) def _prepare_batch(self, si, batch): """Prepares a batch of images and annotations for validation.""" - idx = batch['batch_idx'] == si - cls = batch['cls'][idx].squeeze(-1) - bbox = batch['bboxes'][idx] - ori_shape = batch['ori_shape'][si] - imgsz = batch['img'].shape[2:] - ratio_pad = batch['ratio_pad'][si] + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] if len(cls): bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels @@ -103,8 +110,9 @@ class DetectionValidator(BaseValidator): def _prepare_pred(self, pred, pbatch): """Prepares a batch of images and annotations for validation.""" predn = pred.clone() - ops.scale_boxes(pbatch['imgsz'], predn[:, :4], 
pbatch['ori_shape'], - ratio_pad=pbatch['ratio_pad']) # native-space pred + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] + ) # native-space pred return predn def update_metrics(self, preds, batch): @@ -112,19 +120,21 @@ class DetectionValidator(BaseValidator): for si, pred in enumerate(preds): self.seen += 1 npr = len(pred) - stat = dict(conf=torch.zeros(0, device=self.device), - pred_cls=torch.zeros(0, device=self.device), - tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) pbatch = self._prepare_batch(si, batch) - cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") nl = len(cls) - stat['target_cls'] = cls + stat["target_cls"] = cls if npr == 0: if nl: for k in self.stats.keys(): self.stats[k].append(stat[k]) # TODO: obb has not supported confusion_matrix yet. - if self.args.plots and self.args.task != 'obb': + if self.args.plots and self.args.task != "obb": self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) continue @@ -132,24 +142,24 @@ class DetectionValidator(BaseValidator): if self.args.single_cls: pred[:, 5] = 0 predn = self._prepare_pred(pred, pbatch) - stat['conf'] = predn[:, 4] - stat['pred_cls'] = predn[:, 5] + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - stat['tp'] = self._process_batch(predn, bbox, cls) + stat["tp"] = self._process_batch(predn, bbox, cls) # TODO: obb has not supported confusion_matrix yet. - if self.args.plots and self.args.task != 'obb': + if self.args.plots and self.args.task != "obb": self.confusion_matrix.process_batch(predn, bbox, cls) for k in self.stats.keys(): self.stats[k].append(stat[k]) # Save if self.args.save_json: - self.pred_to_json(predn, batch['im_file'][si]) + self.pred_to_json(predn, batch["im_file"][si]) if self.args.save_txt: - file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt' - self.save_one_txt(predn, self.args.save_conf, pbatch['ori_shape'], file) + file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt' + self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file) def finalize_metrics(self, *args, **kwargs): """Set final values for metrics speed and confusion matrix.""" @@ -159,19 +169,19 @@ class DetectionValidator(BaseValidator): def get_stats(self): """Returns metrics statistics and results dictionary.""" stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy - if len(stats) and stats['tp'].any(): + if len(stats) and stats["tp"].any(): self.metrics.process(**stats) - self.nt_per_class = np.bincount(stats['target_cls'].astype(int), - minlength=self.nc) # number of targets per class + self.nt_per_class = np.bincount( + stats["target_cls"].astype(int), minlength=self.nc + ) # number of targets per class return self.metrics.results_dict def print_results(self): """Prints training/validation set metrics per class.""" - pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format - LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) if self.nt_per_class.sum() == 0: 
- LOGGER.warning( - f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels') + LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels") # Print results per class if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): @@ -180,10 +190,9 @@ class DetectionValidator(BaseValidator): if self.args.plots: for normalize in True, False: - self.confusion_matrix.plot(save_dir=self.save_dir, - names=self.names.values(), - normalize=normalize, - on_plot=self.on_plot) + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) def _process_batch(self, detections, gt_bboxes, gt_cls): """ @@ -201,7 +210,7 @@ class DetectionValidator(BaseValidator): iou = box_iou(gt_bboxes, detections[:, :4]) return self.match_predictions(detections[:, 5], gt_cls, iou) - def build_dataset(self, img_path, mode='val', batch=None): + def build_dataset(self, img_path, mode="val", batch=None): """ Build YOLO Dataset. @@ -214,28 +223,32 @@ class DetectionValidator(BaseValidator): def get_dataloader(self, dataset_path, batch_size): """Construct and return dataloader.""" - dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val') + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader def plot_val_samples(self, batch, ni): """Plot validation image samples.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predicted bounding boxes on input images and saves the result.""" - plot_images(batch['img'], - *output_to_target(preds, max_det=self.args.max_det), - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred def save_one_txt(self, predn, save_conf, shape, file): """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" @@ -243,8 +256,8 @@ class DetectionValidator(BaseValidator): for *xyxy, conf, cls in predn.tolist(): xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format - with open(file, 'a') as f: - f.write(('%g ' * len(line)).rstrip() % line + '\n') + with open(file, "a") as f: + f.write(("%g " * len(line)).rstrip() % line + "\n") def pred_to_json(self, predn, filename): """Serialize YOLO predictions to COCO json format.""" @@ -253,28 +266,31 @@ class DetectionValidator(BaseValidator): box = ops.xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(predn.tolist(), box.tolist()): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 
3) for x in b], - 'score': round(p[4], 5)}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) def eval_json(self, stats): """Evaluates YOLO output in JSON format and returns performance statistics.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - eval = COCOeval(anno, pred, 'bbox') + eval = COCOeval(anno, pred, "bbox") if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval eval.evaluate() @@ -282,5 +298,5 @@ class DetectionValidator(BaseValidator): eval.summarize() stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py index eb2225d8..858e42c5 100644 --- a/ultralytics/models/yolo/model.py +++ b/ultralytics/models/yolo/model.py @@ -12,28 +12,34 @@ class YOLO(Model): def task_map(self): """Map head to model, trainer, validator, and predictor classes.""" return { - 'classify': { - 'model': ClassificationModel, - 'trainer': yolo.classify.ClassificationTrainer, - 'validator': yolo.classify.ClassificationValidator, - 'predictor': yolo.classify.ClassificationPredictor, }, - 'detect': { - 'model': DetectionModel, - 'trainer': yolo.detect.DetectionTrainer, - 'validator': yolo.detect.DetectionValidator, - 'predictor': yolo.detect.DetectionPredictor, }, - 'segment': { - 'model': SegmentationModel, - 'trainer': yolo.segment.SegmentationTrainer, - 'validator': yolo.segment.SegmentationValidator, - 'predictor': yolo.segment.SegmentationPredictor, }, - 'pose': { - 'model': PoseModel, - 'trainer': yolo.pose.PoseTrainer, - 'validator': yolo.pose.PoseValidator, - 'predictor': yolo.pose.PosePredictor, }, - 'obb': { - 'model': OBBModel, - 'trainer': yolo.obb.OBBTrainer, - 'validator': yolo.obb.OBBValidator, - 'predictor': yolo.obb.OBBPredictor, }, } + "classify": { + "model": ClassificationModel, + "trainer": yolo.classify.ClassificationTrainer, + "validator": yolo.classify.ClassificationValidator, + "predictor": yolo.classify.ClassificationPredictor, + }, + "detect": { + "model": DetectionModel, + "trainer": yolo.detect.DetectionTrainer, + "validator": yolo.detect.DetectionValidator, + "predictor": yolo.detect.DetectionPredictor, + }, + "segment": { + "model": SegmentationModel, + 
"trainer": yolo.segment.SegmentationTrainer, + "validator": yolo.segment.SegmentationValidator, + "predictor": yolo.segment.SegmentationPredictor, + }, + "pose": { + "model": PoseModel, + "trainer": yolo.pose.PoseTrainer, + "validator": yolo.pose.PoseValidator, + "predictor": yolo.pose.PosePredictor, + }, + "obb": { + "model": OBBModel, + "trainer": yolo.obb.OBBTrainer, + "validator": yolo.obb.OBBValidator, + "predictor": yolo.obb.OBBPredictor, + }, + } diff --git a/ultralytics/models/yolo/obb/__init__.py b/ultralytics/models/yolo/obb/__init__.py index 09f10481..f60349a7 100644 --- a/ultralytics/models/yolo/obb/__init__.py +++ b/ultralytics/models/yolo/obb/__init__.py @@ -4,4 +4,4 @@ from .predict import OBBPredictor from .train import OBBTrainer from .val import OBBValidator -__all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator' +__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator" diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py index 662266f7..572082b8 100644 --- a/ultralytics/models/yolo/obb/predict.py +++ b/ultralytics/models/yolo/obb/predict.py @@ -25,26 +25,27 @@ class OBBPredictor(DetectionPredictor): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """Initializes OBBPredictor with optional model and data configuration overrides.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'obb' + self.args.task = "obb" def postprocess(self, preds, img, orig_imgs): """Post-processes predictions and returns a list of Results objects.""" - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - nc=len(self.model.names), - classes=self.args.classes, - rotated=True) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + rotated=True, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) results = [] - for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)): + for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])): pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True) - img_path = self.batch[0][i] # xywh, r, conf, cls obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1) results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) diff --git a/ultralytics/models/yolo/obb/train.py b/ultralytics/models/yolo/obb/train.py index 0d1284a7..43ebaecd 100644 --- a/ultralytics/models/yolo/obb/train.py +++ b/ultralytics/models/yolo/obb/train.py @@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer): """Initialize a OBBTrainer object with given arguments.""" if overrides is None: overrides = {} - overrides['task'] = 'obb' + overrides["task"] = "obb" super().__init__(cfg, overrides, _callbacks) def get_model(self, cfg=None, weights=None, verbose=True): """Return OBBModel initialized with specified config and weights.""" - model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) @@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer): def get_validator(self): """Return an instance of OBBValidator for validation of YOLO model.""" - self.loss_names = 'box_loss', 
'cls_loss', 'dfl_loss'
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
         return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py
index a5c030a8..40dbb604 100644
--- a/ultralytics/models/yolo/obb/val.py
+++ b/ultralytics/models/yolo/obb/val.py
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.args.task = 'obb'
+        self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

     def init_metrics(self, model):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
-        val = self.data.get(self.args.split, '')  # validation path
-        self.is_dota = isinstance(val, str) and 'DOTA' in val  # is COCO
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # is DOTA

     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""
-        return ops.non_max_suppression(preds,
-                                       self.args.conf,
-                                       self.args.iou,
-                                       labels=self.lb,
-                                       nc=self.nc,
-                                       multi_label=True,
-                                       agnostic=self.args.single_cls,
-                                       max_det=self.args.max_det,
-                                       rotated=True)
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            nc=self.nc,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            rotated=True,
+        )

     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -66,12 +68,12 @@ class OBBValidator(DetectionValidator):
     def _prepare_batch(self, si, batch):
         """Prepares and returns a batch for OBB validation."""
-        idx = batch['batch_idx'] == si
-        cls = batch['cls'][idx].squeeze(-1)
-        bbox = batch['bboxes'][idx]
-        ori_shape = batch['ori_shape'][si]
-        imgsz = batch['img'].shape[2:]
-        ratio_pad = batch['ratio_pad'][si]
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
@@ -81,18 +83,21 @@ class OBBValidator(DetectionValidator):
     def _prepare_pred(self, pred, pbatch):
         """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
         predn = pred.clone()
-        ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
-                        xywh=True)  # native-space pred
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        )  # native-space pred
         return predn

     def plot_predictions(self, batch, preds, ni):
         """Plots predicted bounding boxes on input images and saves the result."""
-        plot_images(batch['img'],
-                    *output_to_rotated_target(preds, max_det=self.args.max_det),
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
-                    names=self.names,
-                    on_plot=self.on_plot)  # pred
+        plot_images(
+            batch["img"],
+            *output_to_rotated_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred

     def pred_to_json(self, predn, filename):
         """Serialize YOLO predictions to COCO json format."""
@@ -101,12 +106,15 @@ class OBBValidator(DetectionValidator):
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
         for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
-            self.jdict.append({
-                'image_id': image_id,
-                'category_id': self.class_map[int(predn[i, 5].item())],
-                'score': round(predn[i, 4].item(), 5),
-                'rbox': [round(x, 3) for x in r],
-                'poly': [round(x, 3) for x in b]})
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(predn[i, 5].item())],
+                    "score": round(predn[i, 4].item(), 5),
+                    "rbox": [round(x, 3) for x in r],
+                    "poly": [round(x, 3) for x in b],
+                }
+            )

     def save_one_txt(self, predn, save_conf, shape, file):
         """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -116,8 +124,8 @@ class OBBValidator(DetectionValidator):
         xywha[:, :4] /= gn
         xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist()  # normalized xywh
         line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format
-        with open(file, 'a') as f:
-            f.write(('%g ' * len(line)).rstrip() % line + '\n')
+        with open(file, "a") as f:
+            f.write(("%g " * len(line)).rstrip() % line + "\n")

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -125,42 +133,43 @@ class OBBValidator(DetectionValidator):
             import json
             import re
             from collections import defaultdict
-            pred_json = self.save_dir / 'predictions.json'  # predictions
-            pred_txt = self.save_dir / 'predictions_txt'  # predictions
+
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            pred_txt = self.save_dir / "predictions_txt"  # predictions
             pred_txt.mkdir(parents=True, exist_ok=True)
             data = json.load(open(pred_json))

             # Save split results
-            LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
+            LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
             for d in data:
-                image_id = d['image_id']
-                score = d['score']
-                classname = self.names[d['category_id']].replace(' ', '-')
+                image_id = d["image_id"]
+                score = d["score"]
+                classname = self.names[d["category_id"]].replace(" ", "-")

-                lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+                lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
                     image_id,
                     score,
-                    d['poly'][0],
-                    d['poly'][1],
-                    d['poly'][2],
-                    d['poly'][3],
-                    d['poly'][4],
-                    d['poly'][5],
-                    d['poly'][6],
-                    d['poly'][7],
+                    d["poly"][0],
+                    d["poly"][1],
+                    d["poly"][2],
+                    d["poly"][3],
+                    d["poly"][4],
+                    d["poly"][5],
+                    d["poly"][6],
+                    d["poly"][7],
                 )
-                with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+                with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
                     f.writelines(lines)

             # Save merged results. This could result in a slightly lower mAP than the official merging script,
             # because of the probiou calculation.
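For context on the merge step that follows: each DOTA prediction carries a composite `image_id` such as `P0006__1024__0___824` (source image, patch size, patch origin), so merging amounts to stripping the patch suffix, shifting every rotated box by its patch origin, and regrouping boxes per source image before the rotated NMS. A minimal self-contained sketch of that regrouping; the helper name `merge_patch_preds` and the sample record are illustrative, not part of the Ultralytics API:

```python
import re
from collections import defaultdict

import torch


def merge_patch_preds(records):
    """Regroup per-patch DOTA predictions onto their source images (illustrative helper).

    Each record mirrors one `jdict` entry: an `image_id` like 'P0006__1024__0___824',
    an `rbox` [cx, cy, w, h, r], a `score`, and a `category_id`.
    """
    merged = defaultdict(list)
    pattern = re.compile(r"\d+___\d+")
    for d in records:
        image_id = d["image_id"].split("__")[0]  # recover the original image name
        x, y = (int(c) for c in pattern.findall(d["image_id"])[0].split("___"))  # patch origin
        box = list(d["rbox"])
        box[0] += x  # shift the box center from patch to full-image coordinates
        box[1] += y
        merged[image_id].append(box + [d["score"], d["category_id"]])
    return {k: torch.tensor(v) for k, v in merged.items()}  # one (n, 7) tensor per image


sample = [{"image_id": "P0006__1024__0___824", "rbox": [10.0, 20.0, 50.0, 30.0, 0.1], "score": 0.9, "category_id": 0}]
print(merge_patch_preds(sample)["P0006"])  # center shifted by the 0___824 patch origin
```

The small mAP gap flagged in the comment above comes from running ProbIoU-based rotated NMS on these merged boxes rather than the official DOTA merging tool.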
- pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions + pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions pred_merged_txt.mkdir(parents=True, exist_ok=True) merged_results = defaultdict(list) - LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...') + LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...") for d in data: - image_id = d['image_id'].split('__')[0] - pattern = re.compile(r'\d+___\d+') - x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___')) - bbox, score, cls = d['rbox'], d['score'], d['category_id'] + image_id = d["image_id"].split("__")[0] + pattern = re.compile(r"\d+___\d+") + x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___")) + bbox, score, cls = d["rbox"], d["score"], d["category_id"] bbox[0] += x bbox[1] += y bbox.extend([score, cls]) @@ -178,11 +187,11 @@ class OBBValidator(DetectionValidator): b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8) for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist(): - classname = self.names[int(x[-1])].replace(' ', '-') + classname = self.names[int(x[-1])].replace(" ", "-") poly = [round(i, 3) for i in x[:-2]] score = round(x[-2], 3) - lines = '{} {} {} {} {} {} {} {} {} {}\n'.format( + lines = "{} {} {} {} {} {} {} {} {} {}\n".format( image_id, score, poly[0], @@ -194,7 +203,7 @@ class OBBValidator(DetectionValidator): poly[6], poly[7], ) - with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f: + with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f: f.writelines(lines) return stats diff --git a/ultralytics/models/yolo/pose/__init__.py b/ultralytics/models/yolo/pose/__init__.py index 2a79f0f3..d5669430 100644 --- a/ultralytics/models/yolo/pose/__init__.py +++ b/ultralytics/models/yolo/pose/__init__.py @@ -4,4 +4,4 @@ from .predict import PosePredictor from .train import PoseTrainer from .val import PoseValidator -__all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor' +__all__ = "PoseTrainer", "PoseValidator", "PosePredictor" diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py index d00cea02..7c55709f 100644 --- a/ultralytics/models/yolo/pose/predict.py +++ b/ultralytics/models/yolo/pose/predict.py @@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'pose' - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + self.args.task = "pose" + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def postprocess(self, preds, img, orig_imgs): """Return detection results for a given input image or list of images.""" - preds = ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes, - nc=len(self.model.names)) + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + nc=len(self.model.names), + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) @@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor): pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape) img_path = self.batch[0][i] results.append( - Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)) + Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts) + ) return results diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py index c9ccf52e..f5229e50 100644 --- a/ultralytics/models/yolo/pose/train.py +++ b/ultralytics/models/yolo/pose/train.py @@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer): """Initialize a PoseTrainer object with specified configurations and overrides.""" if overrides is None: overrides = {} - overrides['task'] = 'pose' + overrides["task"] = "pose" super().__init__(cfg, overrides, _callbacks) - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def get_model(self, cfg=None, weights=None, verbose=True): """Get pose estimation model with specified configuration and weights.""" - model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) + model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose) if weights: model.load(weights) @@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer): def set_model_attributes(self): """Sets keypoints shape attribute of PoseModel.""" super().set_model_attributes() - self.model.kpt_shape = self.data['kpt_shape'] + self.model.kpt_shape = self.data["kpt_shape"] def get_validator(self): """Returns an instance of the PoseValidator class for validation.""" - self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' - return yolo.pose.PoseValidator(self.test_loader, - save_dir=self.save_dir, - args=copy(self.args), - _callbacks=self.callbacks) + self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss" + return yolo.pose.PoseValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def plot_training_samples(self, batch, ni): """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" - images = batch['img'] - kpts = batch['keypoints'] - cls = batch['cls'].squeeze(-1) - bboxes = batch['bboxes'] - paths = batch['im_file'] - batch_idx = batch['batch_idx'] - plot_images(images, - batch_idx, - cls, - bboxes, - kpts=kpts, - paths=paths, - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + images = batch["img"] + kpts = batch["keypoints"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images( + images, + batch_idx, + cls, + bboxes, + kpts=kpts, + paths=paths, + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots training/val metrics.""" diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py index f7855bf2..84056863 100644 --- a/ultralytics/models/yolo/pose/val.py +++ b/ultralytics/models/yolo/pose/val.py @@ -31,38 +31,53 @@ class PoseValidator(DetectionValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.sigma = None self.kpt_shape = None - self.args.task = 'pose' + self.args.task = "pose" self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot) - if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': - LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " - 'See https://github.com/ultralytics/ultralytics/issues/4031.') + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) def preprocess(self, batch): """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device.""" batch = super().preprocess(batch) - batch['keypoints'] = batch['keypoints'].to(self.device).float() + batch["keypoints"] = batch["keypoints"].to(self.device).float() return batch def get_desc(self): """Returns description of evaluation metrics in string format.""" - return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P', - 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Pose(P", + "R", + "mAP50", + "mAP50-95)", + ) def postprocess(self, preds): """Apply non-maximum suppression and return detections with high confidence scores.""" - return ops.non_max_suppression(preds, - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - agnostic=self.args.single_cls, - max_det=self.args.max_det, - nc=self.nc) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nc=self.nc, + ) def init_metrics(self, model): """Initiate pose estimation metrics for YOLO model.""" super().init_metrics(model) - self.kpt_shape = self.data['kpt_shape'] + self.kpt_shape = self.data["kpt_shape"] is_pose = self.kpt_shape == [17, 3] nkpt = self.kpt_shape[0] self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt @@ -71,21 +86,21 @@ class PoseValidator(DetectionValidator): def _prepare_batch(self, si, batch): """Prepares a batch for processing by converting keypoints to float and moving to device.""" pbatch = super()._prepare_batch(si, batch) - kpts = batch['keypoints'][batch['batch_idx'] == si] - h, w = pbatch['imgsz'] + kpts = batch["keypoints"][batch["batch_idx"] == si] + h, w = pbatch["imgsz"] kpts = kpts.clone() kpts[..., 0] *= w kpts[..., 1] *= h - kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad']) - pbatch['kpts'] = kpts + kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + pbatch["kpts"] = kpts return pbatch def _prepare_pred(self, pred, pbatch): """Prepares and scales keypoints in a batch for pose processing.""" predn = super()._prepare_pred(pred, pbatch) - nk = pbatch['kpts'].shape[1] + nk = pbatch["kpts"].shape[1] pred_kpts = predn[:, 6:].view(len(predn), nk, -1) - ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad']) + ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) return predn, pred_kpts def update_metrics(self, preds, batch): @@ -93,14 +108,16 @@ class PoseValidator(DetectionValidator): for si, pred in enumerate(preds): self.seen += 1 npr = len(pred) - stat = dict(conf=torch.zeros(0, device=self.device), - pred_cls=torch.zeros(0, device=self.device), - tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), - tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) pbatch = self._prepare_batch(si, batch) - cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") nl = 
len(cls) - stat['target_cls'] = cls + stat["target_cls"] = cls if npr == 0: if nl: for k in self.stats.keys(): @@ -113,13 +130,13 @@ class PoseValidator(DetectionValidator): if self.args.single_cls: pred[:, 5] = 0 predn, pred_kpts = self._prepare_pred(pred, pbatch) - stat['conf'] = predn[:, 4] - stat['pred_cls'] = predn[:, 5] + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - stat['tp'] = self._process_batch(predn, bbox, cls) - stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts']) + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"]) if self.args.plots: self.confusion_matrix.process_batch(predn, bbox, cls) @@ -128,7 +145,7 @@ class PoseValidator(DetectionValidator): # Save if self.args.save_json: - self.pred_to_json(predn, batch['im_file'][si]) + self.pred_to_json(predn, batch["im_file"][si]) # if self.args.save_txt: # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') @@ -159,26 +176,30 @@ class PoseValidator(DetectionValidator): def plot_val_samples(self, batch, ni): """Plots and saves validation set samples with predicted bounding boxes and keypoints.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - kpts=batch['keypoints'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + kpts=batch["keypoints"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots predictions for YOLO model.""" pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0) - plot_images(batch['img'], - *output_to_target(preds, max_det=self.args.max_det), - kpts=pred_kpts, - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + kpts=pred_kpts, + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred def pred_to_json(self, predn, filename): """Converts YOLO predictions to COCO JSON format.""" @@ -187,37 +208,41 @@ class PoseValidator(DetectionValidator): box = ops.xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(predn.tolist(), box.tolist()): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'keypoints': p[6:], - 'score': round(p[4], 5)}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "keypoints": p[6:], + "score": round(p[4], 5), + } + ) def eval_json(self, stats): """Evaluates object detection model using COCO JSON format.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations + pred_json = self.save_dir 
/ "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]): + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]): if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval eval.evaluate() eval.accumulate() eval.summarize() idx = i * 4 + 2 - stats[self.metrics.keys[idx + 1]], stats[ - self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py index c84a570e..ec1ac799 100644 --- a/ultralytics/models/yolo/segment/__init__.py +++ b/ultralytics/models/yolo/segment/__init__.py @@ -4,4 +4,4 @@ from .predict import SegmentationPredictor from .train import SegmentationTrainer from .val import SegmentationValidator -__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator' +__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py index ba44a482..8d8cd59a 100644 --- a/ultralytics/models/yolo/segment/predict.py +++ b/ultralytics/models/yolo/segment/predict.py @@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks.""" super().__init__(cfg, overrides, _callbacks) - self.args.task = 'segment' + self.args.task = "segment" def postprocess(self, preds, img, orig_imgs): """Applies non-max suppression and processes detections for each image in an input batch.""" - p = ops.non_max_suppression(preds[0], - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - nc=len(self.model.names), - classes=self.args.classes) + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + ) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index 1d1227da..126baf20 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): """Initialize a SegmentationTrainer object with given arguments.""" if overrides 
is None: overrides = {} - overrides['task'] = 'segment' + overrides["task"] = "segment" super().__init__(cfg, overrides, _callbacks) def get_model(self, cfg=None, weights=None, verbose=True): """Return SegmentationModel initialized with specified config and weights.""" - model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) + model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) if weights: model.load(weights) @@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): def get_validator(self): """Return an instance of SegmentationValidator for validation of YOLO model.""" - self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' - return yolo.segment.SegmentationValidator(self.test_loader, - save_dir=self.save_dir, - args=copy(self.args), - _callbacks=self.callbacks) + self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" + return yolo.segment.SegmentationValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) def plot_training_samples(self, batch, ni): """Creates a plot of training sample images with labels and box coordinates.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - masks=batch['masks'], - paths=batch['im_file'], - fname=self.save_dir / f'train_batch{ni}.jpg', - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) def plot_metrics(self): """Plots training/val metrics.""" diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py index 10ac374f..72c04b2c 100644 --- a/ultralytics/models/yolo/segment/val.py +++ b/ultralytics/models/yolo/segment/val.py @@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator): super().__init__(dataloader, save_dir, pbar, args, _callbacks) self.plot_masks = None self.process = None - self.args.task = 'segment' + self.args.task = "segment" self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) def preprocess(self, batch): """Preprocesses batch by converting masks to float and sending to device.""" batch = super().preprocess(batch) - batch['masks'] = batch['masks'].to(self.device).float() + batch["masks"] = batch["masks"].to(self.device).float() return batch def init_metrics(self, model): @@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator): super().init_metrics(model) self.plot_masks = [] if self.args.save_json: - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") self.process = ops.process_mask_upsample # more accurate else: self.process = ops.process_mask # faster @@ -55,33 +55,46 @@ class SegmentationValidator(DetectionValidator): def get_desc(self): """Return a formatted description of evaluation metrics.""" - return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', - 'R', 'mAP50', 'mAP50-95)') + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Mask(P", + "R", + "mAP50", + "mAP50-95)", + ) def postprocess(self, preds): """Post-processes YOLO predictions and returns output detections with proto.""" - p = ops.non_max_suppression(preds[0], - self.args.conf, - self.args.iou, - labels=self.lb, - multi_label=True, - 
agnostic=self.args.single_cls, - max_det=self.args.max_det, - nc=self.nc) + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nc=self.nc, + ) proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported return p, proto def _prepare_batch(self, si, batch): """Prepares a batch for training or inference by processing images and targets.""" prepared_batch = super()._prepare_batch(si, batch) - midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si - prepared_batch['masks'] = batch['masks'][midx] + midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si + prepared_batch["masks"] = batch["masks"][midx] return prepared_batch def _prepare_pred(self, pred, pbatch, proto): """Prepares a batch for training or inference by processing images and targets.""" predn = super()._prepare_pred(pred, pbatch) - pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz']) + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"]) return predn, pred_masks def update_metrics(self, preds, batch): @@ -89,14 +102,16 @@ class SegmentationValidator(DetectionValidator): for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): self.seen += 1 npr = len(pred) - stat = dict(conf=torch.zeros(0, device=self.device), - pred_cls=torch.zeros(0, device=self.device), - tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), - tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) pbatch = self._prepare_batch(si, batch) - cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox') + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") nl = len(cls) - stat['target_cls'] = cls + stat["target_cls"] = cls if npr == 0: if nl: for k in self.stats.keys(): @@ -106,24 +121,20 @@ class SegmentationValidator(DetectionValidator): continue # Masks - gt_masks = pbatch.pop('masks') + gt_masks = pbatch.pop("masks") # Predictions if self.args.single_cls: pred[:, 5] = 0 predn, pred_masks = self._prepare_pred(pred, pbatch, proto) - stat['conf'] = predn[:, 4] - stat['pred_cls'] = predn[:, 5] + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] # Evaluate if nl: - stat['tp'] = self._process_batch(predn, bbox, cls) - stat['tp_m'] = self._process_batch(predn, - bbox, - cls, - pred_masks, - gt_masks, - self.args.overlap_mask, - masks=True) + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_m"] = self._process_batch( + predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True + ) if self.args.plots: self.confusion_matrix.process_batch(predn, bbox, cls) @@ -136,10 +147,12 @@ class SegmentationValidator(DetectionValidator): # Save if self.args.save_json: - pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), - pbatch['ori_shape'], - ratio_pad=batch['ratio_pad'][si]) - self.pred_to_json(predn, batch['im_file'][si], pred_masks) + pred_masks = ops.scale_image( + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + pbatch["ori_shape"], + ratio_pad=batch["ratio_pad"][si], + ) + self.pred_to_json(predn, batch["im_file"][si], pred_masks) # if self.args.save_txt: # 
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') @@ -166,7 +179,7 @@ class SegmentationValidator(DetectionValidator): gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) gt_masks = torch.where(gt_masks == index, 1.0, 0.0) if gt_masks.shape[1:] != pred_masks.shape[1:]: - gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] gt_masks = gt_masks.gt_(0.5) iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) else: # boxes @@ -176,26 +189,29 @@ class SegmentationValidator(DetectionValidator): def plot_val_samples(self, batch, ni): """Plots validation samples with bounding box labels.""" - plot_images(batch['img'], - batch['batch_idx'], - batch['cls'].squeeze(-1), - batch['bboxes'], - masks=batch['masks'], - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_labels.jpg', - names=self.names, - on_plot=self.on_plot) + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) def plot_predictions(self, batch, preds, ni): """Plots batch predictions with masks and bounding boxes.""" plot_images( - batch['img'], + batch["img"], *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", names=self.names, - on_plot=self.on_plot) # pred + on_plot=self.on_plot, + ) # pred self.plot_masks.clear() def pred_to_json(self, predn, filename, pred_masks): @@ -205,8 +221,8 @@ class SegmentationValidator(DetectionValidator): def single_encode(x): """Encode predicted masks as RLE and append results to jdict.""" - rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] - rle['counts'] = rle['counts'].decode('utf-8') + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") return rle stem = Path(filename).stem @@ -217,37 +233,41 @@ class SegmentationValidator(DetectionValidator): with ThreadPool(NUM_THREADS) as pool: rles = pool.map(single_encode, pred_masks) for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'score': round(p[4], 5), - 'segmentation': rles[i]}) + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + "segmentation": rles[i], + } + ) def eval_json(self, stats): """Return COCO-style object detection evaluation metrics.""" if self.args.save_json and self.is_coco and len(self.jdict): - anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations - pred_json = self.save_dir / 'predictions.json' # predictions - LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + 
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb - check_requirements('pycocotools>=2.0.6') + check_requirements("pycocotools>=2.0.6") from pycocotools.coco import COCO # noqa from pycocotools.cocoeval import COCOeval # noqa for x in anno_json, pred_json: - assert x.is_file(), f'{x} file not found' + assert x.is_file(), f"{x} file not found" anno = COCO(str(anno_json)) # init annotations api pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) - for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]): if self.is_coco: eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval eval.evaluate() eval.accumulate() eval.summarize() idx = i * 4 + 2 - stats[self.metrics.keys[idx + 1]], stats[ - self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 except Exception as e: - LOGGER.warning(f'pycocotools unable to run: {e}') + LOGGER.warning(f"pycocotools unable to run: {e}") return stats diff --git a/ultralytics/nn/__init__.py b/ultralytics/nn/__init__.py index 9889b7ef..6905d349 100644 --- a/ultralytics/nn/__init__.py +++ b/ultralytics/nn/__init__.py @@ -1,9 +1,29 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, - attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, - yaml_model_load) +from .tasks import ( + BaseModel, + ClassificationModel, + DetectionModel, + SegmentationModel, + attempt_load_one_weight, + attempt_load_weights, + guess_model_scale, + guess_model_task, + parse_model, + torch_safe_load, + yaml_model_load, +) -__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', - 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', - 'BaseModel') +__all__ = ( + "attempt_load_one_weight", + "attempt_load_weights", + "parse_model", + "yaml_model_load", + "guess_model_task", + "guess_model_scale", + "torch_safe_load", + "DetectionModel", + "SegmentationModel", + "ClassificationModel", + "BaseModel", +) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index a5c964ce..8f55c3ba 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -32,10 +32,12 @@ def check_class_names(names): names = {int(k): str(v) for k, v in names.items()} n = len(names) if max(names.keys()) >= n: - raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices ' - f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.') - if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764' - names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names + raise KeyError( + f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices " + f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." + ) + if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 
'n01440764' + names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names names = {k: names_map[v] for k, v in names.items()} return names @@ -44,8 +46,8 @@ def default_class_names(data=None): """Applies default class names to an input YAML file or returns numerical class names.""" if data: with contextlib.suppress(Exception): - return yaml_load(check_yaml(data))['names'] - return {i: f'class{i}' for i in range(999)} # return default if above errors + return yaml_load(check_yaml(data))["names"] + return {i: f"class{i}" for i in range(999)} # return default if above errors class AutoBackend(nn.Module): @@ -77,14 +79,16 @@ class AutoBackend(nn.Module): """ @torch.no_grad() - def __init__(self, - weights='yolov8n.pt', - device=torch.device('cpu'), - dnn=False, - data=None, - fp16=False, - fuse=True, - verbose=True): + def __init__( + self, + weights="yolov8n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + fuse=True, + verbose=True, + ): """ Initialize the AutoBackend for inference. @@ -100,17 +104,31 @@ class AutoBackend(nn.Module): super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) nn_module = isinstance(weights, torch.nn.Module) - pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \ - self._model_type(w) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + ncnn, + triton, + ) = self._model_type(w) fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) stride = 32 # default stride model, metadata = None, None # Set device - cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats - device = torch.device('cpu') + device = torch.device("cpu") cuda = False # Download if not local @@ -121,77 +139,79 @@ class AutoBackend(nn.Module): if nn_module: # in-memory PyTorch model model = weights.to(device) model = model.fuse(verbose=verbose) if fuse else model - if hasattr(model, 'kpt_shape'): + if hasattr(model, "kpt_shape"): kpt_shape = model.kpt_shape # pose-only stride = max(int(model.stride.max()), 32) # model stride - names = model.module.names if hasattr(model, 'module') else model.names # get class names + names = model.module.names if hasattr(model, "module") else model.names # get class names model.half() if fp16 else model.float() self.model = model # explicitly assign for to(), cpu(), cuda(), half() pt = True elif pt: # PyTorch from ultralytics.nn.tasks import attempt_load_weights - model = attempt_load_weights(weights if isinstance(weights, list) else w, - device=device, - inplace=True, - fuse=fuse) - if hasattr(model, 'kpt_shape'): + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): kpt_shape = model.kpt_shape # pose-only stride = max(int(model.stride.max()), 32) # model stride - names = model.module.names if hasattr(model, 'module') else model.names # get class names + names = model.module.names if hasattr(model, "module") else model.names # get class names model.half() if fp16 else model.float() self.model = model # explicitly assign for to(), cpu(), cuda(), half() elif jit: # TorchScript - LOGGER.info(f'Loading {w} for 
TorchScript inference...') - extra_files = {'config.txt': ''} # model metadata + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata model = torch.jit.load(w, _extra_files=extra_files, map_location=device) model.half() if fp16 else model.float() - if extra_files['config.txt']: # load metadata dict - metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items())) + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) elif dnn: # ONNX OpenCV DNN - LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') - check_requirements('opencv-python>=4.5.4') + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") net = cv2.dnn.readNetFromONNX(w) elif onnx: # ONNX Runtime - LOGGER.info(f'Loading {w} for ONNX Runtime inference...') - check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) import onnxruntime - providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] + + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"] session = onnxruntime.InferenceSession(w, providers=providers) output_names = [x.name for x in session.get_outputs()] metadata = session.get_modelmeta().custom_metadata_map # metadata elif xml: # OpenVINO - LOGGER.info(f'Loading {w} for OpenVINO inference...') - check_requirements('openvino>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/ + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/ from openvino.runtime import Core, Layout, get_batch # noqa + core = Core() w = Path(w) if not w.is_file(): # if not *.xml - w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir - ov_model = core.read_model(model=str(w), weights=w.with_suffix('.bin')) + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) if ov_model.get_parameters()[0].get_layout().empty: - ov_model.get_parameters()[0].set_layout(Layout('NCHW')) + ov_model.get_parameters()[0].set_layout(Layout("NCHW")) batch_dim = get_batch(ov_model) if batch_dim.is_static: batch_size = batch_dim.get_length() - ov_compiled_model = core.compile_model(ov_model, device_name='AUTO') # AUTO selects best available device - metadata = w.parent / 'metadata.yaml' + ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device + metadata = w.parent / "metadata.yaml" elif engine: # TensorRT - LOGGER.info(f'Loading {w} for TensorRT inference...') + LOGGER.info(f"Loading {w} for TensorRT inference...") try: import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download except ImportError: if LINUX: - check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') + check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com") import tensorrt as trt # noqa - check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 - if device.type == 'cpu': - device = torch.device('cuda:0') - Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 
'data', 'ptr')) + check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0 + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) logger = trt.Logger(trt.Logger.INFO) # Read file - with open(w, 'rb') as f, trt.Runtime(logger) as runtime: - meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length - metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata model = runtime.deserialize_cuda_engine(f.read()) # read engine context = model.create_execution_context() bindings = OrderedDict() @@ -213,116 +233,124 @@ class AutoBackend(nn.Module): im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) - batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size elif coreml: # CoreML - LOGGER.info(f'Loading {w} for CoreML inference...') + LOGGER.info(f"Loading {w} for CoreML inference...") import coremltools as ct + model = ct.models.MLModel(w) metadata = dict(model.user_defined_metadata) elif saved_model: # TF SavedModel - LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") import tensorflow as tf + keras = False # assume TF1 saved_model model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) - metadata = Path(w) / 'metadata.yaml' + metadata = Path(w) / "metadata.yaml" elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt - LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") import tensorflow as tf from ultralytics.engine.exporter import gd_outputs def wrap_frozen_graph(gd, inputs, outputs): """Wrap frozen graphs for deployment.""" - x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped ge = x.graph.as_graph_element return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) gd = tf.Graph().as_graph_def() # TF GraphDef - with open(w, 'rb') as f: + with open(w, "rb") as f: gd.ParseFromString(f.read()) - frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd)) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu from tflite_runtime.interpreter import Interpreter, load_delegate except ImportError: import tensorflow as tf + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime - LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') - delegate = { - 'Linux': 'libedgetpu.so.1', - 'Darwin': 'libedgetpu.1.dylib', - 'Windows': 'edgetpu.dll'}[platform.system()] + 
LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)]) else: # TFLite - LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") interpreter = Interpreter(model_path=w) # load TFLite model interpreter.allocate_tensors() # allocate input_details = interpreter.get_input_details() # inputs output_details = interpreter.get_output_details() # outputs # Load metadata with contextlib.suppress(zipfile.BadZipFile): - with zipfile.ZipFile(w, 'r') as model: + with zipfile.ZipFile(w, "r") as model: meta_file = model.namelist()[0] - metadata = ast.literal_eval(model.read(meta_file).decode('utf-8')) + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) elif tfjs: # TF.js - raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.') + raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.") elif paddle: # PaddlePaddle - LOGGER.info(f'Loading {w} for PaddlePaddle inference...') - check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle') + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") import paddle.inference as pdi # noqa + w = Path(w) if not w.is_file(): # if not *.pdmodel - w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir - config = pdi.Config(str(w), str(w.with_suffix('.pdiparams'))) + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) if cuda: config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) predictor = pdi.create_predictor(config) input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) output_names = predictor.get_output_names() - metadata = w.parents[1] / 'metadata.yaml' + metadata = w.parents[1] / "metadata.yaml" elif ncnn: # ncnn - LOGGER.info(f'Loading {w} for ncnn inference...') - check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires ncnn + LOGGER.info(f"Loading {w} for ncnn inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn import ncnn as pyncnn + net = pyncnn.Net() net.opt.use_vulkan_compute = cuda w = Path(w) if not w.is_file(): # if not *.param - w = next(w.glob('*.param')) # get *.param file from *_ncnn_model dir + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir net.load_param(str(w)) - net.load_model(str(w.with_suffix('.bin'))) - metadata = w.parent / 'metadata.yaml' + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" elif triton: # NVIDIA Triton Inference Server - check_requirements('tritonclient[all]') + check_requirements("tritonclient[all]") from ultralytics.utils.triton import TritonRemoteModel + model = TritonRemoteModel(w) else: from ultralytics.engine.exporter import export_formats - raise TypeError(f"model='{w}' is not a supported model format. " - 'See https://docs.ultralytics.com/modes/predict for help.' - f'\n\n{export_formats()}') + + raise TypeError( + f"model='{w}' is not a supported model format. " + "See https://docs.ultralytics.com/modes/predict for help." 
+ f"\n\n{export_formats()}" + ) # Load external metadata YAML if isinstance(metadata, (str, Path)) and Path(metadata).exists(): metadata = yaml_load(metadata) if metadata: for k, v in metadata.items(): - if k in ('stride', 'batch'): + if k in ("stride", "batch"): metadata[k] = int(v) - elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str): + elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str): metadata[k] = eval(v) - stride = metadata['stride'] - task = metadata['task'] - batch = metadata['batch'] - imgsz = metadata['imgsz'] - names = metadata['names'] - kpt_shape = metadata.get('kpt_shape') + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") elif not (pt or triton or nn_module): LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") # Check names - if 'names' not in locals(): # names missing + if "names" not in locals(): # names missing names = default_class_names(data) names = check_class_names(names) @@ -367,26 +395,28 @@ class AutoBackend(nn.Module): im = im.cpu().numpy() # FP32 y = list(self.ov_compiled_model(im).values()) elif self.engine: # TensorRT - if self.dynamic and im.shape != self.bindings['images'].shape: - i = self.model.get_binding_index('images') + if self.dynamic and im.shape != self.bindings["images"].shape: + i = self.model.get_binding_index("images") self.context.set_binding_shape(i, im.shape) # reshape if dynamic - self.bindings['images'] = self.bindings['images']._replace(shape=im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) for name in self.output_names: i = self.model.get_binding_index(name) self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) - s = self.bindings['images'].shape + s = self.bindings["images"].shape assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" - self.binding_addrs['images'] = int(im.data_ptr()) + self.binding_addrs["images"] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) y = [self.bindings[x].data for x in sorted(self.output_names)] elif self.coreml: # CoreML im = im[0].cpu().numpy() - im_pil = Image.fromarray((im * 255).astype('uint8')) + im_pil = Image.fromarray((im * 255).astype("uint8")) # im = im.resize((192, 320), Image.BILINEAR) - y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized - if 'confidence' in y: - raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with ' - f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.") + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." 
+ ) # TODO: CoreML NMS inference handling # from ultralytics.utils.ops import xywh2xyxy # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels @@ -425,20 +455,20 @@ class AutoBackend(nn.Module): if len(y) == 2 and len(self.names) == 999: # segments and names not defined ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400) - self.names = {i: f'class{i}' for i in range(nc)} + self.names = {i: f"class{i}" for i in range(nc)} else: # Lite or Edge TPU details = self.input_details[0] - integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model + integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model if integer: - scale, zero_point = details['quantization'] - im = (im / scale + zero_point).astype(details['dtype']) # de-scale - self.interpreter.set_tensor(details['index'], im) + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) self.interpreter.invoke() y = [] for output in self.output_details: - x = self.interpreter.get_tensor(output['index']) + x = self.interpreter.get_tensor(output["index"]) if integer: - scale, zero_point = output['quantization'] + scale, zero_point = output["quantization"] x = (x.astype(np.float32) - zero_point) * scale # re-scale if x.ndim > 2: # if task is not classification # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 @@ -483,13 +513,13 @@ class AutoBackend(nn.Module): (None): This method runs the forward pass and doesn't return any value warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module - if any(warmup_types) and (self.device.type != 'cpu' or self.triton): + if any(warmup_types) and (self.device.type != "cpu" or self.triton): im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input for _ in range(2 if self.jit else 1): self.forward(im) # warmup @staticmethod - def _model_type(p='path/to/model.pt'): + def _model_type(p="path/to/model.pt"): """ This function takes a path to a model file and returns the model type. @@ -499,18 +529,20 @@ class AutoBackend(nn.Module): # Return model type from model path, i.e. 
path='path/to/model.onnx' -> type=onnx # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] from ultralytics.engine.exporter import export_formats + sf = list(export_formats().Suffix) # export suffixes if not is_url(p, check=False) and not isinstance(p, str): check_suffix(p, sf) # checks name = Path(p).name types = [s in name for s in sf] - types[5] |= name.endswith('.mlmodel') # retain support for older Apple CoreML *.mlmodel formats + types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats types[8] &= not types[9] # tflite &= not edgetpu if any(types): triton = False else: from urllib.parse import urlsplit + url = urlsplit(p) - triton = url.netloc and url.path and url.scheme in {'http', 'grpc'} + triton = url.netloc and url.path and url.scheme in {"http", "grpc"} return types + [triton] diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py index bbc7a5b1..3c6739e5 100644 --- a/ultralytics/nn/modules/__init__.py +++ b/ultralytics/nn/modules/__init__.py @@ -17,18 +17,101 @@ Example: ``` """ -from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, - HGBlock, HGStem, Proto, RepC3, ResNetLayer) -from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, - GhostConv, LightConv, RepConv, SpatialAttention) +from .block import ( + C1, + C2, + C3, + C3TR, + DFL, + SPP, + SPPF, + Bottleneck, + BottleneckCSP, + C2f, + C3Ghost, + C3x, + GhostBottleneck, + HGBlock, + HGStem, + Proto, + RepC3, + ResNetLayer, +) +from .conv import ( + CBAM, + ChannelAttention, + Concat, + Conv, + Conv2, + ConvTranspose, + DWConv, + DWConvTranspose2d, + Focus, + GhostConv, + LightConv, + RepConv, + SpatialAttention, +) from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment -from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, - MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) +from .transformer import ( + AIFI, + MLP, + DeformableTransformerDecoder, + DeformableTransformerDecoderLayer, + LayerNorm2d, + MLPBlock, + MSDeformAttn, + TransformerBlock, + TransformerEncoderLayer, + TransformerLayer, +) -__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', - 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', - 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', - 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', - 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', - 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer', - 'OBB') +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "RepConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "Detect", + "Segment", + "Pose", + "Classify", + "TransformerEncoderLayer", + "RepC3", + "RTDETRDecoder", + 
"AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", + "ResNetLayer", + "OBB", +) diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py index c754cdfc..5a6d12f7 100644 --- a/ultralytics/nn/modules/block.py +++ b/ultralytics/nn/modules/block.py @@ -8,8 +8,26 @@ import torch.nn.functional as F from .conv import Conv, DWConv, GhostConv, LightConv, RepConv from .transformer import TransformerBlock -__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost', - 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3', 'ResNetLayer') +__all__ = ( + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "RepC3", + "ResNetLayer", +) class DFL(nn.Module): @@ -284,9 +302,11 @@ class GhostBottleneck(nn.Module): self.conv = nn.Sequential( GhostConv(c1, c_, 1, 1), # pw DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw - GhostConv(c_, c2, 1, 1, act=False)) # pw-linear - self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, - act=False)) if s == 2 else nn.Identity() + GhostConv(c_, c2, 1, 1, act=False), # pw-linear + ) + self.shortcut = ( + nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + ) def forward(self, x): """Applies skip connection and concatenation to input tensor.""" @@ -359,8 +379,9 @@ class ResNetLayer(nn.Module): self.is_first = is_first if self.is_first: - self.layer = nn.Sequential(Conv(c1, c2, k=7, s=2, p=3, act=True), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.layer = nn.Sequential( + Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) else: blocks = [ResNetBlock(c1, c2, s, e=e)] blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]) diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py index 7fe615d9..399c4225 100644 --- a/ultralytics/nn/modules/conv.py +++ b/ultralytics/nn/modules/conv.py @@ -7,8 +7,21 @@ import numpy as np import torch import torch.nn as nn -__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', - 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv') +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "RepConv", +) def autopad(k, p=None, d=1): # kernel, padding, dilation @@ -22,6 +35,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation class Conv(nn.Module): """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation).""" + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): @@ -60,9 +74,9 @@ class Conv2(Conv): """Fuse parallel convolutions.""" w = torch.zeros_like(self.conv.weight.data) i = [x // 2 for x in w.shape[2:]] - w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone() + w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone() self.conv.weight.data += w - self.__delattr__('cv2') + self.__delattr__("cv2") self.forward = self.forward_fuse @@ -102,6 +116,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose(nn.Module): """Convolution transpose 
2d layer.""" + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): @@ -164,6 +179,7 @@ class RepConv(nn.Module): This module is used in RT-DETR. Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py """ + default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False): @@ -214,7 +230,7 @@ class RepConv(nn.Module): beta = branch.bn.bias eps = branch.bn.eps elif isinstance(branch, nn.BatchNorm2d): - if not hasattr(self, 'id_tensor'): + if not hasattr(self, "id_tensor"): input_dim = self.c1 // self.g kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32) for i in range(self.c1): @@ -232,29 +248,31 @@ class RepConv(nn.Module): def fuse_convs(self): """Combines two convolution layers into a single layer and removes unused attributes from the class.""" - if hasattr(self, 'conv'): + if hasattr(self, "conv"): return kernel, bias = self.get_equivalent_kernel_bias() - self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels, - out_channels=self.conv1.conv.out_channels, - kernel_size=self.conv1.conv.kernel_size, - stride=self.conv1.conv.stride, - padding=self.conv1.conv.padding, - dilation=self.conv1.conv.dilation, - groups=self.conv1.conv.groups, - bias=True).requires_grad_(False) + self.conv = nn.Conv2d( + in_channels=self.conv1.conv.in_channels, + out_channels=self.conv1.conv.out_channels, + kernel_size=self.conv1.conv.kernel_size, + stride=self.conv1.conv.stride, + padding=self.conv1.conv.padding, + dilation=self.conv1.conv.dilation, + groups=self.conv1.conv.groups, + bias=True, + ).requires_grad_(False) self.conv.weight.data = kernel self.conv.bias.data = bias for para in self.parameters(): para.detach_() - self.__delattr__('conv1') - self.__delattr__('conv2') - if hasattr(self, 'nm'): - self.__delattr__('nm') - if hasattr(self, 'bn'): - self.__delattr__('bn') - if hasattr(self, 'id_tensor'): - self.__delattr__('id_tensor') + self.__delattr__("conv1") + self.__delattr__("conv2") + if hasattr(self, "nm"): + self.__delattr__("nm") + if hasattr(self, "bn"): + self.__delattr__("bn") + if hasattr(self, "id_tensor"): + self.__delattr__("id_tensor") class ChannelAttention(nn.Module): @@ -278,7 +296,7 @@ class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): """Initialize Spatial-attention module with kernel size argument.""" super().__init__() - assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + assert kernel_size in (3, 7), "kernel size must be 3 or 7" padding = 3 if kernel_size == 7 else 1 self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.act = nn.Sigmoid() diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index cf4cff2f..d5951baa 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -14,11 +14,12 @@ from .conv import Conv from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer from .utils import bias_init_with_prob, linear_init_ -__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder' +__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder" class Detect(nn.Module): """YOLOv8 Detect head for detection models.""" + dynamic = False # force grid reconstruction export = False # export mode shape = None @@ -35,7 +36,8 @@ class Detect(nn.Module): self.stride = torch.zeros(self.nl) # strides computed during build c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], 
min(self.nc, 100)) # channels self.cv2 = nn.ModuleList( - nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch) + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() @@ -53,14 +55,14 @@ class Detect(nn.Module): self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape - if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops - box = x_cat[:, :self.reg_max * 4] - cls = x_cat[:, self.reg_max * 4:] + if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) dbox = self.decode_bboxes(box) - if self.export and self.format in ('tflite', 'edgetpu'): + if self.export and self.format in ("tflite", "edgetpu"): # Precompute normalization factor to increase numerical stability # See https://github.com/ultralytics/ultralytics/issues/7371 img_h = shape[2] @@ -79,7 +81,7 @@ class Detect(nn.Module): # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency for a, b, s in zip(m.cv2, m.cv3, m.stride): # from a[-1].bias.data[:] = 1.0 # box - b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) def decode_bboxes(self, bboxes): """Decode bounding boxes.""" @@ -214,26 +216,28 @@ class RTDETRDecoder(nn.Module): and class labels for objects in an image. It integrates features from multiple layers and runs through a series of Transformer decoder layers to output the final predictions. """ + export = False # export mode def __init__( - self, - nc=80, - ch=(512, 1024, 2048), - hd=256, # hidden dim - nq=300, # num queries - ndp=4, # num decoder points - nh=8, # num head - ndl=6, # num decoder layers - d_ffn=1024, # dim of feedforward - dropout=0., - act=nn.ReLU(), - eval_idx=-1, - # Training args - nd=100, # num denoising - label_noise_ratio=0.5, - box_noise_scale=1.0, - learnt_init_query=False): + self, + nc=80, + ch=(512, 1024, 2048), + hd=256, # hidden dim + nq=300, # num queries + ndp=4, # num decoder points + nh=8, # num head + ndl=6, # num decoder layers + d_ffn=1024, # dim of feedforward + dropout=0.0, + act=nn.ReLU(), + eval_idx=-1, + # Training args + nd=100, # num denoising + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=False, + ): """ Initializes the RTDETRDecoder module with the given parameters. 
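Reviewer note: the reformatted RTDETRDecoder signature above keeps every default value unchanged. For orientation, a sketch of constructing the decoder head with those defaults — the channel widths follow the signature's own defaults and are only illustrative:

import torch.nn as nn

from ultralytics.nn.modules import RTDETRDecoder

# ch must match the channel widths of the three feature maps fed to the decoder.
decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300, ndl=6, act=nn.ReLU())
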
@@ -302,28 +306,30 @@ class RTDETRDecoder(nn.Module): feats, shapes = self._get_encoder_input(x) # Prepare denoising training - dn_embed, dn_bbox, attn_mask, dn_meta = \ - get_cdn_group(batch, - self.nc, - self.num_queries, - self.denoising_class_embed.weight, - self.num_denoising, - self.label_noise_ratio, - self.box_noise_scale, - self.training) - - embed, refer_bbox, enc_bboxes, enc_scores = \ - self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) + dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group( + batch, + self.nc, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + self.training, + ) + + embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) # Decoder - dec_bboxes, dec_scores = self.decoder(embed, - refer_bbox, - feats, - shapes, - self.dec_bbox_head, - self.dec_score_head, - self.query_pos_head, - attn_mask=attn_mask) + dec_bboxes, dec_scores = self.decoder( + embed, + refer_bbox, + feats, + shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask, + ) x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta if self.training: return x @@ -331,24 +337,24 @@ class RTDETRDecoder(nn.Module): y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) return y if self.export else (y, x) - def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): + def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2): """Generates anchor bounding boxes for given shapes with specific grid size and validates them.""" anchors = [] for i, (h, w) in enumerate(shapes): sy = torch.arange(end=h, dtype=dtype, device=device) sx = torch.arange(end=w, dtype=dtype, device=device) - grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) + grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2) valid_WH = torch.tensor([w, h], dtype=dtype, device=device) grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2) - wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i) + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4) anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 anchors = torch.log(anchors / (1 - anchors)) - anchors = anchors.masked_fill(~valid_mask, float('inf')) + anchors = anchors.masked_fill(~valid_mask, float("inf")) return anchors, valid_mask def _get_encoder_input(self, x): @@ -415,13 +421,13 @@ class RTDETRDecoder(nn.Module): # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets. # linear_init_(self.enc_score_head) constant_(self.enc_score_head.bias, bias_cls) - constant_(self.enc_bbox_head.layers[-1].weight, 0.) - constant_(self.enc_bbox_head.layers[-1].bias, 0.) + constant_(self.enc_bbox_head.layers[-1].weight, 0.0) + constant_(self.enc_bbox_head.layers[-1].bias, 0.0) for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): # linear_init_(cls_) constant_(cls_.bias, bias_cls) - constant_(reg_.layers[-1].weight, 0.) - constant_(reg_.layers[-1].bias, 0.) 
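Reviewer note: `_generate_anchors` above stores anchors in logit space (`torch.log(anchors / (1 - anchors))`) so the decoder can refine them additively and a plain sigmoid maps the result back into (0, 1). A quick check of that round trip:

import torch

xywh = torch.tensor([[0.25, 0.50, 0.10, 0.10]])  # a normalized anchor strictly inside (0, 1)
logit = torch.log(xywh / (1 - xywh))  # the same transform used by _generate_anchors
assert torch.allclose(torch.sigmoid(logit), xywh)  # sigmoid inverts the logit
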
+ constant_(reg_.layers[-1].weight, 0.0) + constant_(reg_.layers[-1].bias, 0.0) linear_init_(self.enc_output[0]) xavier_uniform_(self.enc_output[0].weight) diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index 9fe9597f..465b7170 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -11,8 +11,18 @@ from torch.nn.init import constant_, xavier_uniform_ from .conv import Conv from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch -__all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI', - 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') +__all__ = ( + "TransformerEncoderLayer", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", +) class TransformerEncoderLayer(nn.Module): @@ -22,9 +32,11 @@ class TransformerEncoderLayer(nn.Module): """Initialize the TransformerEncoderLayer with specified parameters.""" super().__init__() from ...utils.torch_utils import TORCH_1_9 + if not TORCH_1_9: raise ModuleNotFoundError( - 'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).') + "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)." + ) self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) # Implementation of Feedforward model self.fc1 = nn.Linear(c1, cm) @@ -91,12 +103,11 @@ class AIFI(TransformerEncoderLayer): """Builds 2D sine-cosine position embedding.""" grid_w = torch.arange(int(w), dtype=torch.float32) grid_h = torch.arange(int(h), dtype=torch.float32) - grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') - assert embed_dim % 4 == 0, \ - 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim - omega = 1. / (temperature ** omega) + omega = 1.0 / (temperature**omega) out_w = grid_w.flatten()[..., None] @ omega[None] out_h = grid_h.flatten()[..., None] @ omega[None] @@ -213,10 +224,10 @@ class MSDeformAttn(nn.Module): """Initialize MSDeformAttn with the given parameters.""" super().__init__() if d_model % n_heads != 0: - raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}') + raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}") _d_per_head = d_model // n_heads # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation - assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`' + assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`" self.im2col_step = 64 @@ -234,21 +245,24 @@ class MSDeformAttn(nn.Module): def _reset_parameters(self): """Reset module parameters.""" - constant_(self.sampling_offsets.weight.data, 0.) 
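Reviewer note: `build_2d_sincos_position_embedding` above is the standard sin-cos scheme, which is why `embed_dim % 4 == 0` is asserted — each of the two spatial axes contributes both a sin block and a cos block. A one-axis sketch of the same frequency spectrum (sizes are illustrative):

import torch

embed_dim, temperature = 256, 10000.0
pos_dim = embed_dim // 4  # per-axis width of each sin/cos block
omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
pos = torch.arange(8, dtype=torch.float32)  # 8 positions along one axis
phase = pos[:, None] @ omega[None]  # (8, pos_dim): outer product of positions and frequencies
emb = torch.cat([torch.sin(phase), torch.cos(phase)], dim=1)  # (8, 2 * pos_dim)
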
+ constant_(self.sampling_offsets.weight.data, 0.0) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat( - 1, self.n_levels, self.n_points, 1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - constant_(self.attention_weights.weight.data, 0.) - constant_(self.attention_weights.bias.data, 0.) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) xavier_uniform_(self.value_proj.weight.data) - constant_(self.value_proj.bias.data, 0.) + constant_(self.value_proj.bias.data, 0.0) xavier_uniform_(self.output_proj.weight.data) - constant_(self.output_proj.bias.data, 0.) + constant_(self.output_proj.bias.data, 0.0) def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): """ @@ -288,7 +302,7 @@ class MSDeformAttn(nn.Module): add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5 sampling_locations = refer_bbox[:, :, None, :, None, :2] + add else: - raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.') + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.") output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights) return self.output_proj(output) @@ -301,7 +315,7 @@ class DeformableTransformerDecoderLayer(nn.Module): https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py """ - def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4): + def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4): """Initialize the DeformableTransformerDecoderLayer with the given parameters.""" super().__init__() @@ -339,14 +353,16 @@ class DeformableTransformerDecoderLayer(nn.Module): # Self attention q = k = self.with_pos_embed(embed, query_pos) - tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), - attn_mask=attn_mask)[0].transpose(0, 1) + tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[ + 0 + ].transpose(0, 1) embed = embed + self.dropout1(tgt) embed = self.norm1(embed) # Cross attention - tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, - padding_mask) + tgt = self.cross_attn( + self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask + ) embed = embed + self.dropout2(tgt) embed = self.norm2(embed) @@ -370,16 +386,17 @@ class DeformableTransformerDecoder(nn.Module): self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx def forward( - self, - embed, # decoder embeddings - refer_bbox, # anchor - feats, # image features - shapes, # feature shapes - bbox_head, - score_head, - pos_mlp, - attn_mask=None, - padding_mask=None): + self, + embed, # decoder embeddings + refer_bbox, # anchor + feats, # image features + shapes, # feature shapes + bbox_head, + score_head, + pos_mlp, + attn_mask=None, + padding_mask=None, + ): """Perform the forward pass 
through the entire decoder.""" output = embed dec_bboxes = [] diff --git a/ultralytics/nn/modules/utils.py b/ultralytics/nn/modules/utils.py index c7bec7af..2cb615a6 100644 --- a/ultralytics/nn/modules/utils.py +++ b/ultralytics/nn/modules/utils.py @@ -10,7 +10,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.init import uniform_ -__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' +__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid" def _get_clones(module, n): @@ -27,7 +27,7 @@ def linear_init_(module): """Initialize the weights and biases of a linear module.""" bound = 1 / math.sqrt(module.weight.shape[0]) uniform_(module.weight, -bound, bound) - if hasattr(module, 'bias') and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None: uniform_(module.bias, -bound, bound) @@ -39,9 +39,12 @@ def inverse_sigmoid(x, eps=1e-5): return torch.log(x1 / x2) -def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, - sampling_locations: torch.Tensor, - attention_weights: torch.Tensor) -> torch.Tensor: +def multi_scale_deformable_attn_pytorch( + value: torch.Tensor, + value_spatial_shapes: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, +) -> torch.Tensor: """ Multi-scale deformable attention. @@ -58,23 +61,25 @@ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shape # bs, H_*W_, num_heads*embed_dims -> # bs, num_heads*embed_dims, H_*W_ -> # bs*num_heads, embed_dims, H_, W_ - value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) # bs, num_queries, num_heads, num_points, 2 -> # bs, num_heads, num_queries, num_points, 2 -> # bs*num_heads, num_queries, num_points, 2 sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) # bs*num_heads, embed_dims, num_queries, num_points - sampling_value_l_ = F.grid_sample(value_l_, - sampling_grid_l_, - mode='bilinear', - padding_mode='zeros', - align_corners=False) + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) sampling_value_list.append(sampling_value_l_) # (bs, num_queries, num_heads, num_levels, num_points) -> # (bs, num_heads, num_queries, num_levels, num_points) -> # (bs, num_heads, 1, num_queries, num_levels*num_points) - attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, - num_levels * num_points) - output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( - bs, num_heads * embed_dims, num_queries)) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(bs, num_heads * embed_dims, num_queries) + ) return output.transpose(1, 2).contiguous() diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 8261ed04..8efd8e33 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -7,16 +7,54 @@ from pathlib import Path import torch import torch.nn as nn -from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, - C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, - 
DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, - RepConv, ResNetLayer, RTDETRDecoder, Segment) +from ultralytics.nn.modules import ( + AIFI, + C1, + C2, + C3, + C3TR, + OBB, + SPP, + SPPF, + Bottleneck, + BottleneckCSP, + C2f, + C3Ghost, + C3x, + Classify, + Concat, + Conv, + Conv2, + ConvTranspose, + Detect, + DWConv, + DWConvTranspose2d, + Focus, + GhostBottleneck, + GhostConv, + HGBlock, + HGStem, + Pose, + RepC3, + RepConv, + ResNetLayer, + RTDETRDecoder, + Segment, +) from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss from ultralytics.utils.plotting import feature_visualization -from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, - make_divisible, model_info, scale_img, time_sync) +from ultralytics.utils.torch_utils import ( + fuse_conv_and_bn, + fuse_deconv_and_bn, + initialize_weights, + intersect_dicts, + make_divisible, + model_info, + scale_img, + time_sync, +) try: import thop @@ -90,8 +128,10 @@ class BaseModel(nn.Module): def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference.""" - LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. ' - f'Reverting to single-scale inference instead.') + LOGGER.warning( + f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. " + f"Reverting to single-scale inference instead." + ) return self._predict_once(x) def _profile_one_layer(self, m, x, dt): @@ -108,14 +148,14 @@ class BaseModel(nn.Module): None """ c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix - flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs + flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs t = time_sync() for _ in range(10): m(x.copy() if c else x) dt.append((time_sync() - t) * 100) if m == self.model[0]: LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module") - LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}') + LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}") if c: LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total") @@ -129,15 +169,15 @@ class BaseModel(nn.Module): """ if not self.is_fused(): for m in self.model.modules(): - if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'): + if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"): if isinstance(m, Conv2): m.fuse_convs() m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv - delattr(m, 'bn') # remove batchnorm + delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward - if isinstance(m, ConvTranspose) and hasattr(m, 'bn'): + if isinstance(m, ConvTranspose) and hasattr(m, "bn"): m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) - delattr(m, 'bn') # remove batchnorm + delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward if isinstance(m, RepConv): m.fuse_convs() @@ -156,7 +196,7 @@ class BaseModel(nn.Module): Returns: (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. 
""" - bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d() + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model def info(self, detailed=False, verbose=True, imgsz=640): @@ -196,12 +236,12 @@ class BaseModel(nn.Module): weights (dict | torch.nn.Module): The pre-trained weights to be loaded. verbose (bool, optional): Whether to log the transfer progress. Defaults to True. """ - model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts + model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts csd = model.float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, self.state_dict()) # intersect self.load_state_dict(csd, strict=False) # load if verbose: - LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') + LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") def loss(self, batch, preds=None): """ @@ -211,33 +251,33 @@ class BaseModel(nn.Module): batch (dict): Batch to compute loss on preds (torch.Tensor | List[torch.Tensor]): Predictions. """ - if not hasattr(self, 'criterion'): + if not hasattr(self, "criterion"): self.criterion = self.init_criterion() - preds = self.forward(batch['img']) if preds is None else preds + preds = self.forward(batch["img"]) if preds is None else preds return self.criterion(preds, batch) def init_criterion(self): """Initialize the loss criterion for the BaseModel.""" - raise NotImplementedError('compute_loss() needs to be implemented by task heads') + raise NotImplementedError("compute_loss() needs to be implemented by task heads") class DetectionModel(BaseModel): """YOLOv8 detection model.""" - def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes + def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes """Initialize the YOLOv8 detection model with the given config and parameters.""" super().__init__() self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict # Define model - ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels - if nc and nc != self.yaml['nc']: + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") - self.yaml['nc'] = nc # override YAML value + self.yaml["nc"] = nc # override YAML value self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist - self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict - self.inplace = self.yaml.get('inplace', True) + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.inplace = self.yaml.get("inplace", True) # Build strides m = self.model[-1] # Detect() @@ -255,7 +295,7 @@ class DetectionModel(BaseModel): initialize_weights(self) if verbose: self.info() - LOGGER.info('') + LOGGER.info("") def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference and train outputs.""" @@ -285,9 +325,9 @@ class DetectionModel(BaseModel): def _clip_augmented(self, y): """Clip YOLO augmented inference tails.""" nl = 
self.model[-1].nl  # number of detection layers (P3-P5)
-        g = sum(4 ** x for x in range(nl))  # grid points
+        g = sum(4**x for x in range(nl))  # grid points
         e = 1  # exclude layer count
-        i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e))  # indices
+        i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices
         y[0] = y[0][..., :-i]  # large
         i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
         y[-1] = y[-1][..., i:]  # small
@@ -301,7 +341,7 @@ class DetectionModel(BaseModel):
 class OBBModel(DetectionModel):
     """YOLOv8 Oriented Bounding Box (OBB) model."""

-    def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 OBB model with given config and parameters."""
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

@@ -313,7 +353,7 @@ class DetectionModel(BaseModel):
 class SegmentationModel(DetectionModel):
     """YOLOv8 segmentation model."""

-    def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 segmentation model with given config and parameters."""
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

@@ -325,13 +365,13 @@ class DetectionModel(BaseModel):
 class PoseModel(DetectionModel):
     """YOLOv8 pose model."""

-    def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
+    def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
         """Initialize YOLOv8 Pose model."""
         if not isinstance(cfg, dict):
             cfg = yaml_model_load(cfg)  # load model YAML
-        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
+        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
             LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
-            cfg['kpt_shape'] = data_kpt_shape
+            cfg["kpt_shape"] = data_kpt_shape
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

     def init_criterion(self):
@@ -342,7 +382,7 @@ class PoseModel(DetectionModel):
 class ClassificationModel(BaseModel):
     """YOLOv8 classification model."""

-    def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
+    def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
         """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
         super().__init__()
         self._from_yaml(cfg, ch, nc, verbose)
@@ -352,21 +392,21 @@ class ClassificationModel(BaseModel):
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

         # Define model
-        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
-        if nc and nc != self.yaml['nc']:
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
             LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
-            self.yaml['nc'] = nc  # override YAML value
-        elif not nc and not self.yaml.get('nc', None):
-            raise ValueError('nc not specified.
Must specify nc in model.yaml or function arguments.") self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist self.stride = torch.Tensor([1]) # no stride constraints - self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict self.info() @staticmethod def reshape_outputs(model, nc): """Update a TorchVision classification model to class count 'n' if required.""" - name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module + name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module if isinstance(m, Classify): # YOLO Classify() head if m.linear.out_features != nc: m.linear = nn.Linear(m.linear.in_features, nc) @@ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel): predict: Performs a forward pass through the network and returns the output. """ - def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True): + def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): """ Initialize the RTDETRDetectionModel. @@ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel): Returns: (tuple): A tuple containing the total loss and main three losses in a tensor. """ - if not hasattr(self, 'criterion'): + if not hasattr(self, "criterion"): self.criterion = self.init_criterion() - img = batch['img'] + img = batch["img"] # NOTE: preprocess gt_bbox and gt_labels to list. bs = len(img) - batch_idx = batch['batch_idx'] + batch_idx = batch["batch_idx"] gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] targets = { - 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1), - 'bboxes': batch['bboxes'].to(device=img.device), - 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1), - 'gt_groups': gt_groups} + "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), + "bboxes": batch["bboxes"].to(device=img.device), + "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), + "gt_groups": gt_groups, + } preds = self.predict(img, batch=targets) if preds is None else preds dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] if dn_meta is None: dn_bboxes, dn_scores = None, None else: - dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2) - dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2) + dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) + dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) - loss = self.criterion((dec_bboxes, dec_scores), - targets, - dn_bboxes=dn_bboxes, - dn_scores=dn_scores, - dn_meta=dn_meta) + loss = self.criterion( + (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta + ) # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. 
- return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']], - device=img.device) + return sum(loss.values()), torch.as_tensor( + [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device + ) def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): """ @@ -553,6 +593,7 @@ def temporary_modules(modules=None): import importlib import sys + try: # Set modules in sys.modules under their old name for old, new in modules.items(): @@ -580,30 +621,38 @@ def torch_safe_load(weight): """ from ultralytics.utils.downloads import attempt_download_asset - check_suffix(file=weight, suffix='.pt') + check_suffix(file=weight, suffix=".pt") file = attempt_download_asset(weight) # search online if missing locally try: - with temporary_modules({ - 'ultralytics.yolo.utils': 'ultralytics.utils', - 'ultralytics.yolo.v8': 'ultralytics.models.yolo', - 'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models - return torch.load(file, map_location='cpu'), file # load + with temporary_modules( + { + "ultralytics.yolo.utils": "ultralytics.utils", + "ultralytics.yolo.v8": "ultralytics.models.yolo", + "ultralytics.yolo.data": "ultralytics.data", + } + ): # for legacy 8.0 Classify and Pose models + return torch.load(file, map_location="cpu"), file # load except ModuleNotFoundError as e: # e.name is missing module name - if e.name == 'models': + if e.name == "models": raise TypeError( - emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained ' - f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with ' - f'YOLOv8 at https://github.com/ultralytics/ultralytics.' - f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " - f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e - LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." - f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." - f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " - f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'") + emojis( + f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained " + f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with " + f"YOLOv8 at https://github.com/ultralytics/ultralytics." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" + ) + ) from e + LOGGER.warning( + f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements." + f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official YOLOv8 model, i.e. 
'yolo predict model=yolov8n.pt'" + ) check_requirements(e.name) # install missing module - return torch.load(file, map_location='cpu'), file # load + return torch.load(file, map_location="cpu"), file # load def attempt_load_weights(weights, device=None, inplace=True, fuse=False): @@ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): ensemble = Ensemble() for w in weights if isinstance(weights, list) else [weights]: ckpt, w = torch_safe_load(w) # load ckpt - args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args - model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model + args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = args # attach args to model model.pt_path = w # attach *.pt file path to model model.task = guess_model_task(model) - if not hasattr(model, 'stride'): - model.stride = torch.tensor([32.]) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) # Append - ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode + ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode # Module updates for m in ensemble.modules(): t = type(m) if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB): m.inplace = inplace - elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): + elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model @@ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): return ensemble[-1] # Return ensemble - LOGGER.info(f'Ensemble created with {weights}\n') - for k in 'names', 'nc', 'yaml': + LOGGER.info(f"Ensemble created with {weights}\n") + for k in "names", "nc", "yaml": setattr(ensemble, k, getattr(ensemble[0], k)) ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride - assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}' + assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" return ensemble def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): """Loads a single model weights.""" ckpt, weight = torch_safe_load(weight) # load ckpt - args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args - model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model + args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model model.pt_path = weight # attach *.pt file path to model model.task = guess_model_task(model) - if not hasattr(model, 'stride'): - model.stride = torch.tensor([32.]) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) - model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode + model = model.fuse().eval() if fuse and 
hasattr(model, "fuse") else model.eval() # model in eval mode # Module updates for m in model.modules(): t = type(m) if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB): m.inplace = inplace - elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'): + elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model and ckpt @@ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) import ast # Args - max_channels = float('inf') - nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales')) - depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape')) + max_channels = float("inf") + nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) + depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) if scales: - scale = d.get('scale') + scale = d.get("scale") if not scale: scale = tuple(scales.keys())[0] LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") @@ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}") ch = [ch] layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out - for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args - m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module for j, a in enumerate(args): if isinstance(a, str): with contextlib.suppress(ValueError): args[j] = locals()[a] if a in locals() else ast.literal_eval(a) n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain - if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, - BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3): + if m in ( + Classify, + Conv, + ConvTranspose, + GhostConv, + Bottleneck, + GhostBottleneck, + SPP, + SPPF, + DWConv, + Focus, + BottleneckCSP, + C1, + C2, + C2f, + C3, + C3TR, + C3Ghost, + nn.ConvTranspose2d, + DWConvTranspose2d, + C3x, + RepC3, + ): c1, c2 = ch[f], args[0] if c2 != nc: # if c2 not equal to number of classes (i.e. 
for Classify() output) c2 = make_divisible(min(c2, max_channels) * width, 8) @@ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) c2 = ch[f] m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module - t = str(m)[8:-2].replace('__main__.', '') # module type + t = str(m)[8:-2].replace("__main__.", "") # module type m.np = sum(x.numel() for x in m_.parameters()) # number params m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type if verbose: - LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print + LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist layers.append(m_) if i == 0: @@ -757,16 +827,16 @@ def yaml_model_load(path): import re path = Path(path) - if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)): - new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem) - LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.') + if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)): + new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem) + LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.") path = path.with_name(new_stem + path.suffix) - unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml + unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path) d = yaml_load(yaml_file) # model dict - d['scale'] = guess_model_scale(path) - d['yaml_file'] = str(path) + d["scale"] = guess_model_scale(path) + d["yaml_file"] = str(path) return d @@ -784,8 +854,9 @@ def guess_model_scale(model_path): """ with contextlib.suppress(AttributeError): import re - return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x - return '' + + return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x + return "" def guess_model_task(model): @@ -804,17 +875,17 @@ def guess_model_task(model): def cfg2task(cfg): """Guess from YAML dictionary.""" - m = cfg['head'][-1][-2].lower() # output module name - if m in ('classify', 'classifier', 'cls', 'fc'): - return 'classify' - if m == 'detect': - return 'detect' - if m == 'segment': - return 'segment' - if m == 'pose': - return 'pose' - if m == 'obb': - return 'obb' + m = cfg["head"][-1][-2].lower() # output module name + if m in ("classify", "classifier", "cls", "fc"): + return "classify" + if m == "detect": + return "detect" + if m == "segment": + return "segment" + if m == "pose": + return "pose" + if m == "obb": + return "obb" # Guess from model cfg if isinstance(model, dict): @@ -823,40 +894,42 @@ def guess_model_task(model): # Guess from PyTorch model if isinstance(model, nn.Module): # PyTorch model - for x in 'model.args', 'model.model.args', 'model.model.model.args': + for x in "model.args", "model.model.args", "model.model.model.args": with contextlib.suppress(Exception): - return eval(x)['task'] - for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml': + return eval(x)["task"] + for x in "model.yaml", "model.model.yaml", "model.model.model.yaml": with contextlib.suppress(Exception): return cfg2task(eval(x)) for m in model.modules(): if 
isinstance(m, Detect): - return 'detect' + return "detect" elif isinstance(m, Segment): - return 'segment' + return "segment" elif isinstance(m, Classify): - return 'classify' + return "classify" elif isinstance(m, Pose): - return 'pose' + return "pose" elif isinstance(m, OBB): - return 'obb' + return "obb" # Guess from model filename if isinstance(model, (str, Path)): model = Path(model) - if '-seg' in model.stem or 'segment' in model.parts: - return 'segment' - elif '-cls' in model.stem or 'classify' in model.parts: - return 'classify' - elif '-pose' in model.stem or 'pose' in model.parts: - return 'pose' - elif '-obb' in model.stem or 'obb' in model.parts: - return 'obb' - elif 'detect' in model.parts: - return 'detect' + if "-seg" in model.stem or "segment" in model.parts: + return "segment" + elif "-cls" in model.stem or "classify" in model.parts: + return "classify" + elif "-pose" in model.stem or "pose" in model.parts: + return "pose" + elif "-obb" in model.stem or "obb" in model.parts: + return "obb" + elif "detect" in model.parts: + return "detect" # Unable to determine task from model - LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " - "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.") - return 'detect' # assume detect + LOGGER.warning( + "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " + "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'." + ) + return "detect" # assume detect diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py index df9b422a..0cd59620 100644 --- a/ultralytics/solutions/ai_gym.py +++ b/ultralytics/solutions/ai_gym.py @@ -26,7 +26,7 @@ class AIGym: self.angle = None self.count = None self.stage = None - self.pose_type = 'pushup' + self.pose_type = "pushup" self.kpts_to_check = None # Visual Information @@ -36,13 +36,15 @@ class AIGym: # Check if environment support imshow self.env_check = check_imshow(warn=True) - def set_args(self, - kpts_to_check, - line_thickness=2, - view_img=False, - pose_up_angle=145.0, - pose_down_angle=90.0, - pose_type='pullup'): + def set_args( + self, + kpts_to_check, + line_thickness=2, + view_img=False, + pose_up_angle=145.0, + pose_down_angle=90.0, + pose_type="pullup", + ): """ Configures the AIGym line_thickness, save image and view image parameters Args: @@ -72,65 +74,75 @@ class AIGym: if frame_count == 1: self.count = [0] * len(results[0]) self.angle = [0] * len(results[0]) - self.stage = ['-' for _ in results[0]] + self.stage = ["-" for _ in results[0]] self.keypoints = results[0].keypoints.data self.annotator = Annotator(im0, line_width=2) for ind, k in enumerate(reversed(self.keypoints)): - if self.pose_type == 'pushup' or self.pose_type == 'pullup': - self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(), - k[int(self.kpts_to_check[1])].cpu(), - k[int(self.kpts_to_check[2])].cpu()) + if self.pose_type == "pushup" or self.pose_type == "pullup": + self.angle[ind] = self.annotator.estimate_pose_angle( + k[int(self.kpts_to_check[0])].cpu(), + k[int(self.kpts_to_check[1])].cpu(), + k[int(self.kpts_to_check[2])].cpu(), + ) self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) - if self.pose_type == 'abworkout': - self.angle[ind] = self.annotator.estimate_pose_angle(k[int(self.kpts_to_check[0])].cpu(), - k[int(self.kpts_to_check[1])].cpu(), - 
k[int(self.kpts_to_check[2])].cpu()) + if self.pose_type == "abworkout": + self.angle[ind] = self.annotator.estimate_pose_angle( + k[int(self.kpts_to_check[0])].cpu(), + k[int(self.kpts_to_check[1])].cpu(), + k[int(self.kpts_to_check[2])].cpu(), + ) self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10) if self.angle[ind] > self.poseup_angle: - self.stage[ind] = 'down' - if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down': - self.stage[ind] = 'up' + self.stage[ind] = "down" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": + self.stage[ind] = "up" self.count[ind] += 1 - self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], - count_text=self.count[ind], - stage_text=self.stage[ind], - center_kpt=k[int(self.kpts_to_check[1])], - line_thickness=self.tf) - - if self.pose_type == 'pushup': + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) + + if self.pose_type == "pushup": if self.angle[ind] > self.poseup_angle: - self.stage[ind] = 'up' - if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'up': - self.stage[ind] = 'down' + self.stage[ind] = "up" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up": + self.stage[ind] = "down" self.count[ind] += 1 - self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], - count_text=self.count[ind], - stage_text=self.stage[ind], - center_kpt=k[int(self.kpts_to_check[1])], - line_thickness=self.tf) - if self.pose_type == 'pullup': + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) + if self.pose_type == "pullup": if self.angle[ind] > self.poseup_angle: - self.stage[ind] = 'down' - if self.angle[ind] < self.posedown_angle and self.stage[ind] == 'down': - self.stage[ind] = 'up' + self.stage[ind] = "down" + if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down": + self.stage[ind] = "up" self.count[ind] += 1 - self.annotator.plot_angle_and_count_and_stage(angle_text=self.angle[ind], - count_text=self.count[ind], - stage_text=self.stage[ind], - center_kpt=k[int(self.kpts_to_check[1])], - line_thickness=self.tf) + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], + count_text=self.count[ind], + stage_text=self.stage[ind], + center_kpt=k[int(self.kpts_to_check[1])], + line_thickness=self.tf, + ) self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True) if self.env_check and self.view_img: - cv2.imshow('Ultralytics YOLOv8 AI GYM', self.im0) - if cv2.waitKey(1) & 0xFF == ord('q'): + cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0) + if cv2.waitKey(1) & 0xFF == ord("q"): return return self.im0 -if __name__ == '__main__': +if __name__ == "__main__": AIGym() diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py index 306a5bd2..1770a7c0 100644 --- a/ultralytics/solutions/distance_calculation.py +++ b/ultralytics/solutions/distance_calculation.py @@ -41,13 +41,15 @@ class DistanceCalculation: # Check if environment support imshow self.env_check = check_imshow(warn=True) - def set_args(self, - names, - pixels_per_meter=10, - view_img=False, - line_thickness=2, - line_color=(255, 255, 0), - 
centroid_color=(255, 0, 255)): + def set_args( + self, + names, + pixels_per_meter=10, + view_img=False, + line_thickness=2, + line_color=(255, 255, 0), + centroid_color=(255, 0, 255), + ): """ Configures the distance calculation and display parameters. @@ -129,8 +131,9 @@ class DistanceCalculation: distance (float): Distance between two centroids """ cv2.rectangle(self.im0, (15, 25), (280, 70), (255, 255, 255), -1) - cv2.putText(self.im0, f'Distance : {distance:.2f}m', (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, - cv2.LINE_AA) + cv2.putText( + self.im0, f"Distance : {distance:.2f}m", (20, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2, cv2.LINE_AA + ) cv2.line(self.im0, self.centroids[0], self.centroids[1], self.line_color, 3) cv2.circle(self.im0, self.centroids[0], 6, self.centroid_color, -1) cv2.circle(self.im0, self.centroids[1], 6, self.centroid_color, -1) @@ -179,13 +182,13 @@ class DistanceCalculation: def display_frames(self): """Display frame.""" - cv2.namedWindow('Ultralytics Distance Estimation') - cv2.setMouseCallback('Ultralytics Distance Estimation', self.mouse_event_for_distance) - cv2.imshow('Ultralytics Distance Estimation', self.im0) + cv2.namedWindow("Ultralytics Distance Estimation") + cv2.setMouseCallback("Ultralytics Distance Estimation", self.mouse_event_for_distance) + cv2.imshow("Ultralytics Distance Estimation", self.im0) - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): return -if __name__ == '__main__': +if __name__ == "__main__": DistanceCalculation() diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index 5d5b6d1f..c60eaa76 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -8,7 +8,7 @@ import numpy as np from ultralytics.utils.checks import check_imshow, check_requirements from ultralytics.utils.plotting import Annotator -check_requirements('shapely>=2.0.0') +check_requirements("shapely>=2.0.0") from shapely.geometry import LineString, Point, Polygon @@ -22,7 +22,7 @@ class Heatmap: # Visual information self.annotator = None self.view_img = False - self.shape = 'circle' + self.shape = "circle" # Image information self.imw = None @@ -63,23 +63,25 @@ class Heatmap: # Check if environment support imshow self.env_check = check_imshow(warn=True) - def set_args(self, - imw, - imh, - colormap=cv2.COLORMAP_JET, - heatmap_alpha=0.5, - view_img=False, - view_in_counts=True, - view_out_counts=True, - count_reg_pts=None, - count_txt_thickness=2, - count_txt_color=(0, 0, 0), - count_color=(255, 255, 255), - count_reg_color=(255, 0, 255), - region_thickness=5, - line_dist_thresh=15, - decay_factor=0.99, - shape='circle'): + def set_args( + self, + imw, + imh, + colormap=cv2.COLORMAP_JET, + heatmap_alpha=0.5, + view_img=False, + view_in_counts=True, + view_out_counts=True, + count_reg_pts=None, + count_txt_thickness=2, + count_txt_color=(0, 0, 0), + count_color=(255, 255, 255), + count_reg_color=(255, 0, 255), + region_thickness=5, + line_dist_thresh=15, + decay_factor=0.99, + shape="circle", + ): """ Configures the heatmap colormap, width, height and display parameters. 
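The reflowed set_args() above is the whole user-facing configuration surface of Heatmap, and the hunk that follows shows how count_reg_pts picks the counter type: two points build a shapely LineString line counter, four points build a Polygon region. A minimal configuration sketch (all values are illustrative only, not taken from this diff):

import cv2
from ultralytics.solutions.heatmap import Heatmap

heatmap_obj = Heatmap()
heatmap_obj.set_args(
    imw=1280,  # source frame width
    imh=720,  # source frame height
    colormap=cv2.COLORMAP_JET,
    view_img=True,
    count_reg_pts=[(20, 400), (1260, 400)],  # 2 points -> LineString line counter
    shape="circle",  # per-box heat kernel: "circle" or "rect"
)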
@@ -111,20 +113,19 @@ class Heatmap: # Region and line selection if count_reg_pts is not None: - if len(count_reg_pts) == 2: - print('Line Counter Initiated.') + print("Line Counter Initiated.") self.count_reg_pts = count_reg_pts self.counting_region = LineString(count_reg_pts) elif len(count_reg_pts) == 4: - print('Region Counter Initiated.') + print("Region Counter Initiated.") self.count_reg_pts = count_reg_pts self.counting_region = Polygon(self.count_reg_pts) else: - print('Region or line points Invalid, 2 or 4 points supported') - print('Using Line Counter Now') + print("Region or line points Invalid, 2 or 4 points supported") + print("Using Line Counter Now") self.counting_region = Polygon([(20, 400), (1260, 400)]) # dummy points # Heatmap new frame @@ -140,10 +141,10 @@ class Heatmap: self.shape = shape # shape of heatmap, if not selected - if self.shape not in ['circle', 'rect']: + if self.shape not in ["circle", "rect"]: print("Unknown shape value provided, 'circle' & 'rect' supported") - print('Using Circular shape now') - self.shape = 'circle' + print("Using Circular shape now") + self.shape = "circle" def extract_results(self, tracks): """ @@ -177,27 +178,26 @@ class Heatmap: self.annotator = Annotator(self.im0, self.count_txt_thickness, None) if self.count_reg_pts is not None: - # Draw counting region if self.view_in_counts or self.view_out_counts: - self.annotator.draw_region(reg_pts=self.count_reg_pts, - color=self.region_color, - thickness=self.region_thickness) + self.annotator.draw_region( + reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness + ) for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids): - - if self.shape == 'circle': + if self.shape == "circle": center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 - y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]] - mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2 + y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] + mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 - self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \ - (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])]) + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( + 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] + ) else: - self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2 + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 # Store tracking hist track_line = self.track_history[track_id] @@ -226,26 +226,26 @@ class Heatmap: self.in_counts += 1 else: for box, cls in zip(self.boxes, self.clss): - - if self.shape == 'circle': + if self.shape == "circle": center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 - y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]] - mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2 + y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] + mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 - self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \ - (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])]) + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( + 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] + ) else: - 
self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2 + self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 # Normalize, apply colormap to heatmap and combine with original image heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) - incount_label = 'In Count : ' + f'{self.in_counts}' - outcount_label = 'OutCount : ' + f'{self.out_counts}' + incount_label = "In Count : " + f"{self.in_counts}" + outcount_label = "OutCount : " + f"{self.out_counts}" # Display counts based on user choice counts_label = None @@ -256,13 +256,15 @@ class Heatmap: elif not self.view_out_counts: counts_label = incount_label else: - counts_label = incount_label + ' ' + outcount_label + counts_label = incount_label + " " + outcount_label if self.count_reg_pts is not None and counts_label is not None: - self.annotator.count_labels(counts=counts_label, - count_txt_size=self.count_txt_thickness, - txt_color=self.count_txt_color, - color=self.count_color) + self.annotator.count_labels( + counts=counts_label, + count_txt_size=self.count_txt_thickness, + txt_color=self.count_txt_color, + color=self.count_color, + ) self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0) @@ -273,11 +275,11 @@ class Heatmap: def display_frames(self): """Display frame.""" - cv2.imshow('Ultralytics Heatmap', self.im0) + cv2.imshow("Ultralytics Heatmap", self.im0) - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): return -if __name__ == '__main__': +if __name__ == "__main__": Heatmap() diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py index f817c798..1257553f 100644 --- a/ultralytics/solutions/object_counter.py +++ b/ultralytics/solutions/object_counter.py @@ -7,7 +7,7 @@ import cv2 from ultralytics.utils.checks import check_imshow, check_requirements from ultralytics.utils.plotting import Annotator, colors -check_requirements('shapely>=2.0.0') +check_requirements("shapely>=2.0.0") from shapely.geometry import LineString, Point, Polygon @@ -56,22 +56,24 @@ class ObjectCounter: # Check if environment support imshow self.env_check = check_imshow(warn=True) - def set_args(self, - classes_names, - reg_pts, - count_reg_color=(255, 0, 255), - line_thickness=2, - track_thickness=2, - view_img=False, - view_in_counts=True, - view_out_counts=True, - draw_tracks=False, - count_txt_thickness=2, - count_txt_color=(0, 0, 0), - count_color=(255, 255, 255), - track_color=(0, 255, 0), - region_thickness=5, - line_dist_thresh=15): + def set_args( + self, + classes_names, + reg_pts, + count_reg_color=(255, 0, 255), + line_thickness=2, + track_thickness=2, + view_img=False, + view_in_counts=True, + view_out_counts=True, + draw_tracks=False, + count_txt_thickness=2, + count_txt_color=(0, 0, 0), + count_color=(255, 255, 255), + track_color=(0, 255, 0), + region_thickness=5, + line_dist_thresh=15, + ): """ Configures the Counter's image, bounding box line thickness, and counting region points. 
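ObjectCounter.set_args() above follows the same two-point line / four-point region convention, and start_counting(im0, tracks) further down drives the counter frame by frame. A hedged end-to-end sketch (the capture loop, file names, and region values are illustrative, not part of this diff):

import cv2
from ultralytics import YOLO
from ultralytics.solutions.object_counter import ObjectCounter

model = YOLO("yolov8n.pt")
counter = ObjectCounter()
counter.set_args(
    classes_names=model.names,
    reg_pts=[(20, 360), (1080, 360), (1080, 400), (20, 400)],  # 4 points -> Polygon region
    view_img=True,
)

cap = cv2.VideoCapture("traffic.mp4")
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    tracks = model.track(frame, persist=True, show=False)  # tracked boxes feed the counter
    frame = counter.start_counting(frame, tracks)
cap.release()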
@@ -101,16 +103,16 @@ class ObjectCounter: # Region and line selection if len(reg_pts) == 2: - print('Line Counter Initiated.') + print("Line Counter Initiated.") self.reg_pts = reg_pts self.counting_region = LineString(self.reg_pts) elif len(reg_pts) == 4: - print('Region Counter Initiated.') + print("Region Counter Initiated.") self.reg_pts = reg_pts self.counting_region = Polygon(self.reg_pts) else: - print('Invalid Region points provided, region_points can be 2 or 4') - print('Using Line Counter Now') + print("Invalid Region points provided, region_points can be 2 or 4") + print("Using Line Counter Now") self.counting_region = LineString(self.reg_pts) self.names = classes_names @@ -164,8 +166,9 @@ class ObjectCounter: # Extract tracks for box, track_id, cls in zip(boxes, track_ids, clss): - self.annotator.box_label(box, label=str(track_id) + ':' + self.names[cls], - color=colors(int(cls), True)) # Draw bounding box + self.annotator.box_label( + box, label=str(track_id) + ":" + self.names[cls], color=colors(int(cls), True) + ) # Draw bounding box # Draw Tracks track_line = self.track_history[track_id] @@ -175,9 +178,9 @@ class ObjectCounter: # Draw track trails if self.draw_tracks: - self.annotator.draw_centroid_and_tracks(track_line, - color=self.track_color, - track_thickness=self.track_thickness) + self.annotator.draw_centroid_and_tracks( + track_line, color=self.track_color, track_thickness=self.track_thickness + ) # Count objects if len(self.reg_pts) == 4: @@ -199,8 +202,8 @@ class ObjectCounter: else: self.in_counts += 1 - incount_label = 'In Count : ' + f'{self.in_counts}' - outcount_label = 'OutCount : ' + f'{self.out_counts}' + incount_label = "In Count : " + f"{self.in_counts}" + outcount_label = "OutCount : " + f"{self.out_counts}" # Display counts based on user choice counts_label = None @@ -211,24 +214,27 @@ class ObjectCounter: elif not self.view_out_counts: counts_label = incount_label else: - counts_label = incount_label + ' ' + outcount_label + counts_label = incount_label + " " + outcount_label if counts_label is not None: - self.annotator.count_labels(counts=counts_label, - count_txt_size=self.count_txt_thickness, - txt_color=self.count_txt_color, - color=self.count_color) + self.annotator.count_labels( + counts=counts_label, + count_txt_size=self.count_txt_thickness, + txt_color=self.count_txt_color, + color=self.count_color, + ) def display_frames(self): """Display frame.""" if self.env_check: - cv2.namedWindow('Ultralytics YOLOv8 Object Counter') + cv2.namedWindow("Ultralytics YOLOv8 Object Counter") if len(self.reg_pts) == 4: # only add mouse event If user drawn region - cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region, - {'region_points': self.reg_pts}) - cv2.imshow('Ultralytics YOLOv8 Object Counter', self.im0) + cv2.setMouseCallback( + "Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts} + ) + cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0) # Break Window - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): return def start_counting(self, im0, tracks): @@ -254,5 +260,5 @@ class ObjectCounter: return self.im0 -if __name__ == '__main__': +if __name__ == "__main__": ObjectCounter() diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py index 7260141f..55e209f9 100644 --- a/ultralytics/solutions/speed_estimation.py +++ b/ultralytics/solutions/speed_estimation.py @@ -66,7 +66,7 @@ class SpeedEstimator: 
spdl_dist_thresh (int): Euclidean distance threshold for speed line """ if reg_pts is None: - print('Region points not provided, using default values') + print("Region points not provided, using default values") else: self.reg_pts = reg_pts self.names = names @@ -114,8 +114,9 @@ class SpeedEstimator: cls (str): object class name track (list): tracking history for tracks path drawing """ - speed_label = str(int( - self.dist_data[track_id])) + 'km/ph' if track_id in self.dist_data else self.names[int(cls)] + speed_label = ( + str(int(self.dist_data[track_id])) + "km/ph" if track_id in self.dist_data else self.names[int(cls)] + ) bbox_color = colors(int(track_id)) if track_id in self.dist_data else (255, 0, 255) self.annotator.box_label(box, speed_label, bbox_color) @@ -132,19 +133,16 @@ class SpeedEstimator: """ if self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]: + if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh: + direction = "known" - if (self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh): - direction = 'known' - - elif (self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < - self.reg_pts[0][1] + self.spdl_dist_thresh): - direction = 'known' + elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh: + direction = "known" else: - direction = 'unknown' - - if self.trk_previous_times[trk_id] != 0 and direction != 'unknown': + direction = "unknown" + if self.trk_previous_times[trk_id] != 0 and direction != "unknown": if trk_id not in self.trk_idslist: self.trk_idslist.append(trk_id) @@ -178,7 +176,6 @@ class SpeedEstimator: self.annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 0), thickness=self.region_thickness) for box, trk_id, cls in zip(self.boxes, self.trk_ids, self.clss): - track = self.store_track_info(trk_id, box) if trk_id not in self.trk_previous_times: @@ -194,10 +191,10 @@ class SpeedEstimator: def display_frames(self): """Display frame.""" - cv2.imshow('Ultralytics Speed Estimation', self.im0) - if cv2.waitKey(1) & 0xFF == ord('q'): + cv2.imshow("Ultralytics Speed Estimation", self.im0) + if cv2.waitKey(1) & 0xFF == ord("q"): return -if __name__ == '__main__': +if __name__ == "__main__": SpeedEstimator() diff --git a/ultralytics/trackers/__init__.py b/ultralytics/trackers/__init__.py index 46e178e4..bf51b8df 100644 --- a/ultralytics/trackers/__init__.py +++ b/ultralytics/trackers/__init__.py @@ -4,4 +4,4 @@ from .bot_sort import BOTSORT from .byte_tracker import BYTETracker from .track import register_tracker -__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import +__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py index 778786b8..31d5e1be 100644 --- a/ultralytics/trackers/bot_sort.py +++ b/ultralytics/trackers/bot_sort.py @@ -39,6 +39,7 @@ class BOTrack(STrack): bo_track.predict() bo_track.update(new_track, frame_id) """ + shared_kalman = KalmanFilterXYWH() def __init__(self, tlwh, score, cls, feat=None, feat_history=50): @@ -176,7 +177,7 @@ class BOTSORT(BYTETracker): def get_dists(self, tracks, detections): """Get distances between tracks and detections using IoU and (optionally) ReID embeddings.""" dists = matching.iou_distance(tracks, detections) - dists_mask = (dists > self.proximity_thresh) + dists_mask = dists > self.proximity_thresh # TODO: mot20 # if not 
self.args.mot20: diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py index 5331511e..619ea12c 100644 --- a/ultralytics/trackers/byte_tracker.py +++ b/ultralytics/trackers/byte_tracker.py @@ -112,8 +112,9 @@ class STrack(BaseTrack): def re_activate(self, new_track, frame_id, new_id=False): """Reactivates a previously lost track with a new detection.""" - self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, - self.convert_coords(new_track.tlwh)) + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_track.tlwh) + ) self.tracklet_len = 0 self.state = TrackState.Tracked self.is_activated = True @@ -136,8 +137,9 @@ class STrack(BaseTrack): self.tracklet_len += 1 new_tlwh = new_track.tlwh - self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, - self.convert_coords(new_tlwh)) + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_tlwh) + ) self.state = TrackState.Tracked self.is_activated = True @@ -192,7 +194,7 @@ class STrack(BaseTrack): def __repr__(self): """Return a string representation of the BYTETracker object with start and end frames and track ID.""" - return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})' + return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" class BYTETracker: @@ -275,7 +277,7 @@ class BYTETracker: strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks) # Predict the current location with KF self.multi_predict(strack_pool) - if hasattr(self, 'gmc') and img is not None: + if hasattr(self, "gmc") and img is not None: warp = self.gmc.apply(img, dets) STrack.multi_gmc(strack_pool, warp) STrack.multi_gmc(unconfirmed, warp) @@ -349,7 +351,8 @@ class BYTETracker: self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum return np.asarray( [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated], - dtype=np.float32) + dtype=np.float32, + ) def get_kalmanfilter(self): """Returns a Kalman filter object for tracking bounding boxes.""" diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index 2ad4f5d7..c0dd4e45 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -10,7 +10,7 @@ from .bot_sort import BOTSORT from .byte_tracker import BYTETracker # A mapping of tracker types to corresponding tracker classes -TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT} +TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} def on_predict_start(predictor: object, persist: bool = False) -> None: @@ -24,15 +24,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None: Raises: AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. 
""" - if predictor.args.task == 'obb': - raise NotImplementedError('ERROR ❌ OBB task does not support track mode!') - if hasattr(predictor, 'trackers') and persist: + if predictor.args.task == "obb": + raise NotImplementedError("ERROR ❌ OBB task does not support track mode!") + if hasattr(predictor, "trackers") and persist: return tracker = check_yaml(predictor.args.tracker) cfg = IterableSimpleNamespace(**yaml_load(tracker)) - if cfg.tracker_type not in ['bytetrack', 'botsort']: + if cfg.tracker_type not in ["bytetrack", "botsort"]: raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") trackers = [] @@ -76,5 +76,5 @@ def register_tracker(model: object, persist: bool) -> None: model (object): The model object to register tracking callbacks for. persist (bool): Whether to persist the trackers if they already exist. """ - model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) - model.add_callback('on_predict_postprocess_end', partial(on_predict_postprocess_end, persist=persist)) + model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) + model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py index 1e801b4a..07726d73 100644 --- a/ultralytics/trackers/utils/gmc.py +++ b/ultralytics/trackers/utils/gmc.py @@ -33,7 +33,7 @@ class GMC: applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame. """ - def __init__(self, method: str = 'sparseOptFlow', downscale: int = 2) -> None: + def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: """ Initialize a video tracker with specified parameters. 
@@ -46,34 +46,31 @@ class GMC: self.method = method self.downscale = max(1, int(downscale)) - if self.method == 'orb': + if self.method == "orb": self.detector = cv2.FastFeatureDetector_create(20) self.extractor = cv2.ORB_create() self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) - elif self.method == 'sift': + elif self.method == "sift": self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) self.matcher = cv2.BFMatcher(cv2.NORM_L2) - elif self.method == 'ecc': + elif self.method == "ecc": number_of_iterations = 5000 termination_eps = 1e-6 self.warp_mode = cv2.MOTION_EUCLIDEAN self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) - elif self.method == 'sparseOptFlow': - self.feature_params = dict(maxCorners=1000, - qualityLevel=0.01, - minDistance=1, - blockSize=3, - useHarrisDetector=False, - k=0.04) + elif self.method == "sparseOptFlow": + self.feature_params = dict( + maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04 + ) - elif self.method in ['none', 'None', None]: + elif self.method in ["none", "None", None]: self.method = None else: - raise ValueError(f'Error: Unknown GMC method:{method}') + raise ValueError(f"Error: Unknown GMC method:{method}") self.prevFrame = None self.prevKeyPoints = None @@ -97,11 +94,11 @@ class GMC: array([[1, 2, 3], [4, 5, 6]]) """ - if self.method in ['orb', 'sift']: + if self.method in ["orb", "sift"]: return self.applyFeatures(raw_frame, detections) - elif self.method == 'ecc': + elif self.method == "ecc": return self.applyEcc(raw_frame, detections) - elif self.method == 'sparseOptFlow': + elif self.method == "sparseOptFlow": return self.applySparseOptFlow(raw_frame, detections) else: return np.eye(2, 3) @@ -149,7 +146,7 @@ class GMC: try: (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) except Exception as e: - LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}') + LOGGER.warning(f"WARNING: find transform failed. 
Set warp as identity {e}") return H @@ -182,11 +179,11 @@ class GMC: # Find the keypoints mask = np.zeros_like(frame) - mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255 + mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255 if detections is not None: for det in detections: tlbr = (det[:4] / self.downscale).astype(np.int_) - mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0 + mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0 keypoints = self.detector.detect(frame, mask) @@ -228,11 +225,14 @@ class GMC: prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt currKeyPointLocation = keypoints[m.trainIdx].pt - spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0], - prevKeyPointLocation[1] - currKeyPointLocation[1]) + spatialDistance = ( + prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1], + ) - if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \ - (np.abs(spatialDistance[1]) < maxSpatialDistance[1]): + if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and ( + np.abs(spatialDistance[1]) < maxSpatialDistance[1] + ): spatialDistances.append(spatialDistance) matches.append(m) @@ -283,7 +283,7 @@ class GMC: H[0, 2] *= self.downscale H[1, 2] *= self.downscale else: - LOGGER.warning('WARNING: not enough matching points') + LOGGER.warning("WARNING: not enough matching points") # Store to next iteration self.prevFrame = frame.copy() @@ -350,7 +350,7 @@ class GMC: H[0, 2] *= self.downscale H[1, 2] *= self.downscale else: - LOGGER.warning('WARNING: not enough matching points') + LOGGER.warning("WARNING: not enough matching points") self.prevFrame = frame.copy() self.prevKeyPoints = copy.copy(keypoints) diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py index b4fb91fc..4ae68be2 100644 --- a/ultralytics/trackers/utils/kalman_filter.py +++ b/ultralytics/trackers/utils/kalman_filter.py @@ -17,7 +17,7 @@ class KalmanFilterXYAH: def __init__(self): """Initialize Kalman filter model matrices with motion and observation uncertainty weights.""" - ndim, dt = 4, 1. + ndim, dt = 4, 1.0 # Create Kalman filter model matrices self._motion_mat = np.eye(2 * ndim, 2 * ndim) @@ -27,8 +27,8 @@ class KalmanFilterXYAH: # Motion and observation uncertainty are chosen relative to the current state estimate. These weights control # the amount of uncertainty in the model. - self._std_weight_position = 1. / 20 - self._std_weight_velocity = 1. / 160 + self._std_weight_position = 1.0 / 20 + self._std_weight_velocity = 1.0 / 160 def initiate(self, measurement: np.ndarray) -> tuple: """ @@ -47,9 +47,15 @@ class KalmanFilterXYAH: mean = np.r_[mean_pos, mean_vel] std = [ - 2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2, - 2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]] + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3], + ] covariance = np.diag(np.square(std)) return mean, covariance @@ -66,11 +72,17 @@ class KalmanFilterXYAH: velocities are initialized to 0 mean. 
""" std_pos = [ - self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2, - self._std_weight_position * mean[3]] + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3], + ] std_vel = [ - self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5, - self._std_weight_velocity * mean[3]] + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3], + ] motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) mean = np.dot(mean, self._motion_mat.T) @@ -90,8 +102,11 @@ class KalmanFilterXYAH: (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. """ std = [ - self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1, - self._std_weight_position * mean[3]] + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3], + ] innovation_cov = np.diag(np.square(std)) mean = np.dot(self._update_mat, mean) @@ -111,11 +126,17 @@ class KalmanFilterXYAH: velocities are initialized to 0 mean. """ std_pos = [ - self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3], - 1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]] + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3], + ] std_vel = [ - self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3], - 1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]] + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3], + ] sqr = np.square(np.r_[std_pos, std_vel]).T motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] @@ -143,21 +164,23 @@ class KalmanFilterXYAH: projected_mean, projected_cov = self.project(mean, covariance) chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) - kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), - np.dot(covariance, self._update_mat.T).T, - check_finite=False).T + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False + ).T innovation = measurement - projected_mean new_mean = mean + np.dot(innovation, kalman_gain.T) new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) return new_mean, new_covariance - def gating_distance(self, - mean: np.ndarray, - covariance: np.ndarray, - measurements: np.ndarray, - only_position: bool = False, - metric: str = 'maha') -> np.ndarray: + def gating_distance( + self, + mean: np.ndarray, + covariance: np.ndarray, + measurements: np.ndarray, + only_position: bool = False, + metric: str = "maha", + ) -> np.ndarray: """ Compute gating distance between state distribution and measurements. A suitable distance threshold can be obtained from `chi2inv95`. 
If `only_position` is False, the chi-square distribution has 4 degrees of freedom, @@ -183,14 +206,14 @@ class KalmanFilterXYAH: measurements = measurements[:, :2] d = measurements - mean - if metric == 'gaussian': + if metric == "gaussian": return np.sum(d * d, axis=1) - elif metric == 'maha': + elif metric == "maha": cholesky_factor = np.linalg.cholesky(covariance) z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) return np.sum(z * z, axis=0) # square maha else: - raise ValueError('Invalid distance metric') + raise ValueError("Invalid distance metric") class KalmanFilterXYWH(KalmanFilterXYAH): @@ -220,10 +243,15 @@ class KalmanFilterXYWH(KalmanFilterXYAH): mean = np.r_[mean_pos, mean_vel] std = [ - 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3], - 2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3], - 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]] + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + ] covariance = np.diag(np.square(std)) return mean, covariance @@ -240,11 +268,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH): velocities are initialized to 0 mean. """ std_pos = [ - self._std_weight_position * mean[2], self._std_weight_position * mean[3], - self._std_weight_position * mean[2], self._std_weight_position * mean[3]] + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] std_vel = [ - self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3], - self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]] + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + ] motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) mean = np.dot(mean, self._motion_mat.T) @@ -264,8 +298,11 @@ class KalmanFilterXYWH(KalmanFilterXYAH): (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. """ std = [ - self._std_weight_position * mean[2], self._std_weight_position * mean[3], - self._std_weight_position * mean[2], self._std_weight_position * mean[3]] + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] innovation_cov = np.diag(np.square(std)) mean = np.dot(self._update_mat, mean) @@ -285,11 +322,17 @@ class KalmanFilterXYWH(KalmanFilterXYAH): velocities are initialized to 0 mean. 
""" std_pos = [ - self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3], - self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]] + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + ] std_vel = [ - self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3], - self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]] + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + ] sqr = np.square(np.r_[std_pos, std_vel]).T motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py index b2df4bd1..cf59755c 100644 --- a/ultralytics/trackers/utils/matching.py +++ b/ultralytics/trackers/utils/matching.py @@ -13,7 +13,7 @@ try: except (ImportError, AssertionError, AttributeError): from ultralytics.utils.checks import check_requirements - check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx + check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx import lap @@ -70,8 +70,9 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray: (np.ndarray): Cost matrix computed based on IoU. """ - if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \ - or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( + len(btracks) > 0 and isinstance(btracks[0], np.ndarray) + ): atlbrs = atracks btlbrs = btracks else: @@ -80,13 +81,13 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray: ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) if len(atlbrs) and len(btlbrs): - ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32), - np.ascontiguousarray(btlbrs, dtype=np.float32), - iou=True) + ious = bbox_ioa( + np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True + ) return 1 - ious # cost matrix -def embedding_distance(tracks: list, detections: list, metric: str = 'cosine') -> np.ndarray: +def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray: """ Compute distance between tracks and detections based on embeddings. 
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index 67bc7a2d..07641c65 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -25,23 +25,22 @@ from tqdm import tqdm as tqdm_original from ultralytics import __version__ # PyTorch Multi-GPU DDP Constants -RANK = int(os.getenv('RANK', -1)) -LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv("RANK", -1)) +LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html # Other Constants FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLO -ASSETS = ROOT / 'assets' # default images -DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml' +ASSETS = ROOT / "assets" # default images +DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml" NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads -AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode -VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode -TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None # tqdm bar format -LOGGING_NAME = 'ultralytics' -MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans -ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans -HELP_MSG = \ - """ +AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode +VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode +TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format +LOGGING_NAME = "ultralytics" +MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans +ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans +HELP_MSG = """ Usage examples for running YOLOv8: 1. 
Install the ultralytics package: @@ -99,12 +98,12 @@ HELP_MSG = \ """ # Settings -torch.set_printoptions(linewidth=320, precision=4, profile='default') -np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 +torch.set_printoptions(linewidth=320, precision=4, profile="default") +np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5 cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) -os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads -os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab +os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab class TQDM(tqdm_original): @@ -119,8 +118,8 @@ class TQDM(tqdm_original): def __init__(self, *args, **kwargs): """Initialize custom Ultralytics tqdm class with different default arguments.""" # Set new default values (these can still be overridden when calling TQDM) - kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed - kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed + kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed + kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed super().__init__(*args, **kwargs) @@ -134,14 +133,14 @@ class SimpleClass: attr = [] for a in dir(self): v = getattr(self, a) - if not callable(v) and not a.startswith('_'): + if not callable(v) and not a.startswith("_"): if isinstance(v, SimpleClass): # Display only the module and class name for subclasses - s = f'{a}: {v.__module__}.{v.__class__.__name__} object' + s = f"{a}: {v.__module__}.{v.__class__.__name__} object" else: - s = f'{a}: {repr(v)}' + s = f"{a}: {repr(v)}" attr.append(s) - return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr) + return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr) def __repr__(self): """Return a machine-readable string representation of the object.""" @@ -164,24 +163,26 @@ class IterableSimpleNamespace(SimpleNamespace): def __str__(self): """Return a human-readable string representation of the object.""" - return '\n'.join(f'{k}={v}' for k, v in vars(self).items()) + return "\n".join(f"{k}={v}" for k, v in vars(self).items()) def __getattr__(self, attr): """Custom attribute access error message with helpful information.""" name = self.__class__.__name__ - raise AttributeError(f""" + raise AttributeError( + f""" '{name}' object has no attribute '{attr}'. 
This may be caused by a modified or out of date ultralytics 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace {DEFAULT_CFG_PATH} with the latest version from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml - """) + """ + ) def get(self, key, default=None): """Return the value of the specified key if it exists; otherwise, return the default value.""" return getattr(self, key, default) -def plt_settings(rcparams=None, backend='Agg'): +def plt_settings(rcparams=None, backend="Agg"): """ Decorator to temporarily set rc parameters and the backend for a plotting function. @@ -199,7 +200,7 @@ def plt_settings(rcparams=None, backend='Agg'): """ if rcparams is None: - rcparams = {'font.size': 11} + rcparams = {"font.size": 11} def decorator(func): """Decorator to apply temporary rc parameters and backend to a function.""" @@ -208,14 +209,14 @@ def plt_settings(rcparams=None, backend='Agg'): """Sets rc parameters and backend, calls the original function, and restores the settings.""" original_backend = plt.get_backend() if backend != original_backend: - plt.close('all') # auto-close()ing of figures upon backend switching is deprecated since 3.8 + plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8 plt.switch_backend(backend) with plt.rc_context(rcparams): result = func(*args, **kwargs) if backend != original_backend: - plt.close('all') + plt.close("all") plt.switch_backend(original_backend) return result @@ -229,26 +230,26 @@ def set_logging(name=LOGGING_NAME, verbose=True): level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings # Configure the console (stdout) encoding to UTF-8 - formatter = logging.Formatter('%(message)s') # Default formatter - if WINDOWS and sys.stdout.encoding != 'utf-8': + formatter = logging.Formatter("%(message)s") # Default formatter + if WINDOWS and sys.stdout.encoding != "utf-8": try: - if hasattr(sys.stdout, 'reconfigure'): - sys.stdout.reconfigure(encoding='utf-8') - elif hasattr(sys.stdout, 'buffer'): + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8") + elif hasattr(sys.stdout, "buffer"): import io - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") else: - sys.stdout.encoding = 'utf-8' + sys.stdout.encoding = "utf-8" except Exception as e: - print(f'Creating custom formatter for non UTF-8 environments due to {e}') + print(f"Creating custom formatter for non UTF-8 environments due to {e}") class CustomFormatter(logging.Formatter): - def format(self, record): """Sets up logging with UTF-8 encoding and configurable verbosity.""" return emojis(super().format(record)) - formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse + formatter = CustomFormatter("%(message)s") # Use CustomFormatter to eliminate UTF-8 output as last recourse # Create and configure the StreamHandler stream_handler = logging.StreamHandler(sys.stdout) @@ -264,13 +265,13 @@ def set_logging(name=LOGGING_NAME, verbose=True): # Set logger LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.) 
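The `plt_settings` decorator above is how the plotting utilities temporarily swap rc parameters and backends; a short sketch applying it to a hypothetical helper:

```python
from ultralytics.utils import plt_settings


@plt_settings(rcparams={"font.size": 9}, backend="Agg")  # applied only for the duration of the call
def save_histogram(values, path="hist.png"):  # hypothetical helper, not part of this diff
    import matplotlib.pyplot as plt

    plt.hist(values)
    plt.savefig(path)
    plt.close()


save_histogram([1, 2, 2, 3])  # runs under Agg, then the previous backend is restored
```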
-for logger in 'sentry_sdk', 'urllib3.connectionpool': +for logger in "sentry_sdk", "urllib3.connectionpool": logging.getLogger(logger).setLevel(logging.CRITICAL + 1) -def emojis(string=''): +def emojis(string=""): """Return platform-dependent emoji-safe version of string.""" - return string.encode().decode('ascii', 'ignore') if WINDOWS else string + return string.encode().decode("ascii", "ignore") if WINDOWS else string class ThreadingLocked: @@ -310,7 +311,7 @@ class ThreadingLocked: return decorated -def yaml_save(file='data.yaml', data=None, header=''): +def yaml_save(file="data.yaml", data=None, header=""): """ Save YAML data to a file. @@ -336,13 +337,13 @@ def yaml_save(file='data.yaml', data=None, header=''): data[k] = str(v) # Dump data to file in YAML format - with open(file, 'w', errors='ignore', encoding='utf-8') as f: + with open(file, "w", errors="ignore", encoding="utf-8") as f: if header: f.write(header) yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True) -def yaml_load(file='data.yaml', append_filename=False): +def yaml_load(file="data.yaml", append_filename=False): """ Load YAML data from a file. @@ -353,18 +354,18 @@ def yaml_load(file='data.yaml', append_filename=False): Returns: (dict): YAML data and file name. """ - assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()' - with open(file, errors='ignore', encoding='utf-8') as f: + assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()" + with open(file, errors="ignore", encoding="utf-8") as f: s = f.read() # string # Remove special characters if not s.isprintable(): - s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s) + s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s) # Add YAML filename to dict and return data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files) if append_filename: - data['yaml_file'] = str(file) + data["yaml_file"] = str(file) return data @@ -386,7 +387,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None: # Default configuration DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH) for k, v in DEFAULT_CFG_DICT.items(): - if isinstance(v, str) and v.lower() == 'none': + if isinstance(v, str) and v.lower() == "none": DEFAULT_CFG_DICT[k] = None DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys() DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT) @@ -400,8 +401,8 @@ def is_ubuntu() -> bool: (bool): True if OS is Ubuntu, False otherwise. """ with contextlib.suppress(FileNotFoundError): - with open('/etc/os-release') as f: - return 'ID=ubuntu' in f.read() + with open("/etc/os-release") as f: + return "ID=ubuntu" in f.read() return False @@ -412,7 +413,7 @@ def is_colab(): Returns: (bool): True if running inside a Colab notebook, False otherwise. """ - return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ + return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ def is_kaggle(): @@ -422,7 +423,7 @@ def is_kaggle(): Returns: (bool): True if running inside a Kaggle kernel, False otherwise. 
""" - return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' + return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" def is_jupyter(): @@ -434,6 +435,7 @@ def is_jupyter(): """ with contextlib.suppress(Exception): from IPython import get_ipython + return get_ipython() is not None return False @@ -445,10 +447,10 @@ def is_docker() -> bool: Returns: (bool): True if the script is running inside a Docker container, False otherwise. """ - file = Path('/proc/self/cgroup') + file = Path("/proc/self/cgroup") if file.exists(): with open(file) as f: - return 'docker' in f.read() + return "docker" in f.read() else: return False @@ -462,7 +464,7 @@ def is_online() -> bool: """ import socket - for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS: + for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS: try: test_connection = socket.create_connection(address=(host, 53), timeout=2) except (socket.timeout, socket.gaierror, OSError): @@ -516,7 +518,7 @@ def is_pytest_running(): Returns: (bool): True if pytest is running, False otherwise. """ - return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem) + return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(sys.argv[0]).stem) def is_github_action_running() -> bool: @@ -526,7 +528,7 @@ def is_github_action_running() -> bool: Returns: (bool): True if the current environment is a GitHub Actions runner, False otherwise. """ - return 'GITHUB_ACTIONS' in os.environ and 'GITHUB_WORKFLOW' in os.environ and 'RUNNER_OS' in os.environ + return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ def is_git_dir(): @@ -549,7 +551,7 @@ def get_git_dir(): (Path | None): Git root directory if found or None if not found. """ for d in Path(__file__).parents: - if (d / '.git').is_dir(): + if (d / ".git").is_dir(): return d @@ -562,7 +564,7 @@ def get_git_origin_url(): """ if is_git_dir(): with contextlib.suppress(subprocess.CalledProcessError): - origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url']) + origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"]) return origin.decode().strip() @@ -575,7 +577,7 @@ def get_git_branch(): """ if is_git_dir(): with contextlib.suppress(subprocess.CalledProcessError): - origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) return origin.decode().strip() @@ -602,11 +604,11 @@ def get_ubuntu_version(): """ if is_ubuntu(): with contextlib.suppress(FileNotFoundError, AttributeError): - with open('/etc/os-release') as f: + with open("/etc/os-release") as f: return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1] -def get_user_config_dir(sub_dir='Ultralytics'): +def get_user_config_dir(sub_dir="Ultralytics"): """ Get the user config directory. 
@@ -618,19 +620,21 @@ def get_user_config_dir(sub_dir='Ultralytics'): """ # Return the appropriate config directory for each operating system if WINDOWS: - path = Path.home() / 'AppData' / 'Roaming' / sub_dir + path = Path.home() / "AppData" / "Roaming" / sub_dir elif MACOS: # macOS - path = Path.home() / 'Library' / 'Application Support' / sub_dir + path = Path.home() / "Library" / "Application Support" / sub_dir elif LINUX: - path = Path.home() / '.config' / sub_dir + path = Path.home() / ".config" / sub_dir else: - raise ValueError(f'Unsupported operating system: {platform.system()}') + raise ValueError(f"Unsupported operating system: {platform.system()}") # GCP and AWS lambda fix, only /tmp is writeable if not is_dir_writeable(path.parent): - LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." - 'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.') - path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir + LOGGER.warning( + f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." + "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path." + ) + path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir # Create the subdirectory if it does not exist path.mkdir(parents=True, exist_ok=True) @@ -638,8 +642,8 @@ def get_user_config_dir(sub_dir='Ultralytics'): return path -USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir -SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml' +USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir +SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml" def colorstr(*input): @@ -670,28 +674,29 @@ def colorstr(*input): >>> colorstr('blue', 'bold', 'hello world') >>> '\033[34m\033[1mhello world\033[0m' """ - *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string + *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string colors = { - 'black': '\033[30m', # basic colors - 'red': '\033[31m', - 'green': '\033[32m', - 'yellow': '\033[33m', - 'blue': '\033[34m', - 'magenta': '\033[35m', - 'cyan': '\033[36m', - 'white': '\033[37m', - 'bright_black': '\033[90m', # bright colors - 'bright_red': '\033[91m', - 'bright_green': '\033[92m', - 'bright_yellow': '\033[93m', - 'bright_blue': '\033[94m', - 'bright_magenta': '\033[95m', - 'bright_cyan': '\033[96m', - 'bright_white': '\033[97m', - 'end': '\033[0m', # misc - 'bold': '\033[1m', - 'underline': '\033[4m'} - return ''.join(colors[x] for x in args) + f'{string}' + colors['end'] + "black": "\033[30m", # basic colors + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + "bright_black": "\033[90m", # bright colors + "bright_red": "\033[91m", + "bright_green": "\033[92m", + "bright_yellow": "\033[93m", + "bright_blue": "\033[94m", + "bright_magenta": "\033[95m", + "bright_cyan": "\033[96m", + "bright_white": "\033[97m", + "end": "\033[0m", # misc + "bold": "\033[1m", + "underline": "\033[4m", + } + return "".join(colors[x] for x in args) + f"{string}" + colors["end"] def remove_colorstr(input_string): @@ -708,8 +713,8 @@ def remove_colorstr(input_string): >>> remove_colorstr(colorstr('blue', 'bold', 'hello world')) >>> 'hello world' """ - 
ansi_escape = re.compile(r'\x1B\[[0-9;]*[A-Za-z]') - return ansi_escape.sub('', input_string) + ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") + return ansi_escape.sub("", input_string) class TryExcept(contextlib.ContextDecorator): @@ -719,7 +724,7 @@ class TryExcept(contextlib.ContextDecorator): Use as @TryExcept() decorator or 'with TryExcept():' context manager. """ - def __init__(self, msg='', verbose=True): + def __init__(self, msg="", verbose=True): """Initialize TryExcept class with optional message and verbosity settings.""" self.msg = msg self.verbose = verbose @@ -744,7 +749,7 @@ def threaded(func): def wrapper(*args, **kwargs): """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result.""" - if kwargs.pop('threaded', True): # run in thread + if kwargs.pop("threaded", True): # run in thread thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) thread.start() return thread @@ -786,27 +791,28 @@ def set_sentry(): Returns: dict: The modified event or None if the event should not be sent to Sentry. """ - if 'exc_info' in hint: - exc_type, exc_value, tb = hint['exc_info'] - if exc_type in (KeyboardInterrupt, FileNotFoundError) \ - or 'out of memory' in str(exc_value): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value): return None # do not send event - event['tags'] = { - 'sys_argv': sys.argv[0], - 'sys_argv_name': Path(sys.argv[0]).name, - 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other', - 'os': ENVIRONMENT} + event["tags"] = { + "sys_argv": sys.argv[0], + "sys_argv_name": Path(sys.argv[0]).name, + "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other", + "os": ENVIRONMENT, + } return event - if SETTINGS['sync'] and \ - RANK in (-1, 0) and \ - Path(sys.argv[0]).name == 'yolo' and \ - not TESTS_RUNNING and \ - ONLINE and \ - is_pip_package() and \ - not is_git_dir(): - + if ( + SETTINGS["sync"] + and RANK in (-1, 0) + and Path(sys.argv[0]).name == "yolo" + and not TESTS_RUNNING + and ONLINE + and is_pip_package() + and not is_git_dir() + ): # If sentry_sdk package is not installed then return and do not use Sentry try: import sentry_sdk # noqa @@ -814,14 +820,15 @@ def set_sentry(): return sentry_sdk.init( - dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016', + dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016", debug=False, traces_sample_rate=1.0, release=__version__, - environment='production', # 'dev' or 'production' + environment="production", # 'dev' or 'production' before_send=before_send, - ignore_errors=[KeyboardInterrupt, FileNotFoundError]) - sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash + ignore_errors=[KeyboardInterrupt, FileNotFoundError], + ) + sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash class SettingsManager(dict): @@ -833,7 +840,7 @@ class SettingsManager(dict): version (str): Settings version. In case of local version mismatch, new default settings will be saved. """ - def __init__(self, file=SETTINGS_YAML, version='0.0.4'): + def __init__(self, file=SETTINGS_YAML, version="0.0.4"): """Initialize the SettingsManager with default settings, load and validate current settings from the YAML file. 
""" @@ -850,23 +857,24 @@ class SettingsManager(dict): self.file = Path(file) self.version = version self.defaults = { - 'settings_version': version, - 'datasets_dir': str(datasets_root / 'datasets'), - 'weights_dir': str(root / 'weights'), - 'runs_dir': str(root / 'runs'), - 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), - 'sync': True, - 'api_key': '', - 'openai_api_key': '', - 'clearml': True, # integrations - 'comet': True, - 'dvc': True, - 'hub': True, - 'mlflow': True, - 'neptune': True, - 'raytune': True, - 'tensorboard': True, - 'wandb': True} + "settings_version": version, + "datasets_dir": str(datasets_root / "datasets"), + "weights_dir": str(root / "weights"), + "runs_dir": str(root / "runs"), + "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), + "sync": True, + "api_key": "", + "openai_api_key": "", + "clearml": True, # integrations + "comet": True, + "dvc": True, + "hub": True, + "mlflow": True, + "neptune": True, + "raytune": True, + "tensorboard": True, + "wandb": True, + } super().__init__(copy.deepcopy(self.defaults)) @@ -877,13 +885,14 @@ class SettingsManager(dict): self.load() correct_keys = self.keys() == self.defaults.keys() correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values())) - correct_version = check_version(self['settings_version'], self.version) + correct_version = check_version(self["settings_version"], self.version) if not (correct_keys and correct_types and correct_version): LOGGER.warning( - 'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem ' - 'with your settings or a recent ultralytics package update. ' + "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem " + "with your settings or a recent ultralytics package update. " f"\nView settings with 'yolo settings' or at '{self.file}'" - "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.") + "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'." + ) self.reset() def load(self): @@ -910,14 +919,16 @@ def deprecation_warn(arg, new_arg, version=None): """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.""" if not version: version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release - LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. " - f"Please use '{new_arg}' instead.") + LOGGER.warning( + f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. " + f"Please use '{new_arg}' instead." + ) def clean_url(url): """Strip auth from URL, i.e. 
https://url.com/file.txt?auth -> https://url.com/file.txt.""" - url = Path(url).as_posix().replace(':/', '://') # Pathlib turns :// -> :/, as_posix() for Windows - return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows + return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth def url2file(url): @@ -928,13 +939,22 @@ def url2file(url): # Run below code on utils init ------------------------------------------------------------------------------------ # Check first-install steps -PREFIX = colorstr('Ultralytics: ') +PREFIX = colorstr("Ultralytics: ") SETTINGS = SettingsManager() # initialize settings -DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory -WEIGHTS_DIR = Path(SETTINGS['weights_dir']) # global weights directory -RUNS_DIR = Path(SETTINGS['runs_dir']) # global runs directory -ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \ - 'Docker' if is_docker() else platform.system() +DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory +WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory +RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory +ENVIRONMENT = ( + "Colab" + if is_colab() + else "Kaggle" + if is_kaggle() + else "Jupyter" + if is_jupyter() + else "Docker" + if is_docker() + else platform.system() +) TESTS_RUNNING = is_pytest_running() or is_github_action_running() set_sentry() diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py index 172a4c12..daea14ed 100644 --- a/ultralytics/utils/autobatch.py +++ b/ultralytics/utils/autobatch.py @@ -42,14 +42,14 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): """ # Check device - prefix = colorstr('AutoBatch: ') - LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}') + prefix = colorstr("AutoBatch: ") + LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}") device = next(model.parameters()).device # get model device - if device.type == 'cpu': - LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') + if device.type == "cpu": + LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}") return batch_size if torch.backends.cudnn.benchmark: - LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') + LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}") return batch_size # Inspect CUDA memory @@ -60,7 +60,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): r = torch.cuda.memory_reserved(device) / gb # GiB reserved a = torch.cuda.memory_allocated(device) / gb # GiB allocated f = t - (r + a) # GiB free - LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') + LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free") # Profile batch sizes batch_sizes = [1, 2, 4, 8, 16] @@ -70,7 +70,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): # Fit a solution y = [x[2] for x in results if x] # memory [2] - p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit + p = np.polyfit(batch_sizes[: len(y)], 
y, deg=1) # first degree polynomial fit b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) if None in results: # some sizes failed i = results.index(None) # first fail index @@ -78,11 +78,11 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): b = batch_sizes[max(i - 1, 0)] # select prior safe point if b < 1 or b > 1024: # b outside of safe range b = batch_size - LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') + LOGGER.info(f"{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.") fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted - LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') + LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅") return b except Exception as e: - LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') + LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.") return batch_size diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index 4842ff5b..66621f53 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -42,13 +42,9 @@ from ultralytics.utils.files import file_size from ultralytics.utils.torch_utils import select_device -def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt', - data=None, - imgsz=160, - half=False, - int8=False, - device='cpu', - verbose=False): +def benchmark( + model=WEIGHTS_DIR / "yolov8n.pt", data=None, imgsz=160, half=False, int8=False, device="cpu", verbose=False +): """ Benchmark a YOLO model across different formats for speed and accuracy. 
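A minimal invocation sketch for the `benchmark` signature above (the weights file is assumed to be available locally; a small imgsz keeps runtime down):

```python
from ultralytics.utils.benchmarks import benchmark

# exports the model to each supported format, validates each export and returns a pandas DataFrame
df = benchmark(model="yolov8n.pt", imgsz=160, half=False, int8=False, device="cpu", verbose=False)
print(df)
```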
@@ -76,6 +72,7 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt', """ import pandas as pd + pd.options.display.max_columns = 10 pd.options.display.width = 120 device = select_device(device, verbose=False) @@ -85,67 +82,62 @@ def benchmark(model=WEIGHTS_DIR / 'yolov8n.pt', y = [] t0 = time.time() for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU) - emoji, filename = '❌', None # export defaults + emoji, filename = "❌", None # export defaults try: - assert i != 9 or LINUX, 'Edge TPU export only supported on Linux' + assert i != 9 or LINUX, "Edge TPU export only supported on Linux" if i == 10: - assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux' + assert MACOS or LINUX, "TF.js export only supported on macOS and Linux" elif i == 11: - assert sys.version_info < (3, 11), 'PaddlePaddle export only supported on Python<=3.10' - if 'cpu' in device.type: - assert cpu, 'inference not supported on CPU' - if 'cuda' in device.type: - assert gpu, 'inference not supported on GPU' + assert sys.version_info < (3, 11), "PaddlePaddle export only supported on Python<=3.10" + if "cpu" in device.type: + assert cpu, "inference not supported on CPU" + if "cuda" in device.type: + assert gpu, "inference not supported on GPU" # Export - if format == '-': + if format == "-": filename = model.ckpt_path or model.cfg exported_model = model # PyTorch format else: filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False) exported_model = YOLO(filename, task=model.task) - assert suffix in str(filename), 'export failed' - emoji = '❎' # indicates export succeeded + assert suffix in str(filename), "export failed" + emoji = "❎" # indicates export succeeded # Predict - assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported' - assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported - assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML - exported_model.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half) + assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" + assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported + assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML + exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half) # Validate data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect key = TASK2METRIC[model.task] # task to metric, i.e. 
metrics/mAP50-95(B) for task=detect - results = exported_model.val(data=data, - batch=1, - imgsz=imgsz, - plots=False, - device=device, - half=half, - int8=int8, - verbose=False) - metric, speed = results.results_dict[key], results.speed['inference'] - y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) + results = exported_model.val( + data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False + ) + metric, speed = results.results_dict[key], results.speed["inference"] + y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) except Exception as e: if verbose: - assert type(e) is AssertionError, f'Benchmark failure for {name}: {e}' - LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}') + assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}" + LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}") y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference # Print results check_yolo(device=device) # print system info - df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)']) + df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)"]) name = Path(model.ckpt_path).name - s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n' + s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n" LOGGER.info(s) - with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f: + with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f: f.write(s) if verbose and isinstance(verbose, float): metrics = df[key].array # values to compare to floor floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n - assert all(x > floor for x in metrics if pd.notna(x)), f'Benchmark failure: metric(s) < floor {floor}' + assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}" return df @@ -175,15 +167,17 @@ class ProfileModels: ``` """ - def __init__(self, - paths: list, - num_timed_runs=100, - num_warmup_runs=10, - min_time=60, - imgsz=640, - half=True, - trt=True, - device=None): + def __init__( + self, + paths: list, + num_timed_runs=100, + num_warmup_runs=10, + min_time=60, + imgsz=640, + half=True, + trt=True, + device=None, + ): """ Initialize the ProfileModels class for profiling models. 
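And a matching sketch for `ProfileModels`, whose constructor appears above (model paths are placeholders; TensorRT profiling is skipped automatically when no GPU is available):

```python
from ultralytics.utils.benchmarks import ProfileModels

# profiles ONNX (and TensorRT on GPU) latency for each model and prints a Markdown table
ProfileModels(["yolov8n.pt", "yolov8s.pt"], imgsz=640, half=True).profile()
```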
@@ -204,37 +198,32 @@ class ProfileModels: self.imgsz = imgsz self.half = half self.trt = trt # run TensorRT profiling - self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu') + self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu") def profile(self): """Logs the benchmarking results of a model, checks metrics against floor and returns the results.""" files = self.get_files() if not files: - print('No matching *.pt or *.onnx files found.') + print("No matching *.pt or *.onnx files found.") return table_rows = [] output = [] for file in files: - engine_file = file.with_suffix('.engine') - if file.suffix in ('.pt', '.yaml', '.yml'): + engine_file = file.with_suffix(".engine") + if file.suffix in (".pt", ".yaml", ".yml"): model = YOLO(str(file)) model.fuse() # to report correct params and GFLOPs in model.info() model_info = model.info() - if self.trt and self.device.type != 'cpu' and not engine_file.is_file(): - engine_file = model.export(format='engine', - half=self.half, - imgsz=self.imgsz, - device=self.device, - verbose=False) - onnx_file = model.export(format='onnx', - half=self.half, - imgsz=self.imgsz, - simplify=True, - device=self.device, - verbose=False) - elif file.suffix == '.onnx': + if self.trt and self.device.type != "cpu" and not engine_file.is_file(): + engine_file = model.export( + format="engine", half=self.half, imgsz=self.imgsz, device=self.device, verbose=False + ) + onnx_file = model.export( + format="onnx", half=self.half, imgsz=self.imgsz, simplify=True, device=self.device, verbose=False + ) + elif file.suffix == ".onnx": model_info = self.get_onnx_model_info(file) onnx_file = file else: @@ -254,14 +243,14 @@ class ProfileModels: for path in self.paths: path = Path(path) if path.is_dir(): - extensions = ['*.pt', '*.onnx', '*.yaml'] + extensions = ["*.pt", "*.onnx", "*.yaml"] files.extend([file for ext in extensions for file in glob.glob(str(path / ext))]) - elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing + elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing files.append(str(path)) else: files.extend(glob.glob(str(path))) - print(f'Profiling: {sorted(files)}') + print(f"Profiling: {sorted(files)}") return [Path(file) for file in sorted(files)] def get_onnx_model_info(self, onnx_file: str): @@ -306,7 +295,7 @@ class ProfileModels: run_times = [] for _ in TQDM(range(num_runs), desc=engine_file): results = model(input_data, imgsz=self.imgsz, verbose=False) - run_times.append(results[0].speed['inference']) # Convert to milliseconds + run_times.append(results[0].speed["inference"]) # Convert to milliseconds run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping return np.mean(run_times), np.std(run_times) @@ -315,31 +304,31 @@ class ProfileModels: """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run times. 
""" - check_requirements('onnxruntime') + check_requirements("onnxruntime") import onnxruntime as ort # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.intra_op_num_threads = 8 # Limit the number of threads - sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider']) + sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) input_tensor = sess.get_inputs()[0] input_type = input_tensor.type # Mapping ONNX datatype to numpy datatype - if 'float16' in input_type: + if "float16" in input_type: input_dtype = np.float16 - elif 'float' in input_type: + elif "float" in input_type: input_dtype = np.float32 - elif 'double' in input_type: + elif "double" in input_type: input_dtype = np.float64 - elif 'int64' in input_type: + elif "int64" in input_type: input_dtype = np.int64 - elif 'int32' in input_type: + elif "int32" in input_type: input_dtype = np.int32 else: - raise ValueError(f'Unsupported ONNX datatype {input_type}') + raise ValueError(f"Unsupported ONNX datatype {input_type}") input_data = np.random.rand(*input_tensor.shape).astype(input_dtype) input_name = input_tensor.name @@ -369,25 +358,26 @@ class ProfileModels: def generate_table_row(self, model_name, t_onnx, t_engine, model_info): """Generates a formatted string for a table row that includes model performance and metric details.""" layers, params, gradients, flops = model_info - return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |' + return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |" def generate_results_dict(self, model_name, t_onnx, t_engine, model_info): """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics.""" layers, params, gradients, flops = model_info return { - 'model/name': model_name, - 'model/parameters': params, - 'model/GFLOPs': round(flops, 3), - 'model/speed_ONNX(ms)': round(t_onnx[0], 3), - 'model/speed_TensorRT(ms)': round(t_engine[0], 3)} + "model/name": model_name, + "model/parameters": params, + "model/GFLOPs": round(flops, 3), + "model/speed_ONNX(ms)": round(t_onnx[0], 3), + "model/speed_TensorRT(ms)": round(t_engine[0], 3), + } def print_table(self, table_rows): """Formats and prints a comparison table for different models with given statistics and performance data.""" - gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU' - header = f'| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |' - separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|' + gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" + header = f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |" + separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|" - print(f'\n\n{header}') + print(f"\n\n{header}") print(separator) for row in table_rows: print(row) diff --git a/ultralytics/utils/callbacks/__init__.py b/ultralytics/utils/callbacks/__init__.py index 8ad4ad6e..116babe9 100644 --- a/ultralytics/utils/callbacks/__init__.py +++ b/ultralytics/utils/callbacks/__init__.py @@ -2,4 +2,4 @@ from .base import add_integration_callbacks, default_callbacks, get_default_callbacks -__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' +__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks" diff --git a/ultralytics/utils/callbacks/base.py b/ultralytics/utils/callbacks/base.py index 211ae5bf..e1e9b42b 100644 --- a/ultralytics/utils/callbacks/base.py +++ b/ultralytics/utils/callbacks/base.py @@ -143,37 +143,35 @@ def on_export_end(exporter): default_callbacks = { # Run in trainer - 'on_pretrain_routine_start': [on_pretrain_routine_start], - 'on_pretrain_routine_end': [on_pretrain_routine_end], - 'on_train_start': [on_train_start], - 'on_train_epoch_start': [on_train_epoch_start], - 'on_train_batch_start': [on_train_batch_start], - 'optimizer_step': [optimizer_step], - 'on_before_zero_grad': [on_before_zero_grad], - 'on_train_batch_end': [on_train_batch_end], - 'on_train_epoch_end': [on_train_epoch_end], - 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val - 'on_model_save': [on_model_save], - 'on_train_end': [on_train_end], - 'on_params_update': [on_params_update], - 'teardown': [teardown], - + "on_pretrain_routine_start": [on_pretrain_routine_start], + "on_pretrain_routine_end": [on_pretrain_routine_end], + "on_train_start": [on_train_start], + "on_train_epoch_start": [on_train_epoch_start], + "on_train_batch_start": [on_train_batch_start], + "optimizer_step": [optimizer_step], + "on_before_zero_grad": [on_before_zero_grad], + "on_train_batch_end": [on_train_batch_end], + "on_train_epoch_end": [on_train_epoch_end], + "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val + "on_model_save": [on_model_save], + "on_train_end": [on_train_end], + "on_params_update": [on_params_update], + "teardown": [teardown], # Run in validator - 'on_val_start': [on_val_start], - 'on_val_batch_start': [on_val_batch_start], - 'on_val_batch_end': [on_val_batch_end], - 'on_val_end': [on_val_end], - + "on_val_start": [on_val_start], + "on_val_batch_start": [on_val_batch_start], + "on_val_batch_end": [on_val_batch_end], + "on_val_end": [on_val_end], # Run in predictor - 'on_predict_start': [on_predict_start], - 'on_predict_batch_start': [on_predict_batch_start], - 'on_predict_postprocess_end': [on_predict_postprocess_end], - 'on_predict_batch_end': [on_predict_batch_end], - 'on_predict_end': [on_predict_end], - + "on_predict_start": [on_predict_start], + "on_predict_batch_start": [on_predict_batch_start], + "on_predict_postprocess_end": [on_predict_postprocess_end], + "on_predict_batch_end": [on_predict_batch_end], + "on_predict_end": [on_predict_end], # Run in exporter - 'on_export_start': [on_export_start], - 'on_export_end': [on_export_end]} + "on_export_start": [on_export_start], + "on_export_end": [on_export_end], +} def get_default_callbacks(): @@ -197,10 +195,11 @@ def add_integration_callbacks(instance): # Load HUB callbacks from .hub import callbacks as hub_cb + callbacks_list = [hub_cb] # Load training callbacks - 
if 'Trainer' in instance.__class__.__name__: + if "Trainer" in instance.__class__.__name__: from .clearml import callbacks as clear_cb from .comet import callbacks as comet_cb from .dvc import callbacks as dvc_cb @@ -209,6 +208,7 @@ def add_integration_callbacks(instance): from .raytune import callbacks as tune_cb from .tensorboard import callbacks as tb_cb from .wb import callbacks as wb_cb + callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) # Add the callbacks to the callbacks dictionary diff --git a/ultralytics/utils/callbacks/clearml.py b/ultralytics/utils/callbacks/clearml.py index dc0b2716..a030fc5e 100644 --- a/ultralytics/utils/callbacks/clearml.py +++ b/ultralytics/utils/callbacks/clearml.py @@ -4,19 +4,19 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['clearml'] is True # verify integration is enabled + assert SETTINGS["clearml"] is True # verify integration is enabled import clearml from clearml import Task from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO from clearml.binding.matplotlib_bind import PatchedMatplotlib - assert hasattr(clearml, '__version__') # verify package is not directory + assert hasattr(clearml, "__version__") # verify package is not directory except (ImportError, AssertionError): clearml = None -def _log_debug_samples(files, title='Debug Samples') -> None: +def _log_debug_samples(files, title="Debug Samples") -> None: """ Log files (images) as debug samples in the ClearML task. @@ -29,12 +29,11 @@ def _log_debug_samples(files, title='Debug Samples') -> None: if task := Task.current_task(): for f in files: if f.exists(): - it = re.search(r'_batch(\d+)', f.name) + it = re.search(r"_batch(\d+)", f.name) iteration = int(it.groups()[0]) if it else 0 - task.get_logger().report_image(title=title, - series=f.name.replace(it.group(), ''), - local_path=str(f), - iteration=iteration) + task.get_logger().report_image( + title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration + ) def _log_plot(title, plot_path) -> None: @@ -50,13 +49,12 @@ def _log_plot(title, plot_path) -> None: img = mpimg.imread(plot_path) fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks ax.imshow(img) - Task.current_task().get_logger().report_matplotlib_figure(title=title, - series='', - figure=fig, - report_interactive=False) + Task.current_task().get_logger().report_matplotlib_figure( + title=title, series="", figure=fig, report_interactive=False + ) def on_pretrain_routine_start(trainer): @@ -68,19 +66,21 @@ def on_pretrain_routine_start(trainer): PatchPyTorchModelIO.update_current_task(None) PatchedMatplotlib.update_current_task(None) else: - task = Task.init(project_name=trainer.args.project or 'YOLOv8', - task_name=trainer.args.name, - tags=['YOLOv8'], - output_uri=True, - reuse_last_task_id=False, - auto_connect_frameworks={ - 'pytorch': False, - 'matplotlib': False}) - LOGGER.warning('ClearML Initialized a new task. 
If you want to run remotely, ' - 'please add clearml-init and connect your arguments before initializing YOLO.') - task.connect(vars(trainer.args), name='General') + task = Task.init( + project_name=trainer.args.project or "YOLOv8", + task_name=trainer.args.name, + tags=["YOLOv8"], + output_uri=True, + reuse_last_task_id=False, + auto_connect_frameworks={"pytorch": False, "matplotlib": False}, + ) + LOGGER.warning( + "ClearML Initialized a new task. If you want to run remotely, " + "please add clearml-init and connect your arguments before initializing YOLO." + ) + task.connect(vars(trainer.args), name="General") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}") def on_train_epoch_end(trainer): @@ -88,26 +88,26 @@ def on_train_epoch_end(trainer): if task := Task.current_task(): # Log debug samples if trainer.epoch == 1: - _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic') + _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic") # Report the current training progress - for k, v in trainer.label_loss_items(trainer.tloss, prefix='train').items(): - task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch) + for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items(): + task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch) for k, v in trainer.lr.items(): - task.get_logger().report_scalar('lr', k, v, iteration=trainer.epoch) + task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch) def on_fit_epoch_end(trainer): """Reports model information to logger at the end of an epoch.""" if task := Task.current_task(): # You should have access to the validation bboxes under jdict - task.get_logger().report_scalar(title='Epoch Time', - series='Epoch Time', - value=trainer.epoch_time, - iteration=trainer.epoch) + task.get_logger().report_scalar( + title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch + ) for k, v in trainer.metrics.items(): - task.get_logger().report_scalar('val', k, v, iteration=trainer.epoch) + task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch) if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + for k, v in model_info_for_loggers(trainer).items(): task.get_logger().report_single_value(k, v) @@ -116,7 +116,7 @@ def on_val_end(validator): """Logs validation results including labels and predictions.""" if Task.current_task(): # Log val_labels and val_pred - _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation') + _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation") def on_train_end(trainer): @@ -124,8 +124,11 @@ def on_train_end(trainer): if task := Task.current_task(): # Log final results, CM matrix + PR plots files = [ - 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', - *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter for f in files: _log_plot(title=f.stem, plot_path=f) @@ -136,9 +139,14 @@ def on_train_end(trainer): task.update_output_model(model_path=str(trainer.best), 
model_name=trainer.args.name, auto_delete_file=False) -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_val_end': on_val_end, - 'on_train_end': on_train_end} if clearml else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if clearml + else {} +) diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py index e8016f4e..2ed60577 100644 --- a/ultralytics/utils/callbacks/comet.py +++ b/ultralytics/utils/callbacks/comet.py @@ -4,20 +4,20 @@ from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['comet'] is True # verify integration is enabled + assert SETTINGS["comet"] is True # verify integration is enabled import comet_ml - assert hasattr(comet_ml, '__version__') # verify package is not directory + assert hasattr(comet_ml, "__version__") # verify package is not directory import os from pathlib import Path # Ensures certain logging functions only run for supported tasks - COMET_SUPPORTED_TASKS = ['detect'] + COMET_SUPPORTED_TASKS = ["detect"] # Names of plots created by YOLOv8 that are logged to Comet - EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix' - LABEL_PLOT_NAMES = 'labels', 'labels_correlogram' + EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix" + LABEL_PLOT_NAMES = "labels", "labels_correlogram" _comet_image_prediction_count = 0 @@ -27,43 +27,43 @@ except (ImportError, AssertionError): def _get_comet_mode(): """Returns the mode of comet set in the environment variables, defaults to 'online' if not set.""" - return os.getenv('COMET_MODE', 'online') + return os.getenv("COMET_MODE", "online") def _get_comet_model_name(): """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'.""" - return os.getenv('COMET_MODEL_NAME', 'YOLOv8') + return os.getenv("COMET_MODEL_NAME", "YOLOv8") def _get_eval_batch_logging_interval(): """Get the evaluation batch logging interval from environment variable or use default value 1.""" - return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1)) + return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1)) def _get_max_image_predictions_to_log(): """Get the maximum number of image predictions to log from the environment variables.""" - return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100)) + return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100)) def _scale_confidence_score(score): """Scales the given confidence score by a factor specified in an environment variable.""" - scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0)) + scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0)) return score * scale def _should_log_confusion_matrix(): """Determines if the confusion matrix should be logged based on the environment variable settings.""" - return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true' + return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true" def _should_log_image_predictions(): """Determines whether to log image predictions based on a specified environment variable.""" - return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true' + 
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true" def _get_experiment_type(mode, project_name): """Return an experiment based on mode and project name.""" - if mode == 'offline': + if mode == "offline": return comet_ml.OfflineExperiment(project_name=project_name) return comet_ml.Experiment(project_name=project_name) @@ -75,18 +75,21 @@ def _create_experiment(args): return try: comet_mode = _get_comet_mode() - _project_name = os.getenv('COMET_PROJECT_NAME', args.project) + _project_name = os.getenv("COMET_PROJECT_NAME", args.project) experiment = _get_experiment_type(comet_mode, _project_name) experiment.log_parameters(vars(args)) - experiment.log_others({ - 'eval_batch_logging_interval': _get_eval_batch_logging_interval(), - 'log_confusion_matrix_on_eval': _should_log_confusion_matrix(), - 'log_image_predictions': _should_log_image_predictions(), - 'max_image_predictions': _get_max_image_predictions_to_log(), }) - experiment.log_other('Created from', 'yolov8') + experiment.log_others( + { + "eval_batch_logging_interval": _get_eval_batch_logging_interval(), + "log_confusion_matrix_on_eval": _should_log_confusion_matrix(), + "log_image_predictions": _should_log_image_predictions(), + "max_image_predictions": _get_max_image_predictions_to_log(), + } + ) + experiment.log_other("Created from", "yolov8") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}") def _fetch_trainer_metadata(trainer): @@ -134,29 +137,32 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None): """Format ground truth annotations for detection.""" - indices = batch['batch_idx'] == img_idx - bboxes = batch['bboxes'][indices] + indices = batch["batch_idx"] == img_idx + bboxes = batch["bboxes"][indices] if len(bboxes) == 0: - LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels') + LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels") return None - cls_labels = batch['cls'][indices].squeeze(1).tolist() + cls_labels = batch["cls"][indices].squeeze(1).tolist() if class_name_map: cls_labels = [str(class_name_map[label]) for label in cls_labels] - original_image_shape = batch['ori_shape'][img_idx] - resized_image_shape = batch['resized_shape'][img_idx] - ratio_pad = batch['ratio_pad'][img_idx] + original_image_shape = batch["ori_shape"][img_idx] + resized_image_shape = batch["resized_shape"][img_idx] + ratio_pad = batch["ratio_pad"][img_idx] data = [] for box, label in zip(bboxes, cls_labels): box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad) - data.append({ - 'boxes': [box], - 'label': f'gt_{label}', - 'score': _scale_confidence_score(1.0), }) + data.append( + { + "boxes": [box], + "label": f"gt_{label}", + "score": _scale_confidence_score(1.0), + } + ) - return {'name': 'ground_truth', 'data': data} + return {"name": "ground_truth", "data": data} def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None): @@ -166,31 +172,34 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab predictions = metadata.get(image_id) if not predictions: - LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions') + 
LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions") return None data = [] for prediction in predictions: - boxes = prediction['bbox'] - score = _scale_confidence_score(prediction['score']) - cls_label = prediction['category_id'] + boxes = prediction["bbox"] + score = _scale_confidence_score(prediction["score"]) + cls_label = prediction["category_id"] if class_label_map: cls_label = str(class_label_map[cls_label]) - data.append({'boxes': [boxes], 'label': cls_label, 'score': score}) + data.append({"boxes": [boxes], "label": cls_label, "score": score}) - return {'name': 'prediction', 'data': data} + return {"name": "prediction", "data": data} def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map): """Join the ground truth and prediction annotations if they exist.""" - ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, - class_label_map) - prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map, - class_label_map) + ground_truth_annotations = _format_ground_truth_annotations_for_detection( + img_idx, image_path, batch, class_label_map + ) + prediction_annotations = _format_prediction_annotations_for_detection( + image_path, prediction_metadata_map, class_label_map + ) annotations = [ - annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None] + annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None + ] return [annotations] if annotations else None @@ -198,8 +207,8 @@ def _create_prediction_metadata_map(model_predictions): """Create metadata map for model predictions by groupings them based on image ID.""" pred_metadata_map = {} for prediction in model_predictions: - pred_metadata_map.setdefault(prediction['image_id'], []) - pred_metadata_map[prediction['image_id']].append(prediction) + pred_metadata_map.setdefault(prediction["image_id"], []) + pred_metadata_map[prediction["image_id"]].append(prediction) return pred_metadata_map @@ -207,7 +216,7 @@ def _create_prediction_metadata_map(model_predictions): def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch): """Log the confusion matrix to Comet experiment.""" conf_mat = trainer.validator.confusion_matrix.matrix - names = list(trainer.data['names'].values()) + ['background'] + names = list(trainer.data["names"].values()) + ["background"] experiment.log_confusion_matrix( matrix=conf_mat, labels=names, @@ -251,7 +260,7 @@ def _log_image_predictions(experiment, validator, curr_step): if (batch_idx + 1) % batch_logging_interval != 0: continue - image_paths = batch['im_file'] + image_paths = batch["im_file"] for img_idx, image_path in enumerate(image_paths): if _comet_image_prediction_count >= max_image_predictions: return @@ -275,10 +284,10 @@ def _log_image_predictions(experiment, validator, curr_step): def _log_plots(experiment, trainer): """Logs evaluation plots and label plots for the experiment.""" - plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES] + plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES] _log_images(experiment, plot_filenames, None) - label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES] + label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES] _log_images(experiment, label_plot_filenames, None) @@ 
-288,7 +297,7 @@ def _log_model(experiment, trainer): experiment.log_model( model_name, file_or_folder=str(trainer.best), - file_name='best.pt', + file_name="best.pt", overwrite=True, ) @@ -296,7 +305,7 @@ def _log_model(experiment, trainer): def on_pretrain_routine_start(trainer): """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine.""" experiment = comet_ml.get_global_experiment() - is_alive = getattr(experiment, 'alive', False) + is_alive = getattr(experiment, "alive", False) if not experiment or not is_alive: _create_experiment(trainer.args) @@ -308,17 +317,17 @@ def on_train_epoch_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] experiment.log_metrics( - trainer.label_loss_items(trainer.tloss, prefix='train'), + trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch, ) if curr_epoch == 1: - _log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step) + _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step) def on_fit_epoch_end(trainer): @@ -328,14 +337,15 @@ def on_fit_epoch_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] - save_assets = metadata['save_assets'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + save_assets = metadata["save_assets"] experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) if curr_epoch == 1: from ultralytics.utils.torch_utils import model_info_for_loggers + experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) if not save_assets: @@ -355,8 +365,8 @@ def on_train_end(trainer): return metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata['curr_epoch'] - curr_step = metadata['curr_step'] + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] plots = trainer.args.plots _log_model(experiment, trainer) @@ -371,8 +381,13 @@ def on_train_end(trainer): _comet_image_prediction_count = 0 -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if comet_ml else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if comet_ml + else {} +) diff --git a/ultralytics/utils/callbacks/dvc.py b/ultralytics/utils/callbacks/dvc.py index 7fa05c6b..ab51dc52 100644 --- a/ultralytics/utils/callbacks/dvc.py +++ b/ultralytics/utils/callbacks/dvc.py @@ -4,9 +4,10 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['dvc'] is True # verify integration is enabled + assert SETTINGS["dvc"] is True # verify integration is enabled import dvclive - assert checks.check_version('dvclive', '2.11.0', verbose=True) + + assert checks.check_version("dvclive", "2.11.0", verbose=True) import os import re @@ -24,24 +25,24 @@ except (ImportError, AssertionError, TypeError): dvclive = None -def _log_images(path, prefix=''): +def _log_images(path, prefix=""): """Logs images at specified path with an 
optional prefix using DVCLive.""" if live: name = path.name # Group images by batch to enable sliders in UI - if m := re.search(r'_batch(\d+)', name): + if m := re.search(r"_batch(\d+)", name): ni = m[1] - new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem) + new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem) name = (Path(new_stem) / ni).with_suffix(path.suffix) live.log_image(os.path.join(prefix, name), path) -def _log_plots(plots, prefix=''): +def _log_plots(plots, prefix=""): """Logs plot images for training progress if they have not been previously processed.""" for name, params in plots.items(): - timestamp = params['timestamp'] + timestamp = params["timestamp"] if _processed_plots.get(name) != timestamp: _log_images(name, prefix) _processed_plots[name] = timestamp @@ -53,15 +54,15 @@ def _log_confusion_matrix(validator): preds = [] matrix = validator.confusion_matrix.matrix names = list(validator.names.values()) - if validator.confusion_matrix.task == 'detect': - names += ['background'] + if validator.confusion_matrix.task == "detect": + names += ["background"] for ti, pred in enumerate(matrix.T.astype(int)): for pi, num in enumerate(pred): targets.extend([names[ti]] * num) preds.extend([names[pi]] * num) - live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True) + live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True) def on_pretrain_routine_start(trainer): @@ -71,12 +72,12 @@ def on_pretrain_routine_start(trainer): live = dvclive.Live(save_dvc_exp=True, cache_images=True) LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}") def on_pretrain_routine_end(trainer): """Logs plots related to the training process at the end of the pretraining routine.""" - _log_plots(trainer.plots, 'train') + _log_plots(trainer.plots, "train") def on_train_start(trainer): @@ -95,17 +96,18 @@ def on_fit_epoch_end(trainer): """Logs training metrics and model info, and advances to next step on the end of each fit epoch.""" global _training_epoch if live and _training_epoch: - all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} for metric, value in all_metrics.items(): live.log_metric(metric, value) if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + for metric, value in model_info_for_loggers(trainer).items(): live.log_metric(metric, value, plot=False) - _log_plots(trainer.plots, 'train') - _log_plots(trainer.validator.plots, 'val') + _log_plots(trainer.plots, "train") + _log_plots(trainer.validator.plots, "val") live.next_step() _training_epoch = False @@ -115,24 +117,29 @@ def on_train_end(trainer): """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active.""" if live: # At the end log the best metrics. It runs validator on the best model internally. 
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} for metric, value in all_metrics.items(): live.log_metric(metric, value, plot=False) - _log_plots(trainer.plots, 'val') - _log_plots(trainer.validator.plots, 'val') + _log_plots(trainer.plots, "val") + _log_plots(trainer.validator.plots, "val") _log_confusion_matrix(trainer.validator) if trainer.best.exists(): - live.log_artifact(trainer.best, copy=True, type='model') + live.log_artifact(trainer.best, copy=True, type="model") live.end() -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_train_start': on_train_start, - 'on_train_epoch_start': on_train_epoch_start, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if dvclive else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_start": on_train_start, + "on_train_epoch_start": on_train_epoch_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if dvclive + else {} +) diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py index f3a3f353..60c74161 100644 --- a/ultralytics/utils/callbacks/hub.py +++ b/ultralytics/utils/callbacks/hub.py @@ -11,60 +11,62 @@ from ultralytics.utils import LOGGER, SETTINGS def on_pretrain_routine_end(trainer): """Logs info before starting timer for upload rate limit.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Start timer for upload rate limit session.timers = { - 'metrics': time(), - 'ckpt': time(), } # start timer on session.rate_limit + "metrics": time(), + "ckpt": time(), + } # start timer on session.rate_limit def on_fit_epoch_end(trainer): """Uploads training progress metrics at the end of each epoch.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload metrics after val end all_plots = { - **trainer.label_loss_items(trainer.tloss, prefix='train'), - **trainer.metrics, } + **trainer.label_loss_items(trainer.tloss, prefix="train"), + **trainer.metrics, + } if trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers + all_plots = {**all_plots, **model_info_for_loggers(trainer)} session.metrics_queue[trainer.epoch] = json.dumps(all_plots) - if time() - session.timers['metrics'] > session.rate_limits['metrics']: + if time() - session.timers["metrics"] > session.rate_limits["metrics"]: session.upload_metrics() - session.timers['metrics'] = time() # reset timer + session.timers["metrics"] = time() # reset timer session.metrics_queue = {} # reset queue def on_model_save(trainer): """Saves checkpoints to Ultralytics HUB with rate limiting.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload checkpoints with rate limiting is_best = trainer.best_fitness == trainer.fitness - if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: - LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}') + if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: + LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}") 
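# ---- Illustrative sketch, not part of the patch ----
# The HUB callbacks above all share one timer-based rate-limit pattern: an
# upload fires only once `rate_limits[...]` seconds have passed since the
# last one, then the timer resets. A minimal standalone version of that
# pattern (class and method names here are hypothetical, not ultralytics API):
from time import time

class RateLimitedUploader:
    def __init__(self, rate_limit: float):
        self.rate_limit = rate_limit  # minimum seconds between uploads
        self.timer = time()  # start timer, mirrors session.timers[...]

    def maybe_upload(self, payload) -> bool:
        """Upload only if the rate-limit window has elapsed, else skip."""
        if time() - self.timer > self.rate_limit:
            print(f"uploading {payload}")  # stand-in for session.upload_model(...)
            self.timer = time()  # reset timer
            return True
        return False
# ----------------------------------------------------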
session.upload_model(trainer.epoch, trainer.last, is_best) - session.timers['ckpt'] = time() # reset timer + session.timers["ckpt"] = time() # reset timer def on_train_end(trainer): """Upload final model and metrics to Ultralytics HUB at the end of training.""" - session = getattr(trainer, 'hub_session', None) + session = getattr(trainer, "hub_session", None) if session: # Upload final model and metrics with exponential standoff - LOGGER.info(f'{PREFIX}Syncing final model...') + LOGGER.info(f"{PREFIX}Syncing final model...") session.upload_model( trainer.epoch, trainer.best, - map=trainer.metrics.get('metrics/mAP50-95(B)', 0), + map=trainer.metrics.get("metrics/mAP50-95(B)", 0), final=True, ) session.alive = False # stop heartbeats - LOGGER.info(f'{PREFIX}Done ✅\n' - f'{PREFIX}View model at {session.model_url} 🚀') + LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀") def on_train_start(trainer): @@ -87,12 +89,17 @@ def on_export_start(exporter): events(exporter.args) -callbacks = ({ - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_model_save': on_model_save, - 'on_train_end': on_train_end, - 'on_train_start': on_train_start, - 'on_val_start': on_val_start, - 'on_predict_start': on_predict_start, - 'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_model_save": on_model_save, + "on_train_end": on_train_end, + "on_train_start": on_train_start, + "on_val_start": on_val_start, + "on_predict_start": on_predict_start, + "on_export_start": on_export_start, + } + if SETTINGS["hub"] is True + else {} +) # verify enabled diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py index 91dcff5e..9676fe85 100644 --- a/ultralytics/utils/callbacks/mlflow.py +++ b/ultralytics/utils/callbacks/mlflow.py @@ -26,15 +26,15 @@ from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorst try: import os - assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '') # do not log pytest - assert SETTINGS['mlflow'] is True # verify integration is enabled + assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest + assert SETTINGS["mlflow"] is True # verify integration is enabled import mlflow - assert hasattr(mlflow, '__version__') # verify package is not directory + assert hasattr(mlflow, "__version__") # verify package is not directory from pathlib import Path - PREFIX = colorstr('MLflow: ') - SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()} + PREFIX = colorstr("MLflow: ") + SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()} except (ImportError, AssertionError): mlflow = None @@ -61,33 +61,33 @@ def on_pretrain_routine_end(trainer): """ global mlflow - uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow') - LOGGER.debug(f'{PREFIX} tracking uri: {uri}') + uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow") + LOGGER.debug(f"{PREFIX} tracking uri: {uri}") mlflow.set_tracking_uri(uri) # Set experiment and run names - experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8' - run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") 
or trainer.args.project or "/Shared/YOLOv8" + run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name mlflow.set_experiment(experiment_name) mlflow.autolog() try: active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) - LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}') + LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}") if Path(uri).is_dir(): LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") mlflow.log_params(dict(trainer.args)) except Exception as e: - LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n' - f'{PREFIX}WARNING ⚠️ Not tracking this run') + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run") def on_train_epoch_end(trainer): """Log training metrics at the end of each train epoch to MLflow.""" if mlflow: - mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')), - step=trainer.epoch) + mlflow.log_metrics( + metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), step=trainer.epoch + ) mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch) @@ -101,16 +101,23 @@ def on_train_end(trainer): """Log model artifacts at the end of the training.""" if mlflow: mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt - for f in trainer.save_dir.glob('*'): # log all other files in save_dir - if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}: + for f in trainer.save_dir.glob("*"): # log all other files in save_dir + if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: mlflow.log_artifact(str(f)) mlflow.end_run() - LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n' - f"{PREFIX}disable with 'yolo settings mlflow=False'") - - -callbacks = { - 'on_pretrain_routine_end': on_pretrain_routine_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if mlflow else {} + LOGGER.info( + f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n" + f"{PREFIX}disable with 'yolo settings mlflow=False'" + ) + + +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if mlflow + else {} +) diff --git a/ultralytics/utils/callbacks/neptune.py b/ultralytics/utils/callbacks/neptune.py index 088e3f8e..60c85371 100644 --- a/ultralytics/utils/callbacks/neptune.py +++ b/ultralytics/utils/callbacks/neptune.py @@ -4,11 +4,11 @@ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['neptune'] is True # verify integration is enabled + assert SETTINGS["neptune"] is True # verify integration is enabled import neptune from neptune.types import File - assert hasattr(neptune, '__version__') + assert hasattr(neptune, "__version__") run = None # NeptuneAI experiment logger instance @@ -23,11 +23,11 @@ def _log_scalars(scalars, step=0): run[k].append(value=v, step=step) -def _log_images(imgs_dict, group=''): +def _log_images(imgs_dict, group=""): """Log scalars to the NeptuneAI experiment logger.""" if run: for k, v in imgs_dict.items(): - run[f'{group}/{k}'].upload(File(v)) + run[f"{group}/{k}"].upload(File(v)) def _log_plot(title, plot_path): @@ -43,34 +43,35 @@ def _log_plot(title, plot_path): img = mpimg.imread(plot_path) fig = plt.figure() 
- ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks ax.imshow(img) - run[f'Plots/{title}'].upload(fig) + run[f"Plots/{title}"].upload(fig) def on_pretrain_routine_start(trainer): """Callback function called before the training routine starts.""" try: global run - run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) - run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} + run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"]) + run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()} except Exception as e: - LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}") def on_train_epoch_end(trainer): """Callback function called at end of each training epoch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) _log_scalars(trainer.lr, trainer.epoch + 1) if trainer.epoch == 1: - _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') + _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic") def on_fit_epoch_end(trainer): """Callback function called at end of each fit (train+val) epoch.""" if run and trainer.epoch == 0: from ultralytics.utils.torch_utils import model_info_for_loggers - run['Configuration/Model'] = model_info_for_loggers(trainer) + + run["Configuration/Model"] = model_info_for_loggers(trainer) _log_scalars(trainer.metrics, trainer.epoch + 1) @@ -78,7 +79,7 @@ def on_val_end(validator): """Callback function called at end of each validation.""" if run: # Log val_labels and val_pred - _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') + _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation") def on_train_end(trainer): @@ -86,19 +87,28 @@ def on_train_end(trainer): if run: # Log final results, CM matrix + PR plots files = [ - 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', - *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter for f in files: _log_plot(title=f.stem, plot_path=f) # Log the final model - run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( - trainer.best))) - - -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_val_end': on_val_end, - 'on_train_end': on_train_end} if neptune else {} + run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload( + File(str(trainer.best)) + ) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + 
"on_train_end": on_train_end, + } + if neptune + else {} +) diff --git a/ultralytics/utils/callbacks/raytune.py b/ultralytics/utils/callbacks/raytune.py index 417b3314..f2694554 100644 --- a/ultralytics/utils/callbacks/raytune.py +++ b/ultralytics/utils/callbacks/raytune.py @@ -3,7 +3,7 @@ from ultralytics.utils import SETTINGS try: - assert SETTINGS['raytune'] is True # verify integration is enabled + assert SETTINGS["raytune"] is True # verify integration is enabled import ray from ray import tune from ray.air import session @@ -16,9 +16,14 @@ def on_fit_epoch_end(trainer): """Sends training metrics to Ray Tune at end of each epoch.""" if ray.tune.is_session_enabled(): metrics = trainer.metrics - metrics['epoch'] = trainer.epoch + metrics["epoch"] = trainer.epoch session.report(metrics) -callbacks = { - 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} +callbacks = ( + { + "on_fit_epoch_end": on_fit_epoch_end, + } + if tune + else {} +) diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py index 86a50667..0a39d094 100644 --- a/ultralytics/utils/callbacks/tensorboard.py +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -7,7 +7,7 @@ try: from torch.utils.tensorboard import SummaryWriter assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['tensorboard'] is True # verify integration is enabled + assert SETTINGS["tensorboard"] is True # verify integration is enabled WRITER = None # TensorBoard SummaryWriter instance except (ImportError, AssertionError, TypeError): @@ -34,10 +34,10 @@ def _log_tensorboard_graph(trainer): p = next(trainer.model.parameters()) # for device, type im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning + warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) except Exception as e: - LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}') + LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}") def on_pretrain_routine_start(trainer): @@ -46,10 +46,10 @@ def on_pretrain_routine_start(trainer): try: global WRITER WRITER = SummaryWriter(str(trainer.save_dir)) - prefix = colorstr('TensorBoard: ') + prefix = colorstr("TensorBoard: ") LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") except Exception as e: - LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') + LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. 
{e}") def on_train_start(trainer): @@ -60,7 +60,7 @@ def on_train_start(trainer): def on_train_epoch_end(trainer): """Logs scalar statistics at the end of a training epoch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) _log_scalars(trainer.lr, trainer.epoch + 1) @@ -69,8 +69,13 @@ def on_fit_epoch_end(trainer): _log_scalars(trainer.metrics, trainer.epoch + 1) -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_start': on_train_start, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_start": on_train_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_epoch_end": on_train_epoch_end, + } + if SummaryWriter + else {} +) diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py index 88f9bd7c..7f0f57ac 100644 --- a/ultralytics/utils/callbacks/wb.py +++ b/ultralytics/utils/callbacks/wb.py @@ -5,10 +5,10 @@ from ultralytics.utils.torch_utils import model_info_for_loggers try: assert not TESTS_RUNNING # do not log pytest - assert SETTINGS['wandb'] is True # verify integration is enabled + assert SETTINGS["wandb"] is True # verify integration is enabled import wandb as wb - assert hasattr(wb, '__version__') # verify package is not directory + assert hasattr(wb, "__version__") # verify package is not directory import numpy as np import pandas as pd @@ -19,7 +19,7 @@ except (ImportError, AssertionError): wb = None -def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall', y_title='Precision'): +def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"): """ Create and log a custom metric visualization to wandb.plot.pr_curve. @@ -37,24 +37,25 @@ def _custom_table(x, y, classes, title='Precision Recall Curve', x_title='Recall Returns: (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization. """ - df = pd.DataFrame({'class': classes, 'y': y, 'x': x}).round(3) - fields = {'x': 'x', 'y': 'y', 'class': 'class'} - string_fields = {'title': title, 'x-axis-title': x_title, 'y-axis-title': y_title} - return wb.plot_table('wandb/area-under-curve/v0', - wb.Table(dataframe=df), - fields=fields, - string_fields=string_fields) - - -def _plot_curve(x, - y, - names=None, - id='precision-recall', - title='Precision Recall Curve', - x_title='Recall', - y_title='Precision', - num_x=100, - only_mean=False): + df = pd.DataFrame({"class": classes, "y": y, "x": x}).round(3) + fields = {"x": "x", "y": "y", "class": "class"} + string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title} + return wb.plot_table( + "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields + ) + + +def _plot_curve( + x, + y, + names=None, + id="precision-recall", + title="Precision Recall Curve", + x_title="Recall", + y_title="Precision", + num_x=100, + only_mean=False, +): """ Log a metric curve visualization. 
@@ -88,7 +89,7 @@ def _plot_curve(x, table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title]) wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)}) else: - classes = ['mean'] * len(x_log) + classes = ["mean"] * len(x_log) for i, yi in enumerate(y): x_log.extend(x_new) # add new x y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x @@ -99,7 +100,7 @@ def _plot_curve(x, def _log_plots(plots, step): """Logs plots from the input dictionary if they haven't been logged already at the specified step.""" for name, params in plots.items(): - timestamp = params['timestamp'] + timestamp = params["timestamp"] if _processed_plots.get(name) != timestamp: wb.run.log({name.stem: wb.Image(str(name))}, step=step) _processed_plots[name] = timestamp @@ -107,7 +108,7 @@ def _log_plots(plots, step): def on_pretrain_routine_start(trainer): """Initiate and start project if module is present.""" - wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) + wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args)) def on_fit_epoch_end(trainer): @@ -121,7 +122,7 @@ def on_fit_epoch_end(trainer): def on_train_epoch_end(trainer): """Log metrics and save images at the end of each training epoch.""" - wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) + wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) wb.run.log(trainer.lr, step=trainer.epoch + 1) if trainer.epoch == 1: _log_plots(trainer.plots, step=trainer.epoch + 1) @@ -131,17 +132,17 @@ def on_train_end(trainer): """Save the best model as an artifact at end of training.""" _log_plots(trainer.validator.plots, step=trainer.epoch + 1) _log_plots(trainer.plots, step=trainer.epoch + 1) - art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') + art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model") if trainer.best.exists(): art.add_file(trainer.best) - wb.run.log_artifact(art, aliases=['best']) + wb.run.log_artifact(art, aliases=["best"]) for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results): x, y, x_title, y_title = curve_values _plot_curve( x, y, names=list(trainer.validator.metrics.names.values()), - id=f'curves/{curve_name}', + id=f"curves/{curve_name}", title=curve_name, x_title=x_title, y_title=y_title, @@ -149,8 +150,13 @@ def on_train_end(trainer): wb.run.finish() # required or run continues on dashboard -callbacks = { - 'on_pretrain_routine_start': on_pretrain_routine_start, - 'on_train_epoch_end': on_train_epoch_end, - 'on_fit_epoch_end': on_fit_epoch_end, - 'on_train_end': on_train_end} if wb else {} +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if wb + else {} +) diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index ed804fff..ba25fd46 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -21,12 +21,33 @@ import requests import torch from matplotlib import font_manager -from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace, - ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker, - is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file) - 
- -def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''): +from ultralytics.utils import ( + ASSETS, + AUTOINSTALL, + LINUX, + LOGGER, + ONLINE, + ROOT, + USER_CONFIG_DIR, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_colab, + is_docker, + is_github_action_running, + is_jupyter, + is_kaggle, + is_online, + is_pip_package, + url2file, +) + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): """ Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. @@ -46,23 +67,23 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''): """ if package: - requires = [x for x in metadata.distribution(package).requires if 'extra == ' not in x] + requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] else: requires = Path(file_path).read_text().splitlines() requirements = [] for line in requires: line = line.strip() - if line and not line.startswith('#'): - line = line.split('#')[0].strip() # ignore inline comments - match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line) + if line and not line.startswith("#"): + line = line.split("#")[0].strip() # ignore inline comments + match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line) if match: - requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else '')) + requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) return requirements -def parse_version(version='0.0.0') -> tuple: +def parse_version(version="0.0.0") -> tuple: """ Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This function replaces deprecated 'pkg_resources.parse_version(v)'. @@ -74,9 +95,9 @@ def parse_version(version='0.0.0') -> tuple: (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1) """ try: - return tuple(map(int, re.findall(r'\d+', version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) + return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) except Exception as e: - LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}') + LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}") return 0, 0, 0 @@ -121,15 +142,19 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): elif isinstance(imgsz, (list, tuple)): imgsz = list(imgsz) else: - raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " - f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'") + raise TypeError( + f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " + f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'" + ) # Apply max_dim if len(imgsz) > max_dim: - msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \ - "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + msg = ( + "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " + "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + ) if max_dim != 1: - raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}') + raise ValueError(f"imgsz={imgsz} is not a valid image size. 
{msg}") LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}") imgsz = [max(imgsz)] # Make image size a multiple of the stride @@ -137,7 +162,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): # Print warning message if image size was updated if sz != imgsz: - LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}') + LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}") # Add missing dimensions if necessary sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz @@ -145,12 +170,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): return sz -def check_version(current: str = '0.0.0', - required: str = '0.0.0', - name: str = 'version', - hard: bool = False, - verbose: bool = False, - msg: str = '') -> bool: +def check_version( + current: str = "0.0.0", + required: str = "0.0.0", + name: str = "version", + hard: bool = False, + verbose: bool = False, + msg: str = "", +) -> bool: """ Check current version against the required version or range. @@ -181,7 +208,7 @@ def check_version(current: str = '0.0.0', ``` """ if not current: # if current is '' or None - LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.') + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") return True elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' try: @@ -189,34 +216,34 @@ def check_version(current: str = '0.0.0', current = metadata.version(current) # get version string from package name except metadata.PackageNotFoundError: if hard: - raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed')) + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) else: return False if not required: # if required is '' or None return True - op = '' - version = '' + op = "" + version = "" result = True c = parse_version(current) # '1.2.3' -> (1, 2, 3) - for r in required.strip(',').split(','): - op, version = re.match(r'([^0-9]*)([\d.]+)', r).groups() # split '>=22.04' -> ('>=', '22.04') + for r in required.strip(",").split(","): + op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') v = parse_version(version) # '1.2.3' -> (1, 2, 3) - if op == '==' and c != v: + if op == "==" and c != v: result = False - elif op == '!=' and c == v: + elif op == "!=" and c == v: result = False - elif op in ('>=', '') and not (c >= v): # if no constraint passed assume '>=required' + elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required' result = False - elif op == '<=' and not (c <= v): + elif op == "<=" and not (c <= v): result = False - elif op == '>' and not (c > v): + elif op == ">" and not (c > v): result = False - elif op == '<' and not (c < v): + elif op == "<" and not (c < v): result = False if not result: - warning = f'WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}' + warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}" if hard: raise ModuleNotFoundError(emojis(warning)) # assert version requirements met if verbose: @@ -224,7 +251,7 @@ def check_version(current: str = '0.0.0', return result -def 
check_latest_pypi_version(package_name='ultralytics'): +def check_latest_pypi_version(package_name="ultralytics"): """ Returns the latest version of a PyPI package without downloading or installing it. @@ -236,9 +263,9 @@ def check_latest_pypi_version(package_name='ultralytics'): """ with contextlib.suppress(Exception): requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning - response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3) + response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3) if response.status_code == 200: - return response.json()['info']['version'] + return response.json()["info"]["version"] def check_pip_update_available(): @@ -251,16 +278,19 @@ def check_pip_update_available(): if ONLINE and is_pip_package(): with contextlib.suppress(Exception): from ultralytics import __version__ + latest = check_latest_pypi_version() - if check_version(__version__, f'<{latest}'): # check if current version is < latest version - LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 ' - f"Update with 'pip install -U ultralytics'") + if check_version(__version__, f"<{latest}"): # check if current version is < latest version + LOGGER.info( + f"New https://pypi.org/project/ultralytics/{latest} available 😃 " + f"Update with 'pip install -U ultralytics'" + ) return True return False @ThreadingLocked() -def check_font(font='Arial.ttf'): +def check_font(font="Arial.ttf"): """ Find font locally or download to user's configuration directory if it does not already exist. @@ -283,13 +313,13 @@ def check_font(font='Arial.ttf'): return matches[0] # Download to USER_CONFIG_DIR if missing - url = f'https://ultralytics.com/assets/{name}' + url = f"https://ultralytics.com/assets/{name}" if downloads.is_url(url): downloads.safe_download(url=url, file=file) return file -def check_python(minimum: str = '3.8.0') -> bool: +def check_python(minimum: str = "3.8.0") -> bool: """ Check current python version against the required minimum version. @@ -299,11 +329,11 @@ def check_python(minimum: str = '3.8.0') -> bool: Returns: None """ - return check_version(platform.python_version(), minimum, name='Python ', hard=True) + return check_version(platform.python_version(), minimum, name="Python ", hard=True) @TryExcept() -def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''): +def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): """ Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed. @@ -329,41 +359,42 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=() ``` """ - prefix = colorstr('red', 'bold', 'requirements:') + prefix = colorstr("red", "bold", "requirements:") check_python() # check python version check_torchvision() # check torch-torchvision compatibility if isinstance(requirements, Path): # requirements.txt file file = requirements.resolve() - assert file.exists(), f'{prefix} {file} not found, check failed.' - requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude] + assert file.exists(), f"{prefix} {file} not found, check failed." 
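# ---- Illustrative sketch, not part of the patch ----
# check_requirements() reduces each requirement to a (name, specifier) pair,
# then check_version() compares parsed integer tuples operator by operator.
# A condensed, self-contained version of that parse-and-compare step (the
# helper `meets` is hypothetical; '~=' and other exotic operators are omitted):
import re

def meets(current: str, spec: str) -> bool:
    """Check a version like '8.0.1' against a spec like '>=8.0,<9.0'."""
    parse = lambda v: tuple(map(int, re.findall(r"\d+", v)[:3]))  # '2.0.1+cpu' -> (2, 0, 1)
    c = parse(current)
    for r in spec.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # '>=22.04' -> ('>=', '22.04')
        v = parse(version)
        if not {"==": c == v, "!=": c != v, ">=": c >= v, "": c >= v,
                "<=": c <= v, ">": c > v, "<": c < v}[op]:
            return False
    return True

assert meets("8.0.1", ">=8.0") and not meets("7.0.0", ">=8.0,<9.0")
# ----------------------------------------------------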
+ requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] elif isinstance(requirements, str): requirements = [requirements] pkgs = [] for r in requirements: - r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo' - match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', r_stripped) - name, required = match[1], match[2].strip() if match[2] else '' + r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' + match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) + name, required = match[1], match[2].strip() if match[2] else "" try: assert check_version(metadata.version(name), required) # exception if requirements not met except (AssertionError, metadata.PackageNotFoundError): pkgs.append(r) - s = ' '.join(f'"{x}"' for x in pkgs) # console string + s = " ".join(f'"{x}"' for x in pkgs) # console string if s: if install and AUTOINSTALL: # check environment variable n = len(pkgs) # number of packages updates LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") try: t = time.time() - assert is_online(), 'AutoUpdate skipped (offline)' - LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode()) + assert is_online(), "AutoUpdate skipped (offline)" + LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode()) dt = time.time() - t LOGGER.info( f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n" - f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n") + f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" + ) except Exception as e: - LOGGER.warning(f'{prefix} ❌ {e}') + LOGGER.warning(f"{prefix} ❌ {e}") return False else: return False @@ -386,76 +417,82 @@ def check_torchvision(): import torchvision # Compatibility table - compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']} + compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]} # Extract only the major and minor versions - v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2]) - v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2]) + v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2]) + v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2]) if v_torch in compatibility_table: compatible_versions = compatibility_table[v_torch] if all(v_torchvision != v for v in compatible_versions): - print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n' - f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " - "'pip install -U torch torchvision' to update both.\n" - 'For a full compatibility table see https://github.com/pytorch/vision#installation') + print( + f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" + f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " + "'pip install -U torch torchvision' to update both.\n" + "For a full compatibility table see https://github.com/pytorch/vision#installation" + ) -def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''): +def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""): """Check file(s) for acceptable suffix.""" if file and suffix: if isinstance(suffix, str): - suffix = (suffix, ) + suffix = 
(suffix,) for f in file if isinstance(file, (list, tuple)) else [file]: s = Path(f).suffix.lower().strip() # file suffix if len(s): - assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}' + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}" def check_yolov5u_filename(file: str, verbose: bool = True): """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.""" - if 'yolov3' in file or 'yolov5' in file: - if 'u.yaml' in file: - file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml - elif '.pt' in file and 'u' not in file: + if "yolov3" in file or "yolov5" in file: + if "u.yaml" in file: + file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml + elif ".pt" in file and "u" not in file: original_file = file - file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt - file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt - file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt + file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt if file != original_file and verbose: LOGGER.info( f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " - f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs ' - f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n') + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" + ) return file -def check_model_file_from_stem(model='yolov8n'): +def check_model_file_from_stem(model="yolov8n"): """Return a model filename from a valid model stem.""" if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS: - return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt + return Path(model).with_suffix(".pt") # add suffix, i.e. 
yolov8n -> yolov8n.pt else: return model -def check_file(file, suffix='', download=True, hard=True): +def check_file(file, suffix="", download=True, hard=True): """Search/download file (if necessary) and return path.""" check_suffix(file, suffix) # optional file = str(file).strip() # convert to string and strip spaces file = check_yolov5u_filename(file) # yolov5n -> yolov5nu - if (not file or ('://' not in file and Path(file).exists()) or # '://' check required in Windows Python<3.10 - file.lower().startswith('grpc://')): # file exists or gRPC Triton images + if ( + not file + or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 + or file.lower().startswith("grpc://") + ): # file exists or gRPC Triton images return file - elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')): # download + elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download url = file # warning: Pathlib turns :// -> :/ file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth if Path(file).exists(): - LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: downloads.safe_download(url=url, file=file, unzip=False) return file else: # search - files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file + files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file if not files and hard: raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1 and hard: @@ -463,7 +500,7 @@ def check_file(file, suffix='', download=True, hard=True): return files[0] if len(files) else [] # return file -def check_yaml(file, suffix=('.yaml', '.yml'), hard=True): +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): """Search/download YAML file (if necessary) and return path, checking suffix.""" return check_file(file, suffix, hard=hard) @@ -482,51 +519,52 @@ def check_is_path_safe(basedir, path): base_dir_resolved = Path(basedir).resolve() path_resolved = Path(path).resolve() - return path_resolved.is_file() and path_resolved.parts[:len(base_dir_resolved.parts)] == base_dir_resolved.parts + return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts def check_imshow(warn=False): """Check if environment supports image displays.""" try: if LINUX: - assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle() - cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle() + cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image cv2.waitKey(1) cv2.destroyAllWindows() cv2.waitKey(1) return True except Exception as e: if warn: - LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}') + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") return False -def check_yolo(verbose=True, device=''): +def check_yolo(verbose=True, device=""): """Return a human-readable YOLO software and hardware summary.""" import psutil from ultralytics.utils.torch_utils import select_device if is_jupyter(): - if check_requirements('wandb', install=False): - os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with 
infinite hang + if check_requirements("wandb", install=False): + os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang if is_colab(): - shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory if verbose: # System info gib = 1 << 30 # bytes per GiB ram = psutil.virtual_memory().total - total, used, free = shutil.disk_usage('/') - s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)' + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" with contextlib.suppress(Exception): # clear display if ipython is installed from IPython import display + display.clear_output() else: - s = '' + s = "" select_device(device=device, newline=False) - LOGGER.info(f'Setup complete ✅ {s}') + LOGGER.info(f"Setup complete ✅ {s}") def collect_system_info(): @@ -537,32 +575,36 @@ def collect_system_info(): from ultralytics.utils import ENVIRONMENT, is_git_dir from ultralytics.utils.torch_utils import get_cpu_info - ram_info = psutil.virtual_memory().total / (1024 ** 3) # Convert bytes to GB + ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB check_yolo() - LOGGER.info(f"\n{'OS':<20}{platform.platform()}\n" - f"{'Environment':<20}{ENVIRONMENT}\n" - f"{'Python':<20}{sys.version.split()[0]}\n" - f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n" - f"{'RAM':<20}{ram_info:.2f} GB\n" - f"{'CPU':<20}{get_cpu_info()}\n" - f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n") - - for r in parse_requirements(package='ultralytics'): + LOGGER.info( + f"\n{'OS':<20}{platform.platform()}\n" + f"{'Environment':<20}{ENVIRONMENT}\n" + f"{'Python':<20}{sys.version.split()[0]}\n" + f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n" + f"{'RAM':<20}{ram_info:.2f} GB\n" + f"{'CPU':<20}{get_cpu_info()}\n" + f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n" + ) + + for r in parse_requirements(package="ultralytics"): try: current = metadata.version(r.name) - is_met = '✅ ' if check_version(current, str(r.specifier), hard=True) else '❌ ' + is_met = "✅ " if check_version(current, str(r.specifier), hard=True) else "❌ " except metadata.PackageNotFoundError: - current = '(not installed)' - is_met = '❌ ' - LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}') + current = "(not installed)" + is_met = "❌ " + LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}") if is_github_action_running(): - LOGGER.info(f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n" - f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n" - f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n" - f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n" - f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n" - f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n") + LOGGER.info( + f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n" + f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n" + f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n" + f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n" + f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n" + f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n" + ) def check_amp(model): @@ -587,7 +629,7 @@ def check_amp(model): 
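Before the `check_amp()` hunk body below, a simplified sketch of the FP32-vs-AMP comparison it performs (`amp_outputs_close` is a hypothetical stand-in for the internal `amp_allclose` helper; assumes a CUDA device and a model that returns a single tensor):

```python
# Simplified sketch, not the patch and not the exact internal helper: compare
# an FP32 forward pass against an AMP (mixed-precision) pass within atol=0.5,
# mirroring the shape + torch.allclose test used by check_amp().
import torch

def amp_outputs_close(model: torch.nn.Module, im: torch.Tensor, atol: float = 0.5) -> bool:
    """True if an AMP forward pass matches the FP32 pass within `atol`."""
    with torch.no_grad():
        a = model(im)  # FP32 inference
        with torch.cuda.amp.autocast(enabled=True):
            b = model(im)  # mixed-precision inference
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=atol)
```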
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. """ device = next(model.parameters()).device # get model device - if device.type in ('cpu', 'mps'): + if device.type in ("cpu", "mps"): return False # AMP only used on CUDA devices def amp_allclose(m, im): @@ -598,22 +640,27 @@ def check_amp(model): del m return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance - im = ASSETS / 'bus.jpg' # image to check - prefix = colorstr('AMP: ') - LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...') + im = ASSETS / "bus.jpg" # image to check + prefix = colorstr("AMP: ") + LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...") warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." try: from ultralytics import YOLO - assert amp_allclose(YOLO('yolov8n.pt'), im) - LOGGER.info(f'{prefix}checks passed ✅') + + assert amp_allclose(YOLO("yolov8n.pt"), im) + LOGGER.info(f"{prefix}checks passed ✅") except ConnectionError: - LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}') + LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}") except (AttributeError, ModuleNotFoundError): - LOGGER.warning(f'{prefix}checks skipped ⚠️. ' - f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}') + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}" + ) except AssertionError: - LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to ' - f'NaN losses or zero-mAP results, so AMP will be disabled during training.') + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) return False return True @@ -621,8 +668,8 @@ def check_amp(model): def git_describe(path=ROOT): # path must be a directory """Return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.""" with contextlib.suppress(Exception): - return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1] - return '' + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + return "" def print_args(args: Optional[dict] = None, show_file=True, show_func=False): @@ -630,7 +677,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): def strip_auth(v): """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" - return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v x = inspect.currentframe().f_back # previous frame file, _, func, _, _ = inspect.getframeinfo(x) @@ -638,11 +685,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): args, _, _, frm = inspect.getargvalues(x) args = {k: v for k, v in frm.items() if k in args} try: - file = Path(file).resolve().relative_to(ROOT).with_suffix('') + file = Path(file).resolve().relative_to(ROOT).with_suffix("") except ValueError: file = Path(file).stem - s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '') - LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items())) + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) def cuda_device_count() -> int: @@ -654,11 +701,12 @@ def cuda_device_count() -> int: """ try: # Run the nvidia-smi command and capture its output - output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'], - encoding='utf-8') + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) # Take the first line and strip any leading/trailing white space - first_line = output.strip().split('\n')[0] + first_line = output.strip().split("\n")[0] return int(first_line) except (subprocess.CalledProcessError, FileNotFoundError, ValueError): diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py index b07204e9..b669e52f 100644 --- a/ultralytics/utils/dist.py +++ b/ultralytics/utils/dist.py @@ -18,13 +18,13 @@ def find_free_network_port() -> int: `MASTER_PORT` environment variable. 
""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('127.0.0.1', 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] # port def generate_ddp_file(trainer): """Generates a DDP file and returns its file name.""" - module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) + module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1) content = f""" # Ultralytics Multi-GPU training temp file (should be automatically deleted after use) @@ -39,13 +39,15 @@ if __name__ == "__main__": trainer = {name}(cfg=cfg, overrides=overrides) results = trainer.train() """ - (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) - with tempfile.NamedTemporaryFile(prefix='_temp_', - suffix=f'{id(trainer)}.py', - mode='w+', - encoding='utf-8', - dir=USER_CONFIG_DIR / 'DDP', - delete=False) as file: + (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True) + with tempfile.NamedTemporaryFile( + prefix="_temp_", + suffix=f"{id(trainer)}.py", + mode="w+", + encoding="utf-8", + dir=USER_CONFIG_DIR / "DDP", + delete=False, + ) as file: file.write(content) return file.name @@ -53,16 +55,17 @@ if __name__ == "__main__": def generate_ddp_command(world_size, trainer): """Generates and returns command for distributed training.""" import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + if not trainer.resume: shutil.rmtree(trainer.save_dir) # remove the save_dir file = generate_ddp_file(trainer) - dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' + dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch" port = find_free_network_port() - cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file] return cmd, file def ddp_cleanup(trainer, file): """Delete temp file if created.""" - if f'{id(trainer)}.py' in file: # if temp_file suffix in file + if f"{id(trainer)}.py" in file: # if temp_file suffix in file os.remove(file) diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index 363e9a0d..fe8b8af0 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -15,15 +15,17 @@ import torch from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file # Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets -GITHUB_ASSETS_REPO = 'ultralytics/assets' -GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose', '-obb')] + \ - [f'yolov5{k}{resolution}u.pt' for k in 'nsmlx' for resolution in ('', '6')] + \ - [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \ - [f'yolo_nas_{k}.pt' for k in 'sml'] + \ - [f'sam_{k}.pt' for k in 'bl'] + \ - [f'FastSAM-{k}.pt' for k in 'sx'] + \ - [f'rtdetr-{k}.pt' for k in 'lx'] + \ - ['mobile_sam.pt'] +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + + [f"yolo_nas_{k}.pt" for k in "sml"] + + [f"sam_{k}.pt" for k in "bl"] + + [f"FastSAM-{k}.pt" for k in "sx"] + + [f"rtdetr-{k}.pt" for k in "lx"] + + ["mobile_sam.pt"] +) GITHUB_ASSETS_STEMS = [Path(k).stem for k in 
GITHUB_ASSETS_NAMES] @@ -56,7 +58,7 @@ def is_url(url, check=True): return False -def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')): +def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): """ Deletes all ".DS_store" files under a specified directory. @@ -77,12 +79,12 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')): """ for file in files_to_delete: matches = list(Path(path).rglob(file)) - LOGGER.info(f'Deleting {file} files: {matches}') + LOGGER.info(f"Deleting {file} files: {matches}") for f in matches: f.unlink() -def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True): +def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True): """ Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is named after the directory and placed alongside it. @@ -111,17 +113,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p raise FileNotFoundError(f"Directory '{directory}' does not exist.") # Unzip with progress bar - files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)] - zip_file = directory.with_suffix('.zip') + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") compression = ZIP_DEFLATED if compress else ZIP_STORED - with ZipFile(zip_file, 'w', compression) as f: - for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress): + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): f.write(file, file.relative_to(directory)) return zip_file # return path to zip file -def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True): +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): """ Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list. @@ -161,7 +163,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] top_level_dirs = {Path(f).parts[0] for f in files} - if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith('/')): + if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")): # Zip has multiple files at top level path = extract_path = Path(path) / Path(file).stem # i.e. 
../datasets/coco8 else: @@ -172,20 +174,20 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals # Check if destination directory already exists and contains files if path.exists() and any(path.iterdir()) and not exist_ok: # If it exists and is not empty, return the path without unzipping - LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.') + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") return path - for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress): + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): # Ensure the file is within the extract_path to avoid path traversal security vulnerability - if '..' in Path(f).parts: - LOGGER.warning(f'Potentially insecure file path: {f}, skipping extraction.') + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") continue zipObj.extract(f, extract_path) return path # return unzip dir -def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True): +def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", sf=1.5, hard=True): """ Check if there is sufficient disk space to download and store a file. @@ -199,20 +201,23 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h """ try: r = requests.head(url) # response - assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' # check response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response except Exception: return True # requests issue, default to True # Check file size gib = 1 << 30 # bytes per GiB - data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB) + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd())) # bytes + if data * sf < free: return True # sufficient space # Insufficient space - text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, ' - f'Please free {data * sf - free:.1f} GB additional disk space and try again.') + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." + ) if hard: raise MemoryError(text) LOGGER.warning(text) @@ -238,36 +243,41 @@ def get_google_drive_file_info(link): url, filename = get_google_drive_file_info(link) ``` """ - file_id = link.split('/d/')[1].split('/view')[0] - drive_url = f'https://drive.google.com/uc?export=download&id={file_id}' + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" filename = None # Start session with requests.Session() as session: response = session.get(drive_url, stream=True) - if 'quota exceeded' in str(response.content.lower()): + if "quota exceeded" in str(response.content.lower()): raise ConnectionError( - emojis(f'❌ Google Drive file download quota exceeded. ' - f'Please try again later or download this file manually at {link}.')) + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." 
+ ) + ) for k, v in response.cookies.items(): - if k.startswith('download_warning'): - drive_url += f'&confirm={v}' # v is token - cd = response.headers.get('content-disposition') + if k.startswith("download_warning"): + drive_url += f"&confirm={v}" # v is token + cd = response.headers.get("content-disposition") if cd: filename = re.findall('filename="(.+)"', cd)[0] return drive_url, filename -def safe_download(url, - file=None, - dir=None, - unzip=True, - delete=False, - curl=False, - retry=3, - min_bytes=1E0, - exist_ok=False, - progress=True): +def safe_download( + url, + file=None, + dir=None, + unzip=True, + delete=False, + curl=False, + retry=3, + min_bytes=1e0, + exist_ok=False, + progress=True, +): """ Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file. @@ -294,36 +304,38 @@ def safe_download(url, path = safe_download(link) ``` """ - gdrive = url.startswith('https://drive.google.com/') # check if the URL is a Google Drive link + gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link if gdrive: url, file = get_google_drive_file_info(url) - f = Path(dir or '.') / (file or url2file(url)) # URL converted to filename - if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) + f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename + if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) f = Path(url) # filename elif not f.is_file(): # URL and file do not exist desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'" - LOGGER.info(f'{desc}...') + LOGGER.info(f"{desc}...") f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing check_disk_space(url) for i in range(retry + 1): try: if curl or i > 0: # curl download with retry, continue - s = 'sS' * (not progress) # silent - r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode - assert r == 0, f'Curl return value {r}' + s = "sS" * (not progress) # silent + r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode + assert r == 0, f"Curl return value {r}" else: # urllib download - method = 'torch' - if method == 'torch': + method = "torch" + if method == "torch": torch.hub.download_url_to_file(url, f, progress=progress) else: - with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)), - desc=desc, - disable=not progress, - unit='B', - unit_scale=True, - unit_divisor=1024) as pbar: - with open(f, 'wb') as f_opened: + with request.urlopen(url) as response, TQDM( + total=int(response.getheader("Content-Length", 0)), + desc=desc, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + with open(f, "wb") as f_opened: for data in response: f_opened.write(data) pbar.update(len(data)) @@ -334,26 +346,26 @@ def safe_download(url, f.unlink() # remove partial downloads except Exception as e: if i == 0 and not is_online(): - raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e + raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e elif i >= retry: - raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e - LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') + raise ConnectionError(emojis(f"❌ Download failure for {url}. 
Retry limit reached.")) from e + LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...") - if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'): + if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"): from zipfile import is_zipfile unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place if is_zipfile(f): unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip - elif f.suffix in ('.tar', '.gz'): - LOGGER.info(f'Unzipping {f} to {unzip_dir}...') - subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True) + elif f.suffix in (".tar", ".gz"): + LOGGER.info(f"Unzipping {f} to {unzip_dir}...") + subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) if delete: f.unlink() # remove zip return unzip_dir -def get_github_assets(repo='ultralytics/assets', version='latest', retry=False): +def get_github_assets(repo="ultralytics/assets", version="latest", retry=False): """ Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the function fetches the latest release assets. @@ -372,20 +384,20 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False): ``` """ - if version != 'latest': - version = f'tags/{version}' # i.e. tags/v6.2 - url = f'https://api.github.com/repos/{repo}/releases/{version}' + if version != "latest": + version = f"tags/{version}" # i.e. tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" r = requests.get(url) # github api - if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded r = requests.get(url) # try again if r.status_code != 200: - LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}') - return '', [] + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] data = r.json() - return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] -def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **kwargs): +def attempt_download_asset(file, repo="ultralytics/assets", release="v0.0.0", **kwargs): """ Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file locally first, then tries to download it from the specified GitHub repository release. @@ -409,32 +421,32 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', ** # YOLOv3/5u updates file = str(file) file = checks.check_yolov5u_filename(file) - file = Path(file.strip().replace("'", '')) + file = Path(file.strip().replace("'", "")) if file.exists(): return str(file) - elif (SETTINGS['weights_dir'] / file).exists(): - return str(SETTINGS['weights_dir'] / file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) else: # URL specified name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
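For reference alongside the `attempt_download_asset()` hunk, an illustrative `get_github_assets()` call (not part of the patch; printed values are examples only):

```python
# Illustrative only: get_github_assets() as shown above returns a
# (tag, asset_names) pair, or ("", []) if the GitHub API call fails.
from ultralytics.utils.downloads import get_github_assets

tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
print(tag)                     # e.g. 'v0.0.0'
print("yolov8n.pt" in assets)  # model weights published as release assets
```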
- download_url = f'https://github.com/{repo}/releases/download' - if str(file).startswith(('http:/', 'https:/')): # download - url = str(file).replace(':/', '://') # Pathlib turns :// -> :/ + download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ file = url2file(name) # parse authentication https://url.com/file.txt?auth... if Path(file).is_file(): - LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists else: - safe_download(url=url, file=file, min_bytes=1E5, **kwargs) + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: - safe_download(url=f'{download_url}/{release}/{name}', file=file, min_bytes=1E5, **kwargs) + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) else: tag, assets = get_github_assets(repo, release) if not assets: tag, assets = get_github_assets(repo) # latest release if name in assets: - safe_download(url=f'{download_url}/{tag}/{name}', file=file, min_bytes=1E5, **kwargs) + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) return str(file) @@ -464,14 +476,18 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads= if threads > 1: with ThreadPool(threads) as pool: pool.map( - lambda x: safe_download(url=x[0], - dir=x[1], - unzip=unzip, - delete=delete, - curl=curl, - retry=retry, - exist_ok=exist_ok, - progress=threads <= 1), zip(url, repeat(dir))) + lambda x: safe_download( + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) pool.close() pool.join() else: diff --git a/ultralytics/utils/errors.py b/ultralytics/utils/errors.py index 745ca0a4..86aee1d9 100644 --- a/ultralytics/utils/errors.py +++ b/ultralytics/utils/errors.py @@ -17,6 +17,6 @@ class HUBModelError(Exception): The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package. """ - def __init__(self, message='Model not found. Please check model URL and try again.'): + def __init__(self, message="Model not found. Please check model URL and try again."): """Create an exception for when a model is not found.""" super().__init__(emojis(message)) diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py index 9fa1488f..ae8e90c2 100644 --- a/ultralytics/utils/files.py +++ b/ultralytics/utils/files.py @@ -50,13 +50,13 @@ def spaces_in_path(path): """ # If path has spaces, replace them with underscores - if ' ' in str(path): + if " " in str(path): string = isinstance(path, str) # input type path = Path(path) # Create a temporary directory and construct the new path with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) / path.name.replace(' ', '_') + tmp_path = Path(tmp_dir) / path.name.replace(" ", "_") # Copy file/directory if path.is_dir(): @@ -82,7 +82,7 @@ def spaces_in_path(path): yield path -def increment_path(path, exist_ok=False, sep='', mkdir=False): +def increment_path(path, exist_ok=False, sep="", mkdir=False): """ Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 
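The `increment_path()` hunk below reflows its "Method 1" loop; a standalone sketch of that increment rule (`next_free_path` is a hypothetical helper, simplified from the real function, which also handles `exist_ok` and `mkdir`):

```python
# Standalone sketch of the increment rule (hypothetical helper, not part of
# the patch): return the path unchanged if free, else the first free
# '{path}{sep}{n}' for n >= 2, preserving any file suffix.
import os
from pathlib import Path

def next_free_path(path: str, sep: str = "") -> Path:
    p = Path(path)
    if not p.exists():
        return p
    stem, suffix = (p.with_suffix(""), p.suffix) if p.is_file() else (p, "")
    for n in range(2, 9999):
        candidate = f"{stem}{sep}{n}{suffix}"
        if not os.path.exists(candidate):
            return Path(candidate)
    raise FileExistsError(f"no free increment found for {path!r}")

# next_free_path("runs/exp") -> runs/exp, then runs/exp2, runs/exp3, ...
```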
@@ -102,11 +102,11 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): """ path = Path(path) # os-agnostic if path.exists() and not exist_ok: - path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') + path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") # Method 1 for n in range(2, 9999): - p = f'{path}{sep}{n}{suffix}' # increment path + p = f"{path}{sep}{n}{suffix}" # increment path if not os.path.exists(p): break path = Path(p) @@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): def file_age(path=__file__): """Return days since last file update.""" - dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta + dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta return dt.days # + dt.seconds / 86400 # fractional days def file_date(path=__file__): """Return human-readable file modification date, i.e. '2021-3-26'.""" t = datetime.fromtimestamp(Path(path).stat().st_mtime) - return f'{t.year}-{t.month}-{t.day}' + return f"{t.year}-{t.month}-{t.day}" def file_size(path): @@ -137,11 +137,11 @@ def file_size(path): if path.is_file(): return path.stat().st_size / mb elif path.is_dir(): - return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb + return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb return 0.0 -def get_latest_run(search_dir='.'): +def get_latest_run(search_dir="."): """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" - last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) - return max(last_list, key=os.path.getctime) if last_list else '' + last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) + return max(last_list, key=os.path.getctime) if last_list else "" diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py index a1d85aaf..4e9ef2c6 100644 --- a/ultralytics/utils/instance.py +++ b/ultralytics/utils/instance.py @@ -26,9 +26,9 @@ to_4tuple = _ntuple(4) # `xyxy` means left top and right bottom # `xywh` means center x, center y and width, height(YOLO format) # `ltwh` means left top and width, height(COCO format) -_formats = ['xyxy', 'xywh', 'ltwh'] +_formats = ["xyxy", "xywh", "ltwh"] -__all__ = 'Bboxes', # tuple or list +__all__ = ("Bboxes",) # tuple or list class Bboxes: @@ -46,9 +46,9 @@ class Bboxes: This class does not handle normalization or denormalization of bounding boxes. 
""" - def __init__(self, bboxes, format='xyxy') -> None: + def __init__(self, bboxes, format="xyxy") -> None: """Initializes the Bboxes class with bounding box data in a specified format.""" - assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}' + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes assert bboxes.ndim == 2 assert bboxes.shape[1] == 4 @@ -58,21 +58,21 @@ class Bboxes: def convert(self, format): """Converts bounding box format from one type to another.""" - assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}' + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" if self.format == format: return - elif self.format == 'xyxy': - func = xyxy2xywh if format == 'xywh' else xyxy2ltwh - elif self.format == 'xywh': - func = xywh2xyxy if format == 'xyxy' else xywh2ltwh + elif self.format == "xyxy": + func = xyxy2xywh if format == "xywh" else xyxy2ltwh + elif self.format == "xywh": + func = xywh2xyxy if format == "xyxy" else xywh2ltwh else: - func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh + func = ltwh2xyxy if format == "xyxy" else ltwh2xywh self.bboxes = func(self.bboxes) self.format = format def areas(self): """Return box areas.""" - self.convert('xyxy') + self.convert("xyxy") return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # def denormalize(self, w, h): @@ -124,7 +124,7 @@ class Bboxes: return len(self.bboxes) @classmethod - def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes': + def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes": """ Concatenate a list of Bboxes objects into a single Bboxes object. @@ -148,7 +148,7 @@ class Bboxes: return boxes_list[0] return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis)) - def __getitem__(self, index) -> 'Bboxes': + def __getitem__(self, index) -> "Bboxes": """ Retrieve a specific bounding box or a set of bounding boxes using indexing. @@ -169,7 +169,7 @@ class Bboxes: if isinstance(index, int): return Bboxes(self.bboxes[index].view(1, -1)) b = self.bboxes[index] - assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!' + assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!" return Bboxes(b) @@ -205,7 +205,7 @@ class Instances: This class does not perform input validation, and it assumes the inputs are well-formed. """ - def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None: + def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None: """ Args: bboxes (ndarray): bboxes with shape [N, 4]. @@ -263,7 +263,7 @@ class Instances: def add_padding(self, padw, padh): """Handle rect and mosaic situation.""" - assert not self.normalized, 'you should add padding with absolute coordinates.' + assert not self.normalized, "you should add padding with absolute coordinates." self._bboxes.add(offset=(padw, padh, padw, padh)) self.segments[..., 0] += padw self.segments[..., 1] += padh @@ -271,7 +271,7 @@ class Instances: self.keypoints[..., 0] += padw self.keypoints[..., 1] += padh - def __getitem__(self, index) -> 'Instances': + def __getitem__(self, index) -> "Instances": """ Retrieve a specific instance or a set of instances using indexing. 
@@ -301,7 +301,7 @@ class Instances: def flipud(self, h): """Flips the coordinates of bounding boxes, segments, and keypoints vertically.""" - if self._bboxes.format == 'xyxy': + if self._bboxes.format == "xyxy": y1 = self.bboxes[:, 1].copy() y2 = self.bboxes[:, 3].copy() self.bboxes[:, 1] = h - y2 @@ -314,7 +314,7 @@ class Instances: def fliplr(self, w): """Reverses the order of the bounding boxes and segments horizontally.""" - if self._bboxes.format == 'xyxy': + if self._bboxes.format == "xyxy": x1 = self.bboxes[:, 0].copy() x2 = self.bboxes[:, 2].copy() self.bboxes[:, 0] = w - x2 @@ -328,10 +328,10 @@ class Instances: def clip(self, w, h): """Clips bounding boxes, segments, and keypoints values to stay within image boundaries.""" ori_format = self._bboxes.format - self.convert_bbox(format='xyxy') + self.convert_bbox(format="xyxy") self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h) - if ori_format != 'xyxy': + if ori_format != "xyxy": self.convert_bbox(format=ori_format) self.segments[..., 0] = self.segments[..., 0].clip(0, w) self.segments[..., 1] = self.segments[..., 1].clip(0, h) @@ -367,7 +367,7 @@ class Instances: return len(self.bboxes) @classmethod - def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances': + def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances": """ Concatenates a list of Instances objects into a single Instances object. diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 38d633bf..900338dc 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -28,22 +28,27 @@ class VarifocalLoss(nn.Module): """Computes varfocal loss.""" weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label with torch.cuda.amp.autocast(enabled=False): - loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') * - weight).mean(1).sum() + loss = ( + (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) + .mean(1) + .sum() + ) return loss class FocalLoss(nn.Module): """Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).""" - def __init__(self, ): + def __init__( + self, + ): """Initializer for FocalLoss class with no parameters.""" super().__init__() @staticmethod def forward(pred, label, gamma=1.5, alpha=0.25): """Calculates and updates confusion matrix for object detection/classification tasks.""" - loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none') + loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none") # p_t = torch.exp(-loss) # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability @@ -91,8 +96,10 @@ class BboxLoss(nn.Module): tr = tl + 1 # target right wl = tr - target # weight left wr = 1 - wl # weight right - return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl + - F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True) + return ( + F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr + ).mean(-1, keepdim=True) class RotatedBboxLoss(BboxLoss): @@ -145,7 +152,7 @@ class v8DetectionLoss: h = model.args # hyperparameters m = model.model[-1] # Detect() module - self.bce = nn.BCEWithLogitsLoss(reduction='none') + self.bce = nn.BCEWithLogitsLoss(reduction="none") self.hyp = h self.stride = m.stride # model strides self.nc = m.nc # number of classes @@ -190,7 +197,8 @@ class v8DetectionLoss: loss = torch.zeros(3, device=self.device) # box, cls, dfl feats = preds[1] if isinstance(preds, tuple) else preds pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_distri = pred_distri.permute(0, 2, 1).contiguous() @@ -201,7 +209,7 @@ class v8DetectionLoss: anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) # Targets - targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1) + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) @@ -210,8 +218,13 @@ class v8DetectionLoss: pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) _, target_bboxes, target_scores, fg_mask, _ = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) @@ -222,8 +235,9 @@ class v8DetectionLoss: # Bbox loss if fg_mask.sum(): target_bboxes /= stride_tensor - loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) loss[0] *= self.hyp.box # box gain loss[1] *= self.hyp.cls # cls gain @@ -246,7 +260,8 @@ class v8SegmentationLoss(v8DetectionLoss): feats, pred_masks, proto 
= preds if len(preds) == 3 else preds[1] batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) # B, grids, .. pred_scores = pred_scores.permute(0, 2, 1).contiguous() @@ -259,24 +274,31 @@ class v8SegmentationLoss(v8DetectionLoss): # Targets try: - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) except RuntimeError as e: - raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n' - "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " - "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a " - "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' " - 'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e + raise TypeError( + "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n" + "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a " + "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help." 
+ ) from e # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) @@ -286,15 +308,23 @@ class v8SegmentationLoss(v8DetectionLoss): if fg_mask.sum(): # Bbox loss - loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, - target_scores, target_scores_sum, fg_mask) + loss[0], loss[3] = self.bbox_loss( + pred_distri, + pred_bboxes, + anchor_points, + target_bboxes / stride_tensor, + target_scores, + target_scores_sum, + fg_mask, + ) # Masks loss - masks = batch['masks'].to(self.device).float() + masks = batch["masks"].to(self.device).float() if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample - masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0] + masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] - loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, - pred_masks, imgsz, self.overlap) + loss[1] = self.calculate_segmentation_loss( + fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap + ) # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove else: @@ -308,8 +338,9 @@ class v8SegmentationLoss(v8DetectionLoss): return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) @staticmethod - def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, - area: torch.Tensor) -> torch.Tensor: + def single_mask_loss( + gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor + ) -> torch.Tensor: """ Compute the instance segmentation loss for a single image. @@ -327,8 +358,8 @@ class v8SegmentationLoss(v8DetectionLoss): The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the predicted masks from the prototype masks and predicted mask coefficients. 
""" - pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) - loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none') + pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum() def calculate_segmentation_loss( @@ -387,8 +418,9 @@ class v8SegmentationLoss(v8DetectionLoss): else: gt_mask = masks[batch_idx.view(-1) == i][mask_idx] - loss += self.single_mask_loss(gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], - marea_i[fg_mask_i]) + loss += self.single_mask_loss( + gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i] + ) # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove else: @@ -415,7 +447,8 @@ class v8PoseLoss(v8DetectionLoss): loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) # B, grids, .. pred_scores = pred_scores.permute(0, 2, 1).contiguous() @@ -428,8 +461,8 @@ class v8PoseLoss(v8DetectionLoss): # Targets batch_size = pred_scores.shape[0] - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) @@ -439,8 +472,13 @@ class v8PoseLoss(v8DetectionLoss): pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3) _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( - pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt) + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) @@ -451,14 +489,16 @@ class v8PoseLoss(v8DetectionLoss): # Bbox loss if fg_mask.sum(): target_bboxes /= stride_tensor - loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) - keypoints = batch['keypoints'].to(self.device).float().clone() + loss[0], loss[4] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + keypoints = batch["keypoints"].to(self.device).float().clone() keypoints[..., 0] *= imgsz[1] keypoints[..., 1] *= imgsz[0] - loss[1], loss[2] = self.calculate_keypoints_loss(fg_mask, target_gt_idx, keypoints, batch_idx, - stride_tensor, target_bboxes, pred_kpts) + loss[1], loss[2] = self.calculate_keypoints_loss( + fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ) loss[0] *= self.hyp.box # box gain loss[1] *= self.hyp.pose # pose gain @@ -477,8 +517,9 @@ class v8PoseLoss(v8DetectionLoss): y[..., 1] 
+= anchor_points[:, [1]] - 0.5 return y - def calculate_keypoints_loss(self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, - pred_kpts): + def calculate_keypoints_loss( + self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ): """ Calculate the keypoints loss for the model. @@ -507,21 +548,23 @@ class v8PoseLoss(v8DetectionLoss): max_kpts = torch.unique(batch_idx, return_counts=True)[1].max() # Create a tensor to hold batched keypoints - batched_keypoints = torch.zeros((batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), - device=keypoints.device) + batched_keypoints = torch.zeros( + (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device + ) # TODO: any idea how to vectorize this? # Fill batched_keypoints with keypoints based on batch_idx for i in range(batch_size): keypoints_i = keypoints[batch_idx == i] - batched_keypoints[i, :keypoints_i.shape[0]] = keypoints_i + batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i # Expand dimensions of target_gt_idx to match the shape of batched_keypoints target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1) # Use target_gt_idx_expanded to select keypoints from batched_keypoints selected_keypoints = batched_keypoints.gather( - 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])) + 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]) + ) # Divide coordinates by stride selected_keypoints /= stride_tensor.view(1, -1, 1, 1) @@ -547,13 +590,12 @@ class v8ClassificationLoss: def __call__(self, preds, batch): """Compute the classification loss between predictions and true labels.""" - loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean') + loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean") loss_items = loss.detach() return loss, loss_items class v8OBBLoss(v8DetectionLoss): - def __init__(self, model): # model must be de-paralleled super().__init__(model) self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) @@ -583,7 +625,8 @@ class v8OBBLoss(v8DetectionLoss): feats, pred_angle = preds if isinstance(preds[0], list) else preds[1] batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( - (self.reg_max * 4, self.nc), 1) + (self.reg_max * 4, self.nc), 1 + ) # b, grids, .. pred_scores = pred_scores.permute(0, 2, 1).contiguous() @@ -596,19 +639,21 @@ class v8OBBLoss(v8DetectionLoss): # targets try: - batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1) + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1) rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item() targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) except RuntimeError as e: - raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n' - "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, " - "i.e. 
'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a " - "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' " - 'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e + raise TypeError( + "ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n" + "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a " + "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help." + ) from e # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4) @@ -616,10 +661,14 @@ class v8OBBLoss(v8DetectionLoss): bboxes_for_assigner = pred_bboxes.clone().detach() # Only the first four elements need to be scaled bboxes_for_assigner[..., :4] *= stride_tensor - _, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(), - bboxes_for_assigner.type(gt_bboxes.dtype), - anchor_points * stride_tensor, gt_labels, gt_bboxes, - mask_gt) + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + pred_scores.detach().sigmoid(), + bboxes_for_assigner.type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) target_scores_sum = max(target_scores.sum(), 1) @@ -630,8 +679,9 @@ class v8OBBLoss(v8DetectionLoss): # Bbox loss if fg_mask.sum(): target_bboxes[..., :4] /= stride_tensor - loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, - target_scores_sum, fg_mask) + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) else: loss[0] += (pred_angle * 0).sum() diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 8d60b372..676ef737 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -11,7 +11,10 @@ import torch from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings -OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 +OKS_SIGMA = ( + np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) + / 10.0 +) def bbox_ioa(box1, box2, iou=False, eps=1e-7): @@ -33,8 +36,9 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7): b2_x1, b2_y1, b2_x2, b2_y2 = box2.T # Intersection area - inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \ - (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0) + inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * ( + np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1) + ).clip(0) # Box2 area area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) @@ -99,8 +103,9 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps # Intersection area - inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \ - (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0) + inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * ( + b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) + ).clamp_(0) # Union Area union = w1 * h1 + w2 * h2 - inter + eps @@ -111,10 +116,10 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, 
CIoU=False, eps=1e-7 cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 - c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared + c2 = cw**2 + ch**2 + eps # convex diagonal squared rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) + v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) with torch.no_grad(): alpha = v / (v - iou + (1 + eps)) return iou - (rho2 / c2 + v * alpha) # CIoU @@ -202,12 +207,19 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7): a1, b1, c1 = _get_covariance_matrix(obb1) a2, b2, c2 = _get_covariance_matrix(obb2) - t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) / - ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25 + t1 = ( + ((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) + / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps) + ) * 0.25 t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5 - t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) / - (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * - (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5 + t3 = ( + torch.log( + ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) + / (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + + eps + ) + * 0.5 + ) bd = t1 + t2 + t3 bd = torch.clamp(bd, eps, 100.0) hd = torch.sqrt(1.0 - torch.exp(-bd) + eps) @@ -215,7 +227,7 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7): if CIoU: # only include the wh aspect ratio part w1, h1 = obb1[..., 2:4].split(1, dim=-1) w2, h2 = obb2[..., 2:4].split(1, dim=-1) - v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) + v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) with torch.no_grad(): alpha = v / (v - iou + (1 + eps)) return iou - v * alpha # CIoU @@ -239,12 +251,19 @@ def batch_probiou(obb1, obb2, eps=1e-7): a1, b1, c1 = _get_covariance_matrix(obb1) a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2)) - t1 = (((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) / - ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.25 + t1 = ( + ((a1 + a2) * (torch.pow(y1 - y2, 2)) + (b1 + b2) * (torch.pow(x1 - x2, 2))) + / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps) + ) * 0.25 t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2)) + eps)) * 0.5 - t3 = torch.log(((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) / - (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * - (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + eps) * 0.5 + t3 = ( + torch.log( + ((a1 + a2) * (b1 + b2) - (torch.pow(c1 + c2, 2))) + / (4 * torch.sqrt((a1 * b1 - torch.pow(c1, 2)).clamp_(0) * (a2 * b2 - torch.pow(c2, 2)).clamp_(0)) + eps) + + eps + ) + * 0.5 + ) bd = t1 + t2 + t3 bd = torch.clamp(bd, eps, 100.0) hd = torch.sqrt(1.0 - torch.exp(-bd) + eps) @@ -279,10 +298,10 @@ class ConfusionMatrix: iou_thres (float): The Intersection over Union threshold. 
""" - def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'): + def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"): """Initialize attributes for the YOLO model.""" self.task = task - self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc)) + self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc)) self.nc = nc # number of classes self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed self.iou_thres = iou_thres @@ -361,11 +380,11 @@ class ConfusionMatrix: tp = self.matrix.diagonal() # true positives fp = self.matrix.sum(1) - tp # false positives # fn = self.matrix.sum(0) - tp # false negatives (missed detections) - return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect + return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect - @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure') + @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure") @plt_settings() - def plot(self, normalize=True, save_dir='', names=(), on_plot=None): + def plot(self, normalize=True, save_dir="", names=(), on_plot=None): """ Plot the confusion matrix using seaborn and save it to a file. @@ -377,30 +396,31 @@ class ConfusionMatrix: """ import seaborn as sn - array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) nc, nn = self.nc, len(names) # number of classes, names sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels - ticklabels = (list(names) + ['background']) if labels else 'auto' + ticklabels = (list(names) + ["background"]) if labels else "auto" with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered - sn.heatmap(array, - ax=ax, - annot=nc < 30, - annot_kws={ - 'size': 8}, - cmap='Blues', - fmt='.2f' if normalize else '.0f', - square=True, - vmin=0.0, - xticklabels=ticklabels, - yticklabels=ticklabels).set_facecolor((1, 1, 1)) - title = 'Confusion Matrix' + ' Normalized' * normalize - ax.set_xlabel('True') - ax.set_ylabel('Predicted') + warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered + sn.heatmap( + array, + ax=ax, + annot=nc < 30, + annot_kws={"size": 8}, + cmap="Blues", + fmt=".2f" if normalize else ".0f", + square=True, + vmin=0.0, + xticklabels=ticklabels, + yticklabels=ticklabels, + ).set_facecolor((1, 1, 1)) + title = "Confusion Matrix" + " Normalized" * normalize + ax.set_xlabel("True") + ax.set_ylabel("Predicted") ax.set_title(title) plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png' fig.savefig(plot_fname, dpi=250) @@ -411,7 +431,7 @@ class ConfusionMatrix: def print(self): """Print the confusion matrix to the console.""" for i in range(self.nc + 1): - LOGGER.info(' '.join(map(str, self.matrix[i]))) + LOGGER.info(" ".join(map(str, self.matrix[i]))) def smooth(y, f=0.05): @@ -419,28 +439,28 @@ def smooth(y, f=0.05): nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) p = np.ones(nf // 2) # ones padding yp = np.concatenate((p * y[0], y, p * y[-1]), 0) 
# y padded - return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed + return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed @plt_settings() -def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None): +def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None): """Plots a precision-recall curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) py = np.stack(py, axis=1) if 0 < len(names) < 21: # display per-class legend if < 21 classes for i, y in enumerate(py.T): - ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) + ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) else: - ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) + ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) - ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) - ax.set_xlabel('Recall') - ax.set_ylabel('Precision') + ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean()) + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left') - ax.set_title('Precision-Recall Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title("Precision-Recall Curve") fig.savefig(save_dir, dpi=250) plt.close(fig) if on_plot: @@ -448,24 +468,24 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=N @plt_settings() -def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None): +def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None): """Plots a metric-confidence curve.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) if 0 < len(names) < 21: # display per-class legend if < 21 classes for i, y in enumerate(py): - ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) + ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) else: - ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) + ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) y = smooth(py.mean(0), 0.05) - ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') + ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}") ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left') - ax.set_title(f'{ylabel}-Confidence Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f"{ylabel}-Confidence Curve") fig.savefig(save_dir, dpi=250) plt.close(fig) if on_plot: @@ -494,8 +514,8 @@ def compute_ap(recall, precision): mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) # Integrate area under curve - method = 'interp' # methods: 'continuous', 'interp' - if method == 'interp': + method = "interp" # methods: 'continuous', 'interp' + if method == "interp": x = np.linspace(0, 1, 101) # 101-point interp (COCO) ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate else: # 'continuous' @@ -505,16 +525,9 @@ def compute_ap(recall, precision): return ap, mpre, mrec -def ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=False, - on_plot=None, - 
save_dir=Path(), - names=(), - eps=1e-16, - prefix=''): +def ap_per_class( + tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix="" +): """ Computes the average precision per class for object detection evaluation. @@ -591,10 +604,10 @@ def ap_per_class(tp, names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = dict(enumerate(names)) # to dict if plot: - plot_pr_curve(x, prec_values, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot) - plot_mc_curve(x, f1_curve, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot) - plot_mc_curve(x, p_curve, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot) - plot_mc_curve(x, r_curve, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot) + plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot) + plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot) + plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot) + plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot) i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values @@ -746,8 +759,18 @@ class Metric(SimpleClass): Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based on the values provided in the `results` tuple. """ - (self.p, self.r, self.f1, self.all_ap, self.ap_class_index, self.p_curve, self.r_curve, self.f1_curve, self.px, - self.prec_values) = results + ( + self.p, + self.r, + self.f1, + self.all_ap, + self.ap_class_index, + self.p_curve, + self.r_curve, + self.f1_curve, + self.px, + self.prec_values, + ) = results @property def curves(self): @@ -757,8 +780,12 @@ class Metric(SimpleClass): @property def curves_results(self): """Returns a list of curves for accessing specific metrics curves.""" - return [[self.px, self.prec_values, 'Recall', 'Precision'], [self.px, self.f1_curve, 'Confidence', 'F1'], - [self.px, self.p_curve, 'Confidence', 'Precision'], [self.px, self.r_curve, 'Confidence', 'Recall']] + return [ + [self.px, self.prec_values, "Recall", "Precision"], + [self.px, self.f1_curve, "Confidence", "F1"], + [self.px, self.p_curve, "Confidence", "Precision"], + [self.px, self.r_curve, "Confidence", "Recall"], + ] class DetMetrics(SimpleClass): @@ -793,33 +820,35 @@ class DetMetrics(SimpleClass): curves_results: TODO """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names.""" self.save_dir = save_dir self.plot = plot self.on_plot = on_plot self.names = names self.box = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} - self.task = 'detect' + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "detect" def process(self, tp, conf, pred_cls, target_cls): """Process predicted results for object detection and update metrics.""" - results = ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=self.plot, - save_dir=self.save_dir, - names=self.names, - on_plot=self.on_plot)[2:] + 
results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] self.box.nc = len(self.names) self.box.update(results) @property def keys(self): """Returns a list of keys for accessing specific metrics.""" - return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)'] + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] def mean_results(self): """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" @@ -847,12 +876,12 @@ class DetMetrics(SimpleClass): @property def results_dict(self): """Returns dictionary of computed performance metrics and statistics.""" - return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) @property def curves(self): """Returns a list of curves for accessing specific metrics curves.""" - return ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)'] + return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"] @property def curves_results(self): @@ -889,7 +918,7 @@ class SegmentMetrics(SimpleClass): results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: """Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names.""" self.save_dir = save_dir self.plot = plot @@ -897,8 +926,8 @@ class SegmentMetrics(SimpleClass): self.names = names self.box = Metric() self.seg = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} - self.task = 'segment' + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "segment" def process(self, tp, tp_m, conf, pred_cls, target_cls): """ @@ -912,26 +941,30 @@ class SegmentMetrics(SimpleClass): target_cls (list): List of target classes. 
""" - results_mask = ap_per_class(tp_m, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Mask')[2:] + results_mask = ap_per_class( + tp_m, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Mask", + )[2:] self.seg.nc = len(self.names) self.seg.update(results_mask) - results_box = ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Box')[2:] + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] self.box.nc = len(self.names) self.box.update(results_box) @@ -939,8 +972,15 @@ class SegmentMetrics(SimpleClass): def keys(self): """Returns a list of keys for accessing metrics.""" return [ - 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', - 'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)'] + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP50(M)", + "metrics/mAP50-95(M)", + ] def mean_results(self): """Return the mean metrics for bounding box and segmentation results.""" @@ -968,14 +1008,21 @@ class SegmentMetrics(SimpleClass): @property def results_dict(self): """Returns results of object detection model for evaluation.""" - return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) @property def curves(self): """Returns a list of curves for accessing specific metrics curves.""" return [ - 'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)', - 'Precision-Recall(M)', 'F1-Confidence(M)', 'Precision-Confidence(M)', 'Recall-Confidence(M)'] + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(M)", + "F1-Confidence(M)", + "Precision-Confidence(M)", + "Recall-Confidence(M)", + ] @property def curves_results(self): @@ -1012,7 +1059,7 @@ class PoseMetrics(SegmentMetrics): results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. """ - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: """Initialize the PoseMetrics class with directory path, class names, and plotting options.""" super().__init__(save_dir, plot, names) self.save_dir = save_dir @@ -1021,8 +1068,8 @@ class PoseMetrics(SegmentMetrics): self.names = names self.box = Metric() self.pose = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} - self.task = 'pose' + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "pose" def process(self, tp, tp_p, conf, pred_cls, target_cls): """ @@ -1036,26 +1083,30 @@ class PoseMetrics(SegmentMetrics): target_cls (list): List of target classes. 
""" - results_pose = ap_per_class(tp_p, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Pose')[2:] + results_pose = ap_per_class( + tp_p, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Pose", + )[2:] self.pose.nc = len(self.names) self.pose.update(results_pose) - results_box = ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=self.plot, - on_plot=self.on_plot, - save_dir=self.save_dir, - names=self.names, - prefix='Box')[2:] + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] self.box.nc = len(self.names) self.box.update(results_box) @@ -1063,8 +1114,15 @@ class PoseMetrics(SegmentMetrics): def keys(self): """Returns list of evaluation metric keys.""" return [ - 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', - 'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)'] + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(P)", + "metrics/recall(P)", + "metrics/mAP50(P)", + "metrics/mAP50-95(P)", + ] def mean_results(self): """Return the mean results of box and pose.""" @@ -1088,8 +1146,15 @@ class PoseMetrics(SegmentMetrics): def curves(self): """Returns a list of curves for accessing specific metrics curves.""" return [ - 'Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)', - 'Precision-Recall(P)', 'F1-Confidence(P)', 'Precision-Confidence(P)', 'Recall-Confidence(P)'] + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(P)", + "F1-Confidence(P)", + "Precision-Confidence(P)", + "Recall-Confidence(P)", + ] @property def curves_results(self): @@ -1119,8 +1184,8 @@ class ClassifyMetrics(SimpleClass): """Initialize a ClassifyMetrics instance.""" self.top1 = 0 self.top5 = 0 - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} - self.task = 'classify' + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "classify" def process(self, targets, pred): """Target classes and predicted classes.""" @@ -1137,12 +1202,12 @@ class ClassifyMetrics(SimpleClass): @property def results_dict(self): """Returns a dictionary with model's performance metrics and fitness score.""" - return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness])) + return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness])) @property def keys(self): """Returns a list of keys for the results_dict property.""" - return ['metrics/accuracy_top1', 'metrics/accuracy_top5'] + return ["metrics/accuracy_top1", "metrics/accuracy_top5"] @property def curves(self): @@ -1156,32 +1221,33 @@ class ClassifyMetrics(SimpleClass): class OBBMetrics(SimpleClass): - - def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None: + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: self.save_dir = save_dir self.plot = plot self.on_plot = on_plot self.names = names self.box = Metric() - self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 
0.0} def process(self, tp, conf, pred_cls, target_cls): """Process predicted results for object detection and update metrics.""" - results = ap_per_class(tp, - conf, - pred_cls, - target_cls, - plot=self.plot, - save_dir=self.save_dir, - names=self.names, - on_plot=self.on_plot)[2:] + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] self.box.nc = len(self.names) self.box.update(results) @property def keys(self): """Returns a list of keys for accessing specific metrics.""" - return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)'] + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] def mean_results(self): """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" @@ -1209,7 +1275,7 @@ class OBBMetrics(SimpleClass): @property def results_dict(self): """Returns dictionary of computed performance metrics and statistics.""" - return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) @property def curves(self): diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index d589436d..8a86239c 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -52,7 +52,7 @@ class Profile(contextlib.ContextDecorator): def __str__(self): """Returns a human-readable string representing the accumulated elapsed time in the profiler.""" - return f'Elapsed time is {self.t} s' + return f"Elapsed time is {self.t} s" def time(self): """Get current time.""" @@ -76,9 +76,13 @@ def segment2box(segment, width=640, height=640): # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) 
to (xyxy) x, y = segment.T # segment xy inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) - x, y, = x[inside], y[inside] - return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros( - 4, dtype=segment.dtype) # xyxy + x = x[inside] + y = y[inside] + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): @@ -101,8 +105,10 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xyw """ if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round( - (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding else: gain = ratio_pad[0][0] pad = ratio_pad[1] @@ -145,7 +151,7 @@ def nms_rotated(boxes, scores, threshold=0.45): Returns: """ if len(boxes) == 0: - return np.empty((0, ), dtype=np.int8) + return np.empty((0,), dtype=np.int8) sorted_idx = torch.argsort(scores, descending=True) boxes = boxes[sorted_idx] ious = batch_probiou(boxes, boxes).triu_(diagonal=1) @@ -199,8 +205,8 @@ def non_max_suppression( """ # Checks - assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' - assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output @@ -284,7 +290,7 @@ def non_max_suppression( output[xi] = x[i] if (time.time() - t) > time_limit: - LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded') + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") break # time limit exceeded return output @@ -378,7 +384,7 @@ def xyxy2xywh(x): Returns: y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. """ - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center @@ -398,7 +404,7 @@ def xywh2xyxy(x): Returns: y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. 
""" - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy dw = x[..., 2] / 2 # half-width dh = x[..., 3] / 2 # half-height @@ -423,7 +429,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. """ - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y @@ -449,7 +455,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): """ if clip: x = clip_boxes(x, (h - eps, w - eps)) - assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}' + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center @@ -526,8 +532,11 @@ def xyxyxyxy2xywhr(corners): # especially some objects are cut off by augmentations in dataloader. (x, y), (w, h), angle = cv2.minAreaRect(pts) rboxes.append([x, y, w, h, angle / 180 * np.pi]) - rboxes = torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) if is_torch else np.asarray( - rboxes, dtype=points.dtype) + rboxes = ( + torch.tensor(rboxes, device=corners.device, dtype=corners.dtype) + if is_torch + else np.asarray(rboxes, dtype=points.dtype) + ) return rboxes @@ -546,7 +555,7 @@ def xywhr2xyxyxyxy(center): cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin) ctr = center[..., :2] - w, h, angle = (center[..., i:i + 1] for i in range(2, 5)) + w, h, angle = (center[..., i : i + 1] for i in range(2, 5)) cos_value, sin_value = cos(angle), sin(angle) vec1 = [w / 2 * cos_value, w / 2 * sin_value] vec2 = [-h / 2 * sin_value, h / 2 * cos_value] @@ -607,8 +616,9 @@ def resample_segments(segments, n=1000): s = np.concatenate((s, s[0:1, :]), axis=0) x = np.linspace(0, len(s) - 1, n) xp = np.arange(len(s)) - segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], - dtype=np.float32).reshape(2, -1).T # segment xy + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy return segments @@ -647,7 +657,7 @@ def process_mask_upsample(protos, masks_in, bboxes, shape): """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) - masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW masks = crop_mask(masks, bboxes) # CHW return masks.gt_(0.5) @@ -680,7 +690,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False): masks = crop_mask(masks, downsampled_bboxes) # CHW if upsample: - masks = F.interpolate(masks[None], 
shape, mode='bilinear', align_corners=False)[0] # CHW + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW return masks.gt_(0.5) @@ -724,7 +734,7 @@ def scale_masks(masks, shape, padding=True): bottom, right = (int(round(mh - pad[1] + 0.1)), int(round(mw - pad[0] + 0.1))) masks = masks[..., top:bottom, left:right] - masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW return masks @@ -763,7 +773,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False return coords -def masks2segments(masks, strategy='largest'): +def masks2segments(masks, strategy="largest"): """ It takes a list of masks(n,h,w) and returns a list of segments(n,xy) @@ -775,16 +785,16 @@ def masks2segments(masks, strategy='largest'): segments (List): list of segment masks """ segments = [] - for x in masks.int().cpu().numpy().astype('uint8'): + for x in masks.int().cpu().numpy().astype("uint8"): c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] if c: - if strategy == 'concat': # concatenate all segments + if strategy == "concat": # concatenate all segments c = np.concatenate([x.reshape(-1, 2) for x in c]) - elif strategy == 'largest': # select largest segment + elif strategy == "largest": # select largest segment c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) else: c = np.zeros((0, 2)) # no segments found - segments.append(c.astype('float32')) + segments.append(c.astype("float32")) return segments @@ -811,4 +821,4 @@ def clean_str(s): Returns: (str): a string with special characters replaced by an underscore _ """ - return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s) + return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py index 541cf45a..f9c5bb1b 100644 --- a/ultralytics/utils/patches.py +++ b/ultralytics/utils/patches.py @@ -52,7 +52,7 @@ def imshow(winname: str, mat: np.ndarray): winname (str): Name of the window. mat (np.ndarray): Image to be shown. 
""" - _imshow(winname.encode('unicode_escape').decode(), mat) + _imshow(winname.encode("unicode_escape").decode(), mat) # PyTorch functions ---------------------------------------------------------------------------------------------------- @@ -72,6 +72,6 @@ def torch_save(*args, **kwargs): except ImportError: import pickle - if 'pickle_module' not in kwargs: - kwargs['pickle_module'] = pickle # noqa + if "pickle_module" not in kwargs: + kwargs["pickle_module"] = pickle # noqa return _torch_save(*args, **kwargs) diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 46fca19f..280a3e28 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -33,15 +33,55 @@ class Colors: def __init__(self): """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" - hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', - '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') - self.palette = [self.hex2rgb(f'#{c}') for c in hexs] + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] self.n = len(self.palette) - self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255], - [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255], - [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102], - [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]], - dtype=np.uint8) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) def __call__(self, i, bgr=False): """Converts hex color codes to RGB values.""" @@ -51,7 +91,7 @@ class Colors: @staticmethod def hex2rgb(h): """Converts hex color codes to RGB values (i.e. default PIL order).""" - return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)) colors = Colors() # create instance for 'from utils.plots import colors' @@ -71,9 +111,9 @@ class Annotator: kpt_color (List[int]): Color palette for keypoints. """ - def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" - assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.' + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images." non_ascii = not is_ascii(example) # non-latin labels, i.e. 
asian, arabic, cyrillic self.pil = pil or non_ascii self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width @@ -81,26 +121,45 @@ class Annotator: self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) self.draw = ImageDraw.Draw(self.im) try: - font = check_font('Arial.Unicode.ttf' if non_ascii else font) + font = check_font("Arial.Unicode.ttf" if non_ascii else font) size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) self.font = ImageFont.truetype(str(font), size) except Exception: self.font = ImageFont.load_default() # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string) - if check_version(pil_version, '9.2.0'): + if check_version(pil_version, "9.2.0"): self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height else: # use cv2 self.im = im if im.flags.writeable else im.copy() self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale # Pose - self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], - [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]] + self.skeleton = [ + [16, 14], + [14, 12], + [17, 15], + [15, 13], + [12, 13], + [6, 12], + [7, 13], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [2, 3], + [1, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [5, 7], + ] self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] - def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): """Add one xyxy box to image with label.""" if isinstance(box, torch.Tensor): box = box.tolist() @@ -134,13 +193,16 @@ class Annotator: outside = p1[1] - h >= 3 p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText(self.im, - label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA) + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False): """ @@ -171,7 +233,7 @@ class Annotator: im_gpu = im_gpu.flip(dims=[0]) # flip channel im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3) im_gpu = im_gpu * inv_alpha_masks[-1] + mcs - im_mask = (im_gpu * 255) + im_mask = im_gpu * 255 im_mask_np = im_mask.byte().cpu().numpy() self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape) if self.pil: @@ -230,9 +292,9 @@ class Annotator: """Add rectangle to image (PIL-only).""" self.draw.rectangle(xy, fill, outline, width) - def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False): + def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): """Adds text to an image using PIL or cv2.""" - if anchor == 'bottom': # start y from font bottom + if anchor == "bottom": # start y from font bottom w, h = self.font.getsize(text) # text width, height xy[1] += 1 - h if self.pil: @@ -241,8 +303,8 @@ class Annotator: self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) # Using `txt_color` for background and draw fg with white 
color txt_color = (255, 255, 255) - if '\n' in text: - lines = text.split('\n') + if "\n" in text: + lines = text.split("\n") _, h = self.font.getsize(text) for line in lines: self.draw.text(xy, line, fill=txt_color, font=self.font) @@ -314,15 +376,12 @@ class Annotator: text_y = t_size_in[1] # Create a rounded rectangle for in_count - cv2.rectangle(self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, - -1) - cv2.putText(self.im, - str(counts), (text_x, text_y + t_size_in[1]), - 0, - tl / 2, - txt_color, - self.tf, - lineType=cv2.LINE_AA) + cv2.rectangle( + self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1 + ) + cv2.putText( + self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA + ) @staticmethod def estimate_pose_angle(a, b, c): @@ -375,7 +434,7 @@ class Annotator: center_kpt (int): centroid pose index for workout monitoring line_thickness (int): thickness for text display """ - angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}') + angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}") font_scale = 0.6 + (line_thickness / 10.0) # Draw angle @@ -383,21 +442,37 @@ class Annotator: angle_text_position = (int(center_kpt[0]), int(center_kpt[1])) angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5) angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2)) - cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0], - angle_background_position[1] + angle_background_size[1]), - (255, 255, 255), -1) + cv2.rectangle( + self.im, + angle_background_position, + ( + angle_background_position[0] + angle_background_size[0], + angle_background_position[1] + angle_background_size[1], + ), + (255, 255, 255), + -1, + ) cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness) # Draw Counts (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness) count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20) - count_background_position = (angle_background_position[0], - angle_background_position[1] + angle_background_size[1] + 5) + count_background_position = ( + angle_background_position[0], + angle_background_position[1] + angle_background_size[1] + 5, + ) count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2)) - cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0], - count_background_position[1] + count_background_size[1]), - (255, 255, 255), -1) + cv2.rectangle( + self.im, + count_background_position, + ( + count_background_position[0] + count_background_size[0], + count_background_position[1] + count_background_size[1], + ), + (255, 255, 255), + -1, + ) cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness) # Draw Stage @@ -406,9 +481,16 @@ class Annotator: stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5) stage_background_size = (stage_text_width + 10, stage_text_height + 10) - cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0], - stage_background_position[1] + 
stage_background_size[1]), - (255, 255, 255), -1) + cv2.rectangle( + self.im, + stage_background_position, + ( + stage_background_position[0] + stage_background_size[0], + stage_background_position[1] + stage_background_size[1], + ), + (255, 255, 255), + -1, + ) cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness) def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None): @@ -423,14 +505,20 @@ class Annotator: """ cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2) - label = f'Track ID: {track_label}' if track_label else det_label + label = f"Track ID: {track_label}" if track_label else det_label text_size, _ = cv2.getTextSize(label, 0, 0.7, 1) - cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10), - (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1) + cv2.rectangle( + self.im, + (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10), + (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), + mask_color, + -1, + ) - cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), - 2) + cv2.putText( + self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2 + ) def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10): """ @@ -452,24 +540,24 @@ class Annotator: @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 @plt_settings() -def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): +def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): """Plot training labels including class histograms and box statistics.""" import pandas as pd import seaborn as sn # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings - warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight') - warnings.filterwarnings('ignore', category=FutureWarning) + warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") + warnings.filterwarnings("ignore", category=FutureWarning) # Plot dataset labels LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... 
") nc = int(cls.max() + 1) # number of classes boxes = boxes[:1000000] # limit to 1M boxes - x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height']) + x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"]) # Seaborn correlogram - sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) - plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) + sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200) plt.close() # Matplotlib labels @@ -477,14 +565,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) for i in range(nc): y[2].patches[i].set_color([x / 255 for x in colors(i)]) - ax[0].set_ylabel('instances') + ax[0].set_ylabel("instances") if 0 < len(names) < 30: ax[0].set_xticks(range(len(names))) ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10) else: - ax[0].set_xlabel('classes') - sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) - sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) + ax[0].set_xlabel("classes") + sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9) + sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9) # Rectangles boxes[:, 0:2] = 0.5 # center @@ -493,20 +581,20 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None): for cls, box in zip(cls[:500], boxes[:500]): ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot ax[1].imshow(img) - ax[1].axis('off') + ax[1].axis("off") for a in [0, 1, 2, 3]: - for s in ['top', 'right', 'left', 'bottom']: + for s in ["top", "right", "left", "bottom"]: ax[a].spines[s].set_visible(False) - fname = save_dir / 'labels.jpg' + fname = save_dir / "labels.jpg" plt.savefig(fname, dpi=200) plt.close() if on_plot: on_plot(fname) -def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True): +def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True): """ Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop. 
@@ -545,29 +633,31 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad xyxy = ops.xywh2xyxy(b).long() xyxy = ops.clip_boxes(xyxy, im.shape) - crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)] + crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)] if save: file.parent.mkdir(parents=True, exist_ok=True) # make directory - f = str(increment_path(file).with_suffix('.jpg')) + f = str(increment_path(file).with_suffix(".jpg")) # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB return crop @threaded -def plot_images(images, - batch_idx, - cls, - bboxes=np.zeros(0, dtype=np.float32), - confs=None, - masks=np.zeros(0, dtype=np.uint8), - kpts=np.zeros((0, 51), dtype=np.float32), - paths=None, - fname='images.jpg', - names=None, - on_plot=None, - max_subplots=16, - save=True): +def plot_images( + images, + batch_idx, + cls, + bboxes=np.zeros(0, dtype=np.float32), + confs=None, + masks=np.zeros(0, dtype=np.uint8), + kpts=np.zeros((0, 51), dtype=np.float32), + paths=None, + fname="images.jpg", + names=None, + on_plot=None, + max_subplots=16, + save=True, +): """Plot image grid with labels.""" if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() @@ -585,7 +675,7 @@ def plot_images(images, max_size = 1920 # max image size bs, _, h, w = images.shape # batch size, _, height, width bs = min(bs, max_subplots) # limit plot images - ns = np.ceil(bs ** 0.5) # number of subplots (square) + ns = np.ceil(bs**0.5) # number of subplots (square) if np.max(images[0]) <= 1: images *= 255 # de-normalise (optional) @@ -593,7 +683,7 @@ def plot_images(images, mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init for i in range(bs): x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin - mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0) + mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) # Resize (optional) scale = max_size / ns / max(h, w) @@ -612,7 +702,7 @@ def plot_images(images, annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames if len(cls) > 0: idx = batch_idx == i - classes = cls[idx].astype('int') + classes = cls[idx].astype("int") labels = confs is None if len(bboxes): @@ -633,14 +723,14 @@ def plot_images(images, color = colors(c) c = names.get(c, c) if names else c if labels or conf[j] > 0.25: # 0.25 conf thresh - label = f'{c}' if labels else f'{c} {conf[j]:.1f}' + label = f"{c}" if labels else f"{c} {conf[j]:.1f}" annotator.box_label(box, label, color=color, rotated=is_obb) elif len(classes): for c in classes: color = colors(c) c = names.get(c, c) if names else c - annotator.text((x, y), f'{c}', txt_color=color, box_style=True) + annotator.text((x, y), f"{c}", txt_color=color, box_style=True) # Plot keypoints if len(kpts): @@ -680,7 +770,9 @@ def plot_images(images, else: mask = image_masks[j].astype(bool) with contextlib.suppress(Exception): - im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6 + im[y : y + h, x : x + w, :][mask] = ( + im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 + ) annotator.fromarray(im) if save: annotator.im.save(fname) # save @@ -691,7 +783,7 @@ def plot_images(images, @plt_settings() -def 
plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
+def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
     """
     Plot training results from a results CSV file. The function supports various types of data including segmentation,
     pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -714,6 +806,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
     """
     import pandas as pd
     from scipy.ndimage import gaussian_filter1d
+
     save_dir = Path(file).parent if file else Path(dir)
     if classify:
         fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@@ -728,32 +821,32 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
         fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
         index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
     ax = ax.ravel()
-    files = list(save_dir.glob('results*.csv'))
-    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
+    files = list(save_dir.glob("results*.csv"))
+    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
     for f in files:
         try:
             data = pd.read_csv(f)
             s = [x.strip() for x in data.columns]
             x = data.values[:, 0]
             for i, j in enumerate(index):
-                y = data.values[:, j].astype('float')
+                y = data.values[:, j].astype("float")
                 # y[y == 0] = np.nan  # don't show zero values
-                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)  # actual results
-                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2)  # smoothing line
+                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
+                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
                 ax[i].set_title(s[j], fontsize=12)
                 # if j in [8, 9, 10]:  # share train and val loss y axes
                 #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
         except Exception as e:
-            LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
+            LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
     ax[1].legend()
-    fname = save_dir / 'results.png'
+    fname = save_dir / "results.png"
     fig.savefig(fname, dpi=200)
     plt.close()
     if on_plot:
         on_plot(fname)


-def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
+def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
     """
     Plots a scatter plot with points colored based on a 2D histogram.

@@ -774,14 +867,18 @@ def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none
     # Calculate 2D histogram and corresponding colors
     hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
     colors = [
-        hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
-             min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
+        hist[
+            min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+            min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
+        ]
+        for i in range(len(v))
+    ]

     # Scatter plot
     plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)


-def plot_tune_results(csv_file='tune_results.csv'):
+def plot_tune_results(csv_file="tune_results.csv"):
     """
     Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
     in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
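# --- Editor's aside (not part of the patch): plt_color_scatter() above colors
# each point by the count of its 2D-histogram bin, so dense regions stand out.
# A self-contained sketch with synthetic data (the array contents are assumptions):
import matplotlib.pyplot as plt
import numpy as np

from ultralytics.utils.plotting import plt_color_scatter

rng = np.random.default_rng(0)
v = rng.normal(size=500)  # e.g. one hyperparameter column
f = 0.5 * v + rng.normal(scale=0.5, size=500)  # e.g. the fitness column
plt_color_scatter(v, f, bins=20, cmap="viridis")  # point color tracks local 2-D density
plt.savefig("density_scatter.png", dpi=200)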
@@ -810,33 +907,33 @@ def plot_tune_results(csv_file='tune_results.csv'): v = x[:, i + num_metrics_columns] mu = v[j] # best single result plt.subplot(n, n, i + 1) - plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none') - plt.plot(mu, fitness.max(), 'k+', markersize=15) - plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters - plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8 + plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none") + plt.plot(mu, fitness.max(), "k+", markersize=15) + plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters + plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 if i % n != 0: plt.yticks([]) - file = csv_file.with_name('tune_scatter_plots.png') # filename + file = csv_file.with_name("tune_scatter_plots.png") # filename plt.savefig(file, dpi=200) plt.close() - LOGGER.info(f'Saved {file}') + LOGGER.info(f"Saved {file}") # Fitness vs iteration x = range(1, len(fitness) + 1) plt.figure(figsize=(10, 6), tight_layout=True) - plt.plot(x, fitness, marker='o', linestyle='none', label='fitness') - plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line - plt.title('Fitness vs Iteration') - plt.xlabel('Iteration') - plt.ylabel('Fitness') + plt.plot(x, fitness, marker="o", linestyle="none", label="fitness") + plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line + plt.title("Fitness vs Iteration") + plt.xlabel("Iteration") + plt.ylabel("Fitness") plt.grid(True) plt.legend() - file = csv_file.with_name('tune_fitness.png') # filename + file = csv_file.with_name("tune_fitness.png") # filename plt.savefig(file, dpi=200) plt.close() - LOGGER.info(f'Saved {file}') + LOGGER.info(f"Saved {file}") def output_to_target(output, max_det=300): @@ -861,7 +958,7 @@ def output_to_rotated_target(output, max_det=300): return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] -def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')): +def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")): """ Visualize feature maps of a given model module during inference. @@ -872,7 +969,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec n (int, optional): Maximum number of feature maps to plot. Defaults to 32. save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp'). """ - for m in ['Detect', 'Pose', 'Segment']: + for m in ["Detect", "Pose", "Segment"]: if m in module_type: return batch, channels, height, width = x.shape # batch, channels, height, width @@ -886,9 +983,9 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec plt.subplots_adjust(wspace=0.05, hspace=0.05) for i in range(n): ax[i].imshow(blocks[i].squeeze()) # cmap='gray' - ax[i].axis('off') + ax[i].axis("off") - LOGGER.info(f'Saving {f}... ({n}/{channels})') - plt.savefig(f, dpi=300, bbox_inches='tight') + LOGGER.info(f"Saving {f}... 
({n}/{channels})") + plt.savefig(f, dpi=300, bbox_inches="tight") plt.close() - np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save + np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py index 97406cd7..b72255d4 100644 --- a/ultralytics/utils/tal.py +++ b/ultralytics/utils/tal.py @@ -7,7 +7,7 @@ from .checks import check_version from .metrics import bbox_iou, probiou from .ops import xywhr2xyxyxyxy -TORCH_1_10 = check_version(torch.__version__, '1.10.0') +TORCH_1_10 = check_version(torch.__version__, "1.10.0") class TaskAlignedAssigner(nn.Module): @@ -61,12 +61,17 @@ class TaskAlignedAssigner(nn.Module): if self.n_max_boxes == 0: device = gt_bboxes.device - return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device), - torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device), - torch.zeros_like(pd_scores[..., 0]).to(device)) - - mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, - mask_gt) + return ( + torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), + torch.zeros_like(pd_bboxes).to(device), + torch.zeros_like(pd_scores).to(device), + torch.zeros_like(pd_scores[..., 0]).to(device), + torch.zeros_like(pd_scores[..., 0]).to(device), + ) + + mask_pos, align_metric, overlaps = self.get_pos_mask( + pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt + ) target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) @@ -148,7 +153,7 @@ class TaskAlignedAssigner(nn.Module): ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) for k in range(self.topk): # Expand topk_idxs for each value of k and add 1 at the specified positions - count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones) + count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device)) # Filter invalid bboxes count_tensor.masked_fill_(count_tensor > 1, 0) @@ -192,9 +197,11 @@ class TaskAlignedAssigner(nn.Module): target_labels.clamp_(0) # 10x faster than F.one_hot() - target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes), - dtype=torch.int64, - device=target_labels.device) # (b, h*w, 80) + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.num_classes), + dtype=torch.int64, + device=target_labels.device, + ) # (b, h*w, 80) target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) @@ -252,7 +259,6 @@ class TaskAlignedAssigner(nn.Module): class RotatedTaskAlignedAssigner(TaskAlignedAssigner): - def iou_calculation(self, gt_bboxes, pd_bboxes): """Iou calculation for rotated bounding boxes.""" return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0) @@ -295,7 +301,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): _, _, h, w = feats[i].shape sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y - sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) + sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) anchor_points.append(torch.stack((sx, sy), 
-1).view(-1, 2)) stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) return torch.cat(anchor_points), torch.cat(stride_tensor) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 57139d8f..bdd7e335 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -25,11 +25,11 @@ try: except ImportError: thop = None -TORCH_1_9 = check_version(torch.__version__, '1.9.0') -TORCH_2_0 = check_version(torch.__version__, '2.0.0') -TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0') -TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0') -TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0') +TORCH_1_9 = check_version(torch.__version__, "1.9.0") +TORCH_2_0 = check_version(torch.__version__, "2.0.0") +TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0") +TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0") +TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0") @contextmanager @@ -60,13 +60,13 @@ def get_cpu_info(): """Return a string with system CPU information, i.e. 'Apple M2'.""" import cpuinfo # pip install py-cpuinfo - k = 'brand_raw', 'hardware_raw', 'arch_string_raw' # info keys sorted by preference (not all keys always available) + k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available) info = cpuinfo.get_cpu_info() # info dict - string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], 'unknown') - return string.replace('(R)', '').replace('CPU ', '').replace('@ ', '') + string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown") + return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "") -def select_device(device='', batch=0, newline=False, verbose=True): +def select_device(device="", batch=0, newline=False, verbose=True): """ Selects the appropriate PyTorch device based on the provided arguments. 
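# --- Editor's aside (not part of the patch): hedged examples of the device
# strings select_device() accepts, per the parsing logic in the hunk below.
from ultralytics.utils.torch_utils import select_device

device = select_device("cpu")  # always available -> torch.device('cpu')
# device = select_device("0")             # first CUDA GPU, if torch can see one
# device = select_device("0,1", batch=16) # multi-GPU; batch must be a multiple of the GPU count
# device = select_device("mps")           # Apple MPS, if available and torch>=2.0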
@@ -103,49 +103,57 @@ def select_device(device='', batch=0, newline=False, verbose=True): if isinstance(device, torch.device): return device - s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} ' + s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} " device = str(device).lower() - for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ': - device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' - cpu = device == 'cpu' - mps = device in ('mps', 'mps:0') # Apple Metal Performance Shaders (MPS) + for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ": + device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' + cpu = device == "cpu" + mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS) if cpu or mps: - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False elif device: # non-cpu device requested - if device == 'cuda': - device = '0' - visible = os.environ.get('CUDA_VISIBLE_DEVICES', None) - os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() - if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): + if device == "cuda": + device = "0" + visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) + os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))): LOGGER.info(s) - install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \ - 'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else '' - raise ValueError(f"Invalid CUDA 'device={device}' requested." - f" Use 'device=cpu' or pass valid CUDA device(s) if available," - f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" - f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}' - f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}' - f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" - f'{install}') + install = ( + "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " + "CUDA devices are seen by torch.\n" + if torch.cuda.device_count() == 0 + else "" + ) + raise ValueError( + f"Invalid CUDA 'device={device}' requested." + f" Use 'device=cpu' or pass valid CUDA device(s) if available," + f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}" + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}" + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" + f"{install}" + ) if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available - devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 + devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7 n = len(devices) # device count if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count - raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. 
Try 'batch={batch // n * n}' or " - f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.") - space = ' ' * (len(s) + 1) + raise ValueError( + f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or " + f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}." + ) + space = " " * (len(s) + 1) for i, d in enumerate(devices): p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB - arg = 'cuda:0' + arg = "cuda:0" elif mps and TORCH_2_0 and torch.backends.mps.is_available(): # Prefer MPS if available - s += f'MPS ({get_cpu_info()})\n' - arg = 'mps' + s += f"MPS ({get_cpu_info()})\n" + arg = "mps" else: # revert to CPU - s += f'CPU ({get_cpu_info()})\n' - arg = 'cpu' + s += f"CPU ({get_cpu_info()})\n" + arg = "cpu" if verbose: LOGGER.info(s if newline else s.rstrip()) @@ -161,14 +169,20 @@ def time_sync(): def fuse_conv_and_bn(conv, bn): """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.""" - fusedconv = nn.Conv2d(conv.in_channels, - conv.out_channels, - kernel_size=conv.kernel_size, - stride=conv.stride, - padding=conv.padding, - dilation=conv.dilation, - groups=conv.groups, - bias=True).requires_grad_(False).to(conv.weight.device) + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) # Prepare filters w_conv = conv.weight.clone().view(conv.out_channels, -1) @@ -185,15 +199,21 @@ def fuse_conv_and_bn(conv, bn): def fuse_deconv_and_bn(deconv, bn): """Fuse ConvTranspose2d() and BatchNorm2d() layers.""" - fuseddconv = nn.ConvTranspose2d(deconv.in_channels, - deconv.out_channels, - kernel_size=deconv.kernel_size, - stride=deconv.stride, - padding=deconv.padding, - output_padding=deconv.output_padding, - dilation=deconv.dilation, - groups=deconv.groups, - bias=True).requires_grad_(False).to(deconv.weight.device) + fuseddconv = ( + nn.ConvTranspose2d( + deconv.in_channels, + deconv.out_channels, + kernel_size=deconv.kernel_size, + stride=deconv.stride, + padding=deconv.padding, + output_padding=deconv.output_padding, + dilation=deconv.dilation, + groups=deconv.groups, + bias=True, + ) + .requires_grad_(False) + .to(deconv.weight.device) + ) # Prepare filters w_deconv = deconv.weight.clone().view(deconv.out_channels, -1) @@ -221,18 +241,21 @@ def model_info(model, detailed=False, verbose=True, imgsz=640): n_l = len(list(model.modules())) # number of layers if detailed: LOGGER.info( - f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}") + f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}" + ) for i, (name, p) in enumerate(model.named_parameters()): - name = name.replace('module_list.', '') - LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' % - (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)) + name = name.replace("module_list.", "") + LOGGER.info( + "%5g %40s %9s %12g %20s %10.3g %10.3g %10s" + % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype) + ) flops = get_flops(model, imgsz) - fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else '' - fs = f', {flops:.1f} GFLOPs' if flops else '' - 
yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '') - model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model' - LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}') + fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else "" + fs = f", {flops:.1f} GFLOPs" if flops else "" + yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "") + model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model" + LOGGER.info(f"{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}") return n_l, n_p, n_g, flops @@ -262,13 +285,15 @@ def model_info_for_loggers(trainer): """ if trainer.args.profile: # profile ONNX and TensorRT times from ultralytics.utils.benchmarks import ProfileModels + results = ProfileModels([trainer.last], device=trainer.device).profile()[0] - results.pop('model/name') + results.pop("model/name") else: # only return PyTorch times from most recent validation results = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3)} - results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3) + "model/parameters": get_num_params(trainer.model), + "model/GFLOPs": round(get_flops(trainer.model), 3), + } + results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3) return results @@ -284,14 +309,14 @@ def get_flops(model, imgsz=640): imgsz = [imgsz, imgsz] # expand if int/float try: # Use stride size for input tensor - stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride + stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format - flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs + flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs except Exception: # Use actual image size for input tensor (i.e. 
required for RTDETR models) im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format - return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs + return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs except Exception: return 0.0 @@ -301,11 +326,11 @@ def get_flops_with_torch_profiler(model, imgsz=640): if TORCH_2_0: model = de_parallel(model) p = next(model.parameters()) - stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride + stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format with torch.profiler.profile(with_flops=True) as prof: model(im) - flops = sum(x.flops for x in prof.key_averages()) / 1E9 + flops = sum(x.flops for x in prof.key_averages()) / 1e9 imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs return flops @@ -333,7 +358,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): return img h, w = img.shape[2:] s = (int(h * ratio), int(w * ratio)) # new size - img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize + img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize if not same_shape: # pad/crop img h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean @@ -349,7 +374,7 @@ def make_divisible(x, divisor): def copy_attr(a, b, include=(), exclude=()): """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.""" for k, v in b.__dict__.items(): - if (len(include) and k not in include) or k.startswith('_') or k in exclude: + if (len(include) and k not in include) or k.startswith("_") or k in exclude: continue else: setattr(a, k, v) @@ -357,7 +382,7 @@ def copy_attr(a, b, include=(), exclude=()): def get_latest_opset(): """Return second-most (for maturity) recently supported ONNX opset by this version of torch.""" - return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset + return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 # opset def intersect_dicts(da, db, exclude=()): @@ -392,10 +417,10 @@ def init_seeds(seed=0, deterministic=False): if TORCH_2_0: torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible torch.backends.cudnn.deterministic = True - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + os.environ["PYTHONHASHSEED"] = str(seed) else: - LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.') + LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.") else: torch.use_deterministic_algorithms(False) torch.backends.cudnn.deterministic = False @@ -430,13 +455,13 @@ class ModelEMA: v += (1 - d) * msd[k].detach() # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}' - def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + def update_attr(self, model, include=(), exclude=("process_group", "reducer")): """Updates attributes and saves stripped model with optimizer removed.""" if self.enabled: 
copy_attr(self.ema, model, include, exclude) -def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None: +def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None: """ Strip optimizer from 'f' to finalize training, optionally save as 's'. @@ -456,26 +481,26 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None: strip_optimizer(f) ``` """ - x = torch.load(f, map_location=torch.device('cpu')) - if 'model' not in x: - LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.') + x = torch.load(f, map_location=torch.device("cpu")) + if "model" not in x: + LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.") return - if hasattr(x['model'], 'args'): - x['model'].args = dict(x['model'].args) # convert from IterableSimpleNamespace to dict - args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args - if x.get('ema'): - x['model'] = x['ema'] # replace model with ema - for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys + if hasattr(x["model"], "args"): + x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict + args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args + if x.get("ema"): + x["model"] = x["ema"] # replace model with ema + for k in "optimizer", "best_fitness", "ema", "updates": # keys x[k] = None - x['epoch'] = -1 - x['model'].half() # to FP16 - for p in x['model'].parameters(): + x["epoch"] = -1 + x["model"].half() # to FP16 + for p in x["model"].parameters(): p.requires_grad = False - x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys + x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys # x['model'].args = x['train_args'] torch.save(x, s or f) - mb = os.path.getsize(s or f) / 1E6 # file size + mb = os.path.getsize(s or f) / 1e6 # file size LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") @@ -496,18 +521,20 @@ def profile(input, ops, n=10, device=None): results = [] if not isinstance(device, torch.device): device = select_device(device) - LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" - f"{'input':>24s}{'output':>24s}") + LOGGER.info( + f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}" + ) for x in input if isinstance(input, list) else [input]: x = x.to(device) x.requires_grad = True for m in ops if isinstance(ops, list) else [ops]: - m = m.to(device) if hasattr(m, 'to') else m # device - m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + m = m.to(device) if hasattr(m, "to") else m # device + m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward try: - flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs + flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs except Exception: flops = 0 @@ -521,13 +548,13 @@ def profile(input, ops, n=10, device=None): t[2] = time_sync() except Exception: # no backward method # print(e) # for debug - t[2] = float('nan') + t[2] = float("nan") tf += (t[1] - t[0]) * 1000 / n # ms per op forward tb += (t[2] - t[1]) * 1000 / n # ms per op backward - mem = torch.cuda.memory_reserved() / 
1E9 if torch.cuda.is_available() else 0 # (GB) - s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes + mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB) + s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters - LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') + LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}") results.append([p, flops, mem, tf, tb, s_in, s_out]) except Exception as e: LOGGER.info(e) @@ -548,7 +575,7 @@ class EarlyStopping: """ self.best_fitness = 0.0 # i.e. mAP self.best_epoch = 0 - self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop + self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop self.possible_stop = False # possible stop may occur next epoch def __call__(self, epoch, fitness): @@ -572,8 +599,10 @@ class EarlyStopping: self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch stop = delta >= self.patience # stop training if patience exceeded if stop: - LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. ' - f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n' - f'To update EarlyStopping(patience={self.patience}) pass a new patience value, ' - f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.') + LOGGER.info( + f"Stopping training early as no improvement observed in last {self.patience} epochs. " + f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n" + f"To update EarlyStopping(patience={self.patience}) pass a new patience value, " + f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping." + ) return stop diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py index b79b04b4..98c4be20 100644 --- a/ultralytics/utils/triton.py +++ b/ultralytics/utils/triton.py @@ -22,7 +22,7 @@ class TritonRemoteModel: output_names (List[str]): The names of the model outputs. """ - def __init__(self, url: str, endpoint: str = '', scheme: str = ''): + def __init__(self, url: str, endpoint: str = "", scheme: str = ""): """ Initialize the TritonRemoteModel. @@ -36,7 +36,7 @@ class TritonRemoteModel: """ if not endpoint and not scheme: # Parse all args from URL string splits = urlsplit(url) - endpoint = splits.path.strip('/').split('/')[0] + endpoint = splits.path.strip("/").split("/")[0] scheme = splits.scheme url = splits.netloc @@ -44,26 +44,28 @@ class TritonRemoteModel: self.url = url # Choose the Triton client based on the communication scheme - if scheme == 'http': + if scheme == "http": import tritonclient.http as client # noqa + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) config = self.triton_client.get_model_config(endpoint) else: import tritonclient.grpc as client # noqa + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) - config = self.triton_client.get_model_config(endpoint, as_json=True)['config'] + config = self.triton_client.get_model_config(endpoint, as_json=True)["config"] # Sort output names alphabetically, i.e. 'output0', 'output1', etc. 
- config['output'] = sorted(config['output'], key=lambda x: x.get('name')) + config["output"] = sorted(config["output"], key=lambda x: x.get("name")) # Define model attributes - type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8} + type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8} self.InferRequestedOutput = client.InferRequestedOutput self.InferInput = client.InferInput - self.input_formats = [x['data_type'] for x in config['input']] + self.input_formats = [x["data_type"] for x in config["input"]] self.np_input_formats = [type_map[x] for x in self.input_formats] - self.input_names = [x['name'] for x in config['input']] - self.output_names = [x['name'] for x in config['output']] + self.input_names = [x["name"] for x in config["input"]] + self.output_names = [x["name"] for x in config["output"]] def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: """ @@ -80,7 +82,7 @@ class TritonRemoteModel: for i, x in enumerate(inputs): if x.dtype != self.np_input_formats[i]: x = x.astype(self.np_input_formats[i]) - infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', '')) + infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", "")) infer_input.set_data_from_numpy(x) infer_inputs.append(infer_input) diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py index a06f813d..db7d4907 100644 --- a/ultralytics/utils/tuner.py +++ b/ultralytics/utils/tuner.py @@ -6,12 +6,9 @@ from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS -def run_ray_tune(model, - space: dict = None, - grace_period: int = 10, - gpu_per_trial: int = None, - max_samples: int = 10, - **train_args): +def run_ray_tune( + model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args +): """ Runs hyperparameter tuning using Ray Tune. 
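# --- Editor's aside (not part of the patch): a hedged sketch of calling
# run_ray_tune() directly, using only the parameters visible in this diff.
# The model and dataset names are illustrative; extra kwargs become train_args.
from ultralytics import YOLO
from ultralytics.utils.tuner import run_ray_tune

model = YOLO("yolov8n.pt")
run_ray_tune(model, grace_period=5, gpu_per_trial=1, max_samples=8, data="coco8.yaml", epochs=10)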
@@ -38,12 +35,12 @@ def run_ray_tune(model,
         ```
     """

-    LOGGER.info('💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune')
+    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
     if train_args is None:
         train_args = {}

     try:
-        subprocess.run('pip install ray[tune]'.split(), check=True)
+        subprocess.run("pip install ray[tune]".split(), check=True)

         import ray
         from ray import tune
@@ -56,33 +53,34 @@ def run_ray_tune(model,
     try:
         import wandb

-        assert hasattr(wandb, '__version__')
+        assert hasattr(wandb, "__version__")
     except (ImportError, AssertionError):
         wandb = False

     default_space = {
         # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
-        'lr0': tune.uniform(1e-5, 1e-1),
-        'lrf': tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-        'momentum': tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
-        'weight_decay': tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
-        'warmup_epochs': tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
-        'warmup_momentum': tune.uniform(0.0, 0.95),  # warmup initial momentum
-        'box': tune.uniform(0.02, 0.2),  # box loss gain
-        'cls': tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
-        'hsv_h': tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
-        'hsv_s': tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
-        'hsv_v': tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
-        'degrees': tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
-        'translate': tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
-        'scale': tune.uniform(0.0, 0.9),  # image scale (+/- gain)
-        'shear': tune.uniform(0.0, 10.0),  # image shear (+/- deg)
-        'perspective': tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
-        'flipud': tune.uniform(0.0, 1.0),  # image flip up-down (probability)
-        'fliplr': tune.uniform(0.0, 1.0),  # image flip left-right (probability)
-        'mosaic': tune.uniform(0.0, 1.0),  # image mixup (probability)
-        'mixup': tune.uniform(0.0, 1.0),  # image mixup (probability)
-        'copy_paste': tune.uniform(0.0, 1.0)}  # segment copy-paste (probability)
+        "lr0": tune.uniform(1e-5, 1e-1),
+        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
+        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
+        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
+        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
+        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
+        "box": tune.uniform(0.02, 0.2),  # box loss gain
+        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
+        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
+        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
+        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
+        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
+        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
+        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
+        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
+        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
+        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
+        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
+    }

     # Put the model in ray store
     task = model.task
@@ -107,35 +105,39 @@ def run_ray_tune(model,
     # Get search space
     if not space:
         space = default_space
-        LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
+        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")

     # Get dataset
-    data = train_args.get('data', TASK2DATA[task])
-    space['data'] = data
-    if 'data' not in train_args:
+    data = train_args.get("data", TASK2DATA[task])
+    space["data"] = data
+    if "data" not in train_args:
         LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')

     # Define the trainable function with allocated resources
-    trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
+    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

     # Define the ASHA scheduler for hyperparameter search
-    asha_scheduler = ASHAScheduler(time_attr='epoch',
-                                   metric=TASK2METRIC[task],
-                                   mode='max',
-                                   max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
-                                   grace_period=grace_period,
-                                   reduction_factor=3)
+    asha_scheduler = ASHAScheduler(
+        time_attr="epoch",
+        metric=TASK2METRIC[task],
+        mode="max",
+        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
+        grace_period=grace_period,
+        reduction_factor=3,
+    )

     # Define the callbacks for the hyperparameter search
-    tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
+    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

     # Create the Ray Tune hyperparameter search tuner
-    tune_dir = get_save_dir(DEFAULT_CFG, name='tune').resolve()  # must be absolute dir
+    tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve()  # must be absolute dir
     tune_dir.mkdir(parents=True, exist_ok=True)
-    tuner = tune.Tuner(trainable_with_resources,
-                       param_space=space,
-                       tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
-                       run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir))
+    tuner = tune.Tuner(
+        trainable_with_resources,
+        param_space=space,
+        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
+        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
+    )

     # Run the hyperparameter search
     tuner.fit()
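# --- Editor's aside (not part of the patch): tuner.fit() above returns a
# ray.tune.ResultGrid; a hedged sketch of inspecting it. The metric key assumes
# a detection task, per the TASK2METRIC mapping used earlier.
result_grid = tuner.fit()  # i.e. if the final call above were captured
best = result_grid.get_best_result(metric="metrics/mAP50-95(B)", mode="max")
print(best.config)   # best hyperparameter values found
print(best.metrics)  # final metrics reported by the best trial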