Merge branch 'main' into action-recog

action-recog
Ultralytics Assistant 4 months ago committed by GitHub
commit a44cb6ed33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      docs/en/models/yolov10.md
  2. 4
      docs/en/reference/utils/patches.md
  3. 4
      docs/en/reference/utils/torch_utils.md
  4. 2
      ultralytics/__init__.py
  5. 11
      ultralytics/data/augment.py
  6. 2
      ultralytics/engine/exporter.py
  7. 10
      ultralytics/engine/trainer.py
  8. 15
      ultralytics/models/nas/model.py
  9. 2
      ultralytics/models/utils/ops.py
  10. 13
      ultralytics/nn/autobackend.py
  11. 2
      ultralytics/solutions/streamlit_inference.py
  12. 3
      ultralytics/utils/__init__.py
  13. 4
      ultralytics/utils/autobatch.py
  14. 6
      ultralytics/utils/benchmarks.py
  15. 4
      ultralytics/utils/checks.py
  16. 3
      ultralytics/utils/loss.py
  17. 30
      ultralytics/utils/patches.py
  18. 31
      ultralytics/utils/torch_utils.py

@ -198,9 +198,9 @@ Due to the new operations introduced with YOLOv10, not all export formats provid
| [OpenVINO](../integrations/openvino.md) | ✅ | | [OpenVINO](../integrations/openvino.md) | ✅ |
| [TensorRT](../integrations/tensorrt.md) | ✅ | | [TensorRT](../integrations/tensorrt.md) | ✅ |
| [CoreML](../integrations/coreml.md) | ❌ | | [CoreML](../integrations/coreml.md) | ❌ |
| [TF SavedModel](../integrations/tf-savedmodel.md) | | | [TF SavedModel](../integrations/tf-savedmodel.md) | |
| [TF GraphDef](../integrations/tf-graphdef.md) | | | [TF GraphDef](../integrations/tf-graphdef.md) | |
| [TF Lite](../integrations/tflite.md) | | | [TF Lite](../integrations/tflite.md) | |
| [TF Edge TPU](../integrations/edge-tpu.md) | ❌ | | [TF Edge TPU](../integrations/edge-tpu.md) | ❌ |
| [TF.js](../integrations/tfjs.md) | ❌ | | [TF.js](../integrations/tfjs.md) | ❌ |
| [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ | | [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ |

@ -23,6 +23,10 @@ keywords: Ultralytics, utils, patches, imread, imwrite, imshow, torch_save, Open
<br><br><hr><br> <br><br><hr><br>
## ::: ultralytics.utils.patches.torch_load
<br><br><hr><br>
## ::: ultralytics.utils.patches.torch_save ## ::: ultralytics.utils.patches.torch_save
<br><br> <br><br>

@ -27,6 +27,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
<br><br><hr><br> <br><br><hr><br>
## ::: ultralytics.utils.torch_utils.autocast
<br><br><hr><br>
## ::: ultralytics.utils.torch_utils.get_cpu_info ## ::: ultralytics.utils.torch_utils.get_cpu_info
<br><br><hr><br> <br><br><hr><br>

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.63" __version__ = "8.2.64"
import os import os

@ -2322,7 +2322,7 @@ def classify_transforms(
size=224, size=224,
mean=DEFAULT_MEAN, mean=DEFAULT_MEAN,
std=DEFAULT_STD, std=DEFAULT_STD,
interpolation=Image.BILINEAR, interpolation="BILINEAR",
crop_fraction: float = DEFAULT_CROP_FRACTION, crop_fraction: float = DEFAULT_CROP_FRACTION,
): ):
""" """
@ -2337,7 +2337,7 @@ def classify_transforms(
tuple, it defines (height, width). tuple, it defines (height, width).
mean (tuple): Mean values for each RGB channel used in normalization. mean (tuple): Mean values for each RGB channel used in normalization.
std (tuple): Standard deviation values for each RGB channel used in normalization. std (tuple): Standard deviation values for each RGB channel used in normalization.
interpolation (int): Interpolation method for resizing. interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
crop_fraction (float): Fraction of the image to be cropped. crop_fraction (float): Fraction of the image to be cropped.
Returns: Returns:
@ -2360,7 +2360,7 @@ def classify_transforms(
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost # Aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]: if scale_size[0] == scale_size[1]:
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)] tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]
else: else:
# Resize the shortest edge to matching target dim for non-square target # Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)] tfl = [T.Resize(scale_size)]
@ -2389,7 +2389,7 @@ def classify_augmentations(
hsv_v=0.4, # image HSV-Value augmentation (fraction) hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False, force_color_jitter=False,
erasing=0.0, erasing=0.0,
interpolation=Image.BILINEAR, interpolation="BILINEAR",
): ):
""" """
Creates a composition of image augmentation transforms for classification tasks. Creates a composition of image augmentation transforms for classification tasks.
@ -2411,7 +2411,7 @@ def classify_augmentations(
hsv_v (float): Image HSV-Value augmentation factor. hsv_v (float): Image HSV-Value augmentation factor.
force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled.
erasing (float): Probability of random erasing. erasing (float): Probability of random erasing.
interpolation (int): Interpolation method. interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
Returns: Returns:
(torchvision.transforms.Compose): A composition of image augmentation transforms. (torchvision.transforms.Compose): A composition of image augmentation transforms.
@ -2427,6 +2427,7 @@ def classify_augmentations(
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
interpolation = getattr(T.InterpolationMode, interpolation)
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.0: if hflip > 0.0:
primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) primary_tfl.append(T.RandomHorizontalFlip(p=hflip))

@ -885,6 +885,8 @@ class Exporter:
output_integer_quantized_tflite=self.args.int8, output_integer_quantized_tflite=self.args.int8,
quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate) quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate)
custom_input_op_name_np_data_path=np_data, custom_input_op_name_np_data_path=np_data,
disable_group_convolution=True, # for end-to-end model compatibility
enable_batchmatmul_unfold=True, # for end-to-end model compatibility
) )
yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml

