Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
main
Ayush Chaurasia 2 years ago committed by GitHub
parent a5410ed79e
commit 0609561549
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 38
      tests/test_python.py
  2. 2
      ultralytics/yolo/data/__init__.py
  3. 67
      ultralytics/yolo/data/build.py
  4. 35
      ultralytics/yolo/data/dataloaders/stream_loaders.py
  5. 7
      ultralytics/yolo/engine/model.py
  6. 92
      ultralytics/yolo/engine/predictor.py
  7. 2
      ultralytics/yolo/v8/classify/predict.py
  8. 2
      ultralytics/yolo/v8/detect/predict.py
  9. 2
      ultralytics/yolo/v8/segment/predict.py

@ -3,10 +3,12 @@
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
@ -40,6 +42,7 @@ def test_predict_dir():
def test_predict_img():
model = YOLO(MODEL)
img = Image.open(str(SOURCE))
output = model(source=img, save=True, verbose=True) # PIL
@ -54,6 +57,16 @@ def test_predict_img():
tens = torch.zeros(320, 640, 3)
output = model(tens.numpy())
assert len(output) == 1, "predict test failed"
# test multiple source
imgs = [
SOURCE, # filename
Path(SOURCE), # Path
'https://ultralytics.com/images/zidane.jpg', # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3))] # numpy
output = model(imgs)
assert len(output) == 6, "predict test failed!"
def test_val():
@ -129,3 +142,28 @@ def test_workflow():
model.val()
model.predict(SOURCE)
model.export(format="onnx", opset=12) # export a model to ONNX format
def test_predict_callback_and_setup():
def on_predict_batch_end(predictor):
# results -> List[batch_size]
path, _, im0s, _, _ = predictor.batch
# print('on_predict_batch_end', im0s[0].shape)
bs = [predictor.bs for i in range(0, len(path))]
predictor.results = zip(predictor.results, im0s, bs)
model = YOLO("yolov8n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
bs = dataset.bs # access predictor properties
results = model.predict(dataset, stream=True) # source already setup
for _, (result, im0, bs) in enumerate(results):
print('test_callback', im0.shape)
print('test_callback', bs)
boxes = result.boxes # Boxes object for bbox outputs
print(boxes)
test_predict_img()

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from .base import BaseDataset
from .build import build_classification_dataloader, build_dataloader
from .build import build_classification_dataloader, build_dataloader, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .dataset_wrappers import MixAndRectDataset

@ -2,11 +2,18 @@
import os
import random
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, dataloader, distributed
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
LoadStreams, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True)
def check_source(source):
webcam, screenshot, from_img, in_memory = False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
elif isinstance(source, tuple(LOADERS)):
in_memory = True
elif isinstance(source, (list, tuple)):
source = autocast_list(source) # convert all list elements to PIL or np arrays
from_img = True
elif isinstance(source, ((Image.Image, np.ndarray))):
from_img = True
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
return source, webcam, screenshot, from_img, in_memory
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
"""
TODO: docs
"""
# source
source, webcam, screenshot, from_img, in_memory = check_source(source)
source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type
# Dataloader
if in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source,
imgsz=imgsz,
stride=stride,
auto=auto,
transforms=transforms,
vid_stride=vid_stride)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
elif from_img:
dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
else:
dataset = LoadImages(source,
imgsz=imgsz,
stride=stride,
auto=auto,
transforms=transforms,
vid_stride=vid_stride)
setattr(dataset, 'source_type', source_type) # attach source types
return dataset

@ -4,14 +4,16 @@ import glob
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from PIL import Image, ImageOps
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements
@dataclass
class SourceTypes:
webcam: bool = False
screenshot: bool = False
from_img: bool = False
class LoadStreams:
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
@ -63,6 +72,8 @@ class LoadStreams:
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
self.bs = self.__len__()
if not self.rect:
LOGGER.warning('WARNING ⚠ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
@ -128,6 +139,7 @@ class LoadScreenshots:
self.mode = 'stream'
self.frame = 0
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
@ -185,6 +197,7 @@ class LoadImages:
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos):
self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
self.mode = 'image'
# generate fake paths
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
self.bs = 1
@staticmethod
def _single_check(im):
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
return self
def autocast_list(source):
"""
Merges a list of source of different types into a list of numpy arrays or PIL images
"""
files = []
for _, im in enumerate(source):
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
return files
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
if __name__ == "__main__":
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
dataset = LoadPilAndNumpy(im0=img)

@ -233,6 +233,13 @@ class YOLO:
"""
return self.model.names
@property
def transforms(self):
"""
Returns transform of the loaded model.
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@staticmethod
def add_callback(event: str, func):
"""

@ -30,13 +30,13 @@ from collections import defaultdict
from pathlib import Path
import cv2
import torch
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -76,6 +76,8 @@ class BasePredictor:
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False
if self.args.show:
self.args.show = check_imshow(warn=True)
# Usable if setup is done
self.model = None
@ -88,6 +90,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
self.source_type = None
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)
@ -103,53 +106,6 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img, classes=None):
return preds
def setup_source(self, source=None):
if not self.model:
raise Exception("setup model before setting up source!")
# source
source, webcam, screenshot, from_img = self.check_source(source)
# model
stride, pt = self.model.stride, self.model.pt
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
# Dataloader
bs = 1 # batch_size
if webcam:
self.args.show = check_imshow(warn=True)
self.dataset = LoadStreams(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
bs = len(self.dataset)
elif screenshot:
self.dataset = LoadScreenshots(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
elif from_img:
self.dataset = LoadPilAndNumpy(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
else:
self.dataset = LoadImages(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
self.webcam = webcam
self.screenshot = screenshot
self.from_img = from_img
self.imgsz = imgsz
self.bs = bs
@smart_inference_mode()
def __call__(self, source=None, model=None, stream=False):
if stream:
@ -163,14 +119,29 @@ class BasePredictor:
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass
def setup_source(self, source):
if not self.model:
raise Exception("Model not initialized!")
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(source=source,
transforms=getattr(self.model.model, 'transforms', None),
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stride=self.model.stride,
auto=self.model.pt)
self.source_type = self.dataset.source_type
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")
# setup model
if not self.model:
self.setup_model(model)
# setup source. Run every time predict is called
self.setup_source(source)
# setup source every time predict is called
self.setup_source(source if source is not None else self.args.source)
# check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@ -198,7 +169,7 @@ class BasePredictor:
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes)
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@ -237,21 +208,6 @@ class BasePredictor:
self.device = device
self.model.eval()
def check_source(self, source):
source = source if source is not None else self.args.source
webcam, screenshot, from_img = False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
else:
from_img = True
return source, webcam, screenshot, from_img
def show(self, p):
im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows:

@ -33,7 +33,7 @@ class ClassificationPredictor(BasePredictor):
im = im[None] # expand for batch dim
self.seen += 1
im0 = im0.copy()
if self.webcam or self.from_img: # batch_size >= 1
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
else:

@ -42,7 +42,7 @@ class DetectionPredictor(BasePredictor):
im = im[None] # expand for batch dim
self.seen += 1
im0 = im0.copy()
if self.webcam or self.from_img: # batch_size >= 1
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
else:

@ -43,7 +43,7 @@ class SegmentationPredictor(DetectionPredictor):
if len(im.shape) == 3:
im = im[None] # expand for batch dim
self.seen += 1
if self.webcam or self.from_img: # batch_size >= 1
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
else:

Loading…
Cancel
Save