From e371e81aa07802306a3368ec69e0b3c8c40ef844 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 10 Jan 2023 22:59:11 +0530 Subject: [PATCH] Webcam inference fix (#202) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/engine/predictor.py | 8 ++++---- ultralytics/yolo/v8/classify/predict.py | 4 +++- ultralytics/yolo/v8/detect/predict.py | 2 +- ultralytics/yolo/v8/segment/predict.py | 4 +++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 15c2ff1433..c3e93cf163 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -57,7 +57,6 @@ class BasePredictor: dataset (Dataset): Dataset used for prediction. vid_path (str): Path to video file. vid_writer (cv2.VideoWriter): Video writer for saving video output. - show (bool): Whether to view image output. annotator (Annotator): Annotator used for prediction. data_path (str): Path to data. """ @@ -88,7 +87,6 @@ class BasePredictor: self.device = None self.dataset = None self.vid_path, self.vid_writer = None, None - self.show = None self.annotator = None self.data_path = None self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks @@ -108,7 +106,7 @@ class BasePredictor: def setup(self, source=None, model=None): # source - source = str(source or self.args.source) + source = str(source if source is not None else self.args.source) is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) @@ -126,8 +124,10 @@ class BasePredictor: # Dataloader bs = 1 # batch_size + if self.args.show: + self.args.show = check_imshow(warn=True) if webcam: - self.show = check_imshow(warn=True) + self.args.show = check_imshow(warn=True) self.dataset = LoadStreams(source, imgsz=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride) bs = len(self.dataset) elif screenshot: diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index d22939038f..3ef879b824 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -4,7 +4,7 @@ import hydra import torch from ultralytics.yolo.engine.predictor import BasePredictor -from ultralytics.yolo.utils import DEFAULT_CONFIG +from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import Annotator @@ -59,6 +59,8 @@ class ClassificationPredictor(BasePredictor): def predict(cfg): cfg.model = cfg.model or "squeezenet1_0" cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size + cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" + predictor = ClassificationPredictor(cfg) predictor() diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 8b9ed3ac57..918bf9bada 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -87,7 +87,7 @@ class DetectionPredictor(BasePredictor): def predict(cfg): cfg.model = cfg.model or "yolov8n.pt" cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size - cfg.source = cfg.source or ROOT / "assets" + cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" predictor = DetectionPredictor(cfg) predictor() diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 3d89d7fde6..4a6a133733 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -3,7 +3,7 @@ import hydra import torch -from ultralytics.yolo.utils import DEFAULT_CONFIG, ops +from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import colors, save_one_box @@ -103,6 +103,8 @@ class SegmentationPredictor(DetectionPredictor): def predict(cfg): cfg.model = cfg.model or "yolov8n-seg.pt" cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size + cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" + predictor = SegmentationPredictor(cfg) predictor()