diff --git a/docs/en/models/yolov10.md b/docs/en/models/yolov10.md index 06057b542b..4646eb5d45 100644 --- a/docs/en/models/yolov10.md +++ b/docs/en/models/yolov10.md @@ -198,9 +198,9 @@ Due to the new operations introduced with YOLOv10, not all export formats provid | [OpenVINO](../integrations/openvino.md) | ✅ | | [TensorRT](../integrations/tensorrt.md) | ✅ | | [CoreML](../integrations/coreml.md) | ❌ | -| [TF SavedModel](../integrations/tf-savedmodel.md) | ❌ | -| [TF GraphDef](../integrations/tf-graphdef.md) | ❌ | -| [TF Lite](../integrations/tflite.md) | ❌ | +| [TF SavedModel](../integrations/tf-savedmodel.md) | ✅ | +| [TF GraphDef](../integrations/tf-graphdef.md) | ✅ | +| [TF Lite](../integrations/tflite.md) | ✅ | | [TF Edge TPU](../integrations/edge-tpu.md) | ❌ | | [TF.js](../integrations/tfjs.md) | ❌ | | [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ | diff --git a/docs/en/reference/utils/patches.md b/docs/en/reference/utils/patches.md index 444a274237..50422d8c82 100644 --- a/docs/en/reference/utils/patches.md +++ b/docs/en/reference/utils/patches.md @@ -23,6 +23,10 @@ keywords: Ultralytics, utils, patches, imread, imwrite, imshow, torch_save, Open



+## ::: ultralytics.utils.patches.torch_load + +



+ ## ::: ultralytics.utils.patches.torch_save

diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md index 6a48fec741..dd4c364d98 100644 --- a/docs/en/reference/utils/torch_utils.md +++ b/docs/en/reference/utils/torch_utils.md @@ -27,6 +27,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere



+## ::: ultralytics.utils.torch_utils.autocast + +



+ ## ::: ultralytics.utils.torch_utils.get_cpu_info



diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 362d4eead9..8a0415c929 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.63" +__version__ = "8.2.64" import os diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index 90cdd23133..cd084f3e69 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -2322,7 +2322,7 @@ def classify_transforms( size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, - interpolation=Image.BILINEAR, + interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION, ): """ @@ -2337,7 +2337,7 @@ def classify_transforms( tuple, it defines (height, width). mean (tuple): Mean values for each RGB channel used in normalization. std (tuple): Standard deviation values for each RGB channel used in normalization. - interpolation (int): Interpolation method for resizing. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. crop_fraction (float): Fraction of the image to be cropped. Returns: @@ -2360,7 +2360,7 @@ def classify_transforms( # Aspect ratio is preserved, crops center within image, no borders are added, image is lost if scale_size[0] == scale_size[1]: # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) - tfl = [T.Resize(scale_size[0], interpolation=interpolation)] + tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))] else: # Resize the shortest edge to matching target dim for non-square target tfl = [T.Resize(scale_size)] @@ -2389,7 +2389,7 @@ def classify_augmentations( hsv_v=0.4, # image HSV-Value augmentation (fraction) force_color_jitter=False, erasing=0.0, - interpolation=Image.BILINEAR, + interpolation="BILINEAR", ): """ Creates a composition of image augmentation transforms for classification tasks. 
@@ -2411,7 +2411,7 @@ def classify_augmentations( hsv_v (float): Image HSV-Value augmentation factor. force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. erasing (float): Probability of random erasing. - interpolation (int): Interpolation method. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. Returns: (torchvision.transforms.Compose): A composition of image augmentation transforms. @@ -2427,6 +2427,7 @@ def classify_augmentations( raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + interpolation = getattr(T.InterpolationMode, interpolation) primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] if hflip > 0.0: primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 59eb1ae8fd..6e64fe05d8 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -885,6 +885,8 @@ class Exporter: output_integer_quantized_tflite=self.args.int8, quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate) custom_input_op_name_np_data_path=np_data, + disable_group_convolution=True, # for end-to-end model compatibility + enable_batchmatmul_unfold=True, # for end-to-end model compatibility ) yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 3fb3e0b852..4415ba94eb 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.files import get_latest_run 
from ultralytics.utils.torch_utils import ( + TORCH_1_13, EarlyStopping, ModelEMA, + autocast, convert_optimizer_state_dict_to_fp16, init_seeds, one_cycle, @@ -264,7 +266,11 @@ class BaseTrainer: if RANK > -1 and world_size > 1: # DDP dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) self.amp = bool(self.amp) # as boolean - self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) + self.scaler = ( + torch.amp.GradScaler("cuda", enabled=self.amp) + if TORCH_1_13 + else torch.cuda.amp.GradScaler(enabled=self.amp) + ) if world_size > 1: self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True) @@ -376,7 +382,7 @@ class BaseTrainer: x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) # Forward - with torch.cuda.amp.autocast(self.amp): + with autocast(self.amp): batch = self.preprocess_batch(batch) self.loss, self.loss_items = self.model(batch) if RANK != -1: diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index fd444f1389..90446c585e 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -17,7 +17,7 @@ import torch from ultralytics.engine.model import Model from ultralytics.utils.downloads import attempt_download_asset -from ultralytics.utils.torch_utils import model_info, smart_inference_mode +from ultralytics.utils.torch_utils import model_info from .predict import NASPredictor from .val import NASValidator @@ -50,16 +50,25 @@ class NAS(Model): assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." 
super().__init__(model, task="detect") - @smart_inference_mode() - def _load(self, weights: str, task: str): + def _load(self, weights: str, task=None) -> None: """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" import super_gradients suffix = Path(weights).suffix if suffix == ".pt": self.model = torch.load(attempt_download_asset(weights)) + elif suffix == "": self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") + + # Override the forward method to ignore additional arguments + def new_forward(x, *args, **kwargs): + """Ignore additional __call__ arguments.""" + return self.model._original_forward(x) + + self.model._original_forward = self.model.forward + self.model.forward = new_forward + # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32]) diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py index 4f66feef65..64d10e36bc 100644 --- a/ultralytics/models/utils/ops.py +++ b/ultralytics/models/utils/ops.py @@ -133,7 +133,7 @@ class HungarianMatcher(nn.Module): # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0]) # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2]) # - # with torch.cuda.amp.autocast(False): + # with torch.amp.autocast("cuda", enabled=False): # # binary cross entropy cost # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none') # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none') diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index fcb80ee6e1..cde35a5776 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -587,14 +587,21 @@ class AutoBackend(nn.Module): if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well # 
Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models - x[:, [0, 2]] *= w - x[:, [1, 3]] *= h + if x.shape[-1] == 6: # end-to-end model + x[:, :, [0, 2]] *= w + x[:, :, [1, 3]] *= h + else: + x[:, [0, 2]] *= w + x[:, [1, 3]] *= h y.append(x) # TF segment fixes: export is reversed vs ONNX export and protos are transposed if len(y) == 2: # segment with (det, proto) output order reversed if len(y[1].shape) != 4: y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) - y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) + if y[1].shape[-1] == 6: # end-to-end model + y = [y[1]] + else: + y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] # for x in y: diff --git a/ultralytics/solutions/streamlit_inference.py b/ultralytics/solutions/streamlit_inference.py index 99916552f7..85394350da 100644 --- a/ultralytics/solutions/streamlit_inference.py +++ b/ultralytics/solutions/streamlit_inference.py @@ -69,7 +69,7 @@ def inference(model=None): # Add dropdown menu for model selection available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolov8")] if model: - available_models.insert(0, model) + available_models.insert(0, model.split(".pt")[0]) # insert model without suffix as *.pt is added later selected_model = st.sidebar.selectbox("Model", available_models) with st.spinner("Model is downloading..."): diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index b97a0fc420..39f6ad2b33 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running() set_sentry() # Apply monkey patches -from ultralytics.utils.patches import imread, imshow, imwrite, 
torch_save +from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save +torch.load = torch_load torch.save = torch_save if WINDOWS: # Apply cv2 patches for non-ASCII and non-UTF characters in image paths diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py index 2f695df82a..784210c574 100644 --- a/ultralytics/utils/autobatch.py +++ b/ultralytics/utils/autobatch.py @@ -7,7 +7,7 @@ import numpy as np import torch from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr -from ultralytics.utils.torch_utils import profile +from ultralytics.utils.torch_utils import autocast, profile def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): @@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): (int): Optimal batch size computed using the autobatch() function. """ - with torch.cuda.amp.autocast(amp): + with autocast(enabled=amp): return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6) diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index a6af7afd07..c36163e99e 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -100,9 +100,11 @@ def benchmark( assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet" if i in {3, 5}: # CoreML and OpenVINO assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" - if i in {6, 7, 8, 9, 10}: # All TF formats + if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" - assert not is_end2end, "End-to-end models not supported by onnx2tf yet" + if i in {9, 10}: # TF EdgeTPU and TF.js + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" + assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet" if i in {11}: # Paddle assert not 
isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index dfd7922839..d94e157fb6 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -641,6 +641,8 @@ def check_amp(model): Returns: (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. """ + from ultralytics.utils.torch_utils import autocast + device = next(model.parameters()).device # get model device if device.type in {"cpu", "mps"}: return False # AMP only used on CUDA devices @@ -648,7 +650,7 @@ def check_amp(model): def amp_allclose(m, im): """All close FP32 vs AMP results.""" a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference - with torch.cuda.amp.autocast(True): + with autocast(enabled=True): b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference del m return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 3c3d3b71e2..15bf92f9da 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ultralytics.utils.metrics import OKS_SIGMA from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors +from ultralytics.utils.torch_utils import autocast from .metrics import bbox_iou, probiou from .tal import bbox2dist @@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module): def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0): """Computes varfocal loss.""" weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label - with torch.cuda.amp.autocast(enabled=False): + with autocast(enabled=False): loss = ( 
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) .mean(1) diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py index d438407114..d918e0efea 100644 --- a/ultralytics/utils/patches.py +++ b/ultralytics/utils/patches.py @@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray): # PyTorch functions ---------------------------------------------------------------------------------------------------- -_torch_save = torch.save # copy to avoid recursion errors +_torch_load = torch.load # copy to avoid recursion errors +_torch_save = torch.save + + +def torch_load(*args, **kwargs): + """ + Load a PyTorch model with updated arguments to avoid warnings. + + This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. + + Args: + *args (Any): Variable length argument list to pass to torch.load. + **kwargs (Any): Arbitrary keyword arguments to pass to torch.load. + + Returns: + (Any): The loaded PyTorch object. + + Note: + For PyTorch versions 1.13 and above, this function automatically sets 'weights_only=False' + if the argument is not provided, to avoid deprecation warnings. + """ + from ultralytics.utils.torch_utils import TORCH_1_13 + + if TORCH_1_13 and "weights_only" not in kwargs: + kwargs["weights_only"] = False + + return _torch_load(*args, **kwargs) def torch_save(*args, use_dill=True, **kwargs): @@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs): Args: *args (tuple): Positional arguments to pass to torch.save. use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. - **kwargs (any): Keyword arguments to pass to torch.save. + **kwargs (Any): Keyword arguments to pass to torch.save. 
""" try: assert use_dill diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 21973d7e29..fcecd14816 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -68,6 +68,37 @@ def smart_inference_mode(): return decorate +def autocast(enabled: bool, device: str = "cuda"): + """ + Get the appropriate autocast context manager based on PyTorch version and AMP setting. + + This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both + older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions. + + Args: + enabled (bool): Whether to enable automatic mixed precision. + device (str, optional): The device to use for autocast. Defaults to 'cuda'. + + Returns: + (torch.amp.autocast): The appropriate autocast context manager. + + Note: + - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`. + - For older versions, it uses `torch.cuda.amp.autocast`. + + Example: + ```python + with autocast(enabled=True): + # Your mixed precision operations here + pass + ``` + """ + if TORCH_1_13: + return torch.amp.autocast(device, enabled=enabled) + else: + return torch.cuda.amp.autocast(enabled) + + def get_cpu_info(): """Return a string with system CPU information, i.e. 'Apple M2'.""" import cpuinfo # pip install py-cpuinfo