|
|
|
@ -30,13 +30,13 @@ from collections import defaultdict |
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
import cv2 |
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
from ultralytics.nn.autobackend import AutoBackend |
|
|
|
|
from ultralytics.yolo.cfg import get_cfg |
|
|
|
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams |
|
|
|
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS |
|
|
|
|
from ultralytics.yolo.data import load_inference_source |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops |
|
|
|
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow |
|
|
|
|
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow |
|
|
|
|
from ultralytics.yolo.utils.files import increment_path |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode |
|
|
|
|
|
|
|
|
@ -76,6 +76,8 @@ class BasePredictor: |
|
|
|
|
if self.args.conf is None: |
|
|
|
|
self.args.conf = 0.25 # default conf=0.25 |
|
|
|
|
self.done_warmup = False |
|
|
|
|
if self.args.show: |
|
|
|
|
self.args.show = check_imshow(warn=True) |
|
|
|
|
|
|
|
|
|
# Usable if setup is done |
|
|
|
|
self.model = None |
|
|
|
@ -88,6 +90,7 @@ class BasePredictor: |
|
|
|
|
self.vid_path, self.vid_writer = None, None |
|
|
|
|
self.annotator = None |
|
|
|
|
self.data_path = None |
|
|
|
|
self.source_type = None |
|
|
|
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks |
|
|
|
|
callbacks.add_integration_callbacks(self) |
|
|
|
|
|
|
|
|
@ -103,53 +106,6 @@ class BasePredictor: |
|
|
|
|
def postprocess(self, preds, img, orig_img, classes=None): |
|
|
|
|
return preds |
|
|
|
|
|
|
|
|
|
def setup_source(self, source=None): |
|
|
|
|
if not self.model: |
|
|
|
|
raise Exception("setup model before setting up source!") |
|
|
|
|
# source |
|
|
|
|
source, webcam, screenshot, from_img = self.check_source(source) |
|
|
|
|
# model |
|
|
|
|
stride, pt = self.model.stride, self.model.pt |
|
|
|
|
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size |
|
|
|
|
|
|
|
|
|
# Dataloader |
|
|
|
|
bs = 1 # batch_size |
|
|
|
|
if webcam: |
|
|
|
|
self.args.show = check_imshow(warn=True) |
|
|
|
|
self.dataset = LoadStreams(source, |
|
|
|
|
imgsz=imgsz, |
|
|
|
|
stride=stride, |
|
|
|
|
auto=pt, |
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None), |
|
|
|
|
vid_stride=self.args.vid_stride) |
|
|
|
|
bs = len(self.dataset) |
|
|
|
|
elif screenshot: |
|
|
|
|
self.dataset = LoadScreenshots(source, |
|
|
|
|
imgsz=imgsz, |
|
|
|
|
stride=stride, |
|
|
|
|
auto=pt, |
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None)) |
|
|
|
|
elif from_img: |
|
|
|
|
self.dataset = LoadPilAndNumpy(source, |
|
|
|
|
imgsz=imgsz, |
|
|
|
|
stride=stride, |
|
|
|
|
auto=pt, |
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None)) |
|
|
|
|
else: |
|
|
|
|
self.dataset = LoadImages(source, |
|
|
|
|
imgsz=imgsz, |
|
|
|
|
stride=stride, |
|
|
|
|
auto=pt, |
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None), |
|
|
|
|
vid_stride=self.args.vid_stride) |
|
|
|
|
self.vid_path, self.vid_writer = [None] * bs, [None] * bs |
|
|
|
|
|
|
|
|
|
self.webcam = webcam |
|
|
|
|
self.screenshot = screenshot |
|
|
|
|
self.from_img = from_img |
|
|
|
|
self.imgsz = imgsz |
|
|
|
|
self.bs = bs |
|
|
|
|
|
|
|
|
|
@smart_inference_mode() |
|
|
|
|
def __call__(self, source=None, model=None, stream=False): |
|
|
|
|
if stream: |
|
|
|
@ -163,14 +119,29 @@ class BasePredictor: |
|
|
|
|
for _ in gen: # running CLI inference without accumulating any outputs (do not modify) |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
def setup_source(self, source): |
|
|
|
|
if not self.model: |
|
|
|
|
raise Exception("Model not initialized!") |
|
|
|
|
|
|
|
|
|
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size |
|
|
|
|
self.dataset = load_inference_source(source=source, |
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None), |
|
|
|
|
imgsz=self.imgsz, |
|
|
|
|
vid_stride=self.args.vid_stride, |
|
|
|
|
stride=self.model.stride, |
|
|
|
|
auto=self.model.pt) |
|
|
|
|
self.source_type = self.dataset.source_type |
|
|
|
|
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs |
|
|
|
|
|
|
|
|
|
def stream_inference(self, source=None, model=None): |
|
|
|
|
self.run_callbacks("on_predict_start") |
|
|
|
|
|
|
|
|
|
# setup model |
|
|
|
|
if not self.model: |
|
|
|
|
self.setup_model(model) |
|
|
|
|
# setup source. Run every time predict is called |
|
|
|
|
self.setup_source(source) |
|
|
|
|
# setup source every time predict is called |
|
|
|
|
self.setup_source(source if source is not None else self.args.source) |
|
|
|
|
|
|
|
|
|
# check if save_dir/ label file exists |
|
|
|
|
if self.args.save or self.args.save_txt: |
|
|
|
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) |
|
|
|
@ -198,7 +169,7 @@ class BasePredictor: |
|
|
|
|
with self.dt[2]: |
|
|
|
|
self.results = self.postprocess(preds, im, im0s, self.classes) |
|
|
|
|
for i in range(len(im)): |
|
|
|
|
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s) |
|
|
|
|
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s) |
|
|
|
|
p = Path(p) |
|
|
|
|
|
|
|
|
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: |
|
|
|
@ -237,21 +208,6 @@ class BasePredictor: |
|
|
|
|
self.device = device |
|
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
def check_source(self, source): |
|
|
|
|
source = source if source is not None else self.args.source |
|
|
|
|
webcam, screenshot, from_img = False, False, False |
|
|
|
|
if isinstance(source, (str, int, Path)): # int for local usb carame |
|
|
|
|
source = str(source) |
|
|
|
|
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) |
|
|
|
|
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')) |
|
|
|
|
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) |
|
|
|
|
screenshot = source.lower().startswith('screen') |
|
|
|
|
if is_url and is_file: |
|
|
|
|
source = check_file(source) # download |
|
|
|
|
else: |
|
|
|
|
from_img = True |
|
|
|
|
return source, webcam, screenshot, from_img |
|
|
|
|
|
|
|
|
|
def show(self, p): |
|
|
|
|
im0 = self.annotator.result() |
|
|
|
|
if platform.system() == 'Linux' and p not in self.windows: |
|
|
|
|