@@ -21,8 +21,6 @@ class ActionRecognition:
     def __init__(
         self,
         video_classifier_model="microsoft/xclip-base-patch32",
         labels: Optional[List[str]] = None,
-        fp16: bool = False,
         crop_margin_percentage: int = 10,
         num_video_sequence_samples: int = 8,
         vid_stride: int = 2,
@@ -35,7 +33,6 @@ class ActionRecognition:
         Args:
             video_classifier_model (str): Name or path of the video classifier model. Defaults to "microsoft/xclip-base-patch32".
             labels (List[str], optional): List of labels for zero-shot classification. Defaults to predefined list.
-            fp16 (bool, optional): Whether to use half-precision floating point. Defaults to False.
             crop_margin_percentage (int, optional): Percentage of margin to add around detected objects. Defaults to 10.
             num_video_sequence_samples (int, optional): Number of video frames to use for classification. Defaults to 8.
             vid_stride (int, optional): Number of frames to skip between detections. Defaults to 2.
@@ -47,22 +44,22 @@ class ActionRecognition:
             if labels is not None
             else ["walking", "running", "brushing teeth", "looking into phone", "weight lifting", "cooking", "sitting"]
         )
-        self.fp16 = fp16
 
         self.device = select_device(device)
+        self.fp16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() and "cuda" in str(self.device)
 
         # Check if environment supports imshow
         self.env_check = check_imshow(warn=True)
         self.window_name = "Ultralytics YOLOv8 Action Recognition"
 
         if video_classifier_model in TorchVisionVideoClassifier.available_model_names():
             print("'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.")
             print(
                 "'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels."
             )
             self.video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=self.device)
         else:
             self.video_classifier = HuggingFaceVideoClassifier(
-                self.labels, model_name=video_classifier_model, device=self.device, fp16=fp16
+                self.labels, model_name=video_classifier_model, device=self.device, fp16=self.fp16
             )
 
         self.track_history = defaultdict(list)
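
For sanity-checking the new default outside the class, here is a minimal sketch of the same auto-detection logic. It assumes the Ultralytics `select_device` helper (which resolves an empty string to `cuda:0` when a GPU is available, otherwise CPU); the `auto_fp16` name itself is hypothetical, not part of this PR.

import torch
from ultralytics.utils.torch_utils import select_device


def auto_fp16(device: str = "") -> bool:
    """Hypothetical helper: enable half precision only on bf16-capable CUDA devices."""
    resolved = select_device(device)
    # select_device returns a torch.device, so compare its .type rather than the raw string.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported() and resolved.type == "cuda"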
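
And a usage sketch under the new behavior. The label list and the TorchVision model name "s3d" are illustrative assumptions; valid names depend on what `TorchVisionVideoClassifier.available_model_names()` reports in your install.

from action_recognition import ActionRecognition  # the example module in this PR

# HuggingFace path: custom zero-shot labels are honored, and fp16 is now derived
# from the resolved device instead of a constructor flag.
recognizer = ActionRecognition(
    video_classifier_model="microsoft/xclip-base-patch32",
    labels=["walking", "running", "cooking"],
    device="",  # empty string lets select_device choose cuda:0 when available
)

# TorchVision path: provided labels are ignored in favor of Kinetics-400 classes.
tv_recognizer = ActionRecognition(video_classifier_model="s3d", device="cpu")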