Add device selection for `yolo_bbox2segment` (#17409)

pull/17411/head^2
Laughing 3 months ago committed by GitHub
parent 9a7b344fd0
commit b0aef79d36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 5
      ultralytics/data/converter.py

@ -577,7 +577,7 @@ def merge_multi_segment(segments):
return s return s
def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"): def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt", device=None):
""" """
Converts existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) Converts existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB)
in YOLO format. Generates segmentation data using SAM auto-annotator as needed. in YOLO format. Generates segmentation data using SAM auto-annotator as needed.
@ -587,6 +587,7 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
save_dir (str | Path): Path to save the generated labels, labels will be saved save_dir (str | Path): Path to save the generated labels, labels will be saved
into `labels-segment` in the same directory level of `im_dir` if save_dir is None. Default: None. into `labels-segment` in the same directory level of `im_dir` if save_dir is None. Default: None.
sam_model (str): Segmentation model to use for intermediate segmentation data; optional. sam_model (str): Segmentation model to use for intermediate segmentation data; optional.
device (int | str): The specific device to run SAM models. Default: None.
Notes: Notes:
The input directory structure assumed for dataset: The input directory structure assumed for dataset:
@ -621,7 +622,7 @@ def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
boxes[:, [0, 2]] *= w boxes[:, [0, 2]] *= w
boxes[:, [1, 3]] *= h boxes[:, [1, 3]] *= h
im = cv2.imread(label["im_file"]) im = cv2.imread(label["im_file"])
sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False) sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False, device=device)
label["segments"] = sam_results[0].masks.xyn label["segments"] = sam_results[0].masks.xyn
save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment" save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment"

Loading…
Cancel
Save