diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 796a08968a..43f5d4cfeb 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -39,7 +39,8 @@ on:
jobs:
HUB:
- if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule' || github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && github.event.inputs.hub == 'true'))
+ # if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule' || github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && github.event.inputs.hub == 'true'))
+ if: github.repository == 'ultralytics/ultralytics' && 'workflow_dispatch' && github.event.inputs.hub == 'true'
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 8d9f749e19..c299bc5bfd 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -84,11 +84,8 @@ jobs:
outputs:
new_release: ${{ steps.check_tag.outputs.new_release }}
steps:
- - name: Cleanup disk
- # Free up to 30GB of disk space per https://github.com/ultralytics/ultralytics/pull/15848
- uses: jlumbroso/free-disk-space@v1.3.1
- with:
- tool-cache: true
+ - name: Cleanup disk space
+ uses: ultralytics/actions/cleanup-disk@main
- name: Checkout repo
uses: actions/checkout@v4
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index dd8503541e..991e0edd99 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -5,6 +5,10 @@ on:
schedule:
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
+permissions:
+ pull-requests: write
+ issues: write
+
jobs:
stale:
runs-on: ubuntu-latest
diff --git a/README.md b/README.md
index 39fd7bacaf..51f13230ed 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
+
@@ -26,7 +27,7 @@ We hope that the resources here will help you get the most out of YOLO. Please b
To request an Enterprise License please complete the form at [Ultralytics Licensing](https://www.ultralytics.com/license).
-
+
@@ -16,7 +16,7 @@ This comprehensive guide provides a detailed walkthrough for deploying Ultralyti
allowfullscreen>
- Watch: How to Setup NVIDIA Jetson with Ultralytics YOLOv8
+ Watch: How to Setup NVIDIA Jetson with Ultralytics YOLO11
diff --git a/docs/en/models/yolov5.md b/docs/en/models/yolov5.md
index 8ff1c36ec0..91c562a44e 100644
--- a/docs/en/models/yolov5.md
+++ b/docs/en/models/yolov5.md
@@ -4,7 +4,11 @@ description: Explore YOLOv5u, an advanced object detection model with optimized
keywords: YOLOv5, YOLOv5u, object detection, Ultralytics, anchor-free, pre-trained models, accuracy, speed, real-time detection
---
-# YOLOv5
+# Ultralytics YOLOv5
+
+!!! tip "Ultralytics YOLOv5 Publication"
+
+ Ultralytics has not published a formal research paper for YOLOv5 due to the rapidly evolving nature of the models. We focus on advancing the technology and making it easier to use, rather than producing static documentation. For the most up-to-date information on YOLO architecture, features, and usage, please refer to our [GitHub repository](https://github.com/ultralytics/ultralytics) and [documentation](https://docs.ultralytics.com).
## Overview
diff --git a/docs/en/models/yolov8.md b/docs/en/models/yolov8.md
index 036cd305a1..c8e4397d15 100644
--- a/docs/en/models/yolov8.md
+++ b/docs/en/models/yolov8.md
@@ -6,6 +6,10 @@ keywords: YOLOv8, real-time object detection, YOLO series, Ultralytics, computer
# Ultralytics YOLOv8
+!!! tip "Ultralytics YOLOv8 Publication"
+
+ Ultralytics has not published a formal research paper for YOLOv8 due to the rapidly evolving nature of the models. We focus on advancing the technology and making it easier to use, rather than producing static documentation. For the most up-to-date information on YOLO architecture, features, and usage, please refer to our [GitHub repository](https://github.com/ultralytics/ultralytics) and [documentation](https://docs.ultralytics.com).
+
## Overview
YOLOv8 is the latest iteration in the YOLO series of real-time object detectors, offering cutting-edge performance in terms of accuracy and speed. Building upon the advancements of previous YOLO versions, YOLOv8 introduces new features and optimizations that make it an ideal choice for various [object detection](https://www.ultralytics.com/glossary/object-detection) tasks in a wide range of applications.
diff --git a/docs/en/modes/export.md b/docs/en/modes/export.md
index 4be5bd5b90..776d826445 100644
--- a/docs/en/modes/export.md
+++ b/docs/en/modes/export.md
@@ -136,13 +136,13 @@ INT8 quantization is an excellent way to compress the model and speed up inferen
from ultralytics import YOLO
model = YOLO("yolo11n.pt") # Load a model
- model.export(format="onnx", int8=True)
+ model.export(format="engine", int8=True)
```
=== "CLI"
```bash
- yolo export model=yolo11n.pt format=onnx int8=True # export model with INT8 quantization
+ yolo export model=yolo11n.pt format=engine int8=True # export TensorRT model with INT8 quantization
```
INT8 quantization can be applied to various formats, such as TensorRT and CoreML. More details can be found in the [Export section](../modes/export.md).
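To illustrate the sentence above about INT8 applying to several export formats, here is a minimal sketch of the equivalent CoreML export. It is not part of this patch and assumes the same `int8=True` flag shown for TensorRT is also accepted by the `coreml` format:

```python
from ultralytics import YOLO

# Minimal sketch: INT8-quantized CoreML export, mirroring the TensorRT example above (assumed flag usage)
model = YOLO("yolo11n.pt")
model.export(format="coreml", int8=True)  # writes a quantized CoreML package for Apple devices
```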
diff --git a/docs/en/modes/predict.md b/docs/en/modes/predict.md
index a298294d59..4c69aa52d4 100644
--- a/docs/en/modes/predict.md
+++ b/docs/en/modes/predict.md
@@ -665,7 +665,7 @@ For more details see the [`Probs` class documentation](../reference/engine/resul
model = YOLO("yolo11n-obb.pt")
# Run inference on an image
- results = model("bus.jpg") # results list
+ results = model("boats.jpg") # results list
# View results
for r in results:
diff --git a/docs/en/tasks/obb.md b/docs/en/tasks/obb.md
index 35e659ed47..621ffc783d 100644
--- a/docs/en/tasks/obb.md
+++ b/docs/en/tasks/obb.md
@@ -141,14 +141,14 @@ Use a trained YOLO11n-obb model to run predictions on images.
model = YOLO("path/to/best.pt") # load a custom model
# Predict with the model
- results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
+ results = model("https://ultralytics.com/images/boats.jpg") # predict on an image
```
=== "CLI"
```bash
- yolo obb predict model=yolo11n-obb.pt source='https://ultralytics.com/images/bus.jpg' # predict with official model
- yolo obb predict model=path/to/best.pt source='https://ultralytics.com/images/bus.jpg' # predict with custom model
+ yolo obb predict model=yolo11n-obb.pt source='https://ultralytics.com/images/boats.jpg' # predict with official model
+ yolo obb predict model=path/to/best.pt source='https://ultralytics.com/images/boats.jpg' # predict with custom model
```
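As a brief illustration of what the OBB predict calls above return, the sketch below reads the oriented boxes from the results. It is not part of this patch and assumes the standard Ultralytics `Results.obb` attributes:

```python
from ultralytics import YOLO

# Sketch: inspect oriented bounding boxes returned by the predict example above
model = YOLO("yolo11n-obb.pt")
results = model("https://ultralytics.com/images/boats.jpg")

for result in results:
    obb = result.obb  # oriented-box container (assumed Results.obb API)
    print(obb.xywhr)  # boxes as (x_center, y_center, width, height, rotation)
    print(obb.cls, obb.conf)  # class indices and confidence scores
```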
diff --git a/examples/heatmaps.ipynb b/examples/heatmaps.ipynb
index 11ffdc9058..c674ad4800 100644
--- a/examples/heatmaps.ipynb
+++ b/examples/heatmaps.ipynb
@@ -96,10 +96,7 @@
    "source": [
     "import cv2\n",
     "\n",
-    "from ultralytics import YOLO, solutions\n",
-    "\n",
-    "# Load YOLO model\n",
-    "model = YOLO(\"yolo11n.pt\")\n",
+    "from ultralytics import solutions\n",
     "\n",
     "# Open video file\n",
     "cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n",
@@ -113,10 +110,9 @@
     "\n",
     "# Initialize heatmap object\n",
     "heatmap_obj = solutions.Heatmap(\n",
-    "    colormap=cv2.COLORMAP_PARULA,\n",
-    "    view_img=True,\n",
-    "    shape=\"circle\",\n",
-    "    names=model.names,\n",
+    "    colormap=cv2.COLORMAP_PARULA,  # Color of the heatmap\n",
+    "    show=True,  # Display the image during processing\n",
+    "    model=\"yolo11n.pt\",  # Ultralytics YOLO11 model file\n",
     ")\n",
     "\n",
     "while cap.isOpened():\n",
@@ -125,11 +121,8 @@
     "        print(\"Video frame is empty or video processing has been successfully completed.\")\n",
     "        break\n",
     "\n",
-    "    # Perform tracking on the current frame\n",
-    "    tracks = model.track(im0, persist=True, show=False)\n",
-    "\n",
     "    # Generate heatmap on the frame\n",
-    "    im0 = heatmap_obj.generate_heatmap(im0, tracks)\n",
+    "    im0 = heatmap_obj.generate_heatmap(im0)\n",
     "\n",
     "    # Write the frame to the output video\n",
     "    video_writer.write(im0)\n",
diff --git a/examples/object_counting.ipynb b/examples/object_counting.ipynb
index 572f1033a1..50168f262e 100644
--- a/examples/object_counting.ipynb
+++ b/examples/object_counting.ipynb
@@ -104,10 +104,7 @@
    "source": [
     "import cv2\n",
     "\n",
-    "from ultralytics import YOLO, solutions\n",
-    "\n",
-    "# Load the pre-trained YOLO11 model\n",
-    "model = YOLO(\"yolo11n.pt\")\n",
+    "from ultralytics import solutions\n",
     "\n",
     "# Open the video file\n",
     "cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n",
@@ -119,19 +116,15 @@
     "# Define points for a line or region of interest in the video frame\n",
     "line_points = [(20, 400), (1080, 400)]  # Line coordinates\n",
     "\n",
-    "# Specify classes to count, for example: person (0) and car (2)\n",
-    "classes_to_count = [0, 2]  # Class IDs for person and car\n",
-    "\n",
     "# Initialize the video writer to save the output video\n",
     "video_writer = cv2.VideoWriter(\"object_counting_output.avi\", cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (w, h))\n",
     "\n",
     "# Initialize the Object Counter with visualization options and other parameters\n",
     "counter = solutions.ObjectCounter(\n",
-    "    view_img=True,  # Display the image during processing\n",
-    "    reg_pts=line_points,  # Region of interest points\n",
-    "    names=model.names,  # Class names from the YOLO model\n",
-    "    draw_tracks=True,  # Draw tracking lines for objects\n",
-    "    line_thickness=2,  # Thickness of the lines drawn\n",
+    "    show=True,  # Display the image during processing\n",
+    "    region=line_points,  # Region of interest points\n",
+    "    model=\"yolo11n.pt\",  # Ultralytics YOLO11 model file\n",
+    "    line_width=2,  # Thickness of the lines and bounding boxes\n",
     ")\n",
     "\n",
     "# Process video frames in a loop\n",
@@ -141,11 +134,8 @@
     "        print(\"Video frame is empty or video processing has been successfully completed.\")\n",
     "        break\n",
     "\n",
-    "    # Perform object tracking on the current frame, filtering by specified classes\n",
-    "    tracks = model.track(im0, persist=True, show=False, classes=classes_to_count)\n",
-    "\n",
     "    # Use the Object Counter to count objects in the frame and get the annotated image\n",
-    "    im0 = counter.start_counting(im0, tracks)\n",
+    "    im0 = counter.count(im0)\n",
     "\n",
     "    # Write the annotated frame to the output video\n",
     "    video_writer.write(im0)\n",
diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb
index 98c659b864..75dd455e9a 100644
--- a/examples/tutorial.ipynb
+++ b/examples/tutorial.ipynb
@@ -583,7 +583,7 @@
     "\n",
     "model = YOLO('yolo11n-obb.pt')  # load a pretrained YOLO OBB model\n",
     "model.train(data='dota8.yaml', epochs=3)  # train the model\n",
-    "model('https://ultralytics.com/images/bus.jpg')  # predict on an image"
+    "model('https://ultralytics.com/images/boats.jpg')  # predict on an image"
    ],
    "metadata": {
     "id": "IJNKClOOB5YS"
diff --git a/mkdocs.yml b/mkdocs.yml
index c8d151b96e..2ea041f331 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -274,7 +274,7 @@ nav:
           - VisDrone: datasets/detect/visdrone.md
           - VOC: datasets/detect/voc.md
           - xView: datasets/detect/xview.md
-          - Roboflow 100: datasets/detect/roboflow-100.md
+          - RF100: datasets/detect/roboflow-100.md
           - Brain-tumor: datasets/detect/brain-tumor.md
           - African-wildlife: datasets/detect/african-wildlife.md
           - Signature: datasets/detect/signature.md
diff --git a/pyproject.toml b/pyproject.toml
index f6cb23204a..2545739bab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,7 +61,8 @@ classifiers = [

 # Required dependencies ------------------------------------------------------------------------------------------------
 dependencies = [
-    "numpy>=1.23.0", # temporary patch for compat errors https://github.com/ultralytics/yolov5/actions/runs/9538130424/job/26286956354
+    "numpy>=1.23.0",
+    "numpy<2.0.0; sys_platform == 'darwin'", # macOS OpenVINO errors https://github.com/ultralytics/ultralytics/pull/17221
     "matplotlib>=3.3.0",
     "opencv-python>=4.6.0",
     "pillow>=7.1.2",
diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 89f8c39b25..4fd1a7aee3 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -116,7 +116,7 @@ def test_predict_sam():
     from ultralytics.models.sam import Predictor as SAMPredictor

     # Load a model
-    model = SAM(WEIGHTS_DIR / "sam_b.pt")
+    model = SAM(WEIGHTS_DIR / "sam2.1_b.pt")

     # Display model information (optional)
     model.info()
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 0887cf9050..72a9396473 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.18"
+__version__ = "8.3.24"

 import os
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 153ab27e38..0af93a37d3 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -787,7 +787,7 @@ def entrypoint(debug=""):
         from ultralytics import FastSAM

         model = FastSAM(model)
-    elif "sam_" in stem or "sam2_" in stem:
+    elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem:
         from ultralytics import SAM

         model = SAM(model)
@@ -809,7 +809,9 @@ def entrypoint(debug=""):

     # Mode
     if mode in {"predict", "track"} and "source" not in overrides:
-        overrides["source"] = DEFAULT_CFG.source or ASSETS
+        overrides["source"] = (
+            "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS
+        )
         LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
     elif mode in {"train", "val"}:
         if "data" not in overrides and "resume" not in overrides:
diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml
index a353fd2a21..69e430b8c3 100644
--- a/ultralytics/cfg/solutions/default.yaml
+++ b/ultralytics/cfg/solutions/default.yaml
@@ -1,18 +1,19 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-
 # Configuration for Ultralytics Solutions
-model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version and yolov8n.pt for YOLOv8 nano version)
-
+# Object counting settings
 region: # Object counting, queue or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)]
-line_width: 2 # Width of the annotator used to draw regions on the image/video frames + bounding boxes and tracks drawing. Default value is 2.
-show: True # Flag to control whether to display output image or not, you can set this as False i.e. when deploying it on some embedded devices.
 show_in: True # Flag to display objects moving *into* the defined region
 show_out: True # Flag to display objects moving *out of* the defined region
-classes: # To count specific classes. i.e, if you want to detect, track and count the person with COCO model, you can use classes=0, Default its None
+
+# Heatmaps settings
+colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
+
+# Workouts monitoring settings
 up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value. You can adjust it for different workouts, based on position of keypoints.
 down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can change it for different workouts, based on position of keypoints.
 kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
-colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
+
+# Analytics settings
 analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing.
 json_file: # parking system regions file path.
diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py
index fe1aac10ae..fa5821418a 100644
--- a/ultralytics/data/converter.py
+++ b/ultralytics/data/converter.py
@@ -632,9 +632,10 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
         txt_file = save_dir / lb_name
         cls = label["cls"]
         for i, s in enumerate(label["segments"]):
+            if len(s) == 0:
+                continue
             line = (int(cls[i]), *s.reshape(-1))
             texts.append(("%g " * len(line)).rstrip() % line)
-
         if texts:
             with open(txt_file, "a") as f:
                 f.writelines(text + "\n" for text in texts)
     LOGGER.info(f"Generated segment labels saved in {save_dir}")
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 27dc1bfb0d..6174ab7add 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -213,9 +213,13 @@ class Exporter:
             LOGGER.warning("WARNING ⚠️ Sony MCT only supports int8 export, setting int8=True.")
             self.args.int8 = True

         # Device
+        dla = None
         if fmt == "engine" and self.args.device is None:
             LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
             self.args.device = "0"
+        if fmt == "engine" and "dla" in str(self.args.device):  # convert int/list to str first
+            dla = self.args.device.split(":")[-1]
+            assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
         self.device = select_device("cpu" if self.args.device is None else self.args.device)

         # Checks
         if not hasattr(model, "names"):
@@ -349,7 +353,7 @@ class Exporter:
         if jit or ncnn:  # TorchScript
             f[0], _ = self.export_torchscript()
         if engine:  # TensorRT required before ONNX
-            f[1], _ = self.export_engine()
+            f[1], _ = self.export_engine(dla=dla)
         if onnx:  # ONNX
             f[2], _ = self.export_onnx()
         if xml:  # OpenVINO
@@ -495,6 +499,7 @@ class Exporter:
     @try_export
     def export_openvino(self, prefix=colorstr("OpenVINO:")):
         """YOLO OpenVINO export."""
+        # WARNING: numpy>=2.0.0 issue with OpenVINO on macOS https://github.com/ultralytics/ultralytics/pull/17221
         check_requirements(f'openvino{"<=2024.0.0" if ARM64 else ">=2024.0.0"}')  # fix OpenVINO issue on ARM64
         import openvino as ov
@@ -724,7 +729,7 @@ class Exporter:
         return f, ct_model

     @try_export
-    def export_engine(self, prefix=colorstr("TensorRT:")):
+    def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
         """YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
         f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
@@ -733,10 +738,10 @@ class Exporter:
             import tensorrt as trt  # noqa
         except ImportError:
             if LINUX:
-                check_requirements("tensorrt>7.0.0,<=10.1.0")
+                check_requirements("tensorrt>7.0.0,!=10.1.0")
             import tensorrt as trt  # noqa
         check_version(trt.__version__, ">=7.0.0", hard=True)
-        check_version(trt.__version__, "<=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
+        check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")

         # Setup and checks
         LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
@@ -759,6 +764,20 @@ class Exporter:
         network = builder.create_network(flag)
         half = builder.platform_has_fast_fp16 and self.args.half
         int8 = builder.platform_has_fast_int8 and self.args.int8
+
+        # Optionally switch to DLA if enabled
+        if dla is not None:
+            if not IS_JETSON:
+                raise ValueError("DLA is only available on NVIDIA Jetson devices")
+            LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
+            if not self.args.half and not self.args.int8:
+                raise ValueError(
+                    "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
+                )
+            config.default_device_type = trt.DeviceType.DLA
+            config.DLA_core = int(dla)
+            config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+
         # Read ONNX file
         parser = trt.OnnxParser(network, logger)
         if not parser.parse_from_file(f_onnx):
@@ -913,8 +932,10 @@ class Exporter:
         tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
         if self.args.data:
             f.mkdir()
-            images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
-            images = torch.cat(images, 0).float()
+            images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+            images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
+                0, 2, 3, 1
+            )
             np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
             np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py
index e110531244..cee5133a09 100644
--- a/ultralytics/models/sam/build.py
+++ b/ultralytics/models/sam/build.py
@@ -263,6 +263,7 @@ def _build_sam2(
     memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
     memory_encoder = MemoryEncoder(out_dim=64)

+    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
     sam2 = SAM2Model(
         image_encoder=image_encoder,
         memory_attention=memory_attention,
@@ -288,6 +289,9 @@ def _build_sam2(
         multimask_max_pt_num=1,
         use_mlp_for_obj_ptr_proj=True,
         compile_image_encoder=False,
+        no_obj_embed_spatial=is_sam2_1,
+        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
         sam_mask_decoder_extra_args=dict(
             dynamic_multimask_via_stability=True,
             dynamic_multimask_stability_delta=0.05,
@@ -313,6 +317,10 @@ sam_model_map = {
     "sam2_s.pt": build_sam2_s,
     "sam2_b.pt": build_sam2_b,
     "sam2_l.pt": build_sam2_l,
+    "sam2.1_t.pt": build_sam2_t,
+    "sam2.1_s.pt": build_sam2_s,
+    "sam2.1_b.pt": build_sam2_b,
+    "sam2.1_l.pt": build_sam2_l,
 }
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 2728b0b481..562314b2b9 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -161,18 +161,19 @@ class SAM2Model(torch.nn.Module):
         use_multimask_token_for_obj_ptr: bool = False,
         iou_prediction_use_sigmoid=False,
         memory_temporal_stride_for_eval=1,
-        add_all_frames_to_correct_as_cond=False,
         non_overlap_masks_for_mem_enc=False,
         use_obj_ptrs_in_encoder=False,
         max_obj_ptrs_in_encoder=16,
         add_tpos_enc_to_obj_ptrs=True,
         proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
         only_obj_ptrs_in_the_past_for_eval=False,
         pred_obj_scores: bool = False,
         pred_obj_scores_mlp: bool = False,
         fixed_no_obj_ptr: bool = False,
         soft_no_obj_ptr: bool = False,
         use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
@@ -205,8 +206,6 @@ class SAM2Model(torch.nn.Module):
             use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
             memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
-                frame list.
             non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
                 memory encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
@@ -216,6 +215,9 @@
                 the encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
+            use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
+                in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
+                and `add_tpos_enc_to_obj_ptrs=True`.
             only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
                 evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@@ -223,6 +225,7 @@ class SAM2Model(torch.nn.Module):
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
             soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
             use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+            no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
             sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
@@ -253,6 +256,7 @@ class SAM2Model(torch.nn.Module):
         if proj_tpos_enc_in_obj_ptrs:
             assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
         self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
         self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

         # Part 2: memory attention to condition current frame's visual features
@@ -309,9 +313,12 @@ class SAM2Model(torch.nn.Module):
             self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
             trunc_normal_(self.no_obj_ptr, std=0.02)
         self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+        self.no_obj_embed_spatial = None
+        if no_obj_embed_spatial:
+            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+            trunc_normal_(self.no_obj_embed_spatial, std=0.02)

         self._build_sam_heads()
-        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
         self.max_cond_frames_in_attn = max_cond_frames_in_attn

         # Model compilation
@@ -533,8 +540,6 @@ class SAM2Model(torch.nn.Module):
         if self.pred_obj_scores:
             # Allow *soft* no obj ptr, unlike for masks
             if self.soft_no_obj_ptr:
-                # Only hard possible with gt
-                assert not self.teacher_force_obj_scores_for_mem
                 lambda_is_obj_appearing = object_score_logits.sigmoid()
             else:
                 lambda_is_obj_appearing = is_obj_appearing.float()
@@ -647,6 +652,7 @@ class SAM2Model(torch.nn.Module):
         if self.num_maskmem == 0:  # Disable memory and skip fusion
             return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
         num_obj_ptr_tokens = 0
+        tpos_sign_mul = -1 if track_in_reverse else 1
         # Step 1: condition the visual features of the current frame on previous memories
         if not is_init_cond_frame:
             # Retrieve the memories encoded with the maskmem backbone
@@ -664,7 +670,7 @@ class SAM2Model(torch.nn.Module):
             # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
             # We also allow taking the memory frame non-consecutively (with r>1), in which case
             # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
-            r = self.memory_temporal_stride_for_eval
+            r = 1 if self.training else self.memory_temporal_stride_for_eval
             for t_pos in range(1, self.num_maskmem):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
@@ -718,7 +724,14 @@ class SAM2Model(torch.nn.Module):
                     ptr_cond_outputs = selected_cond_outputs
                 pos_and_ptrs = [
                     # Temporal pos encoding contains how far away each pointer is from current frame
-                    (abs(frame_idx - t), out["obj_ptr"])
+                    (
+                        (
+                            (frame_idx - t) * tpos_sign_mul
+                            if self.use_signed_tpos_enc_to_obj_ptrs
+                            else abs(frame_idx - t)
+                        ),
+                        out["obj_ptr"],
+                    )
                     for t, out in ptr_cond_outputs.items()
                 ]
                 # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
@@ -787,6 +800,7 @@ class SAM2Model(torch.nn.Module):
         current_vision_feats,
         feat_sizes,
         pred_masks_high_res,
+        object_score_logits,
         is_mask_from_pts,
     ):
         """Encodes frame features and masks into a new memory representation for video segmentation."""
@@ -819,10 +833,17 @@ class SAM2Model(torch.nn.Module):
         )
         maskmem_features = maskmem_out["vision_features"]
         maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+        # add a no-object embedding to the spatial memory to indicate that the frame
+        # is predicted to be occluded (i.e. no object is appearing in the frame)
+        if self.no_obj_embed_spatial is not None:
+            is_obj_appearing = (object_score_logits > 0).float()
+            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+                ..., None, None
+            ].expand(*maskmem_features.shape)

         return maskmem_features, maskmem_pos_enc

-    def track_step(
+    def _track_step(
         self,
         frame_idx,
         is_init_cond_frame,
@@ -833,15 +854,7 @@ class SAM2Model(torch.nn.Module):
         mask_inputs,
         output_dict,
         num_frames,
-        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
-        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
-        # to skip the memory encoder with `run_mem_encoder=False`. For example,
-        # in demo we might call `track_step` multiple times for each user click,
-        # and only encode the memory when the user finalizes their clicks. And in ablation
-        # settings like SAM training on static images, we don't need the memory encoder.
-        run_mem_encoder=True,
-        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
-        prev_sam_mask_logits=None,
+        prev_sam_mask_logits,
     ):
         """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
         current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
@@ -861,7 +874,7 @@
             sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
         else:
             # fused the visual feature with previous memory features in the memory bank
-            pix_feat_with_mem = self._prepare_memory_conditioned_features(
+            pix_feat = self._prepare_memory_conditioned_features(
                 frame_idx=frame_idx,
                 is_init_cond_frame=is_init_cond_frame,
                 current_vision_feats=current_vision_feats[-1:],
@@ -880,12 +893,78 @@
                 mask_inputs = prev_sam_mask_logits
             multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
             sam_outputs = self._forward_sam_heads(
-                backbone_features=pix_feat_with_mem,
+                backbone_features=pix_feat,
                 point_inputs=point_inputs,
                 mask_inputs=mask_inputs,
                 high_res_features=high_res_features,
                 multimask_output=multimask_output,
             )
+        return current_out, sam_outputs, high_res_features, pix_feat
+
+    def _encode_memory_in_output(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        point_inputs,
+        run_mem_encoder,
+        high_res_masks,
+        object_score_logits,
+        current_out,
+    ):
+        """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
+        used in future frames).
+        """
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                object_score_logits=object_score_logits,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
+    def track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+        # to skip the memory encoder with `run_mem_encoder=False`. For example,
+        # in demo we might call `track_step` multiple times for each user click,
+        # and only encode the memory when the user finalizes their clicks. And in ablation
+        # settings like SAM training on static images, we don't need the memory encoder.
+        run_mem_encoder=True,
+        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+        prev_sam_mask_logits=None,
+    ):
+        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+        current_out, sam_outputs, _, _ = self._track_step(
+            frame_idx,
+            is_init_cond_frame,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+            point_inputs,
+            mask_inputs,
+            output_dict,
+            num_frames,
+            track_in_reverse,
+            prev_sam_mask_logits,
+        )
+
         (
             _,
             _,
@@ -893,28 +972,28 @@
             low_res_masks,
             high_res_masks,
             obj_ptr,
-            _,
+            object_score_logits,
         ) = sam_outputs

         current_out["pred_masks"] = low_res_masks
         current_out["pred_masks_high_res"] = high_res_masks
         current_out["obj_ptr"] = obj_ptr
+        if not self.training:
+            # Only add this in inference (to avoid unused param in activation checkpointing;
+            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+            current_out["object_score_logits"] = object_score_logits

         # Finally run the memory encoder on the predicted mask to encode
         # it into a new memory feature (that can be used in future frames)
-        if run_mem_encoder and self.num_maskmem > 0:
-            high_res_masks_for_mem_enc = high_res_masks
-            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
-                current_vision_feats=current_vision_feats,
-                feat_sizes=feat_sizes,
-                pred_masks_high_res=high_res_masks_for_mem_enc,
-                is_mask_from_pts=(point_inputs is not None),
-            )
-            current_out["maskmem_features"] = maskmem_features
-            current_out["maskmem_pos_enc"] = maskmem_pos_enc
-        else:
-            current_out["maskmem_features"] = None
-            current_out["maskmem_pos_enc"] = None
+        self._encode_memory_in_output(
+            current_vision_feats,
+            feat_sizes,
+            point_inputs,
+            run_mem_encoder,
+            high_res_masks,
+            object_score_logits,
+            current_out,
+        )

         return current_out
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 4002e092b6..a83159080f 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -478,7 +478,7 @@ class Predictor(BasePredictor):
         results = []
         for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
             if len(masks) == 0:
-                masks = None
+                masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
             else:
                 masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
                 masks = masks > self.model.mask_threshold  # to bool
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 75cb7e5949..b6b8fcbb68 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -224,10 +224,10 @@ class AutoBackend(nn.Module):
                 import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
             except ImportError:
                 if LINUX:
-                    check_requirements("tensorrt>7.0.0,<=10.1.0")
+                    check_requirements("tensorrt>7.0.0,!=10.1.0")
                 import tensorrt as trt  # noqa
             check_version(trt.__version__, ">=7.0.0", hard=True)
-            check_version(trt.__version__, "<=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
+            check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
             if device.type == "cpu":
                 device = torch.device("cuda:0")
             Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
@@ -343,6 +343,7 @@ class AutoBackend(nn.Module):
                     model_path=w,
                     experimental_delegates=[load_delegate(delegate, options={"device": device})],
                 )
+                device = "cpu"  # Required, otherwise PyTorch will try to use the wrong device
             else:  # TFLite
                 LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
                 interpreter = Interpreter(model_path=w)  # load TFLite model
diff --git a/ultralytics/solutions/parking_management.py b/ultralytics/solutions/parking_management.py
index fa815938ab..a62de99524 100644
--- a/ultralytics/solutions/parking_management.py
+++ b/ultralytics/solutions/parking_management.py
@@ -168,7 +168,6 @@ class ParkingManagement(BaseSolution):
     Examples:
         >>> from ultralytics.solutions import ParkingManagement
        >>> parking_manager = ParkingManagement(model="yolov8n.pt", json_file="parking_regions.json")
-        >>> results = parking_manager(source="parking_lot_video.mp4")
        >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}")
        >>> print(f"Available spaces: {parking_manager.pr_info['Available']}")
     """
diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py
index 1af0c0ba09..e43aba6441 100644
--- a/ultralytics/solutions/solutions.py
+++ b/ultralytics/solutions/solutions.py
@@ -1,16 +1,13 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 from collections import defaultdict
-from pathlib import Path

 import cv2

 from ultralytics import YOLO
-from ultralytics.utils import LOGGER, yaml_load
+from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
 from ultralytics.utils.checks import check_imshow, check_requirements

-DEFAULT_SOL_CFG_PATH = Path(__file__).resolve().parents[1] / "cfg/solutions/default.yaml"
-

 class BaseSolution:
     """
@@ -55,15 +52,18 @@ class BaseSolution:
         self.Point = Point

         # Load config and update with args
-        self.CFG = yaml_load(DEFAULT_SOL_CFG_PATH)
-        self.CFG.update(kwargs)
-        LOGGER.info(f"Ultralytics Solutions: ✅ {self.CFG}")
+        DEFAULT_SOL_DICT.update(kwargs)
+        DEFAULT_CFG_DICT.update(kwargs)
+        self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}
+        LOGGER.info(f"Ultralytics Solutions: ✅ {DEFAULT_SOL_DICT}")

         self.region = self.CFG["region"]  # Store region data for other classes usage
-        self.line_width = self.CFG["line_width"]  # Store line_width for usage
+        self.line_width = (
+            self.CFG["line_width"] if self.CFG["line_width"] is not None else 2
+        )  # Store line_width for usage

         # Load Model and store classes names
-        self.model = YOLO(self.CFG["model"])
+        self.model = YOLO(self.CFG["model"] if self.CFG["model"] else "yolov8n.pt")
         self.names = self.model.names

         # Initialize environment and region setup
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 05a4f464b7..d9cd96e3c4 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -38,6 +38,7 @@ FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLO
 ASSETS = ROOT / "assets"  # default images
 DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
+DEFAULT_SOL_CFG_PATH = ROOT / "cfg/solutions/default.yaml"  # Ultralytics solutions yaml path
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLO multiprocessing threads
 AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true"  # global auto-install mode
 VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"  # global verbose mode
@@ -508,6 +509,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:

 # Default configuration
 DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
+DEFAULT_SOL_DICT = yaml_load(DEFAULT_SOL_CFG_PATH)  # Ultralytics solutions configuration
 for k, v in DEFAULT_CFG_DICT.items():
     if isinstance(v, str) and v.lower() == "none":
         DEFAULT_CFG_DICT[k] = None
@@ -566,12 +568,16 @@ def is_kaggle():

 def is_jupyter():
     """
-    Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.
+    Check if the current script is running inside a Jupyter Notebook.

     Returns:
         (bool): True if running inside a Jupyter Notebook, False otherwise.
+
+    Note:
+        - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable.
+        - "get_ipython" in globals() method suffers false positives when IPython package installed manually.
     """
-    return "get_ipython" in globals()
+    return IS_COLAB or IS_KAGGLE


 def is_docker() -> bool:
@@ -799,10 +805,10 @@ def get_user_config_dir(sub_dir="Ultralytics"):
 PROC_DEVICE_MODEL = read_device_model()  # is_jetson() and is_raspberrypi() depend on this constant
 ONLINE = is_online()
 IS_COLAB = is_colab()
+IS_KAGGLE = is_kaggle()
 IS_DOCKER = is_docker()
 IS_JETSON = is_jetson()
 IS_JUPYTER = is_jupyter()
-IS_KAGGLE = is_kaggle()
 IS_PIP_PACKAGE = is_pip_package()
 IS_RASPBERRYPI = is_raspberrypi()
 GIT_DIR = get_git_dir()
@@ -1193,7 +1199,7 @@ class SettingsManager(JSONDict):
             "neptune": True,  # Neptune integration
             "raytune": True,  # Ray Tune integration
             "tensorboard": True,  # TensorBoard logging
-            "wandb": True,  # Weights & Biases logging
+            "wandb": False,  # Weights & Biases logging
             "vscode_msg": True,  # VSCode messaging
         }
diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py
index 3a217c3f25..3fae97f917 100644
--- a/ultralytics/utils/callbacks/comet.py
+++ b/ultralytics/utils/callbacks/comet.py
@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
+from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics

 try:
     assert not TESTS_RUNNING  # do not log pytest
@@ -16,8 +17,11 @@ try:
     COMET_SUPPORTED_TASKS = ["detect"]

     # Names of plots created by Ultralytics that are logged to Comet
-    EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
+    CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
+    EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
     LABEL_PLOT_NAMES = "labels", "labels_correlogram"
+    SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
+    POSE_METRICS_PLOT_PREFIX = "Box", "Pose"

     _comet_image_prediction_count = 0
@@ -86,7 +90,7 @@ def _create_experiment(args):
                 "max_image_predictions": _get_max_image_predictions_to_log(),
             }
         )
-        experiment.log_other("Created from", "yolov8")
+        experiment.log_other("Created from", "ultralytics")
     except Exception as e:
         LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
@@ -274,11 +278,31 @@ def _log_image_predictions(experiment, validator, curr_step):

 def _log_plots(experiment, trainer):
     """Logs evaluation plots and label plots for the experiment."""
-    plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
-    _log_images(experiment, plot_filenames, None)
-
-    label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
-    _log_images(experiment, label_plot_filenames, None)
+    plot_filenames = None
+    if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+        plot_filenames = [
+            trainer.save_dir / f"{prefix}{plots}.png"
+            for plots in EVALUATION_PLOT_NAMES
+            for prefix in SEGMENT_METRICS_PLOT_PREFIX
+        ]
+    elif isinstance(trainer.validator.metrics, PoseMetrics):
+        plot_filenames = [
+            trainer.save_dir / f"{prefix}{plots}.png"
+            for plots in EVALUATION_PLOT_NAMES
+            for prefix in POSE_METRICS_PLOT_PREFIX
+        ]
+    elif isinstance(trainer.validator.metrics, DetMetrics) or isinstance(trainer.validator.metrics, OBBMetrics):
+        plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
+
+    if plot_filenames is not None:
+        _log_images(experiment, plot_filenames, None)
+
+    confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
+    _log_images(experiment, confusion_matrix_filenames, None)
+
+    if not isinstance(trainer.validator.metrics, ClassifyMetrics):
+        label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
+        _log_images(experiment, label_plot_filenames, None)


 def _log_model(experiment, trainer):
@@ -307,9 +331,6 @@ def on_train_epoch_end(trainer):
     experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)

-    if curr_epoch == 1:
-        _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
-

 def on_fit_epoch_end(trainer):
     """Logs model assets at the end of each epoch."""
@@ -356,6 +377,8 @@ def on_train_end(trainer):
     _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
     _log_image_predictions(experiment, trainer.validator, curr_step)
+    _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
+    _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
     experiment.end()

     global _comet_image_prediction_count
diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py
index 7b6d00cfc3..b82b8d85ec 100644
--- a/ultralytics/utils/callbacks/wb.py
+++ b/ultralytics/utils/callbacks/wb.py
@@ -137,17 +137,19 @@ def on_train_end(trainer):
     if trainer.best.exists():
         art.add_file(trainer.best)
         wb.run.log_artifact(art, aliases=["best"])
-    for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
-        x, y, x_title, y_title = curve_values
-        _plot_curve(
-            x,
-            y,
-            names=list(trainer.validator.metrics.names.values()),
-            id=f"curves/{curve_name}",
-            title=curve_name,
-            x_title=x_title,
-            y_title=y_title,
-        )
+    # Check if we actually have plots to save
+    if trainer.args.plots:
+        for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
+            x, y, x_title, y_title = curve_values
+            _plot_curve(
+                x,
+                y,
+                names=list(trainer.validator.metrics.names.values()),
+                id=f"curves/{curve_name}",
+                title=curve_name,
+                x_title=x_title,
+                y_title=y_title,
+            )
     wb.run.finish()  # required or run continues on dashboard
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index c483e31366..9591d3dea2 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -335,7 +335,7 @@ def check_font(font="Arial.ttf"):
     return file


-def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = True) -> bool:
+def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
     """
     Check current python version against the required minimum version.
@@ -688,7 +688,7 @@ def check_amp(model):
     im = ASSETS / "bus.jpg"  # image to check
     prefix = colorstr("AMP: ")
-    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLO11n...")
+    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
     warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
     try:
         from ultralytics import YOLO
@@ -696,11 +696,13 @@ def check_amp(model):
         assert amp_allclose(YOLO("yolo11n.pt"), im)
         LOGGER.info(f"{prefix}checks passed ✅")
     except ConnectionError:
-        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLO11n. {warning_msg}")
+        LOGGER.warning(
+            f"{prefix}checks skipped ⚠️. " f"Offline and unable to download YOLO11n for AMP checks. {warning_msg}"
+        )
     except (AttributeError, ModuleNotFoundError):
         LOGGER.warning(
             f"{prefix}checks skipped ⚠️. "
-            f"Unable to load YOLO11n due to possible Ultralytics package modifications. {warning_msg}"
+            f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}"
         )
     except AssertionError:
         LOGGER.warning(