diff --git a/docs/en/modes/predict.md b/docs/en/modes/predict.md index fc3397822..d13bd5eca 100644 --- a/docs/en/modes/predict.md +++ b/docs/en/modes/predict.md @@ -720,6 +720,7 @@ The `plot()` method supports various arguments to customize the output: | `show` | `bool` | Display the annotated image directly using the default image viewer. | `False` | | `save` | `bool` | Save the annotated image to a file specified by `filename`. | `False` | | `filename` | `str` | Path and name of the file to save the annotated image if `save` is `True`. | `None` | +| `color_mode` | `str` | Specify the color mode, e.g., 'instance' or 'class'. | `'class'` | ## Thread-Safe Inference diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 51f34eb25..04f4fc96d 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.76" +__version__ = "8.2.77" import os diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index ced02f0ab..34a5d3e08 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -460,6 +460,7 @@ class Results(SimpleClass): show=False, save=False, filename=None, + color_mode="class", ): """ Plots detection results on an input RGB image. @@ -481,6 +482,7 @@ class Results(SimpleClass): show (bool): Whether to display the annotated image. save (bool): Whether to save the annotated image. filename (str | None): Filename to save image if save is True. + color_mode (bool): Specify the color mode, e.g., 'instance' or 'class'. Default to 'class'. Returns: (np.ndarray): Annotated image as a numpy array. @@ -491,6 +493,7 @@ class Results(SimpleClass): ... im = result.plot() ... im.show() """ + assert color_mode in {"instance", "class"}, f"Expected color_mode='instance' or 'class', not {color_mode}." if img is None and isinstance(self.orig_img, torch.Tensor): img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy() @@ -519,17 +522,22 @@ class Results(SimpleClass): .contiguous() / 255 ) - idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) + idx = pred_boxes.cls if pred_boxes and color_mode == "class" else reversed(range(len(pred_masks))) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) # Plot Detect results if pred_boxes is not None and show_boxes: - for d in reversed(pred_boxes): + for i, d in enumerate(reversed(pred_boxes)): c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) name = ("" if id is None else f"id:{id} ") + names[c] label = (f"{name} {conf:.2f}" if conf else name) if labels else None box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() - annotator.box_label(box, label, color=colors(c, True), rotated=is_obb) + annotator.box_label( + box, + label, + color=colors(i if color_mode == "instance" else c, True), + rotated=is_obb, + ) # Plot Classify results if pred_probs is not None and show_probs: @@ -539,8 +547,14 @@ class Results(SimpleClass): # Plot Pose results if self.keypoints is not None: - for k in reversed(self.keypoints.data): - annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line) + for i, k in enumerate(reversed(self.keypoints.data)): + annotator.kpts( + k, + self.orig_shape, + radius=kpt_radius, + kpt_line=kpt_line, + kpt_color=colors(i, True) if color_mode == "instance" else None, + ) # Show results if show: diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 879a6ebd2..aaca19dc3 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -369,7 +369,7 @@ class Annotator: # Convert im back to PIL and update draw self.fromarray(self.im) - def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25): + def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25, kpt_color=None): """ Plot keypoints on the image. @@ -379,6 +379,7 @@ class Annotator: radius (int, optional): Radius of the drawn keypoints. Default is 5. kpt_line (bool, optional): If True, the function will draw lines connecting keypoints for human pose. Default is True. + kpt_color (tuple, optional): The color of the keypoints (B, G, R). Note: `kpt_line=True` currently only supports human pose plotting. @@ -391,7 +392,7 @@ class Annotator: is_pose = nkpt == 17 and ndim in {2, 3} kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting for i, k in enumerate(kpts): - color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i) + color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i)) x_coord, y_coord = k[0], k[1] if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: if len(k) == 3: @@ -414,7 +415,14 @@ class Annotator: continue if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0: continue - cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA) + cv2.line( + self.im, + pos1, + pos2, + kpt_color or self.limb_color[i].tolist(), + thickness=2, + lineType=cv2.LINE_AA, + ) if self.pil: # Convert im back to PIL and update draw self.fromarray(self.im)