chore: Refactor TorchVisionVideoClassifier initialization

Refactor the `__init__` method of the `TorchVisionVideoClassifier` class to improve code readability and maintainability. Move the initialization of `supports_r3d`, `supports_transforms_v2`, `supports_mvitv1b`, `supports_s3d`, `supports_mvitv2s`, and `supports_swin3dt` to the beginning of the method for better organization. Also, move the initialization of `model_name_to_model_and_weights` to a separate block for clarity. This refactoring enhances the overall structure of the code and makes it easier to understand and maintain.
action-recog
fcakyon 7 months ago
parent e559bc0119
commit db1e5f6c1b
  1. 64
      ultralytics/solutions/action_recognition.py

@ -225,37 +225,6 @@ class ActionRecognition:
class TorchVisionVideoClassifier:
"""Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/."""
supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
model_name_to_model_and_weights = {}
if supports_r3d:
from torchvision.models.video import R3D_18_Weights, r3d_18
model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
if supports_s3d:
from torchvision.models.video import S3D_Weights, s3d
model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
if supports_swin3db:
from torchvision.models.video import Swin3D_B_Weights, swin3d_b
model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
if supports_swin3dt:
from torchvision.models.video import Swin3D_T_Weights, swin3d_t
model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
if supports_mvitv1b:
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
if supports_mvitv2s:
from torchvision.models.video import MViT_V2_S_Weights, mvit_v2_s
model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
def __init__(self, model_name: str, device: str or torch.device = ""):
"""
Initialize the VideoClassifier with the specified model name and device.
@ -267,6 +236,39 @@ class TorchVisionVideoClassifier:
Raises:
ValueError: If an invalid model name is provided.
"""
supports_r3d = check_requirements("torchvision>=0.8.1", install=False)
self.supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
supports_mvitv1b = supports_s3d = check_requirements("torchvision>=0.14.0", install=False)
supports_mvitv2s = supports_swin3dt = supports_swin3db = check_requirements("torchvision>=0.15.0", install=False)
model_name_to_model_and_weights = {}
if supports_r3d:
from torchvision.models.video import R3D_18_Weights, r3d_18
model_name_to_model_and_weights["r3d_18"] = (r3d_18, R3D_18_Weights.DEFAULT)
if supports_s3d:
from torchvision.models.video import S3D_Weights, s3d
model_name_to_model_and_weights["s3d"] = (s3d, S3D_Weights.DEFAULT)
if supports_swin3db:
from torchvision.models.video import Swin3D_B_Weights, swin3d_b
model_name_to_model_and_weights["swin3d_b"] = (swin3d_b, Swin3D_B_Weights.DEFAULT)
if supports_swin3dt:
from torchvision.models.video import Swin3D_T_Weights, swin3d_t
model_name_to_model_and_weights["swin3d_t"] = (swin3d_t, Swin3D_T_Weights.DEFAULT)
if supports_mvitv1b:
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
model_name_to_model_and_weights["mvit_v1_b"] = (mvit_v1_b, MViT_V1_B_Weights.DEFAULT)
if supports_mvitv2s:
from torchvision.models.video import MViT_V2_S_Weights, mvit_v2_s
model_name_to_model_and_weights["mvit_v2_s"] = (mvit_v2_s, MViT_V2_S_Weights.DEFAULT)
self.model_name_to_model_and_weights = model_name_to_model_and_weights
if model_name not in self.model_name_to_model_and_weights:
raise ValueError(f"Invalid model name '{model_name}'. Available models: {self.available_model_names()}")
model, self.weights = self.model_name_to_model_and_weights[model_name]

Loading…
Cancel
Save