refactor: Update image transformation and supported models logic in TorchVisionVideoClassifier and HuggingFaceVideoClassifier

Branch: action-recog
fcakyon committed 4 months ago · parent 16efb5426b · commit 327a37f4fd

Changed files:
  1. examples/YOLOv8-Action-Recognition/action_recognition.py (57 changed lines)
  2. ultralytics/solutions/action_recognition.py (52 changed lines)

examples/YOLOv8-Action-Recognition/action_recognition.py

@@ -20,29 +20,30 @@ from ultralytics.utils.torch_utils import select_device
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""

-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)

     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -86,9 +87,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]

-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2

             transform = v2.Compose(
@@ -152,6 +151,8 @@ class TorchVisionVideoClassifier:
 class HuggingFaceVideoClassifier:
     """Zero-shot video classifier using Hugging Face models for various devices."""

+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
     def __init__(
         self,
         labels: List[str],
@@ -194,9 +195,7 @@ class HuggingFaceVideoClassifier:
         if input_size is None:
             input_size = [224, 224]

-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2

             transform = v2.Compose(
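Both classifiers now read a class-level supports_transforms_v2 flag instead of re-running the version check on every preprocess call. The hunks show only the v2 branch; below is a self-contained sketch of the two-branch transform construction being gated here, where the compose contents and the legacy fallback are illustrative assumptions, not code from the diff:

import torch


def build_transform(input_size, mean, std, supports_transforms_v2):
    """Build a frame transform, preferring the transforms v2 API when present."""
    if supports_transforms_v2:  # torchvision>=0.16
        from torchvision.transforms import v2

        return v2.Compose(
            [
                v2.ToDtype(torch.float32, scale=True),
                v2.Resize(input_size, antialias=True),
                v2.Normalize(mean=mean, std=std),
            ]
        )
    # Assumed legacy fallback; the actual else branch lies outside the hunk.
    from torchvision import transforms

    return transforms.Compose(
        [
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(input_size),
            transforms.Normalize(mean=mean, std=std),
        ]
    )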

ultralytics/solutions/action_recognition.py

@@ -225,29 +225,31 @@ class ActionRecognition:
 class TorchVisionVideoClassifier:
     """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""

-    from torchvision.models.video import (
-        MViT_V1_B_Weights,
-        MViT_V2_S_Weights,
-        R3D_18_Weights,
-        S3D_Weights,
-        Swin3D_B_Weights,
-        Swin3D_T_Weights,
-        mvit_v1_b,
-        mvit_v2_s,
-        r3d_18,
-        s3d,
-        swin3d_b,
-        swin3d_t,
-    )
-
-    model_name_to_model_and_weights = {
-        "s3d": (s3d, S3D_Weights.DEFAULT),
-        "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT),
-        "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT),
-        "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT),
-        "mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),
-        "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),
-    }
+    supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
+    supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+    supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
+    supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
+
+    model_name_to_model_and_weights = {}
+    if supports_r3d:
+        from torchvision.models.video import r3d_18, R3D_18_Weights
+        model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
+    if supports_s3d:
+        from torchvision.models.video import s3d, S3D_Weights
+        model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
+    if supports_swin3db:
+        from torchvision.models.video import swin3d_b, Swin3D_B_Weights
+        model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
+    if supports_swin3dt:
+        from torchvision.models.video import swin3d_t, Swin3D_T_Weights
+        model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
+    if supports_mvitv1b:
+        from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
+        model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
+    if supports_mvitv2s:
+        from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights
+        model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)

     def __init__(self, model_name: str, device: str or torch.device = ""):
         """
@@ -290,9 +292,7 @@ class TorchVisionVideoClassifier:
         if input_size is None:
             input_size = [224, 224]

-        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
-        if supports_transforms_v2:
+        if self.supports_transforms_v2:
             from torchvision.transforms import v2

             transform = v2.Compose(
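Because the registry is now populated conditionally, callers can inspect which backbones the local environment actually provides before constructing a classifier. A hypothetical usage sketch built only on names visible in the diff (model_name_to_model_and_weights and the TorchVisionVideoClassifier(model_name, device) constructor):

# Hypothetical caller-side selection; names taken from the diff above.
available = list(TorchVisionVideoClassifier.model_name_to_model_and_weights)
if not available:
    raise RuntimeError("Installed torchvision provides none of the supported video backbones")

# Prefer a small Swin3D backbone when torchvision>=0.15 provides it.
model_name = "swin3d_t" if "swin3d_t" in available else available[0]
classifier = TorchVisionVideoClassifier(model_name=model_name, device="cpu")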
