From 327a37f4fddf25e4ac88cc4bc29d91845f15d458 Mon Sep 17 00:00:00 2001
From: fcakyon
Date: Mon, 22 Jul 2024 13:56:20 +0300
Subject: [PATCH] refactor: Update image transformation and supported models logic in TorchVisionVideoClassifier and HuggingFaceVideoClassifier

---
 .../action_recognition.py                   | 57 +++++++++----------
 ultralytics/solutions/action_recognition.py | 52 ++++++++---------
 2 files changed, 54 insertions(+), 55 deletions(-)

diff --git a/examples/YOLOv8-Action-Recognition/action_recognition.py b/examples/YOLOv8-Action-Recognition/action_recognition.py
index d10063c00d..1d6d7106f6 100644
--- a/examples/YOLOv8-Action-Recognition/action_recognition.py
+++ b/examples/YOLOv8-Action-Recognition/action_recognition.py
@@ -20,29 +20,30 @@ from ultralytics.utils.torch_utils import select_device
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""
 
-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
 
     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -86,9 +87,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
 
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
@@ -152,6 +151,8 @@ class TorchVisionVideoClassifier:
 class HuggingFaceVideoClassifier:
     """Zero-shot video classifier using Hugging Face models for various devices."""
 
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
     def __init__(
         self,
         labels: List[str],
@@ -194,9 +195,7 @@ class HuggingFaceVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
 
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
diff --git a/ultralytics/solutions/action_recognition.py b/ultralytics/solutions/action_recognition.py
index 8de3de0428..ae9b8a2170 100644
--- a/ultralytics/solutions/action_recognition.py
+++ b/ultralytics/solutions/action_recognition.py
@@ -225,29 +225,31 @@ class ActionRecognition:
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""
 
-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
+
 
     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -290,9 +292,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
 
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
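
Below is a minimal usage sketch of the version-gated model registry this patch introduces. It is illustrative only: the flat import of action_recognition and the "cpu" device string are assumptions not shown in the diff, while the model_name_to_model_and_weights class attribute and the TorchVisionVideoClassifier(model_name, device) constructor do come from the changes above.

# Usage sketch; assumes the example script's directory is on the Python path.
from action_recognition import TorchVisionVideoClassifier

# The registry is populated at class-definition time, so the available keys
# depend on the torchvision version found by check_requirements(..., install=False).
print(list(TorchVisionVideoClassifier.model_name_to_model_and_weights))

# Only construct a classifier for a model the installed torchvision supports.
if "r3d_18" in TorchVisionVideoClassifier.model_name_to_model_and_weights:
    classifier = TorchVisionVideoClassifier(model_name="r3d_18", device="cpu")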