|
|
@ -20,29 +20,30 @@ from ultralytics.utils.torch_utils import select_device |
|
|
|
class TorchVisionVideoClassifier: |
|
|
|
class TorchVisionVideoClassifier: |
|
|
|
"""Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/.""" |
|
|
|
"""Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/.""" |
|
|
|
|
|
|
|
|
|
|
|
from torchvision.models.video import ( |
|
|
|
supports_r3d = check_requirements("torchvision>=0.8.1", install=False) |
|
|
|
MViT_V1_B_Weights, |
|
|
|
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False) |
|
|
|
MViT_V2_S_Weights, |
|
|
|
supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False) |
|
|
|
R3D_18_Weights, |
|
|
|
supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False) |
|
|
|
S3D_Weights, |
|
|
|
|
|
|
|
Swin3D_B_Weights, |
|
|
|
model_name_to_model_and_weights = {} |
|
|
|
Swin3D_T_Weights, |
|
|
|
if supports_r3d: |
|
|
|
mvit_v1_b, |
|
|
|
from torchvision.models.video import r3d_18, R3D_18_Weights |
|
|
|
mvit_v2_s, |
|
|
|
model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT) |
|
|
|
r3d_18, |
|
|
|
if supports_s3d: |
|
|
|
s3d, |
|
|
|
from torchvision.models.video import s3d, S3D_Weights |
|
|
|
swin3d_b, |
|
|
|
model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT) |
|
|
|
swin3d_t, |
|
|
|
if supports_swin3db: |
|
|
|
) |
|
|
|
from torchvision.models.video import swin3d_b, Swin3D_B_Weights |
|
|
|
|
|
|
|
model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT) |
|
|
|
model_name_to_model_and_weights = { |
|
|
|
if supports_swin3dt: |
|
|
|
"s3d": (s3d, S3D_Weights.DEFAULT), |
|
|
|
from torchvision.models.video import swin3d_t, Swin3D_T_Weights |
|
|
|
"r3d_18": (r3d_18, R3D_18_Weights.DEFAULT), |
|
|
|
model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT) |
|
|
|
"swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT), |
|
|
|
if supports_mvitv1b: |
|
|
|
"swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT), |
|
|
|
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights |
|
|
|
"mvit_v1_b": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT), |
|
|
|
model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT) |
|
|
|
"mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT), |
|
|
|
if supports_mvitv2s: |
|
|
|
} |
|
|
|
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights |
|
|
|
|
|
|
|
model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT) |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_name: str, device: str or torch.device = ""): |
|
|
|
def __init__(self, model_name: str, device: str or torch.device = ""): |
|
|
|
""" |
|
|
|
""" |
|
|
@ -86,9 +87,7 @@ class TorchVisionVideoClassifier: |
|
|
|
if input_size is None: |
|
|
|
if input_size is None: |
|
|
|
input_size = [224, 224] |
|
|
|
input_size = [224, 224] |
|
|
|
|
|
|
|
|
|
|
|
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False) |
|
|
|
if self.supports_transforms_v2: |
|
|
|
|
|
|
|
|
|
|
|
if supports_transforms_v2: |
|
|
|
|
|
|
|
from torchvision.transforms import v2 |
|
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
|
|
|
|
|
|
transform = v2.Compose( |
|
|
|
transform = v2.Compose( |
|
|
@ -152,6 +151,8 @@ class TorchVisionVideoClassifier: |
|
|
|
class HuggingFaceVideoClassifier: |
|
|
|
class HuggingFaceVideoClassifier: |
|
|
|
"""Zero-shot video classifier using Hugging Face models for various devices.""" |
|
|
|
"""Zero-shot video classifier using Hugging Face models for various devices.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False) |
|
|
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
self, |
|
|
|
labels: List[str], |
|
|
|
labels: List[str], |
|
|
@ -194,9 +195,7 @@ class HuggingFaceVideoClassifier: |
|
|
|
if input_size is None: |
|
|
|
if input_size is None: |
|
|
|
input_size = [224, 224] |
|
|
|
input_size = [224, 224] |
|
|
|
|
|
|
|
|
|
|
|
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False) |
|
|
|
if self.supports_transforms_v2: |
|
|
|
|
|
|
|
|
|
|
|
if supports_transforms_v2: |
|
|
|
|
|
|
|
from torchvision.transforms import v2 |
|
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
|
|
|
|
|
|
transform = v2.Compose( |
|
|
|
transform = v2.Compose( |
|
|
|