diff --git a/ultralytics/solutions/action_recognition.py b/ultralytics/solutions/action_recognition.py
index ead986966b..0483606a2a 100644
--- a/ultralytics/solutions/action_recognition.py
+++ b/ultralytics/solutions/action_recognition.py
@@ -21,8 +21,6 @@ class ActionRecognition:
     def __init__(
         self,
         video_classifier_model="microsoft/xclip-base-patch32",
-        labels: Optional[List[str]] = None,
-        fp16: bool = False,
         crop_margin_percentage: int = 10,
         num_video_sequence_samples: int = 8,
         vid_stride: int = 2,
@@ -35,7 +33,6 @@ class ActionRecognition:
         Args:
             video_classifier_model (str): Name or path of the video classifier model. Defaults to "microsoft/xclip-base-patch32".
             labels (List[str], optional): List of labels for zero-shot classification. Defaults to predefined list.
-            fp16 (bool, optional): Whether to use half-precision floating point. Defaults to False.
             crop_margin_percentage (int, optional): Percentage of margin to add around detected objects. Defaults to 10.
             num_video_sequence_samples (int, optional): Number of video frames to use for classification. Defaults to 8.
             vid_stride (int, optional): Number of frames to skip between detections. Defaults to 2.
@@ -47,22 +44,22 @@ class ActionRecognition:
             if labels is not None
             else ["walking", "running", "brushing teeth", "looking into phone", "weight lifting", "cooking", "sitting"]
         )
-        self.fp16 = fp16
+        self.device = select_device(device)
+        self.fp16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() and "cuda" in str(self.device)

         # Check if environment supports imshow
         self.env_check = check_imshow(warn=True)
         self.window_name = "Ultralytics YOLOv8 Action Recognition"

         if video_classifier_model in TorchVisionVideoClassifier.available_model_names():
-            print("'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.")
             print(
                 "'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels."
             )
             self.video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=self.device)
         else:
             self.video_classifier = HuggingFaceVideoClassifier(
-                self.labels, model_name=video_classifier_model, device=self.device, fp16=fp16
+                self.labels, model_name=video_classifier_model, device=self.device, fp16=self.fp16
             )

         self.track_history = defaultdict(list)
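
For reference, the net effect of this change is that half precision is no longer a constructor argument: it is derived from the selected device at init time. Below is a minimal, standalone sketch of that auto-detection pattern; the `autoselect_fp16` helper and its return shape are illustrative assumptions for this example, not part of the PR or the Ultralytics API.

```python
from typing import Tuple

import torch


def autoselect_fp16(device: str = "") -> Tuple[torch.device, bool]:
    """Pick a torch device and decide whether half precision should be enabled.

    Illustrative helper (not part of the PR): fp16 is enabled only when a CUDA
    device is selected and the GPU reports bf16 support, used here as a proxy
    for a recent, mixed-precision-capable card.
    """
    dev = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_fp16 = dev.type == "cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return dev, use_fp16


if __name__ == "__main__":
    device, fp16 = autoselect_fp16()
    print(f"device={device}, fp16={fp16}")  # e.g. device=cpu, fp16=False on a CPU-only machine
```

With this pattern, callers no longer pass `fp16` explicitly; `HuggingFaceVideoClassifier` simply receives the flag computed from the chosen device, while `TorchVisionVideoClassifier` ignores it as before.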