@ -41,8 +41,10 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import ( from ultralytics.utils.torch_utils import (
TORCH_1_13,
EarlyStopping, EarlyStopping,
ModelEMA, ModelEMA,
autocast,
convert_optimizer_state_dict_to_fp16, convert_optimizer_state_dict_to_fp16,
init_seeds, init_seeds,
one_cycle, one_cycle,
@ -264,7 +266,11 @@ class BaseTrainer:
if RANK > -1 and world_size > 1: # DDP if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean self.amp = bool(self.amp) # as boolean
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) self.scaler = (
torch.amp.GradScaler("cuda", enabled=self.amp)
if TORCH_1_13
else torch.cuda.amp.GradScaler(enabled=self.amp)
)
if world_size > 1: if world_size > 1:
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True) self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
@ -376,7 +382,7 @@ class BaseTrainer:
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward # Forward
with torch.cuda.amp.autocast(self.amp): with autocast(self.amp):
batch = self.preprocess_batch(batch) batch = self.preprocess_batch(batch)
self.loss, self.loss_items = self.model(batch) self.loss, self.loss_items = self.model(batch)
if RANK != -1: if RANK != -1:

@ -17,7 +17,7 @@ import torch
from ultralytics.engine.model import Model from ultralytics.engine.model import Model
from ultralytics.utils.downloads import attempt_download_asset from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.torch_utils import model_info, smart_inference_mode from ultralytics.utils.torch_utils import model_info
from .predict import NASPredictor from .predict import NASPredictor
from .val import NASValidator from .val import NASValidator
@ -50,16 +50,25 @@ class NAS(Model):
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
super().__init__(model, task="detect") super().__init__(model, task="detect")
@smart_inference_mode() def _load(self, weights: str, task=None) -> None:
def _load(self, weights: str, task: str):
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
import super_gradients import super_gradients
suffix = Path(weights).suffix suffix = Path(weights).suffix
if suffix == ".pt": if suffix == ".pt":
self.model = torch.load(attempt_download_asset(weights)) self.model = torch.load(attempt_download_asset(weights))
elif suffix == "": elif suffix == "":
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Override the forward method to ignore additional arguments
def new_forward(x, *args, **kwargs):
"""Ignore additional __call__ arguments."""
return self.model._original_forward(x)
self.model._original_forward = self.model.forward
self.model.forward = new_forward
# Standardize model # Standardize model
self.model.fuse = lambda verbose=True: self.model self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32]) self.model.stride = torch.tensor([32])

@ -133,7 +133,7 @@ class HungarianMatcher(nn.Module):
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0]) # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2]) # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
# #
# with torch.cuda.amp.autocast(False): # with torch.amp.autocast("cuda", enabled=False):
# # binary cross entropy cost # # binary cross entropy cost
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none') # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none') # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')

@ -587,14 +587,21 @@ class AutoBackend(nn.Module):
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
x[:, [0, 2]] *= w if x.shape[-1] == 6: # end-to-end model
x[:, [1, 3]] *= h x[:, :, [0, 2]] *= w
x[:, :, [1, 3]] *= h
else:
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
y.append(x) y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed # TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4: if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) if y[1].shape[-1] == 6: # end-to-end model
y = [y[1]]
else:
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
# for x in y: # for x in y:

