refactor: Update image transformation logic in TorchVisionVideoClassifier and HuggingFaceVideoClassifier

action-recog
fcakyon 4 months ago
parent b5b82d93fe
commit 16efb5426b
  1. examples/YOLOv8-Action-Recognition/action_recognition.py (74 changes)
  2. ultralytics/solutions/action_recognition.py (70 changes)
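For orientation before the hunks: in both files the refactor swaps an unconditional torchvision.transforms.v2 pipeline for a version-gated one, falling back to the legacy torchvision.transforms API when torchvision is older than 0.16.0. A minimal standalone sketch of that pattern, assuming ImageNet-style mean/std (the real code reads per-model stats from the torchvision weights or the HF processor) and a plain version probe in place of check_requirements("torchvision>=0.16.0", install=False):

import torch
import torchvision

# Hypothetical per-channel stats; the real code reads them from the model weights or HF processor.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
INPUT_SIZE = [224, 224]

# Plain version probe standing in for check_requirements("torchvision>=0.16.0", install=False).
supports_transforms_v2 = tuple(int(v) for v in torchvision.__version__.split(".")[:2]) >= (0, 16)

if supports_transforms_v2:
    from torchvision.transforms import v2

    transform = v2.Compose(
        [
            v2.ToDtype(torch.float32, scale=True),   # uint8 [0, 255] -> float32 [0, 1]
            v2.Resize(INPUT_SIZE, antialias=True),
            v2.Normalize(mean=MEAN, std=STD),
        ]
    )
else:
    from torchvision import transforms

    transform = transforms.Compose(
        [
            transforms.Lambda(lambda x: x.float() / 255.0),  # manual scaling on the legacy API
            transforms.Resize(INPUT_SIZE),
            transforms.Normalize(mean=MEAN, std=STD),
        ]
    )

crop = torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8)  # one HWC frame crop
tensor = transform(crop.permute(2, 0, 1))                       # CHW in, normalized float32 out

The v2 path handles dtype scaling and antialiased resizing natively; the legacy path reproduces the scaling with a Lambda, which is what the new else branches below do.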

examples/YOLOv8-Action-Recognition/action_recognition.py
@@ -9,10 +9,10 @@ from urllib.parse import urlparse
 import cv2
 import numpy as np
 import torch
-from transformers import AutoModel, AutoProcessor
 from ultralytics import YOLO
 from ultralytics.data.loaders import get_best_youtube_url
+from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.plotting import Annotator
 from ultralytics.utils.torch_utils import select_device
@@ -82,17 +82,32 @@ class TorchVisionVideoClassifier:
         Returns:
             torch.Tensor: Preprocessed crops as a tensor with dimensions (1, T, C, H, W).
         """
         if input_size is None:
             input_size = [224, 224]
-        from torchvision.transforms import v2
-
-        transform = v2.Compose(
-            [
-                v2.ToDtype(torch.float32, scale=True),
-                v2.Resize(input_size, antialias=True),
-                v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
-            ]
-        )
+        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
+        if supports_transforms_v2:
+            from torchvision.transforms import v2
+
+            transform = v2.Compose(
+                [
+                    v2.ToDtype(torch.float32, scale=True),
+                    v2.Resize(input_size, antialias=True),
+                    v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
+                ]
+            )
+        else:
+            from torchvision.transforms import transforms
+
+            transform = transforms.Compose(
+                [
+                    transforms.Lambda(lambda x: x.float() / 255.0),
+                    transforms.Resize(input_size),
+                    transforms.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
+                ]
+            )
         processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops]
         return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)
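One detail worth noting on the unchanged return line: the docstring above promises (1, T, C, H, W), but stack → unsqueeze → permute(0, 2, 1, 3, 4) actually yields channels before time, i.e. (1, C, T, H, W), which is what torchvision video models expect, so it is probably the docstring that is off. A quick shape check with illustrative sizes:

import torch

processed_crops = [torch.zeros(3, 224, 224) for _ in range(8)]  # 8 frame crops, each (C, H, W)
stacked = torch.stack(processed_crops)                          # (T, C, H, W) -> torch.Size([8, 3, 224, 224])
batched = stacked.unsqueeze(0)                                  # (1, T, C, H, W)
print(batched.permute(0, 2, 1, 3, 4).shape)                     # torch.Size([1, 3, 8, 224, 224]) = (1, C, T, H, W)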
@@ -153,6 +168,9 @@ class HuggingFaceVideoClassifier:
             device (str or torch.device, optional): The device to run the model on. Defaults to "".
             fp16 (bool, optional): Whether to use FP16 for inference. Defaults to False.
         """
+        check_requirements("transformers")
+        from transformers import AutoModel, AutoProcessor
+
         self.fp16 = fp16
         self.labels = labels
         self.device = select_device(device)
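This makes transformers a lazy dependency: it is only validated and imported when a HuggingFaceVideoClassifier is actually constructed (check_requirements can also auto-install it). A minimal sketch of the same deferred-import idea without the ultralytics helper; the model id is illustrative, not taken from this diff:

import importlib.util

def load_hf_video_model(model_id: str = "microsoft/xclip-base-patch32"):
    """Illustrative lazy loader: transformers is only touched when this runs."""
    # Stand-in for check_requirements("transformers"): fail up front with a clear message
    # (the real helper can additionally try to install the missing package).
    if importlib.util.find_spec("transformers") is None:
        raise ImportError("transformers is required for the HuggingFace video classifier")
    from transformers import AutoModel, AutoProcessor

    return AutoModel.from_pretrained(model_id), AutoProcessor.from_pretrained(model_id)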
@ -175,17 +193,31 @@ class HuggingFaceVideoClassifier:
""" """
if input_size is None: if input_size is None:
input_size = [224, 224] input_size = [224, 224]
from torchvision import transforms
supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
transform = transforms.Compose(
[ if supports_transforms_v2:
transforms.Lambda(lambda x: x.float() / 255.0), from torchvision.transforms import v2
transforms.Resize(input_size),
transforms.Normalize( transform = v2.Compose(
mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std [
), v2.ToDtype(torch.float32, scale=True),
] v2.Resize(input_size, antialias=True),
) v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
]
)
else:
from torchvision import transforms
transform = transforms.Compose(
[
transforms.Lambda(lambda x: x.float() / 255.0),
transforms.Resize(input_size),
transforms.Normalize(
mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std
),
]
)
processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] # (T, C, H, W) processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] # (T, C, H, W)
output = torch.stack(processed_crops).unsqueeze(0).to(self.device) # (1, T, C, H, W) output = torch.stack(processed_crops).unsqueeze(0).to(self.device) # (1, T, C, H, W)
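Both branches need per-checkpoint normalization statistics. The legacy branch reads them from the HuggingFace processor; the new v2 branch here reads self.weights.transforms(), which, going only by the code shown in this diff, looks like an attribute of TorchVisionVideoClassifier rather than of this class, so it may deserve a second look. A hedged sketch of where the processor-side stats live (checkpoint is illustrative; requires a recent transformers):

from transformers import AutoProcessor

# Illustrative checkpoint; any model whose processor bundles an image_processor works the same way.
processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
print(processor.image_processor.image_mean)  # per-channel mean, as used by transforms.Normalize above
print(processor.image_processor.image_std)   # per-channel std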

ultralytics/solutions/action_recognition.py
@@ -289,15 +289,31 @@ class TorchVisionVideoClassifier:
         """
         if input_size is None:
             input_size = [224, 224]
-        from torchvision.transforms import v2
-
-        transform = v2.Compose(
-            [
-                v2.ToDtype(torch.float32, scale=True),
-                v2.Resize(input_size, antialias=True),
-                v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
-            ]
-        )
+        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
+        if supports_transforms_v2:
+            from torchvision.transforms import v2
+
+            transform = v2.Compose(
+                [
+                    v2.ToDtype(torch.float32, scale=True),
+                    v2.Resize(input_size, antialias=True),
+                    v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
+                ]
+            )
+        else:
+            from torchvision import transforms
+
+            transform = transforms.Compose(
+                [
+                    transforms.Lambda(lambda x: x.float() / 255.0),
+                    transforms.Resize(input_size),
+                    transforms.Normalize(
+                        mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std
+                    ),
+                ]
+            )
         processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops]
         return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)
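Mirror image of the previous note: in this TorchVision hunk the legacy else branch reads self.processor.image_processor stats, which in the code shown belongs to the HuggingFace class, while self.weights.transforms() in the v2 branch is the source that exists on this class. A quick hedged look at those weight-bundled stats (model choice illustrative):

from torchvision.models.video import S3D_Weights

weights = S3D_Weights.DEFAULT
preset = weights.transforms()      # preprocessing preset bundled with the pretrained weights
print(preset.mean, preset.std)     # the per-channel mean/std the v2 branch passes to v2.Normalize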
@@ -383,17 +399,31 @@ class HuggingFaceVideoClassifier:
         """
         if input_size is None:
             input_size = [224, 224]
-        from torchvision import transforms
-
-        transform = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x.float() / 255.0),
-                transforms.Resize(input_size),
-                transforms.Normalize(
-                    mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std
-                ),
-            ]
-        )
+        supports_transforms_v2 = check_requirements("torchvision>=0.16.0", install=False)
+
+        if supports_transforms_v2:
+            from torchvision.transforms import v2
+
+            transform = v2.Compose(
+                [
+                    v2.ToDtype(torch.float32, scale=True),
+                    v2.Resize(input_size, antialias=True),
+                    v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),
+                ]
+            )
+        else:
+            from torchvision import transforms
+
+            transform = transforms.Compose(
+                [
+                    transforms.Lambda(lambda x: x.float() / 255.0),
+                    transforms.Resize(input_size),
+                    transforms.Normalize(
+                        mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std
+                    ),
+                ]
+            )
         processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops]  # (T, C, H, W)
         output = torch.stack(processed_crops).unsqueeze(0).to(self.device)  # (1, T, C, H, W)
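Unlike the TorchVision path, the HuggingFace path stops at (1, T, C, H, W) and does not permute channels before time, since many HF video models (e.g. XCLIP) take pixel_values as (batch, frames, channels, height, width). Shape contrast with illustrative sizes:

import torch

frames = torch.stack([torch.zeros(3, 224, 224) for _ in range(8)]).unsqueeze(0)  # (1, T, C, H, W)
print(frames.shape)                          # torch.Size([1, 8, 3, 224, 224]) -- HuggingFace path, returned as-is
print(frames.permute(0, 2, 1, 3, 4).shape)   # torch.Size([1, 3, 8, 224, 224]) -- TorchVision path after permute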
