refactor: Update image transformation and supported models logic in TorchVisionVideoClassifier and HuggingFaceVideoClassifier

action-recog
fcakyon 4 months ago
parent 16efb5426b
commit 327a37f4fd
  1. examples/YOLOv8-Action-Recognition/action_recognition.py (57 changes)
  2. ultralytics/solutions/action_recognition.py (52 changes)

--- a/examples/YOLOv8-Action-Recognition/action_recognition.py
+++ b/examples/YOLOv8-Action-Recognition/action_recognition.py
@@ -20,29 +20,30 @@ from ultralytics.utils.torch_utils import select_device
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""
 
-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
 
     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -86,9 +87,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
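The other half of the refactor: check_requirements("torchvision>=0.16.0", install=False) used to run on every call to this preprocessing method; it is now evaluated once as the class attribute supports_transforms_v2 and read via self. The Compose body is truncated in this diff, so the sketch below of a v2-with-legacy-fallback transform uses illustrative ops and ImageNet-style mean/std values, not the commit's exact pipeline:

    import torch

    try:
        from torchvision.transforms import v2  # gated on torchvision>=0.16.0 in the commit
        HAS_TRANSFORMS_V2 = True
    except ImportError:
        HAS_TRANSFORMS_V2 = False

    def build_transform(input_size=(224, 224)):
        """Resize + float conversion + normalization, preferring the v2 API."""
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # illustrative values
        if HAS_TRANSFORMS_V2:
            return v2.Compose([
                v2.Resize(input_size, antialias=True),
                v2.ToDtype(torch.float32, scale=True),  # uint8 [0,255] -> float [0,1]
                v2.Normalize(mean=mean, std=std),
            ])
        from torchvision import transforms as T  # legacy fallback for older torchvision
        return T.Compose([
            T.Resize(input_size),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=mean, std=std),
        ])

    frames = torch.randint(0, 256, (8, 3, 480, 640), dtype=torch.uint8)  # T x C x H x W clip
    batch = build_transform()(frames)  # both APIs accept batched image tensors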
@@ -152,6 +151,8 @@ class TorchVisionVideoClassifier:
 class HuggingFaceVideoClassifier:
     """Zero-shot video classifier using Hugging Face models for various devices."""
 
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
     def __init__(
         self,
         labels: List[str],
@@ -194,9 +195,7 @@ class HuggingFaceVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
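A hedged usage sketch of the refactored class: only __init__ and the registry are visible in these hunks, so no inference call is shown, and the import path assumes the examples script is on sys.path:

    from action_recognition import TorchVisionVideoClassifier  # hypothetical import path

    # The registry now reflects what the installed torchvision can actually provide.
    available = sorted(TorchVisionVideoClassifier.model_name_to_model_and_weights)
    print("available video backbones:", available)

    if "r3d_18" in available:  # r3d_18 has the lowest gate, torchvision>=0.8.1
        classifier = TorchVisionVideoClassifier(model_name="r3d_18", device="cpu")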

--- a/ultralytics/solutions/action_recognition.py
+++ b/ultralytics/solutions/action_recognition.py
@@ -225,29 +225,31 @@ class ActionRecognition:
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""
 
-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
 
     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -290,9 +292,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]
-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2
 
             transform = v2.Compose(
