@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 from collections import defaultdict
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import cv2
 import numpy as np
@@ -10,6 +10,7 @@ import torch
 from ultralytics.utils.checks import check_imshow, check_requirements
 from ultralytics.utils.plotting import Annotator
 from ultralytics.utils.torch_utils import select_device
+from ultralytics.engine.results import Results


 class ActionRecognition:
@@ -18,13 +19,13 @@ class ActionRecognition:
     def __init__(
         self,
         video_classifier_model="microsoft/xclip-base-patch32",
-        labels=None,
-        fp16=False,
-        crop_margin_percentage=10,
-        num_video_sequence_samples=8,
-        vid_stride=2,
-        video_cls_overlap_ratio=0.25,
-        device="",
+        labels: Optional[List[str]] = None,
+        fp16: bool = False,
+        crop_margin_percentage: int = 10,
+        num_video_sequence_samples: int = 8,
+        vid_stride: int = 2,
+        video_cls_overlap_ratio: float = 0.25,
+        device: str or torch.device = "",
     ):
         """
         Initializes the ActionRecognition with the given parameters.
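
Note (not part of the patch): a minimal construction sketch for the type-hinted signature above. The module name, label list, and device comment are illustrative assumptions only.

    from action_recognition import ActionRecognition  # hypothetical module name for this file

    recognizer = ActionRecognition(
        video_classifier_model="microsoft/xclip-base-patch32",
        labels=["walking", "running", "dancing"],  # assumed example labels
        fp16=False,
        crop_margin_percentage=10,
        num_video_sequence_samples=8,
        vid_stride=2,
        video_cls_overlap_ratio=0.25,
        device="",  # assumption: empty string defers device selection to the class
    )
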
@@ -72,12 +73,12 @@
         self.vid_stride = vid_stride
         self.video_cls_overlap_ratio = video_cls_overlap_ratio

-    def process_tracks(self, tracks):
+    def process_tracks(self, tracks: List[Results]):
         """
         Extracts results from the provided tracking data and stores track information.

         Args:
-            tracks (list): List of tracks obtained from the object tracking process.
+            tracks (List[Results]): List of tracks obtained from the object tracking process.
         """
         self.boxes = tracks[0].boxes.xyxy.cpu().numpy()
         self.track_ids = tracks[0].boxes.id.cpu().numpy()
@@ -106,19 +107,7 @@
         if cv2.waitKey(1) & 0xFF == ord("q"):
             return

-    def predict_action(self, sequences):
-        """
-        Perform inference on the given sequences.
-
-        Args:
-            sequences (torch.Tensor): The input sequences for the model. Batched video frames with shape (B, T, H, W, C).
-
-        Returns:
-            (torch.Tensor): The model's output.
-        """
-        return self.video_classifier(sequences)
-
-    def postprocess(self, outputs):
+    def postprocess(self, outputs: torch.Tensor) -> Tuple[List[List[str]], List[List[float]]]:
         """
         Postprocess the model's batch output.

@@ -137,21 +126,21 @@
         probs = logits_per_video.softmax(dim=-1)

         for prob in probs:
-            top2_indices = prob.topk(2).indices.tolist()
-            top2_labels = [self.labels[idx] for idx in top2_indices]
-            top2_confs = prob[top2_indices].tolist()
-            pred_labels.append(top2_labels)
-            pred_confs.append(top2_confs)
+            top3_indices = prob.topk(3).indices.tolist()
+            top3_labels = [self.labels[idx] for idx in top3_indices]
+            top3_confs = prob[top3_indices].tolist()
+            pred_labels.append(top3_labels)
+            pred_confs.append(top3_confs)

         return pred_labels, pred_confs

-    def recognize_actions(self, im0, tracks):
+    def recognize_actions(self, im0: np.ndarray, tracks: List[Results]) -> np.ndarray:
         """
         Recognizes actions based on tracking data.

         Args:
             im0 (ndarray): Image.
-            tracks (list): List of tracks obtained from the object tracking process.
+            tracks (List[Results]): List of tracks obtained from the object tracking process.

         Returns:
             (ndarray): The image with annotated boxes and tracks.
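
Note (not part of the patch): a standalone sketch of the top-k extraction that postprocess performs with k=3. The logits tensor and label list below are made-up stand-ins for the classifier output.

    import torch

    labels = ["walking", "running", "dancing", "sitting"]     # stand-in label set
    logits_per_video = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # stand-in logits, shape (B, num_labels)

    pred_labels, pred_confs = [], []
    probs = logits_per_video.softmax(dim=-1)
    for prob in probs:
        top3_indices = prob.topk(3).indices.tolist()               # indices of the 3 highest probabilities
        pred_labels.append([labels[idx] for idx in top3_indices])  # map indices back to label strings
        pred_confs.append(prob[top3_indices].tolist())             # matching confidence scores

    print(pred_labels, pred_confs)
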
@@ -188,8 +177,8 @@
             % int(self.num_video_sequence_samples * self.vid_stride * (1 - self.video_cls_overlap_ratio))
             == 0
         ):
-            crops_batch = torch.cat(crops_to_infer, dim=0)
-            output_batch = self.predict_action(crops_batch)
+            crops_batch = torch.cat(crops_to_infer, dim=0)  # crops_batch shape: (B, T, H, W, C)
+            output_batch = self.video_classifier(crops_batch)
             pred_labels, pred_confs = self.postprocess(output_batch)

         if track_ids_to_infer and crops_to_infer:
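
Note (not part of the patch): with the defaults shown in this file (num_video_sequence_samples=8, vid_stride=2, video_cls_overlap_ratio=0.25), the modulo condition above triggers classification every int(8 * 2 * (1 - 0.25)) = 12 frames, so consecutive classification windows overlap by roughly 25%. A tiny sketch of that arithmetic:

    num_video_sequence_samples = 8
    vid_stride = 2
    video_cls_overlap_ratio = 0.25

    window_span = num_video_sequence_samples * vid_stride           # one sequence spans 16 source frames
    infer_every = int(window_span * (1 - video_cls_overlap_ratio))  # run the classifier every 12 frames
    print(window_span, infer_every)  # 16 12
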
@@ -466,8 +455,18 @@ class HuggingFaceVideoClassifier:
         return pred_labels, pred_confs


-def crop_and_pad(frame, box, margin_percent):
-    """Crop box with margin and take square crop from frame."""
+def crop_and_pad(frame, box, margin_percent: int = 10) -> np.ndarray:
+    """
+    Crop box with margin and take square crop from frame.
+
+    Args:
+        frame (ndarray): The input frame.
+        box (list): The bounding box coordinates.
+        margin_percent (int, optional): The percentage [0-100] of margin to add around the detected object. Defaults to 10.
+
+    Returns:
+        ndarray: The cropped and resized frame.
+    """
     x1, y1, x2, y2 = map(int, box)
     w, h = x2 - x1, y2 - y1

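
Note (not part of the patch): the hunk above shows only the start of crop_and_pad's body, so the sketch below is an independent, assumption-laden illustration of a margin-then-square crop; the real implementation in this file may differ in its details (e.g. output size, clamping).

    import cv2
    import numpy as np


    def crop_and_pad_sketch(frame: np.ndarray, box, margin_percent: int = 10, size: int = 224) -> np.ndarray:
        """Sketch only: expand the box by a margin, take a centered square crop, and resize."""
        x1, y1, x2, y2 = map(int, box)
        w, h = x2 - x1, y2 - y1

        # Expand the box by margin_percent on each side, clamped to the frame bounds.
        mx, my = int(w * margin_percent / 100), int(h * margin_percent / 100)
        x1, y1 = max(0, x1 - mx), max(0, y1 - my)
        x2, y2 = min(frame.shape[1], x2 + mx), min(frame.shape[0], y2 + my)

        # Take the largest centered square region around the expanded box.
        side = max(x2 - x1, y2 - y1)
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        half = side // 2
        sx1, sy1 = max(0, cx - half), max(0, cy - half)
        sx2, sy2 = min(frame.shape[1], cx + half), min(frame.shape[0], cy + half)

        return cv2.resize(frame[sy1:sy2, sx1:sx2], (size, size), interpolation=cv2.INTER_LINEAR)
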
@ -507,7 +506,7 @@ if __name__ == "__main__": |
|
|
|
|
if not success: |
|
|
|
|
break |
|
|
|
|
# Perform object tracking |
|
|
|
|
tracks = model.track(frame, persist=True, classes=[0]) |
|
|
|
|
tracks: List[Results] = model.track(frame, persist=True, classes=[0]) |
|
|
|
|
# Perform action recognition |
|
|
|
|
annotated_frame = action_recognition.recognize_actions(frame, tracks) |
|
|
|
|
# Display the frame |
|
|
|
|
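
Note (not part of the patch): the final hunk ends at the display comment, so the self-contained loop below is an assumed end-to-end usage sketch; the webcam index, window name, label list, and module import are illustrative assumptions.

    import cv2
    from ultralytics import YOLO

    from action_recognition import ActionRecognition  # hypothetical module name for this file

    model = YOLO("yolov8n.pt")
    action_recognition = ActionRecognition(labels=["walking", "running", "dancing"])  # assumed labels

    cap = cv2.VideoCapture(0)  # assumed webcam source
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        tracks = model.track(frame, persist=True, classes=[0])  # track persons only
        annotated_frame = action_recognition.recognize_actions(frame, tracks)
        cv2.imshow("Action Recognition", annotated_frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    cap.release()
    cv2.destroyAllWindows()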