|
|
|
@ -232,24 +232,29 @@ class TorchVisionVideoClassifier: |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights = {} |
|
|
|
|
if supports_r3d: |
|
|
|
|
from torchvision.models.video import r3d_18, R3D_18_Weights |
|
|
|
|
from torchvision.models.video import R3D_18_Weights, r3d_18 |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT) |
|
|
|
|
if supports_s3d: |
|
|
|
|
from torchvision.models.video import s3d, S3D_Weights |
|
|
|
|
from torchvision.models.video import S3D_Weights, s3d |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT) |
|
|
|
|
if supports_swin3db: |
|
|
|
|
from torchvision.models.video import swin3d_b, Swin3D_B_Weights |
|
|
|
|
from torchvision.models.video import Swin3D_B_Weights, swin3d_b |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT) |
|
|
|
|
if supports_swin3dt: |
|
|
|
|
from torchvision.models.video import swin3d_t, Swin3D_T_Weights |
|
|
|
|
from torchvision.models.video import Swin3D_T_Weights, swin3d_t |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT) |
|
|
|
|
if supports_mvitv1b: |
|
|
|
|
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights |
|
|
|
|
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT) |
|
|
|
|
if supports_mvitv2s: |
|
|
|
|
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights |
|
|
|
|
model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT) |
|
|
|
|
from torchvision.models.video import MViT_V2_S_Weights, mvit_v2_s |
|
|
|
|
|
|
|
|
|
model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT) |
|
|
|
|
|
|
|
|
|
def __init__(self, model_name: str, device: str or torch.device = ""): |
|
|
|
|
""" |
|
|
|
|