`ultralytics 8.3.10` Apple iPhone HEIC support (#16853)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/16858/head v8.3.10
Glenn Jocher 4 months ago committed by GitHub
parent 1e5e612f83
commit 02d5c290e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 5
      docs/en/modes/predict.md
  2. 2
      ultralytics/__init__.py
  3. 240
      ultralytics/data/loaders.py
  4. 2
      ultralytics/data/utils.py
  5. 2
      ultralytics/engine/predictor.py
  6. 15
      ultralytics/utils/checks.py

@ -408,6 +408,10 @@ YOLO11 supports various image and video formats, as specified in [ultralytics/da
The below table contains valid Ultralytics image formats.
!!! note

    HEIC images are supported for inference only, not for training.
| Image Suffixes | Example Predict Command | Reference |
| -------------- | -------------------------------- | -------------------------------------------------------------------------- |
| `.bmp` | `yolo predict source=image.bmp` | [Microsoft BMP File Format](https://en.wikipedia.org/wiki/BMP_file_format) |
@ -420,6 +424,7 @@ The below table contains valid Ultralytics image formats.
| `.tiff` | `yolo predict source=image.tiff` | [Tag Image File Format](https://en.wikipedia.org/wiki/TIFF) |
| `.webp` | `yolo predict source=image.webp` | [WebP](https://en.wikipedia.org/wiki/WebP) |
| `.pfm` | `yolo predict source=image.pfm` | [Portable FloatMap](https://en.wikipedia.org/wiki/Netpbm#File_formats) |
| `.HEIC` | `yolo predict source=image.HEIC` | [High Efficiency Image Format](https://en.wikipedia.org/wiki/HEIF) |
### Videos

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.9"
__version__ = "8.3.10"
import os

@ -18,11 +18,29 @@ from PIL import Image
from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.patches import imread
@dataclass
class SourceTypes:
"""Class to represent various types of input sources for predictions."""
"""
Class to represent various types of input sources for predictions.
This class uses dataclass to define boolean flags for different types of input sources that can be used for
making predictions with YOLO models.
Attributes:
stream (bool): Flag indicating if the input source is a video stream.
screenshot (bool): Flag indicating if the input source is a screenshot.
from_img (bool): Flag indicating if the input source is an image file.
Examples:
>>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)
>>> print(source_types.stream)
True
>>> print(source_types.from_img)
False
"""
stream: bool = False
screenshot: bool = False
@ -32,38 +50,47 @@ class SourceTypes:
class LoadStreams:
"""
Stream Loader for various types of video streams, Supports RTSP, RTMP, HTTP, and TCP streams.
Stream Loader for various types of video streams.
Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video
streams simultaneously, making it suitable for real-time video analysis tasks.
Attributes:
sources (str): The source input paths or URLs for the video streams.
vid_stride (int): Video frame-rate stride, defaults to 1.
buffer (bool): Whether to buffer input streams, defaults to False.
sources (List[str]): The source input paths or URLs for the video streams.
vid_stride (int): Video frame-rate stride.
buffer (bool): Whether to buffer input streams.
running (bool): Flag to indicate if the streaming thread is running.
mode (str): Set to 'stream' indicating real-time capture.
imgs (list): List of image frames for each stream.
fps (list): List of FPS for each stream.
frames (list): List of total frames for each stream.
threads (list): List of threads for each stream.
shape (list): List of shapes for each stream.
caps (list): List of cv2.VideoCapture objects for each stream.
imgs (List[List[np.ndarray]]): List of image frames for each stream.
fps (List[float]): List of FPS for each stream.
frames (List[int]): List of total frames for each stream.
threads (List[Thread]): List of threads for each stream.
shape (List[Tuple[int, int, int]]): List of shapes for each stream.
caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.
bs (int): Batch size for processing.
Methods:
__init__: Initialize the stream loader.
update: Read stream frames in daemon thread.
close: Close stream loader and release resources.
__iter__: Returns an iterator object for the class.
__next__: Returns source paths, transformed, and original images for processing.
__len__: Return the length of the sources object.
Example:
```bash
yolo predict source='rtsp://example.com/media.mp4'
```
Examples:
>>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4")
>>> for sources, imgs, _ in stream_loader:
... # Process the images
... pass
>>> stream_loader.close()
Notes:
- The class uses threading to efficiently load frames from multiple streams simultaneously.
- It automatically handles YouTube links, converting them to the best available stream URL.
- The class implements a buffer system to manage frame storage and retrieval.
"""
def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
"""Initialize stream loader for multiple video sources, supporting various stream types."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
@ -114,7 +141,7 @@ class LoadStreams:
LOGGER.info("") # newline
def update(self, i, cap, stream):
"""Read stream `i` frames in daemon thread."""
"""Read stream frames in daemon thread and update image buffer."""
n, f = 0, self.frames[i] # frame number, frame array
while self.running and cap.isOpened() and n < (f - 1):
if len(self.imgs[i]) < 30: # keep a <=30-image buffer
@ -134,7 +161,7 @@ class LoadStreams:
time.sleep(0.01) # wait until the buffer is empty
def close(self):
"""Close stream loader and release resources."""
"""Terminates stream loader, stops threads, and releases video capture resources."""
self.running = False # stop flag for Thread
for thread in self.threads:
if thread.is_alive():
@ -152,7 +179,7 @@ class LoadStreams:
return self
def __next__(self):
"""Returns source paths, transformed and original images for processing."""
"""Returns the next batch of frames from multiple video streams for processing."""
self.count += 1
images = []
@ -179,16 +206,16 @@ class LoadStreams:
return self.sources, images, [""] * self.bs
def __len__(self):
"""Return the length of the sources object."""
"""Return the number of video streams in the LoadStreams object."""
return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years
class LoadScreenshots:
"""
YOLOv8 screenshot dataloader.
Ultralytics screenshot dataloader for capturing and processing screen images.
This class manages the loading of screenshot images for processing with YOLOv8.
Suitable for use with `yolo predict source=screen`.
This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with
`yolo predict source=screen`.
Attributes:
source (str): The source input indicating which screen to capture.
@ -201,15 +228,21 @@ class LoadScreenshots:
frame (int): Counter for captured frames.
sct (mss.mss): Screen capture object from `mss` library.
bs (int): Batch size, set to 1.
monitor (dict): Monitor configuration details.
fps (int): Frames per second, set to 30.
monitor (Dict[str, int]): Monitor configuration details.
Methods:
__iter__: Returns an iterator object.
__next__: Captures the next screenshot and returns it.
Examples:
>>> loader = LoadScreenshots("0 100 100 640 480") # screen 0, top-left (100,100), 640x480
>>> for source, im, im0s, vid_cap, s in loader:
... print(f"Captured frame: {im.shape}")
"""
def __init__(self, source):
"""Source = [screen_number left top width height] (pixels)."""
"""Initialize screenshot capture with specified screen and region parameters."""
check_requirements("mss")
import mss # noqa
@ -236,11 +269,11 @@ class LoadScreenshots:
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
def __iter__(self):
"""Returns an iterator of the object."""
"""Returns the loader itself as the iterator; each screenshot is captured later by `__next__`."""
return self
def __next__(self):
"""Screen capture with 'mss' to get raw pixels from the screen as np array."""
"""Captures and returns the next screenshot as a numpy array using the mss library."""
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
@ -250,29 +283,45 @@ class LoadScreenshots:
class LoadImagesAndVideos:
"""
YOLOv8 image/video dataloader.
A class for loading and processing images and videos for YOLO object detection.
This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
various formats, including single image files, video files, and lists of image and video paths.
This class manages the loading and pre-processing of image and video data from various sources, including
single image files, video files, and lists of image and video paths.
Attributes:
files (list): List of image and video file paths.
files (List[str]): List of image and video file paths.
nf (int): Total number of files (images and videos).
video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False).
mode (str): Current mode, 'image' or 'video'.
vid_stride (int): Stride for video frame-rate, defaults to 1.
bs (int): Batch size, set to 1 for this class.
vid_stride (int): Stride for video frame-rate.
bs (int): Batch size.
cap (cv2.VideoCapture): Video capture object for OpenCV.
frame (int): Frame counter for video.
frames (int): Total number of frames in the video.
count (int): Counter for iteration, initialized at 0 during `__iter__()`.
count (int): Counter for iteration, initialized at 0 during __iter__().
ni (int): Number of images.
Methods:
_new_video(path): Create a new cv2.VideoCapture object for a given video path.
__init__: Initialize the LoadImagesAndVideos object.
__iter__: Returns an iterator object for VideoStream or ImageFolder.
__next__: Returns the next batch of images or video frames along with their paths and metadata.
_new_video: Creates a new video capture object for the given path.
__len__: Returns the number of batches in the object.
Examples:
>>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1)
>>> for paths, imgs, info in loader:
... # Process batch of images or video frames
... pass
Notes:
- Supports various image formats including HEIC.
- Handles both local files and directories.
- Can read from a text file containing paths to images and videos.
"""
def __init__(self, path, batch=1, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
"""Initialize dataloader for images and videos, supporting various input formats."""
parent = None
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
parent = Path(path).parent
@ -316,12 +365,12 @@ class LoadImagesAndVideos:
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""
"""Iterates through image/video files, yielding source paths, images, and metadata."""
self.count = 0
return self
def __next__(self):
"""Returns the next batch of images or video frames along with their paths and metadata."""
"""Returns the next batch of images or video frames with their paths and metadata."""
paths, imgs, info = [], [], []
while len(imgs) < self.bs:
if self.count >= self.nf: # end of file list
@ -336,6 +385,7 @@ class LoadImagesAndVideos:
if not self.cap or not self.cap.isOpened():
self._new_video(path)
success = False
for _ in range(self.vid_stride):
success = self.cap.grab()
if not success:
@ -359,8 +409,19 @@ class LoadImagesAndVideos:
if self.count < self.nf:
self._new_video(self.files[self.count])
else:
# Handle image files (including HEIC)
self.mode = "image"
im0 = cv2.imread(path) # BGR
if path.split(".")[-1].lower() == "heic":
# Load HEIC image using Pillow with pillow-heif
check_requirements("pillow-heif")
from pillow_heif import register_heif_opener
register_heif_opener() # Register HEIF opener with Pillow
with Image.open(path) as img:
im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) # convert image to BGR nparray
else:
im0 = imread(path) # BGR
if im0 is None:
LOGGER.warning(f"WARNING ⚠ Image Read Error {path}")
else:
@ -374,7 +435,7 @@ class LoadImagesAndVideos:
return paths, imgs, info
def _new_video(self, path):
"""Creates a new video capture object for the given path."""
"""Creates a new video capture object for the given path and initializes video-related attributes."""
self.frame = 0
self.cap = cv2.VideoCapture(path)
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
@ -383,30 +444,39 @@ class LoadImagesAndVideos:
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
def __len__(self):
"""Returns the number of batches in the object."""
return math.ceil(self.nf / self.bs) # number of files
"""Returns the number of batches in the object (number of files divided by batch size, rounded up)."""
return math.ceil(self.nf / self.bs) # number of batches
class LoadPilAndNumpy:
"""
Load images from PIL and Numpy arrays for batch processing.
This class is designed to manage loading and pre-processing of image data from both PIL and Numpy formats.
It performs basic validation and format conversion to ensure that the images are in the required format for
downstream processing.
This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
validation and format conversion to ensure that the images are in the required format for downstream processing.
Attributes:
paths (list): List of image paths or autogenerated filenames.
im0 (list): List of images stored as Numpy arrays.
mode (str): Type of data being processed, defaults to 'image'.
paths (List[str]): List of image paths or autogenerated filenames.
im0 (List[np.ndarray]): List of images stored as Numpy arrays.
mode (str): Type of data being processed, set to 'image'.
bs (int): Batch size, equivalent to the length of `im0`.
Methods:
_single_check(im): Validate and format a single image to a Numpy array.
_single_check: Validate and format a single image to a Numpy array.
Examples:
>>> from PIL import Image
>>> import numpy as np
>>> pil_img = Image.new("RGB", (100, 100))
>>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
>>> loader = LoadPilAndNumpy([pil_img, np_img])
>>> paths, images, _ = next(iter(loader))
>>> print(f"Loaded {len(images)} images")
Loaded 2 images
"""
def __init__(self, im0):
"""Initialize PIL and Numpy Dataloader."""
"""Initializes a loader for PIL and Numpy images, converting inputs to a standardized format."""
if not isinstance(im0, list):
im0 = [im0]
# use `image{i}.jpg` when Image.filename returns an empty path.
@ -417,7 +487,7 @@ class LoadPilAndNumpy:
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
"""Validate and format an image to numpy array, ensuring RGB order and contiguous memory."""
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
if isinstance(im, Image.Image):
if im.mode != "RGB":
@ -427,41 +497,48 @@ class LoadPilAndNumpy:
return im
def __len__(self):
"""Returns the length of the 'im0' attribute."""
"""Returns the length of the 'im0' attribute, representing the number of loaded images."""
return len(self.im0)
def __next__(self):
"""Returns batch paths, images, processed images, None, ''."""
"""Returns the next batch of images, paths, and metadata for processing."""
if self.count == 1: # loop only once as it's batch inference
raise StopIteration
self.count += 1
return self.paths, self.im0, [""] * self.bs
def __iter__(self):
"""Enables iteration for class LoadPilAndNumpy."""
"""Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
self.count = 0
return self
class LoadTensor:
"""
Load images from torch.Tensor data.
A class for loading and processing tensor data for object detection tasks.
This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.
This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for
further processing in object detection pipelines.
Attributes:
im0 (torch.Tensor): The input tensor containing the image(s).
im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).
bs (int): Batch size, inferred from the shape of `im0`.
mode (str): Current mode, set to 'image'.
paths (list): List of image paths or filenames.
count (int): Counter for iteration, initialized at 0 during `__iter__()`.
mode (str): Current processing mode, set to 'image'.
paths (List[str]): List of image paths or auto-generated filenames.
Methods:
_single_check(im, stride): Validate and possibly modify the input tensor.
_single_check: Validates and formats an input tensor.
Examples:
>>> import torch
>>> tensor = torch.rand(1, 3, 640, 640)
>>> loader = LoadTensor(tensor)
>>> paths, images, info = next(iter(loader))
>>> print(f"Processed {len(images)} images")
"""
def __init__(self, im0) -> None:
"""Initialize Tensor Dataloader."""
"""Initialize LoadTensor object for processing torch.Tensor image data."""
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0]
self.mode = "image"
@ -469,7 +546,7 @@ class LoadTensor:
@staticmethod
def _single_check(im, stride=32):
"""Validate and format an image to torch.Tensor."""
"""Validates and formats a single image tensor, ensuring correct shape and normalization."""
s = (
f"WARNING ⚠ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
@ -491,24 +568,24 @@ class LoadTensor:
return im
def __iter__(self):
"""Returns an iterator object."""
"""Resets the iteration counter and returns the loader itself as the iterator over the tensor image data."""
self.count = 0
return self
def __next__(self):
"""Return next item in the iterator."""
"""Returns the single batch of paths, tensor images, and info strings; raises StopIteration on later calls."""
if self.count == 1:
raise StopIteration
self.count += 1
return self.paths, self.im0, [""] * self.bs
def __len__(self):
"""Returns the batch size."""
"""Returns the batch size of the tensor input."""
return self.bs
def autocast_list(source):
"""Merges a list of source of different types into a list of numpy arrays or PIL images."""
"""Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
files = []
for im in source:
if isinstance(im, (str, Path)): # filename or uri
@ -528,21 +605,24 @@ def get_best_youtube_url(url, method="pytube"):
"""
Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
This function uses the specified method to extract the video info from YouTube. It supports the following methods:
- "pytube": Uses the pytube library to fetch the video streams.
- "pafy": Uses the pafy library to fetch the video streams.
- "yt-dlp": Uses the yt-dlp library to fetch the video streams.
The function then finds the highest quality MP4 format that has a video codec but no audio codec, and returns the
URL of this video stream.
Args:
url (str): The URL of the YouTube video.
method (str): The method to use for extracting video info. Default is "pytube". Other options are "pafy" and
"yt-dlp".
method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp".
Defaults to "pytube".
Returns:
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
(str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
Examples:
>>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
>>> best_url = get_best_youtube_url(url)
>>> print(best_url)
https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...
Notes:
- Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.
- The function prioritizes streams with at least 1080p resolution when available.
- For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension.
"""
if method == "pytube":
# Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954

@ -35,7 +35,7 @@ from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"} # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"

@ -381,7 +381,7 @@ class BasePredictor:
# Save images
else:
cv2.imwrite(save_path, im)
cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support
def show(self, p=""):
"""Display an image in a window using the OpenCV imshow function."""

@ -238,12 +238,14 @@ def check_version(
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
for r in required.strip(",").split(","):
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
if not op:
op = ">=" # assume >= if no op passed
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
if op == "==" and c != v:
result = False
elif op == "!=" and c == v:
result = False
elif op in {">=", ""} and not (c >= v): # if no constraint passed assume '>=required'
elif op == ">=" and not (c >= v):
result = False
elif op == "<=" and not (c <= v):
result = False
@ -333,18 +335,19 @@ def check_font(font="Arial.ttf"):
return file
def check_python(minimum: str = "3.8.0", hard: bool = True) -> bool:
def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = True) -> bool:
"""
Check current python version against the required minimum version.
Args:
minimum (str): Required minimum version of python.
hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
verbose (bool, optional): If True, print warning message if requirement is not met.
Returns:
(bool): Whether the installed Python version meets the minimum constraints.
"""
return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard)
return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)
@TryExcept()
@ -374,8 +377,6 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
```
"""
prefix = colorstr("red", "bold", "requirements:")
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f"{prefix} {file} not found, check failed."
@ -770,6 +771,8 @@ def cuda_is_available() -> bool:
return cuda_device_count() > 0
# Define constants
# Run checks and define constants
check_python("3.8", hard=False, verbose=True) # check python version
check_torchvision() # check torch-torchvision compatibility
IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

Loading…
Cancel
Save