`ultralytics 8.1.26` `LoadImagesAndVideos` batched inference (#8817)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher committed 1 year ago via GitHub
parent 1f9667fff2
commit 7451ca1f54
Changed files (11):

  1. .gitignore (2 changed lines)
  2. docs/en/reference/data/loaders.md (2)
  3. docs/en/reference/utils/files.md (4)
  4. tests/test_python.py (3)
  5. ultralytics/__init__.py (2)
  6. ultralytics/cfg/__init__.py (2)
  7. ultralytics/data/build.py (13)
  8. ultralytics/data/loaders.py (127)
  9. ultralytics/engine/model.py (3)
  10. ultralytics/engine/predictor.py (196)
  11. ultralytics/trackers/track.py (5)

.gitignore (vendored, 2 changed lines)

@@ -29,7 +29,7 @@ MANIFEST
 # PyInstaller
 # Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
+# before PyInstaller builds the exe, so as to inject date/other info into it.
 *.manifest
 *.spec

docs/en/reference/data/loaders.md

@@ -23,7 +23,7 @@ keywords: Ultralytics, data loaders, LoadStreams, LoadImages, LoadTensor, YOLO,
 <br><br>
-## ::: ultralytics.data.loaders.LoadImages
+## ::: ultralytics.data.loaders.LoadImagesAndVideos
 <br><br>

docs/en/reference/utils/files.md

@@ -38,3 +38,7 @@ keywords: Ultralytics, utility functions, file operations, working directory, fi
 ## ::: ultralytics.utils.files.get_latest_run
 <br><br>
+
+## ::: ultralytics.utils.files.update_models
+<br><br>

tests/test_python.py

@@ -8,6 +8,7 @@ import cv2
 import numpy as np
 import pytest
 import torch
+import yaml
 from PIL import Image
 from torchvision.transforms import ToTensor

@@ -169,8 +170,6 @@ def test_track_stream():
     Note imgsz=160 required for tracking for higher confidence and better matches
     """
-    import yaml
-
     video_url = "https://ultralytics.com/assets/decelera_portrait_min.mov"
     model = YOLO(MODEL)
     model.track(video_url, imgsz=160, tracker="bytetrack.yaml")

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.1.25"
+__version__ = "8.1.26"

 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

ultralytics/cfg/__init__.py

@@ -396,7 +396,7 @@ def handle_yolo_settings(args: List[str]) -> None:
 def handle_explorer():
     """Open the Ultralytics Explorer GUI."""
     checks.check_requirements("streamlit")
-    LOGGER.info(f"💡 Loading Explorer dashboard...")
+    LOGGER.info("💡 Loading Explorer dashboard...")
     subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])

ultralytics/data/build.py

@@ -11,7 +11,7 @@ from torch.utils.data import dataloader, distributed

 from ultralytics.data.loaders import (
     LOADERS,
-    LoadImages,
+    LoadImagesAndVideos,
     LoadPilAndNumpy,
     LoadScreenshots,
     LoadStreams,

@@ -150,34 +150,35 @@ def check_source(source):
     return source, webcam, screenshot, from_img, in_memory, tensor


-def load_inference_source(source=None, vid_stride=1, buffer=False):
+def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
     """
     Loads an inference source for object detection and applies necessary transformations.

     Args:
         source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
+        batch (int, optional): Batch size for dataloaders. Default is 1.
         vid_stride (int, optional): The frame interval for video sources. Default is 1.
         buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.

     Returns:
         dataset (Dataset): A dataset object for the specified input source.
     """
-    source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
-    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
+    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
+    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

     # Dataloader
     if tensor:
         dataset = LoadTensor(source)
     elif in_memory:
         dataset = source
-    elif webcam:
+    elif stream:
         dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
     elif screenshot:
         dataset = LoadScreenshots(source)
     elif from_img:
         dataset = LoadPilAndNumpy(source)
     else:
-        dataset = LoadImages(source, vid_stride=vid_stride)
+        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)

     # Attach source types to the dataset
     setattr(dataset, "source_type", source_type)
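Taken together, the new `batch` argument threads from `load_inference_source` into `LoadImagesAndVideos`. A minimal usage sketch, assuming a placeholder `assets/` folder and the 4-tuple return contract introduced in loaders.py below:

```python
from ultralytics.data.build import load_inference_source

# Build a batched dataset over a folder of images/videos ("assets/" is a placeholder path)
dataset = load_inference_source(source="assets/", batch=4, vid_stride=1)

for paths, imgs, is_video, info in dataset:  # each iteration yields up to `batch` items
    print(len(imgs), info)
```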

ultralytics/data/loaders.py

@@ -24,7 +24,7 @@ from ultralytics.utils.checks import check_requirements
 class SourceTypes:
     """Class to represent various types of input sources for predictions."""

-    webcam: bool = False
+    stream: bool = False
     screenshot: bool = False
     from_img: bool = False
     tensor: bool = False
@@ -32,9 +32,7 @@ class SourceTypes:
 class LoadStreams:
     """
-    Stream Loader for various types of video streams.
-
-    Suitable for use with `yolo predict source='rtsp://example.com/media.mp4'`, supports RTSP, RTMP, HTTP, and TCP streams.
+    Stream Loader for various types of video streams. Supports RTSP, RTMP, HTTP, and TCP streams.

     Attributes:
         sources (str): The source input paths or URLs for the video streams.

@@ -57,6 +55,11 @@ class LoadStreams:
         __iter__: Returns an iterator object for the class.
         __next__: Returns source paths, transformed, and original images for processing.
         __len__: Return the length of the sources object.
+
+    Example:
+        ```bash
+        yolo predict source='rtsp://example.com/media.mp4'
+        ```
     """
     def __init__(self, sources="file.streams", vid_stride=1, buffer=False):

@@ -69,6 +72,7 @@ class LoadStreams:
         sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
         n = len(sources)
+        self.bs = n
         self.fps = [0] * n  # frames per second
         self.frames = [0] * n
         self.threads = [None] * n

@@ -76,6 +80,8 @@ class LoadStreams:
         self.imgs = [[] for _ in range(n)]  # images
         self.shape = [[] for _ in range(n)]  # image shapes
         self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
+        self.info = [""] * n
+        self.is_video = [True] * n
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f"{i + 1}/{n}: {s}... "

@@ -109,9 +115,6 @@ class LoadStreams:
             self.threads[i].start()
         LOGGER.info("")  # newline

-        # Check for common shapes
-        self.bs = self.__len__()
-
     def update(self, i, cap, stream):
         """Read stream `i` frames in daemon thread."""
         n, f = 0, self.frames[i]  # frame number, frame array
@@ -175,11 +178,11 @@ class LoadStreams:
             images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
             x.clear()

-        return self.sources, images, None, ""
+        return self.sources, images, self.is_video, self.info

     def __len__(self):
         """Return the length of the sources object."""
-        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
+        return self.bs  # 1E12 frames = 32 streams at 30 FPS for 30 years
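With this change every loader yields the same 4-tuple of parallel lists. A minimal consumption sketch (the RTSP URL is a placeholder and requires a live stream to actually run):

```python
from ultralytics.data.loaders import LoadStreams

loader = LoadStreams("rtsp://example.com/media.mp4")  # placeholder URL, needs a live stream
for sources, images, is_video, info in loader:  # four lists, one entry per stream
    assert len(images) == loader.bs  # batch size now equals the number of streams
    break
```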
 class LoadScreenshots:

@@ -243,10 +246,10 @@ class LoadScreenshots:
         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
         self.frame += 1
-        return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string
+        return [str(self.screen)], [im0], [True], [s]  # screen, img, is_video, string
-class LoadImages:
+class LoadImagesAndVideos:
     """
     YOLOv8 image/video dataloader.

@@ -269,7 +272,7 @@ class LoadImages:
         _new_video(path): Create a new cv2.VideoCapture object for a given video path.
     """

-    def __init__(self, path, vid_stride=1):
+    def __init__(self, path, batch=1, vid_stride=1):
         """Initialize the Dataloader and raise FileNotFoundError if file not found."""
         parent = None
         if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line

@@ -298,7 +301,7 @@ class LoadImages:
         self.video_flag = [False] * ni + [True] * nv
         self.mode = "image"
         self.vid_stride = vid_stride  # video frame-rate stride
-        self.bs = 1
+        self.bs = batch
         if any(videos):
             self._new_video(videos[0])  # new video
         else:
@@ -315,49 +318,68 @@ class LoadImages:
         return self

     def __next__(self):
-        """Return next image, path and metadata from dataset."""
-        if self.count == self.nf:
-            raise StopIteration
-        path = self.files[self.count]
-
-        if self.video_flag[self.count]:
-            # Read video
-            self.mode = "video"
-            for _ in range(self.vid_stride):
-                self.cap.grab()
-            success, im0 = self.cap.retrieve()
-            while not success:
-                self.count += 1
-                self.cap.release()
-                if self.count == self.nf:  # last video
-                    raise StopIteration
-                path = self.files[self.count]
-                self._new_video(path)
-                success, im0 = self.cap.read()
-            self.frame += 1
-            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
-            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
-        else:
-            # Read image
-            self.count += 1
-            im0 = cv2.imread(path)  # BGR
-            if im0 is None:
-                raise FileNotFoundError(f"Image Not Found {path}")
-            s = f"image {self.count}/{self.nf} {path}: "
-
-        return [path], [im0], self.cap, s
+        """Returns the next batch of images or video frames along with their paths and metadata."""
+        paths, imgs, is_video, info = [], [], [], []
+        while len(imgs) < self.bs:
+            if self.count >= self.nf:  # end of file list
+                if len(imgs) > 0:
+                    return paths, imgs, is_video, info  # return last partial batch
+                else:
+                    raise StopIteration
+
+            path = self.files[self.count]
+            if self.video_flag[self.count]:
+                self.mode = "video"
+                if not self.cap or not self.cap.isOpened():
+                    self._new_video(path)
+
+                for _ in range(self.vid_stride):
+                    success = self.cap.grab()
+                    if not success:
+                        break  # end of video or failure
+
+                if success:
+                    success, im0 = self.cap.retrieve()
+                    if success:
+                        self.frame += 1
+                        paths.append(path)
+                        imgs.append(im0)
+                        is_video.append(True)
+                        info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
+                        if self.frame == self.frames:  # end of video
+                            self.count += 1
+                            self.cap.release()
+                else:
+                    # Move to the next file if the current video ended or failed to open
+                    self.count += 1
+                    if self.cap:
+                        self.cap.release()
+                    if self.count < self.nf:
+                        self._new_video(self.files[self.count])
+            else:
+                self.mode = "image"
+                im0 = cv2.imread(path)  # BGR
+                if im0 is None:
+                    raise FileNotFoundError(f"Image Not Found {path}")
+                paths.append(path)
+                imgs.append(im0)
+                is_video.append(False)  # no capture object for images
+                info.append(f"image {self.count + 1}/{self.nf} {path}: ")
+                self.count += 1  # move to the next file
+
+        return paths, imgs, is_video, info
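The batching loop is easiest to see end to end; a sketch assuming a hypothetical folder of ten images and `batch=4`:

```python
from ultralytics.data.loaders import LoadImagesAndVideos

dataset = LoadImagesAndVideos("images/", batch=4)  # "images/" is a placeholder folder of 10 files
for paths, imgs, is_video, info in dataset:
    print(len(imgs))  # 4, 4, then a final partial batch of 2
```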
     def _new_video(self, path):
-        """Create a new video capture object."""
+        """Creates a new video capture object for the given path."""
         self.frame = 0
         self.cap = cv2.VideoCapture(path)
+        if not self.cap.isOpened():
+            raise FileNotFoundError(f"Failed to open video {path}")
         self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

     def __len__(self):
-        """Returns the number of files in the object."""
-        return self.nf  # number of files
+        """Returns the number of batches in the object."""
+        return math.ceil(self.nf / self.bs)  # number of batches
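`__len__` accordingly now counts batches rather than files; for the same hypothetical ten files and `batch=4`:

```python
import math

nf, bs = 10, 4  # hypothetical file count and batch size
assert math.ceil(nf / bs) == 3  # len(dataset) is 3 batches, no longer 10 files
```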
 class LoadPilAndNumpy:

@@ -373,7 +395,6 @@ class LoadPilAndNumpy:
         im0 (list): List of images stored as Numpy arrays.
         mode (str): Type of data being processed, defaults to 'image'.
         bs (int): Batch size, equivalent to the length of `im0`.
-        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

     Methods:
         _single_check(im): Validate and format a single image to a Numpy array.

@@ -386,7 +407,6 @@ class LoadPilAndNumpy:
         self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
         self.im0 = [self._single_check(im) for im in im0]
         self.mode = "image"
-        # Generate fake paths
         self.bs = len(self.im0)
     @staticmethod

@@ -409,7 +429,7 @@ class LoadPilAndNumpy:
         if self.count == 1:  # loop only once as it's batch inference
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, None, ""
+        return self.paths, self.im0, [False] * self.bs, [""] * self.bs

     def __iter__(self):
         """Enables iteration for class LoadPilAndNumpy."""
@@ -474,7 +494,7 @@ class LoadTensor:
         if self.count == 1:
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, None, ""
+        return self.paths, self.im0, [False] * self.bs, [""] * self.bs

     def __len__(self):
         """Returns the batch size."""
@@ -498,9 +518,6 @@ def autocast_list(source):
         return files


-LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots  # tuple
-
-
 def get_best_youtube_url(url, use_pafy=True):
     """
     Retrieves the URL of the best quality MP4 video stream from a given YouTube video.

@@ -531,3 +548,7 @@ def get_best_youtube_url(url, use_pafy=True):
         good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
         if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
             return f.get("url")
+
+
+# Define constants
+LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)

ultralytics/engine/model.py

@@ -423,7 +423,7 @@ class Model(nn.Module):
             x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
         )

-        custom = {"conf": 0.25, "save": is_cli, "mode": "predict"}  # method defaults
+        custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}  # method defaults
         args = {**self.overrides, **custom, **kwargs}  # highest priority args on the right
         prompts = args.pop("prompts", None)  # for SAM-type models

@@ -474,6 +474,7 @@ class Model(nn.Module):
         register_tracker(self, persist)
         kwargs["conf"] = kwargs.get("conf") or 0.1  # ByteTrack-based method needs low confidence predictions as input
+        kwargs["batch"] = kwargs.get("batch") or 1  # batch-size 1 for tracking in videos
         kwargs["mode"] = "track"
         return self.predict(source=source, stream=stream, **kwargs)
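From the user side, `batch` now behaves like any other predict argument, while tracking pins it to 1. A hedged sketch (weights and source paths are placeholders):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights
results = model.predict(source="images/", batch=8)  # frames are loaded 8 at a time
for r in model.track(source="video.mp4", stream=True):  # track() forces batch=1
    pass
```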

ultralytics/engine/predictor.py

@@ -73,9 +73,7 @@ class BasePredictor:
         data (dict): Data configuration.
         device (torch.device): Device used for prediction.
         dataset (Dataset): Dataset used for prediction.
-        vid_path (str): Path to video file.
-        vid_writer (cv2.VideoWriter): Video writer for saving video output.
-        data_path (str): Path to data.
+        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

@@ -100,10 +98,11 @@ class BasePredictor:
         self.imgsz = None
         self.device = None
         self.dataset = None
-        self.vid_path, self.vid_writer, self.vid_frame = None, None, None
+        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
         self.plotted_img = None
-        self.data_path = None
         self.source_type = None
+        self.seen = 0
+        self.windows = []
         self.batch = None
         self.results = None
         self.transforms = None
@@ -155,44 +154,6 @@ class BasePredictor:
         letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
         return [letterbox(image=x) for x in im]

-    def write_results(self, idx, results, batch):
-        """Write inference results to a file or directory."""
-        p, im, _ = batch
-        log_string = ""
-        if len(im.shape) == 3:
-            im = im[None]  # expand for batch dim
-        if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
-            log_string += f"{idx}: "
-            frame = self.dataset.count
-        else:
-            frame = getattr(self.dataset, "frame", 0)
-        self.data_path = p
-        self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
-        log_string += "%gx%g " % im.shape[2:]  # print string
-        result = results[idx]
-        log_string += result.verbose()
-
-        if self.args.save or self.args.show:  # Add bbox to image
-            plot_args = {
-                "line_width": self.args.line_width,
-                "boxes": self.args.show_boxes,
-                "conf": self.args.show_conf,
-                "labels": self.args.show_labels,
-            }
-            if not self.args.retina_masks:
-                plot_args["im_gpu"] = im[idx]
-            self.plotted_img = result.plot(**plot_args)
-
-        # Write
-        if self.args.save_txt:
-            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
-        if self.args.save_crop:
-            result.save_crop(
-                save_dir=self.save_dir / "crops",
-                file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
-            )
-
-        return log_string
-
     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions for an image and returns them."""
         return preds
@@ -228,18 +189,20 @@ class BasePredictor:
             else None
         )
         self.dataset = load_inference_source(
-            source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
+            source=source,
+            batch=self.args.batch,
+            vid_stride=self.args.vid_stride,
+            buffer=self.args.stream_buffer,
         )
         self.source_type = self.dataset.source_type
         if not getattr(self, "stream", True) and (
-            self.dataset.mode == "stream"  # streams
-            or len(self.dataset) > 1000  # images
+            self.source_type.stream
+            or self.source_type.screenshot
+            or len(self.dataset) > 1000  # many images
             or any(getattr(self.dataset, "video_flag", [False]))
         ):  # videos
             LOGGER.warning(STREAM_WARNING)
-        self.vid_path = [None] * self.dataset.bs
-        self.vid_writer = [None] * self.dataset.bs
-        self.vid_frame = [None] * self.dataset.bs
+        self.vid_writer = {}
     @smart_inference_mode()
     def stream_inference(self, source=None, model=None, *args, **kwargs):

@@ -271,10 +234,9 @@ class BasePredictor:
             ops.Profile(device=self.device),
         )
         self.run_callbacks("on_predict_start")
-        for batch in self.dataset:
+        for self.batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
-            self.batch = batch
-            path, im0s, vid_cap, s = batch
+            paths, im0s, is_video, s = self.batch

             # Preprocess
             with profilers[0]:

@@ -290,8 +252,8 @@ class BasePredictor:
             # Postprocess
             with profilers[2]:
                 self.results = self.postprocess(preds, im, im0s)
-
             self.run_callbacks("on_predict_postprocess_end")
+
             # Visualize, save, write results
             n = len(im0s)
             for i in range(n):

@@ -301,41 +263,32 @@ class BasePredictor:
                     "inference": profilers[1].dt * 1e3 / n,
                     "postprocess": profilers[2].dt * 1e3 / n,
                 }
-                p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
-                p = Path(p)
-
                 if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
-                    s += self.write_results(i, self.results, (p, im, im0))
-                if self.args.save or self.args.save_txt:
-                    self.results[i].save_dir = self.save_dir.__str__()
-                if self.args.show and self.plotted_img is not None:
-                    self.show(p)
-                if self.args.save and self.plotted_img is not None:
-                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))
+                    s[i] += self.write_results(i, Path(paths[i]), im, is_video)
+
+            # Print batch results
+            if self.args.verbose:
+                LOGGER.info("\n".join(s))

             self.run_callbacks("on_predict_batch_end")
             yield from self.results
-
-            # Print time (inference-only)
-            if self.args.verbose:
-                LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")

         # Release assets
-        if isinstance(self.vid_writer[-1], cv2.VideoWriter):
-            self.vid_writer[-1].release()  # release final video writer
+        for v in self.vid_writer.values():
+            if isinstance(v, cv2.VideoWriter):
+                v.release()

-        # Print results
+        # Print final results
         if self.args.verbose and self.seen:
             t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
             LOGGER.info(
                 f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
-                f"{(1, 3, *im.shape[2:])}" % t
+                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
             )
         if self.args.save or self.args.save_txt or self.args.save_crop:
             nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
             s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
         self.run_callbacks("on_predict_end")
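Note that the reported Speed shape now uses the effective batch, `min(self.args.batch, self.seen)`; a quick check with hypothetical values:

```python
batch, seen, imgsz = 8, 3, (640, 640)  # hypothetical: batch=8 requested, only 3 images seen
print((min(batch, seen), 3, *imgsz))  # -> (3, 3, 640, 640), the shape printed in the Speed line
```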
     def setup_model(self, model, verbose=True):

@@ -354,48 +307,81 @@ class BasePredictor:
         self.args.half = self.model.fp16  # update half
         self.model.eval()

-    def show(self, p):
-        """Display an image in a window using OpenCV imshow()."""
-        im0 = self.plotted_img
-        if platform.system() == "Linux" and p not in self.windows:
-            self.windows.append(p)
-            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
-            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
-        cv2.imshow(str(p), im0)
-        cv2.waitKey(500 if self.batch[3].startswith("image") else 1)  # 1 millisecond
+    def write_results(self, i, p, im, is_video):
+        """Write inference results to a file or directory."""
+        string = ""  # print string
+        if len(im.shape) == 3:
+            im = im[None]  # expand for batch dim
+        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
+            string += f"{i}: "
+            frame = self.dataset.count
+        else:
+            frame = getattr(self.dataset, "frame", 0) - len(self.results) + i
+        self.txt_path = self.save_dir / "labels" / (p.stem + f"_{frame}" if is_video[i] else "")
+        string += "%gx%g " % im.shape[2:]
+        result = self.results[i]
+        result.save_dir = self.save_dir.__str__()  # used in other locations
+        string += result.verbose() + f"{result.speed['inference']:.1f}ms"
+
+        # Add predictions to image
+        if self.args.save or self.args.show:
+            self.plotted_img = result.plot(
+                line_width=self.args.line_width,
+                boxes=self.args.show_boxes,
+                conf=self.args.show_conf,
+                labels=self.args.show_labels,
+                im_gpu=None if self.args.retina_masks else im[i],
+            )
+
+        # Save results
+        if self.args.save_txt:
+            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
+        if self.args.save_crop:
+            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
+        if self.args.show:
+            self.show(str(p), is_video[i])
+        if self.args.save:
+            self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame)
+
+        return string
-    def save_preds(self, vid_cap, idx, save_path):
+    def save_predicted_images(self, save_path="", is_video=False, frame=0):
         """Save video predictions as mp4 at specified path."""
-        im0 = self.plotted_img
-        # Save imgs
-        if self.dataset.mode == "image":
-            cv2.imwrite(save_path, im0)
-        else:  # 'video' or 'stream'
+        im = self.plotted_img
+
+        # Save videos and streams
+        if is_video:
             frames_path = f'{save_path.split(".", 1)[0]}_frames/'
-            if self.vid_path[idx] != save_path:  # new video
-                self.vid_path[idx] = save_path
+            if save_path not in self.vid_writer:  # new video
                 if self.args.save_frames:
                     Path(frames_path).mkdir(parents=True, exist_ok=True)
-                    self.vid_frame[idx] = 0
-                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
-                    self.vid_writer[idx].release()  # release previous video writer
-                if vid_cap:  # video
-                    fps = int(vid_cap.get(cv2.CAP_PROP_FPS))  # integer required, floats produce error in MP4 codec
-                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                else:  # stream
-                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                 suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
-                self.vid_writer[idx] = cv2.VideoWriter(
-                    str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
+                self.vid_writer[save_path] = cv2.VideoWriter(
+                    filename=str(Path(save_path).with_suffix(suffix)),
+                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
+                    fps=30,  # integer required, floats produce error in MP4 codec
+                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                 )
-            # Write video
-            self.vid_writer[idx].write(im0)

-            # Write frame
+            # Save video
+            self.vid_writer[save_path].write(im)
             if self.args.save_frames:
-                cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
-                self.vid_frame[idx] += 1
+                cv2.imwrite(f"{frames_path}{frame}.jpg", im)
+
+        # Save images
+        else:
+            cv2.imwrite(save_path, im)
+    def show(self, p="", is_video=False):
+        """Display an image in a window using OpenCV imshow()."""
+        im = self.plotted_img
+        if platform.system() == "Linux" and p not in self.windows:
+            self.windows.append(p)
+            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
+            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
+        cv2.imshow(p, im)
+        cv2.waitKey(1 if is_video else 500)  # 1 millisecond

     def run_callbacks(self, event: str):
         """Runs all registered callbacks for a specific event."""

ultralytics/trackers/track.py

@@ -39,6 +39,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
         tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
         trackers.append(tracker)
     predictor.trackers = trackers
+    predictor.vid_path = [None] * predictor.dataset.bs  # for determining when to reset tracker on new video


 def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:

@@ -54,8 +55,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
     is_obb = predictor.args.task == "obb"
     for i in range(bs):
-        if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name):  # new video
+        vid_path = predictor.save_dir / Path(path[i]).name
+        if not persist and predictor.vid_path[i] != vid_path:  # new video
             predictor.trackers[i].reset()
+            predictor.vid_path[i] = vid_path
         det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
         if len(det) == 0:
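The effect of the new `vid_path` bookkeeping is that trackers reset only at file boundaries, so track IDs persist within one video across batches. A hedged sketch (weights and clip names are placeholders):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights
for source in ("clip1.mp4", "clip2.mp4"):  # tracker resets at the clip boundary unless persist=True
    for result in model.track(source=source, stream=True):
        ids = result.boxes.id  # track IDs restart for the second clip
```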
