Merge branch 'main' into yolov9

pull/8571/head
Glenn Jocher 9 months ago committed by GitHub
commit 0b77c6e69d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      .github/workflows/ci.yaml
  2. 32
      docs/en/models/yolo-world.md
  3. 4
      ultralytics/__init__.py
  4. 2
      ultralytics/cfg/models/v8/yolov8-world.yaml
  5. 10
      ultralytics/cfg/models/v8/yolov8-worldv2.yaml
  6. 13
      ultralytics/engine/exporter.py
  7. 14
      ultralytics/engine/model.py
  8. 2
      ultralytics/engine/trainer.py
  9. 6
      ultralytics/models/yolo/model.py
  10. 36
      ultralytics/utils/benchmarks.py
  11. 3
      ultralytics/utils/downloads.py
  12. 20
      ultralytics/utils/patches.py

@ -118,9 +118,9 @@ jobs:
run: | run: |
yolo checks yolo checks
pip list pip list
# - name: Benchmark DetectionModel - name: Benchmark World DetectionModel
# shell: bash shell: bash
# run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.318 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel - name: Benchmark SegmentationModel
shell: bash shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.281 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.281

@ -36,21 +36,29 @@ This section details the models available with their specific pre-trained weight
All the YOLOv8-World weights have been directly migrated from the official [YOLO-World](https://github.com/AILab-CVC/YOLO-World) repository, highlighting their excellent contributions. All the YOLOv8-World weights have been directly migrated from the official [YOLO-World](https://github.com/AILab-CVC/YOLO-World) repository, highlighting their excellent contributions.
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | | Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|---------------|-----------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------| |-----------------|-------------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | | YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | | YOLOv8s-worldv2 | [yolov8s-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | | YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | | YOLOv8m-worldv2 | [yolov8m-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8l-worldv2 | [yolov8l-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8x-worldv2 | [yolov8x-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
## Zero-shot Transfer on COCO Dataset ## Zero-shot Transfer on COCO Dataset
| Model Type | mAP | mAP50 | mAP75 | | Model Type | mAP | mAP50 | mAP75 |
|---------------|------|-------|-------| |-----------------|------|-------|-------|
| yolov8s-world | 37.4 | 52.0 | 40.6 | | yolov8s-world | 37.4 | 52.0 | 40.6 |
| yolov8m-world | 42.0 | 57.0 | 45.6 | | yolov8s-worldv2 | 37.7 | 52.2 | 41.0 |
| yolov8l-world | 45.7 | 61.3 | 49.8 | | yolov8m-world | 42.0 | 57.0 | 45.6 |
| yolov8x-world | 47.0 | 63.0 | 51.2 | | yolov8m-worldv2 | 43.0 | 58.4 | 46.8 |
| yolov8l-world | 45.7 | 61.3 | 49.8 |
| yolov8l-worldv2 | 45.8 | 61.3 | 49.8 |
| yolov8x-world | 47.0 | 63.0 | 51.2 |
| yolov8x-worldv2 | 47.1 | 62.8 | 51.4 |
## Usage Examples ## Usage Examples

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.22"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
from ultralytics.models.fastsam import FastSAM from ultralytics.models.fastsam import FastSAM
@ -10,6 +8,7 @@ from ultralytics.utils import ASSETS, SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download from ultralytics.utils.downloads import download
__all__ = ( __all__ = (
"__version__", "__version__",
"ASSETS", "ASSETS",
@ -24,3 +23,4 @@ __all__ = (
"settings", "settings",
"Explorer", "Explorer",
) )
__version__ = "8.1.22"

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # YOLOv8-World object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # YOLOv8-World-v2 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes
@ -29,18 +29,18 @@ backbone:
head: head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]] - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4 - [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C2fAttn, [512, 256, 8]] # 12 - [-1, 3, C2fAttn, [512, 256, 8]] # 12
- [-1, 1, nn.Upsample, [None, 2, "nearest"]] - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3 - [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small)
- [15, 1, Conv, [256, 3, 2]] - [15, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4 - [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 2, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]] - [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5 - [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 2, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large)
- [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5) - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5)

@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset from ultralytics.data.utils import check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import ( from ultralytics.utils import (
ARM64, ARM64,
DEFAULT_CFG, DEFAULT_CFG,
@ -201,6 +201,14 @@ class Exporter:
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'" assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
if edgetpu and not LINUX: if edgetpu and not LINUX:
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/") raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
print(type(model))
if isinstance(model, WorldModel):
LOGGER.warning(
"WARNING ⚠ YOLOWorld (original version) export is not supported to any format.\n"
"WARNING ⚠ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
"(torchscript, onnx, openvino, engine, coreml) formats. "
"See https://docs.ultralytics.com/models/yolo-world for details."
)
# Input # Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
@ -252,9 +260,10 @@ class Exporter:
self.metadata = { self.metadata = {
"description": description, "description": description,
"author": "Ultralytics", "author": "Ultralytics",
"license": "AGPL-3.0 https://ultralytics.com/license",
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"version": __version__, "version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
"stride": int(max(model.stride)), "stride": int(max(model.stride)),
"task": model.task, "task": model.task,
"batch": self.args.batch, "batch": self.args.batch,

@ -295,7 +295,7 @@ class Model(nn.Module):
self.model.load(weights) self.model.load(weights)
return self return self
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
""" """
Saves the current model state to a file. Saves the current model state to a file.
@ -303,12 +303,22 @@ class Model(nn.Module):
Args: Args:
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'. filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
Raises: Raises:
AssertionError: If the model is not a PyTorch model. AssertionError: If the model is not a PyTorch model.
""" """
self._check_is_pytorch_model() self._check_is_pytorch_model()
torch.save(self.ckpt, filename) from ultralytics import __version__
from datetime import datetime
updates = {
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}
torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)
def info(self, detailed: bool = False, verbose: bool = True): def info(self, detailed: bool = False, verbose: bool = True):
""" """

@ -488,6 +488,8 @@ class BaseTrainer:
"train_results": results, "train_results": results,
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"version": __version__, "version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
} }
# Save last and best # Save last and best

@ -13,8 +13,8 @@ class YOLO(Model):
def __init__(self, model="yolov8n.pt", task=None, verbose=False): def __init__(self, model="yolov8n.pt", task=None, verbose=False):
"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.""" """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
stem = Path(model).stem # filename stem without suffix, i.e. "yolov8n" model = Path(model)
if "-world" in stem: if "-world" in model.stem and model.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
new_instance = YOLOWorld(model) new_instance = YOLOWorld(model)
self.__class__ = type(new_instance) self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__ self.__dict__ = new_instance.__dict__
@ -67,7 +67,7 @@ class YOLOWorld(Model):
Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats. Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
Args: Args:
model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'. model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
""" """
super().__init__(model=model, task="detect") super().__init__(model=model, task="detect")

@ -32,7 +32,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch.cuda import torch.cuda
from ultralytics import YOLO from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
@ -84,14 +84,20 @@ def benchmark(
emoji, filename = "", None # export defaults emoji, filename = "", None # export defaults
try: try:
# Checks # Checks
if i == 9: if i == 9: # Edge TPU
assert LINUX, "Edge TPU export only supported on Linux" assert LINUX, "Edge TPU export only supported on Linux"
elif i == 7: elif i == 7: # TF GraphDef
assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task" assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
elif i in {5, 10}: # CoreML and TF.js elif i in {5, 10}: # CoreML and TF.js
assert MACOS or LINUX, "export only supported on macOS and Linux" assert MACOS or LINUX, "export only supported on macOS and Linux"
if i in {3, 5}: # CoreML and OpenVINO if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if "cpu" in device.type: if "cpu" in device.type:
assert cpu, "inference not supported on CPU" assert cpu, "inference not supported on CPU"
if "cuda" in device.type: if "cuda" in device.type:
@ -261,7 +267,8 @@ class ProfileModels:
""" """
return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3): @staticmethod
def iterative_sigma_clipping(data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations.""" """Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
data = np.array(data) data = np.array(data)
for _ in range(max_iters): for _ in range(max_iters):
@ -359,9 +366,13 @@ class ProfileModels:
def generate_table_row(self, model_name, t_onnx, t_engine, model_info): def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details.""" """Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info layers, params, gradients, flops = model_info
return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |" return (
f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
)
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info): @staticmethod
def generate_results_dict(model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics.""" """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info layers, params, gradients, flops = model_info
return { return {
@ -372,11 +383,18 @@ class ProfileModels:
"model/speed_TensorRT(ms)": round(t_engine[0], 3), "model/speed_TensorRT(ms)": round(t_engine[0], 3),
} }
def print_table(self, table_rows): @staticmethod
def print_table(table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data.""" """Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
header = f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |" header = (
separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|" f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
)
separator = (
"|-------------|---------------------|--------------------|------------------------------|"
"-----------------------------------|------------------|-----------------|"
)
print(f"\n\n{header}") print(f"\n\n{header}")
print(separator) print(separator)

@ -20,7 +20,8 @@ GITHUB_ASSETS_NAMES = (
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "sml"] + [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"] + [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolo_nas_{k}.pt" for k in "sml"] + [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"] + [f"sam_{k}.pt" for k in "bl"]

@ -60,27 +60,29 @@ def imshow(winname: str, mat: np.ndarray):
_torch_save = torch.save # copy to avoid recursion errors _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs): def torch_save(*args, use_dill=True, **kwargs):
""" """
Use dill (if exists) to serialize the lambda functions where pickle does not do this. Also adds 3 retries with Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
exponential standoff in case of save failure to improve robustness to transient issues. exponential standoff in case of save failure.
Args: Args:
*args (tuple): Positional arguments to pass to torch.save. *args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (dict): Keyword arguments to pass to torch.save. **kwargs (dict): Keyword arguments to pass to torch.save.
""" """
try: try:
import dill as pickle # noqa assert use_dill
except ImportError: import dill as pickle
except (AssertionError, ImportError):
import pickle import pickle
if "pickle_module" not in kwargs: if "pickle_module" not in kwargs:
kwargs["pickle_module"] = pickle # noqa kwargs["pickle_module"] = pickle
for i in range(4): # 3 retries for i in range(4): # 3 retries
try: try:
return _torch_save(*args, **kwargs) return _torch_save(*args, **kwargs)
except RuntimeError: # unable to save, possibly waiting for device to flush or anti-virus to finish scanning except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
if i == 3: if i == 3:
raise raise e
time.sleep((2**i) / 2) # exponential standoff 0.5s, 1.0s, 2.0s time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s

Loading…
Cancel
Save