|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import time
|
|
|
|
from collections import defaultdict
|
|
|
|
from typing import List, Optional, Tuple
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from transformers import AutoModel, AutoProcessor
|
|
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
from ultralytics.data.loaders import get_best_youtube_url
|
|
|
|
from ultralytics.utils.plotting import Annotator
|
|
|
|
from ultralytics.utils.torch_utils import select_device
|
|
|
|
|
|
|
|
|
|
|
|
class TorchVisionVideoClassifier:
|
|
|
|
from torchvision.models.video import (
|
|
|
|
MViT_V1_B_Weights,
|
|
|
|
MViT_V2_S_Weights,
|
|
|
|
R3D_18_Weights,
|
|
|
|
S3D_Weights,
|
|
|
|
Swin3D_B_Weights,
|
|
|
|
Swin3D_T_Weights,
|
|
|
|
mvit_v1_b,
|
|
|
|
mvit_v2_s,
|
|
|
|
r3d_18,
|
|
|
|
s3d,
|
|
|
|
swin3d_b,
|
|
|
|
swin3d_t,
|
|
|
|
)
|
|
|
|
|
|
|
|
model_name_to_model_and_weights = {
|
|
|
|
"s3d": (s3d, S3D_Weights.DEFAULT),
|
|
|
|
"r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
|
|
|
|
"swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
|
|
|
|
"swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
|
|
|
|
"mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
|
|
|
|
"mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
|
|
|
|
}
|
|
|
|
|
|
|
|
def __init__(self, model_name: str, device: str or torch.device = ""):
|
|
|
|
"""
|
|
|
|
Initialize the VideoClassifier with the specified model name and device.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
model_name (str): The name of the model to use.
|
|
|
|
device (str or torch.device, optional): The device to run the model on. Defaults to "".
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
ValueError: If an invalid model name is provided.
|
|
|
|
"""
|
|
|
|
if model_name not in self.model_name_to_model_and_weights:
|
|
|
|
raise ValueError(f"Invalid model name '{model_name}'. Available models: {self.available_model_names()}")
|
|
|
|
model, self.weights = self.model_name_to_model_and_weights[model_name]
|
|
|
|
self.device = select_device(device)
|
|
|
|
self.model = model(weights=self.weights).to(self.device).eval()
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def available_model_names() -> List[str]:
|
|
|
|
"""
|
|
|
|
Get the list of available model names.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list: List of available model names.
|
|
|
|
"""
|
|
|
|
return list(TorchVisionVideoClassifier.model_name_to_model_and_weights.keys())
|
|
|
|
|
|
|
|
def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor:
|
|
|
|
"""
|
|
|
|
Preprocess a list of crops for video classification.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
crops (List[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C)
|
|
|
|
input_size (tuple, optional): The target input size for the model. Defaults to (224, 224).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: Preprocessed crops as a tensor with dimensions (1, T, C, H, W).
|
|
|
|
"""
|
|
|
|
if input_size is None:
|
|
|
|
input_size = [224, 224]
|
|
|
|
from torchvision.transforms import v2
|
|
|
|
|
|
|
|
transform = v2.Compose(
|
|
|
|
[
|
|
|
|
v2.ToDtype(torch.float32, scale=True),
|
|
|
|
v2.Resize(input_size, antialias=True),
|
|
|
|
v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
|
|
|
processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops]
|
|
|
|
return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)
|
|
|
|
|
|
|
|
def __call__(self, sequences: torch.Tensor):
|
|
|
|
"""
|
|
|
|
Perform inference on the given sequences.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
sequences (torch.Tensor): The input sequences for the model. The expected input dimensions are
|
|
|
|
(B, T, C, H, W) for batched video frames or (T, C, H, W) for single video frames.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: The model's output.
|
|
|
|
"""
|
|
|
|
with torch.inference_mode():
|
|
|
|
return self.model(sequences)
|
|
|
|
|
|
|
|
def postprocess(self, outputs: torch.Tensor) -> Tuple[List[str], List[float]]:
|
|
|
|
"""
|
|
|
|
Postprocess the model's batch output.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
outputs (torch.Tensor): The model's output.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
List[str]: The predicted labels.
|
|
|
|
List[float]: The predicted confidences.
|
|
|
|
"""
|
|
|
|
pred_labels = []
|
|
|
|
pred_confs = []
|
|
|
|
for output in outputs:
|
|
|
|
pred_class = output.argmax(0).item()
|
|
|
|
pred_label = self.weights.meta["categories"][pred_class]
|
|
|
|
pred_labels.append(pred_label)
|
|
|
|
pred_conf = output.softmax(0)[pred_class].item()
|
|
|
|
pred_confs.append(pred_conf)
|
|
|
|
|
|
|
|
return pred_labels, pred_confs
|
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceVideoClassifier:
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
labels: List[str],
|
|
|
|
model_name: str = "microsoft/xclip-base-patch16-zero-shot",
|
|
|
|
device: str or torch.device = "",
|
|
|
|
fp16: bool = False,
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
Initialize the HuggingFaceVideoClassifier with the specified model name.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
labels (List[str]): List of labels for zero-shot classification.
|
|
|
|
model_name (str): The name of the model to use. Defaults to "microsoft/xclip-base-patch16-zero-shot".
|
|
|
|
device (str or torch.device, optional): The device to run the model on. Defaults to "".
|
|
|
|
fp16 (bool, optional): Whether to use FP16 for inference. Defaults to False.
|
|
|
|
"""
|
|
|
|
self.fp16 = fp16
|
|
|
|
self.labels = labels
|
|
|
|
self.device = select_device(device)
|
|
|
|
self.processor = AutoProcessor.from_pretrained(model_name)
|
|
|
|
model = AutoModel.from_pretrained(model_name).to(self.device)
|
|
|
|
if fp16:
|
|
|
|
model = model.half()
|
|
|
|
self.model = model.eval()
|
|
|
|
|
|
|
|
def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor:
|
|
|
|
"""
|
|
|
|
Preprocess a list of crops for video classification.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
crops (List[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C)
|
|
|
|
input_size (tuple, optional): The target input size for the model. Defaults to (224, 224).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: Preprocessed crops as a tensor (1, T, C, H, W).
|
|
|
|
"""
|
|
|
|
if input_size is None:
|
|
|
|
input_size = [224, 224]
|
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
transform = transforms.Compose(
|
|
|
|
[
|
|
|
|
transforms.Lambda(lambda x: x.float() / 255.0),
|
|
|
|
transforms.Resize(input_size),
|
|
|
|
transforms.Normalize(
|
|
|
|
mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std
|
|
|
|
),
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
|
|
|
processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] # (T, C, H, W)
|
|
|
|
output = torch.stack(processed_crops).unsqueeze(0).to(self.device) # (1, T, C, H, W)
|
|
|
|
if self.fp16:
|
|
|
|
output = output.half()
|
|
|
|
return output
|
|
|
|
|
|
|
|
def __call__(self, sequences: torch.Tensor) -> torch.Tensor:
|
|
|
|
"""
|
|
|
|
Perform inference on the given sequences.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
sequences (torch.Tensor): The input sequences for the model. Batched video frames with shape (B, T, H, W, C).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: The model's output.
|
|
|
|
"""
|
|
|
|
|
|
|
|
input_ids = self.processor(text=self.labels, return_tensors="pt", padding=True)["input_ids"].to(self.device)
|
|
|
|
|
|
|
|
inputs = {"pixel_values": sequences, "input_ids": input_ids}
|
|
|
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
outputs = self.model(**inputs)
|
|
|
|
|
|
|
|
return outputs.logits_per_video
|
|
|
|
|
|
|
|
def postprocess(self, outputs: torch.Tensor) -> Tuple[List[List[str]], List[List[float]]]:
|
|
|
|
"""
|
|
|
|
Postprocess the model's batch output.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
outputs (torch.Tensor): The model's output.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
List[List[str]]: The predicted top3 labels.
|
|
|
|
List[List[float]]: The predicted top3 confidences.
|
|
|
|
"""
|
|
|
|
pred_labels = []
|
|
|
|
pred_confs = []
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
logits_per_video = outputs # Assuming outputs is already the logits tensor
|
|
|
|
probs = logits_per_video.softmax(dim=-1) # Use softmax to convert logits to probabilities
|
|
|
|
|
|
|
|
for prob in probs:
|
|
|
|
top2_indices = prob.topk(2).indices.tolist()
|
|
|
|
top2_labels = [self.labels[idx] for idx in top2_indices]
|
|
|
|
top2_confs = prob[top2_indices].tolist()
|
|
|
|
pred_labels.append(top2_labels)
|
|
|
|
pred_confs.append(top2_confs)
|
|
|
|
|
|
|
|
return pred_labels, pred_confs
|
|
|
|
|
|
|
|
|
|
|
|
def crop_and_pad(frame, box, margin_percent):
|
|
|
|
"""Crop box with margin and take square crop from frame."""
|
|
|
|
x1, y1, x2, y2 = map(int, box)
|
|
|
|
w, h = x2 - x1, y2 - y1
|
|
|
|
|
|
|
|
# Add margin
|
|
|
|
margin_x, margin_y = int(w * margin_percent / 100), int(h * margin_percent / 100)
|
|
|
|
x1, y1 = max(0, x1 - margin_x), max(0, y1 - margin_y)
|
|
|
|
x2, y2 = min(frame.shape[1], x2 + margin_x), min(frame.shape[0], y2 + margin_y)
|
|
|
|
|
|
|
|
# Take square crop from frame
|
|
|
|
size = max(y2 - y1, x2 - x1)
|
|
|
|
center_y, center_x = (y1 + y2) // 2, (x1 + x2) // 2
|
|
|
|
half_size = size // 2
|
|
|
|
square_crop = frame[
|
|
|
|
max(0, center_y - half_size) : min(frame.shape[0], center_y + half_size),
|
|
|
|
max(0, center_x - half_size) : min(frame.shape[1], center_x + half_size),
|
|
|
|
]
|
|
|
|
|
|
|
|
return cv2.resize(square_crop, (224, 224), interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
|
|
|
|
|
|
|
def run(
|
|
|
|
weights: str = "yolov8n.pt",
|
|
|
|
device: str = "",
|
|
|
|
source: str = "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
|
|
output_path: Optional[str] = None,
|
|
|
|
crop_margin_percentage: int = 10,
|
|
|
|
num_video_sequence_samples: int = 8,
|
|
|
|
skip_frame: int = 2,
|
|
|
|
video_cls_overlap_ratio: float = 0.25,
|
|
|
|
fp16: bool = False,
|
|
|
|
video_classifier_model: str = "microsoft/xclip-base-patch32",
|
|
|
|
labels: List[str] = None,
|
|
|
|
) -> None:
|
|
|
|
"""
|
|
|
|
Run action recognition on a video source using YOLO for object detection and a video classifier.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
weights (str): Path to the YOLO model weights. Defaults to "yolov8n.pt".
|
|
|
|
device (str): Device to run the model on. Use 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon, or 'cpu'. Defaults to auto-detection.
|
|
|
|
source (str): Path to mp4 video file or YouTube URL. Defaults to a sample YouTube video.
|
|
|
|
output_path (Optional[str], optional): Path to save the output video. Defaults to None.
|
|
|
|
crop_margin_percentage (int, optional): Percentage of margin to add around detected objects. Defaults to 10.
|
|
|
|
num_video_sequence_samples (int, optional): Number of video frames to use for classification. Defaults to 8.
|
|
|
|
skip_frame (int, optional): Number of frames to skip between detections. Defaults to 4.
|
|
|
|
video_cls_overlap_ratio (float, optional): Overlap ratio between video sequences. Defaults to 0.25.
|
|
|
|
fp16 (bool, optional): Whether to use half-precision floating point. Defaults to False.
|
|
|
|
video_classifier_model (str, optional): Name or path of the video classifier model. Defaults to "microsoft/xclip-base-patch32".
|
|
|
|
labels (List[str], optional): List of labels for zero-shot classification. Defaults to predefined list.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
None</edit>
|
|
|
|
"""
|
|
|
|
if labels is None:
|
|
|
|
labels = [
|
|
|
|
"walking",
|
|
|
|
"running",
|
|
|
|
"brushing teeth",
|
|
|
|
"looking into phone",
|
|
|
|
"weight lifting",
|
|
|
|
"cooking",
|
|
|
|
"sitting",
|
|
|
|
]
|
|
|
|
# Initialize models and device
|
|
|
|
device = select_device(device)
|
|
|
|
yolo_model = YOLO(weights).to(device)
|
|
|
|
if video_classifier_model in TorchVisionVideoClassifier.available_model_names():
|
|
|
|
print("'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.")
|
|
|
|
print(
|
|
|
|
"'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels."
|
|
|
|
)
|
|
|
|
video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=device)
|
|
|
|
else:
|
|
|
|
video_classifier = HuggingFaceVideoClassifier(
|
|
|
|
labels, model_name=video_classifier_model, device=device, fp16=fp16
|
|
|
|
)
|
|
|
|
|
|
|
|
# Initialize video capture
|
|
|
|
if source.startswith("http") and urlparse(source).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}:
|
|
|
|
source = get_best_youtube_url(source)
|
|
|
|
elif not source.endswith(".mp4"):
|
|
|
|
raise ValueError("Invalid source. Supported sources are YouTube URLs and MP4 files.")
|
|
|
|
cap = cv2.VideoCapture(source)
|
|
|
|
|
|
|
|
# Get video properties
|
|
|
|
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
|
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
|
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
|
|
|
|
|
# Initialize VideoWriter
|
|
|
|
if output_path is not None:
|
|
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
|
|
|
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
|
|
|
|
|
|
|
|
# Initialize track history
|
|
|
|
track_history = defaultdict(list)
|
|
|
|
frame_counter = 0
|
|
|
|
|
|
|
|
track_ids_to_infer = []
|
|
|
|
crops_to_infer = []
|
|
|
|
pred_labels = []
|
|
|
|
pred_confs = []
|
|
|
|
|
|
|
|
while cap.isOpened():
|
|
|
|
success, frame = cap.read()
|
|
|
|
if not success:
|
|
|
|
break
|
|
|
|
|
|
|
|
frame_counter += 1
|
|
|
|
|
|
|
|
# Run YOLO tracking
|
|
|
|
results = yolo_model.track(frame, persist=True, classes=[0]) # Track only person class
|
|
|
|
|
|
|
|
if results[0].boxes.id is not None:
|
|
|
|
boxes = results[0].boxes.xyxy.cpu().numpy()
|
|
|
|
track_ids = results[0].boxes.id.cpu().numpy()
|
|
|
|
|
|
|
|
# Visualize prediction
|
|
|
|
annotator = Annotator(frame, line_width=3, font_size=10, pil=False)
|
|
|
|
|
|
|
|
if frame_counter % skip_frame == 0:
|
|
|
|
crops_to_infer = []
|
|
|
|
track_ids_to_infer = []
|
|
|
|
|
|
|
|
for box, track_id in zip(boxes, track_ids):
|
|
|
|
if frame_counter % skip_frame == 0:
|
|
|
|
crop = crop_and_pad(frame, box, crop_margin_percentage)
|
|
|
|
track_history[track_id].append(crop)
|
|
|
|
|
|
|
|
if len(track_history[track_id]) > num_video_sequence_samples:
|
|
|
|
track_history[track_id].pop(0)
|
|
|
|
|
|
|
|
if len(track_history[track_id]) == num_video_sequence_samples and frame_counter % skip_frame == 0:
|
|
|
|
start_time = time.time()
|
|
|
|
crops = video_classifier.preprocess_crops_for_video_cls(track_history[track_id])
|
|
|
|
end_time = time.time()
|
|
|
|
preprocess_time = end_time - start_time
|
|
|
|
print(f"video cls preprocess time: {preprocess_time:.4f} seconds")
|
|
|
|
crops_to_infer.append(crops)
|
|
|
|
track_ids_to_infer.append(track_id)
|
|
|
|
|
|
|
|
if crops_to_infer and (
|
|
|
|
not pred_labels
|
|
|
|
or frame_counter % int(num_video_sequence_samples * skip_frame * (1 - video_cls_overlap_ratio)) == 0
|
|
|
|
):
|
|
|
|
crops_batch = torch.cat(crops_to_infer, dim=0)
|
|
|
|
|
|
|
|
start_inference_time = time.time()
|
|
|
|
output_batch = video_classifier(crops_batch)
|
|
|
|
end_inference_time = time.time()
|
|
|
|
inference_time = end_inference_time - start_inference_time
|
|
|
|
print(f"video cls inference time: {inference_time:.4f} seconds")
|
|
|
|
|
|
|
|
pred_labels, pred_confs = video_classifier.postprocess(output_batch)
|
|
|
|
|
|
|
|
if track_ids_to_infer and crops_to_infer:
|
|
|
|
for box, track_id, pred_label, pred_conf in zip(boxes, track_ids_to_infer, pred_labels, pred_confs):
|
|
|
|
top2_preds = sorted(zip(pred_label, pred_conf), key=lambda x: x[1], reverse=True)
|
|
|
|
label_text = " | ".join([f"{label} ({conf:.2f})" for label, conf in top2_preds])
|
|
|
|
annotator.box_label(box, label_text, color=(0, 0, 255))
|
|
|
|
|
|
|
|
# Write the annotated frame to the output video
|
|
|
|
if output_path is not None:
|
|
|
|
out.write(frame)
|
|
|
|
|
|
|
|
# Display the annotated frame
|
|
|
|
cv2.imshow("YOLOv8 Tracking with S3D Classification", frame)
|
|
|
|
|
|
|
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
|
|
|
break
|
|
|
|
|
|
|
|
cap.release()
|
|
|
|
if output_path is not None:
|
|
|
|
out.release()
|
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
|
|
|
|
|
|
|
def parse_opt():
|
|
|
|
"""Parse command line arguments."""
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="ultralytics detector model path")
|
|
|
|
parser.add_argument("--device", default="", help='cuda device, i.e. 0 or 0,1,2,3 or cpu/mps, "" for auto-detection')
|
|
|
|
parser.add_argument(
|
|
|
|
"--source",
|
|
|
|
type=str,
|
|
|
|
default="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
|
|
help="video file path or youtube URL",
|
|
|
|
)
|
|
|
|
parser.add_argument("--output-path", type=str, default="output_video.mp4", help="output video file path")
|
|
|
|
parser.add_argument(
|
|
|
|
"--crop-margin-percentage", type=int, default=10, help="percentage of margin to add around detected objects"
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--num-video-sequence-samples", type=int, default=8, help="number of video frames to use for classification"
|
|
|
|
)
|
|
|
|
parser.add_argument("--skip-frame", type=int, default=2, help="number of frames to skip between detections")
|
|
|
|
parser.add_argument(
|
|
|
|
"--video-cls-overlap-ratio", type=float, default=0.25, help="overlap ratio between video sequences"
|
|
|
|
)
|
|
|
|
parser.add_argument("--fp16", action="store_true", help="use FP16 for inference")
|
|
|
|
parser.add_argument(
|
|
|
|
"--video-classifier-model", type=str, default="microsoft/xclip-base-patch32", help="video classifier model name"
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--labels",
|
|
|
|
nargs="+",
|
|
|
|
type=str,
|
|
|
|
default=["dancing", "singing a song"],
|
|
|
|
help="labels for zero-shot video classification",
|
|
|
|
)
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def main(opt):
|
|
|
|
"""Main function."""
|
|
|
|
run(**vars(opt))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
opt = parse_opt()
|
|
|
|
main(opt)
|