@ -225,37 +225,6 @@ class ActionRecognition:
class TorchVisionVideoClassifier :
""" Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/. """
# Version feature flags. Each flag holds the result of check_requirements for a
# minimum torchvision version, called with install = False.
# NOTE(review): presumably install = False means "check only, never install" --
# confirm against check_requirements, which is defined outside this view.
supports_r3d = check_requirements ( " torchvision>=0.8.1 " , install = False )
# Not used by the registry below; kept as a separate flag for the 0.16.0 check.
supports_transforms_v2 = check_requirements ( " torchvision>=0.16.0 " , install = False )
# One check result is shared by both flags (mvit_v1_b and s3d, torchvision 0.14.0).
supports_mvitv1b = supports_s3d = check_requirements ( " torchvision>=0.14.0 " , install = False )
# One check result is shared by all three flags (mvit_v2_s, swin3d_t, swin3d_b; torchvision 0.15.0).
supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements ( " torchvision>=0.15.0 " , install = False )
# Registry mapping a model-name string to a (constructor, default weights) tuple.
# An entry is added only when its flag is truthy, and the matching
# torchvision.models.video import only runs in that case, so names that the
# installed torchvision lacks are never imported.
# NOTE(review): indentation is lost in this view; these statements appear to sit
# directly in the class body, in which case the imported names are also bound in
# the class namespace -- confirm against the real file.
model_name_to_model_and_weights = { }
# r3d_18: gated on supports_r3d.
if supports_r3d :
from torchvision . models . video import R3D_18_Weights , r3d_18
model_name_to_model_and_weights [ " r3d_18 " ] = ( r3d_18 , R3D_18_Weights . DEFAULT )
# s3d: gated on supports_s3d.
if supports_s3d :
from torchvision . models . video import S3D_Weights , s3d
model_name_to_model_and_weights [ " s3d " ] = ( s3d , S3D_Weights . DEFAULT )
# swin3d_b: gated on supports_swin3db.
if supports_swin3db :
from torchvision . models . video import Swin3D_B_Weights , swin3d_b
model_name_to_model_and_weights [ " swin3d_b " ] = ( swin3d_b , Swin3D_B_Weights . DEFAULT )
# swin3d_t: gated on supports_swin3dt.
if supports_swin3dt :
from torchvision . models . video import Swin3D_T_Weights , swin3d_t
model_name_to_model_and_weights [ " swin3d_t " ] = ( swin3d_t , Swin3D_T_Weights . DEFAULT )
# mvit_v1_b: gated on supports_mvitv1b.
if supports_mvitv1b :
from torchvision . models . video import MViT_V1_B_Weights , mvit_v1_b
model_name_to_model_and_weights [ " mvit_v1_b " ] = ( mvit_v1_b , MViT_V1_B_Weights . DEFAULT )
# mvit_v2_s: gated on supports_mvitv2s.
if supports_mvitv2s :
from torchvision . models . video import MViT_V2_S_Weights , mvit_v2_s
model_name_to_model_and_weights [ " mvit_v2_s " ] = ( mvit_v2_s , MViT_V2_S_Weights . DEFAULT )
def __init__ ( self , model_name : str , device : str or torch . device = " " ) :
"""
Initialize the VideoClassifier with the specified model name and device .
@ -267,6 +236,39 @@ class TorchVisionVideoClassifier:
Raises :
ValueError : If an invalid model name is provided .
"""
supports_r3d = check_requirements ( " torchvision>=0.8.1 " , install = False )
self . supports_transforms_v2 = check_requirements ( " torchvision>=0.16.0 " , install = False )
supports_mvitv1b = supports_s3d = check_requirements ( " torchvision>=0.14.0 " , install = False )
supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements ( " torchvision>=0.15.0 " , install = False )
model_name_to_model_and_weights = { }
if supports_r3d :
from torchvision . models . video import R3D_18_Weights , r3d_18
model_name_to_model_and_weights [ " r3d_18 " ] = ( r3d_18 , R3D_18_Weights . DEFAULT )
if supports_s3d :
from torchvision . models . video import S3D_Weights , s3d
model_name_to_model_and_weights [ " s3d " ] = ( s3d , S3D_Weights . DEFAULT )
if supports_swin3db :
from torchvision . models . video import Swin3D_B_Weights , swin3d_b
model_name_to_model_and_weights [ " swin3d_b " ] = ( swin3d_b , Swin3D_B_Weights . DEFAULT )
if supports_swin3dt :
from torchvision . models . video import Swin3D_T_Weights , swin3d_t
model_name_to_model_and_weights [ " swin3d_t " ] = ( swin3d_t , Swin3D_T_Weights . DEFAULT )
if supports_mvitv1b :
from torchvision . models . video import MViT_V1_B_Weights , mvit_v1_b
model_name_to_model_and_weights [ " mvit_v1_b " ] = ( mvit_v1_b , MViT_V1_B_Weights . DEFAULT )
if supports_mvitv2s :
from torchvision . models . video import MViT_V2_S_Weights , mvit_v2_s
model_name_to_model_and_weights [ " mvit_v2_s " ] = ( mvit_v2_s , MViT_V2_S_Weights . DEFAULT )
self . model_name_to_model_and_weights = model_name_to_model_and_weights
if model_name not in self . model_name_to_model_and_weights :
raise ValueError ( f " Invalid model name ' { model_name } ' . Available models: { self . available_model_names ( ) } " )
model , self . weights = self . model_name_to_model_and_weights [ model_name ]