@ -69,7 +69,7 @@ def inference(model=None):
# Add dropdown menu for model selection # Add dropdown menu for model selection
available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolov8")] available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolov8")]
if model: if model:
available_models.insert(0, model) available_models.insert(0, model.split(".pt")[0]) # insert model without suffix as *.pt is added later
selected_model = st.sidebar.selectbox("Model", available_models) selected_model = st.sidebar.selectbox("Model", available_models)
with st.spinner("Model is downloading..."): with st.spinner("Model is downloading..."):

@ -1066,8 +1066,9 @@ TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry() set_sentry()
# Apply monkey patches # Apply monkey patches
from ultralytics.utils.patches import imread, imshow, imwrite, torch_save from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
torch.load = torch_load
torch.save = torch_save torch.save = torch_save
if WINDOWS: if WINDOWS:
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths # Apply cv2 patches for non-ASCII and non-UTF characters in image paths

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import profile from ultralytics.utils.torch_utils import autocast, profile
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
@ -23,7 +23,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
(int): Optimal batch size computed using the autobatch() function. (int): Optimal batch size computed using the autobatch() function.
""" """
with torch.cuda.amp.autocast(amp): with autocast(enabled=amp):
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6) return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)

@ -100,9 +100,11 @@ def benchmark(
assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet" assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
if i in {3, 5}: # CoreML and OpenVINO if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
assert not is_end2end, "End-to-end models not supported by onnx2tf yet" if i in {9, 10}: # TF EdgeTPU and TF.js
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet"
if i in {11}: # Paddle if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"

@ -641,6 +641,8 @@ def check_amp(model):
Returns: Returns:
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
""" """
from ultralytics.utils.torch_utils import autocast
device = next(model.parameters()).device # get model device device = next(model.parameters()).device # get model device
if device.type in {"cpu", "mps"}: if device.type in {"cpu", "mps"}:
return False # AMP only used on CUDA devices return False # AMP only used on CUDA devices
@ -648,7 +650,7 @@ def check_amp(model):
def amp_allclose(m, im): def amp_allclose(m, im):
"""All close FP32 vs AMP results.""" """All close FP32 vs AMP results."""
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
with torch.cuda.amp.autocast(True): with autocast(enabled=True):
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
del m del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance

@ -7,6 +7,7 @@ import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from ultralytics.utils.torch_utils import autocast
from .metrics import bbox_iou, probiou from .metrics import bbox_iou, probiou
from .tal import bbox2dist from .tal import bbox2dist
@ -27,7 +28,7 @@ class VarifocalLoss(nn.Module):
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0): def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss.""" """Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False): with autocast(enabled=False):
loss = ( loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1) .mean(1)

@@ -57,7 +57,33 @@ def imshow(winname: str, mat: np.ndarray):
# PyTorch functions ---------------------------------------------------------------------------------------------------- # PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save # copy to avoid recursion errors _torch_load = torch.load # copy to avoid recursion errors
_torch_save = torch.save
def torch_load(*args, **kwargs):
    """
    Load a PyTorch model with updated arguments to avoid warnings.

    This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.

    Args:
        *args (Any): Variable length argument list to pass to torch.load.
        **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.

    Returns:
        (Any): The loaded PyTorch object.

    Note:
        For PyTorch versions 1.13 and above, this function automatically sets 'weights_only=False'
        if the argument is not provided, to avoid deprecation warnings.
    """
    from ultralytics.utils.torch_utils import TORCH_1_13  # local import to avoid a circular dependency

    # Only inject the default when the caller has not chosen explicitly
    if TORCH_1_13 and "weights_only" not in kwargs:
        kwargs["weights_only"] = False

    return _torch_load(*args, **kwargs)
def torch_save(*args, use_dill=True, **kwargs): def torch_save(*args, use_dill=True, **kwargs):
@ -68,7 +94,7 @@ def torch_save(*args, use_dill=True, **kwargs):
Args: Args:
*args (tuple): Positional arguments to pass to torch.save. *args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (any): Keyword arguments to pass to torch.save. **kwargs (Any): Keyword arguments to pass to torch.save.
""" """
try: try:
assert use_dill assert use_dill

@@ -68,6 +68,37 @@ def smart_inference_mode():
return decorate return decorate
def autocast(enabled: bool, device: str = "cuda"):
    """
    Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast. Defaults to 'cuda'.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Note:
        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
        - For older versions, it uses `torch.cuda.amp.autocast`.

    Example:
        ```python
        with autocast(enabled=True):
            # Your mixed precision operations here
            pass
        ```
    """
    if TORCH_1_13:
        return torch.amp.autocast(device, enabled=enabled)
    else:
        return torch.cuda.amp.autocast(enabled)
def get_cpu_info(): def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'.""" """Return a string with system CPU information, i.e. 'Apple M2'."""
import cpuinfo # pip install py-cpuinfo import cpuinfo # pip install py-cpuinfo

Loading…
Cancel
Save