tons of typehinting

action-recog
fcakyon 4 months ago
parent d5231e8eaf
commit 1fde585f47
      ultralytics/solutions/action_recognition.py

@@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import defaultdict
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import cv2
import numpy as np
@@ -10,6 +10,7 @@ import torch
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator
from ultralytics.utils.torch_utils import select_device
from ultralytics.engine.results import Results
class ActionRecognition:
@@ -18,13 +19,13 @@ class ActionRecognition:
def __init__(
self,
video_classifier_model="microsoft/xclip-base-patch32",
labels=None,
fp16=False,
crop_margin_percentage=10,
num_video_sequence_samples=8,
vid_stride=2,
video_cls_overlap_ratio=0.25,
device="",
labels: Optional[List[str]] = None,
fp16: bool = False,
crop_margin_percentage: int = 10,
num_video_sequence_samples: int = 8,
vid_stride: int = 2,
video_cls_overlap_ratio: float = 0.25,
device: Union[str, torch.device] = "",
):
"""
Initializes the ActionRecognition with the given parameters.
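As a usage note, a minimal construction sketch against the signature in this hunk; the import path follows the file location above, and the label list is an illustrative placeholder, not part of the commit.

from ultralytics.solutions.action_recognition import ActionRecognition

action_recognition = ActionRecognition(
    video_classifier_model="microsoft/xclip-base-patch32",  # zero-shot video classifier checkpoint
    labels=["walking", "running", "standing"],  # candidate action descriptions (placeholder)
    fp16=False,  # keep full precision, e.g. for CPU inference
    crop_margin_percentage=10,  # extra context kept around each tracked box
    num_video_sequence_samples=8,  # frames per classified clip
    vid_stride=2,  # sample every 2nd frame into the clip buffer
    video_cls_overlap_ratio=0.25,  # overlap between consecutive classified clips
    device="",  # empty string defers to select_device
)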
@@ -72,12 +73,12 @@ class ActionRecognition:
self.vid_stride = vid_stride
self.video_cls_overlap_ratio = video_cls_overlap_ratio
def process_tracks(self, tracks):
def process_tracks(self, tracks: List[Results]):
"""
Extracts results from the provided tracking data and stores track information.
Args:
tracks (list): List of tracks obtained from the object tracking process.
tracks (List[Results]): List of tracks obtained from the object tracking process.
"""
self.boxes = tracks[0].boxes.xyxy.cpu().numpy()
self.track_ids = tracks[0].boxes.id.cpu().numpy()
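A standalone sketch of the attribute access above; note that boxes.id can be None before the tracker assigns IDs, a guard the method itself does not include.

from typing import List, Tuple

import numpy as np

from ultralytics.engine.results import Results


def extract_track_info(tracks: List[Results]) -> Tuple[np.ndarray, np.ndarray]:
    """Return (N, 4) xyxy boxes and (N,) track IDs from the first Results object."""
    boxes = tracks[0].boxes.xyxy.cpu().numpy()  # corner-format boxes
    ids = tracks[0].boxes.id  # None until tracking has assigned IDs
    track_ids = ids.cpu().numpy() if ids is not None else np.empty(0)
    return boxes, track_ids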
@@ -106,19 +107,7 @@ class ActionRecognition:
if cv2.waitKey(1) & 0xFF == ord("q"):
return
def predict_action(self, sequences):
"""
Perform inference on the given sequences.
Args:
sequences (torch.Tensor): The input sequences for the model. Batched video frames with shape (B, T, H, W, C).
Returns:
(torch.Tensor): The model's output.
"""
return self.video_classifier(sequences)
def postprocess(self, outputs):
def postprocess(self, outputs: torch.Tensor) -> Tuple[List[List[str]], List[List[float]]]:
"""
Postprocess the model's batch output.
@@ -137,21 +126,21 @@ class ActionRecognition:
probs = logits_per_video.softmax(dim=-1)
for prob in probs:
top2_indices = prob.topk(2).indices.tolist()
top2_labels = [self.labels[idx] for idx in top2_indices]
top2_confs = prob[top2_indices].tolist()
pred_labels.append(top2_labels)
pred_confs.append(top2_confs)
top3_indices = prob.topk(3).indices.tolist()
top3_labels = [self.labels[idx] for idx in top3_indices]
top3_confs = prob[top3_indices].tolist()
pred_labels.append(top3_labels)
pred_confs.append(top3_confs)
return pred_labels, pred_confs
def recognize_actions(self, im0, tracks):
def recognize_actions(self, im0: np.ndarray, tracks: List[Results]) -> np.ndarray:
"""
Recognizes actions based on tracking data.
Args:
im0 (ndarray): Image.
tracks (list): List of tracks obtained from the object tracking process.
tracks (List[Results]): List of tracks obtained from the object tracking process.
Returns:
(ndarray): The image with annotated boxes and tracks.
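Looking back at postprocess in the hunk above, a self-contained sketch of its top-3 selection; the label list is a placeholder and the random logits stand in for outputs.logits_per_video.

import torch

labels = ["walking", "running", "standing", "sitting"]  # placeholder label set
logits_per_video = torch.randn(2, len(labels))  # stand-in for outputs.logits_per_video, shape (B, num_labels)

pred_labels, pred_confs = [], []
probs = logits_per_video.softmax(dim=-1)  # probability distribution per video clip
for prob in probs:
    top3_indices = prob.topk(3).indices.tolist()  # indices of the 3 most likely actions
    pred_labels.append([labels[i] for i in top3_indices])
    pred_confs.append(prob[top3_indices].tolist())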
@@ -188,8 +177,8 @@ class ActionRecognition:
% int(self.num_video_sequence_samples * self.vid_stride * (1 - self.video_cls_overlap_ratio))
== 0
):
crops_batch = torch.cat(crops_to_infer, dim=0)
output_batch = self.predict_action(crops_batch)
crops_batch = torch.cat(crops_to_infer, dim=0) # crops_batch shape: (B, T, H, W, C)
output_batch = self.video_classifier(crops_batch)
pred_labels, pred_confs = self.postprocess(output_batch)
if track_ids_to_infer and crops_to_infer:
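A sketch of the clip-scheduling arithmetic in this hunk: with 8 samples, stride 2 and 0.25 overlap, a new batch is classified every 12 frames. The frame-counter name and the random crop tensors are assumptions, since the hunk shows only the modulo expression and the batch concatenation.

import torch

num_video_sequence_samples, vid_stride, video_cls_overlap_ratio = 8, 2, 0.25
step = int(num_video_sequence_samples * vid_stride * (1 - video_cls_overlap_ratio))  # 12 frames between inferences

crops_to_infer = [torch.rand(1, 8, 224, 224, 3) for _ in range(2)]  # one (1, T, H, W, C) clip per tracked object
for frame_counter in range(1, 37):
    if crops_to_infer and frame_counter % step == 0:
        crops_batch = torch.cat(crops_to_infer, dim=0)  # (B, T, H, W, C) = (2, 8, 224, 224, 3)
        print(f"frame {frame_counter}: classify batch with shape {tuple(crops_batch.shape)}")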
@@ -466,8 +455,18 @@ class HuggingFaceVideoClassifier:
return pred_labels, pred_confs
def crop_and_pad(frame, box, margin_percent):
"""Crop box with margin and take square crop from frame."""
def crop_and_pad(frame, box, margin_percent: int = 10) -> np.ndarray:
"""
Crop box with margin and take square crop from frame.
Args:
frame (ndarray): The input frame.
box (list): The bounding box coordinates in (x1, y1, x2, y2) format.
margin_percent (int, optional): The percentage [0-100] of margin to add around the detected object. Defaults to 10.
Returns:
(ndarray): The cropped and resized frame.
"""
x1, y1, x2, y2 = map(int, box)
w, h = x2 - x1, y2 - y1
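The hunk truncates the function body, so here is one possible completion that matches the docstring (margin expansion, square crop, resize); the square-crop details and the 224x224 output size are assumptions, not the committed implementation.

import cv2
import numpy as np


def crop_and_pad_sketch(frame: np.ndarray, box, margin_percent: int = 10) -> np.ndarray:
    """Crop the box with margin, then take a square crop centered on it (sketch)."""
    x1, y1, x2, y2 = map(int, box)
    w, h = x2 - x1, y2 - y1

    # Expand the box by the requested margin, clamped to the frame bounds.
    margin_x, margin_y = int(w * margin_percent / 100), int(h * margin_percent / 100)
    x1, y1 = max(0, x1 - margin_x), max(0, y1 - margin_y)
    x2, y2 = min(frame.shape[1], x2 + margin_x), min(frame.shape[0], y2 + margin_y)

    # Take a square window centered on the expanded box, then resize.
    size = max(x2 - x1, y2 - y1)
    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    sx, sy = max(0, cx - size // 2), max(0, cy - size // 2)
    square = frame[sy : sy + size, sx : sx + size]
    return cv2.resize(square, (224, 224))  # assumed classifier input size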
@@ -507,7 +506,7 @@ if __name__ == "__main__":
if not success:
break
# Perform object tracking
tracks = model.track(frame, persist=True, classes=[0])
tracks: List[Results] = model.track(frame, persist=True, classes=[0])
# Perform action recognition
annotated_frame = action_recognition.recognize_actions(frame, tracks)
# Display the frame
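Putting the pieces of this hunk together, a runnable end-to-end sketch; the webcam source, weights file and label list are illustrative choices, not part of the commit.

import cv2
from ultralytics import YOLO
from ultralytics.solutions.action_recognition import ActionRecognition

model = YOLO("yolov8n.pt")  # detector/tracker weights (illustrative)
action_recognition = ActionRecognition(labels=["walking", "running", "standing"])

cap = cv2.VideoCapture(0)  # webcam source (illustrative)
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    tracks = model.track(frame, persist=True, classes=[0])  # track people only
    annotated_frame = action_recognition.recognize_actions(frame, tracks)
    cv2.imshow("Action Recognition", annotated_frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
cap.release()
cv2.destroyAllWindows()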
