@@ -6,18 +6,26 @@ from typing import List, Tuple
 
 import cv2
 import numpy as np
 import torch
-from ultralytics.utils.checks import check_imshow
 from ultralytics.utils.checks import check_imshow
 from ultralytics.utils.plotting import Annotator
 from ultralytics.utils.torch_utils import select_device
 
 
 class ActionRecognition:
     """A class to recognize actions in a real-time video stream based on object tracks."""
 
-    def __init__(self, video_classifier_model="microsoft/xclip-base-patch32", labels=None, fp16=False,
-                 crop_margin_percentage=10, num_video_sequence_samples=8, skip_frame=2, video_cls_overlap_ratio=0.25, device=""):
+    def __init__(
+        self,
+        video_classifier_model="microsoft/xclip-base-patch32",
+        labels=None,
+        fp16=False,
+        crop_margin_percentage=10,
+        num_video_sequence_samples=8,
+        skip_frame=2,
+        video_cls_overlap_ratio=0.25,
+        device="",
+    ):
         """
         Initializes the ActionRecognition with the given parameters.
 
@@ -31,9 +39,11 @@ class ActionRecognition:
             video_cls_overlap_ratio (float, optional): Overlap ratio between video sequences. Defaults to 0.25.
             device (str or torch.device, optional): The device to run the model on. Defaults to "".
         """
-        self.labels = labels if labels is not None else [
-            "walking", "running", "brushing teeth", "looking into phone", "weight lifting", "cooking", "sitting"
-        ]
+        self.labels = (
+            labels
+            if labels is not None
+            else ["walking", "running", "brushing teeth", "looking into phone", "weight lifting", "cooking", "sitting"]
+        )
         self.fp16 = fp16
         self.device = select_device(device)
 
@@ -43,7 +53,9 @@
 
         if video_classifier_model in TorchVisionVideoClassifier.available_model_names():
             print("'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.")
-            print("'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels.")
+            print(
+                "'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels."
+            )
             self.video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=self.device)
         else:
             self.video_classifier = HuggingFaceVideoClassifier(
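
Note: the hunk above is the branch that picks between the two classifier backends. As a minimal usage sketch (assuming the PR exposes this ActionRecognition class directly, and that "s3d" is among the names returned by TorchVisionVideoClassifier.available_model_names(); neither is confirmed by this diff), construction might look like:

# Assumed import path; adjust to wherever ActionRecognition lives in this PR.
from action_recognition import ActionRecognition

# A torchvision model name selects the TorchVisionVideoClassifier branch,
# where custom labels/fp16 are ignored and Kinetics-400 labels are used.
tv_recognizer = ActionRecognition(video_classifier_model="s3d", device="cpu")

# Any other identifier (e.g. a HuggingFace model id) falls through to the
# HuggingFaceVideoClassifier branch, where labels and fp16 are honored.
hf_recognizer = ActionRecognition(
    video_classifier_model="microsoft/xclip-base-patch32",
    labels=["walking", "running", "sitting"],
    fp16=False,
    device="",
)
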
@@ -164,14 +176,19 @@
             track_ids_to_infer = []
 
             for box, track_id in zip(self.boxes, self.track_ids):
-                if len(self.track_history[track_id]) == self.num_video_sequence_samples and self.frame_counter % self.skip_frame == 0:
+                if (
+                    len(self.track_history[track_id]) == self.num_video_sequence_samples
+                    and self.frame_counter % self.skip_frame == 0
+                ):
                     crops = self.video_classifier.preprocess_crops_for_video_cls(self.track_history[track_id])
                     crops_to_infer.append(crops)
                     track_ids_to_infer.append(track_id)
 
             if crops_to_infer and (
                 not pred_labels
-                or self.frame_counter % int(self.num_video_sequence_samples * self.skip_frame * (1 - self.video_cls_overlap_ratio)) == 0
+                or self.frame_counter
+                % int(self.num_video_sequence_samples * self.skip_frame * (1 - self.video_cls_overlap_ratio))
+                == 0
             ):
                 crops_batch = torch.cat(crops_to_infer, dim=0)
                 output_batch = self.predict_action(crops_batch)
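
As a quick check on the cadence condition re-wrapped in the last hunk, the arithmetic with the constructor defaults works out as follows (a small standalone sketch; the variable names simply mirror the attributes in the diff):

num_video_sequence_samples = 8   # default from __init__
skip_frame = 2                   # default from __init__
video_cls_overlap_ratio = 0.25   # default from __init__

# Stride between video-classification passes once a track has enough sampled crops.
stride = int(num_video_sequence_samples * skip_frame * (1 - video_cls_overlap_ratio))
print(stride)  # 12 -> "frame_counter % 12 == 0" re-runs predict_action roughly every 12 frames

So with the defaults, each classified clip spans 16 frames (8 crops sampled every 2nd frame) and consecutive clips overlap by 25%, which matches what the video_cls_overlap_ratio docstring above describes.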