@@ -22,7 +22,7 @@ class ActionRecognition:
         fp16=False,
         crop_margin_percentage=10,
         num_video_sequence_samples=8,
-        skip_frame=2,
+        vid_stride=2,
         video_cls_overlap_ratio=0.25,
         device="",
     ):
@@ -35,7 +35,7 @@
             fp16 (bool, optional): Whether to use half-precision floating point. Defaults to False.
             crop_margin_percentage (int, optional): Percentage of margin to add around detected objects. Defaults to 10.
             num_video_sequence_samples (int, optional): Number of video frames to use for classification. Defaults to 8.
-            skip_frame (int, optional): Number of frames to skip between detections. Defaults to 2.
+            vid_stride (int, optional): Number of frames to skip between detections. Defaults to 2.
             video_cls_overlap_ratio (float, optional): Overlap ratio between video sequences. Defaults to 0.25.
             device (str or torch.device, optional): The device to run the model on. Defaults to "".
         """
@@ -69,7 +69,7 @@
         # Properties with default values
         self.crop_margin_percentage = crop_margin_percentage
         self.num_video_sequence_samples = num_video_sequence_samples
-        self.skip_frame = skip_frame
+        self.vid_stride = vid_stride
         self.video_cls_overlap_ratio = video_cls_overlap_ratio
 
     def process_tracks(self, tracks):
@@ -169,14 +169,14 @@
 
         self.process_tracks(tracks)
 
-        if self.frame_counter % self.skip_frame == 0:
+        if self.frame_counter % self.vid_stride == 0:
             crops_to_infer = []
             track_ids_to_infer = []
 
             for box, track_id in zip(self.boxes, self.track_ids):
                 if (
                     len(self.track_history[track_id]) == self.num_video_sequence_samples
-                    and self.frame_counter % self.skip_frame == 0
+                    and self.frame_counter % self.vid_stride == 0
                 ):
                     crops = self.video_classifier.preprocess_crops_for_video_cls(self.track_history[track_id])
                     crops_to_infer.append(crops)
@@ -185,7 +185,7 @@
             if crops_to_infer and (
                 not pred_labels
                 or self.frame_counter
-                % int(self.num_video_sequence_samples * self.skip_frame * (1 - self.video_cls_overlap_ratio))
+                % int(self.num_video_sequence_samples * self.vid_stride * (1 - self.video_cls_overlap_ratio))
                 == 0
             ):
                 crops_batch = torch.cat(crops_to_infer, dim=0)
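
Note on the renamed parameter: with the defaults shown above (num_video_sequence_samples=8, vid_stride=2, video_cls_overlap_ratio=0.25), the modulus in the last hunk works out to int(8 * 2 * (1 - 0.25)) = 12, so the video classifier fires roughly every 12 frames once a track has a full sequence. A minimal standalone sketch of that arithmetic, not part of the patch and with an illustrative helper name:

def classification_interval(num_video_sequence_samples: int, vid_stride: int, video_cls_overlap_ratio: float) -> int:
    # Frames between successive video-classifier calls, mirroring the modulus in the patched condition above.
    return int(num_video_sequence_samples * vid_stride * (1 - video_cls_overlap_ratio))

assert classification_interval(8, 2, 0.25) == 12  # defaults above -> classify about every 12 frames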