Fix `TORCHVISION_0_18` for `allow_empty=True` (#14415)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/14417/head
Glenn Jocher 4 months ago committed by GitHub
parent 21ca235681
commit 157b0251a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      ultralytics/data/dataset.py
  2. 1
      ultralytics/utils/torch_utils.py

@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCH_1_13 from ultralytics.utils.torch_utils import TORCHVISION_0_18
from .augment import ( from .augment import (
Compose, Compose,
@ -417,7 +417,7 @@ class ClassificationDataset:
import torchvision # scope for faster 'import ultralytics' import torchvision # scope for faster 'import ultralytics'
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
if TORCH_1_13: # 'allow_empty' argument first introduced in torch 1.13 if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True) self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
else: else:
self.base = torchvision.datasets.ImageFolder(root=root) self.base = torchvision.datasets.ImageFolder(root=root)

@ -40,6 +40,7 @@ TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
@contextmanager @contextmanager

Loading…
Cancel
Save