Fix Windows non-UTF source filenames (#4524)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Glenn Jocher authored 1 year ago, committed via GitHub
parent a7419617a6
commit 1db9afc2e5
16 files changed:

1. .github/workflows/links.yml (4 changed lines)
2. docs/quickstart.md (21 changed lines)
3. docs/reference/data/converter.md (4 changed lines)
4. docs/reference/data/dataset.md (8 changed lines)
5. docs/reference/data/utils.md (4 changed lines)
6. tests/test_python.py (38 changed lines)
7. ultralytics/data/augment.py (2 changed lines)
8. ultralytics/data/converter.py (40 changed lines)
9. ultralytics/data/utils.py (4 changed lines)
10. ultralytics/engine/exporter.py (18 changed lines)
11. ultralytics/engine/predictor.py (3 changed lines)
12. ultralytics/nn/tasks.py (1 changed line)
13. ultralytics/utils/__init__.py (9 changed lines)
14. ultralytics/utils/benchmarks.py (10 changed lines)
15. ultralytics/utils/downloads.py (12 changed lines)
16. ultralytics/utils/patches.py (46 changed lines)

@@ -28,7 +28,7 @@ jobs:
           timeout_minutes: 5
           retry_wait_seconds: 60
           max_attempts: 3
-          command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html'
+          command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html'
       - name: Test Markdown, HTML, YAML, Python and Notebook links with retry
         if: github.event_name == 'workflow_dispatch'
@@ -37,4 +37,4 @@ jobs:
           timeout_minutes: 5
           retry_wait_seconds: 60
           max_attempts: 3
-          command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html' './**/*.yml' './**/*.yaml' './**/*.py' './**/*.ipynb'
+          command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html' './**/*.yml' './**/*.yaml' './**/*.py' './**/*.ipynb'
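
For illustration, the effect of the updated exclude pattern can be sanity-checked with Python's `re` module (lychee itself uses a Rust regex engine, but this pattern should behave the same under both):

```python
import re

# The pattern passed to lychee --exclude, checked here with Python's re for illustration
pattern = re.compile(r'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com)')

assert pattern.match('https://www.kaggle.com/models')  # now skipped by the link checker
assert not pattern.match('https://github.com/ultralytics/ultralytics')  # still checked
```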

@@ -16,10 +16,18 @@ Ultralytics provides various installation methods including pip, conda, and Docker.
     [![PyPI version](https://badge.fury.io/py/ultralytics.svg)](https://badge.fury.io/py/ultralytics) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://pepy.tech/project/ultralytics)

     ```bash
-    # Install the ultralytics package using pip
+    # Install the ultralytics package from PyPI
     pip install ultralytics
     ```

+    You can also install the `ultralytics` package directly from the GitHub repository. This might be useful if you want the latest development version. Make sure to have the Git command-line tool installed on your system. The `@main` syntax installs the `main` branch and may be changed to another branch, e.g. `@my-branch`, or removed altogether to default to the `main` branch.
+
+    ```bash
+    # Install the ultralytics package from GitHub
+    pip install git+https://github.com/ultralytics/ultralytics.git@main
+    ```

 === "Conda install"

     Conda is an alternative package manager to pip which may also be used for installation. Visit Anaconda for more details at [https://anaconda.org/conda-forge/ultralytics](https://anaconda.org/conda-forge/ultralytics). The Ultralytics feedstock repository for updating the conda package is at [https://github.com/conda-forge/ultralytics-feedstock/](https://github.com/conda-forge/ultralytics-feedstock/).
@@ -53,10 +61,19 @@ Ultralytics provides various installation methods including pip, conda, and Docker.
     ```

 === "Docker"

-    Utilize Docker to execute the `ultralytics` package in an isolated container. By employing the official `ultralytics` image from [Docker Hub](https://hub.docker.com/r/ultralytics/ultralytics), you can avoid local installation. Below are the commands to get the latest image and execute it:
+    Utilize Docker to effortlessly execute the `ultralytics` package in an isolated container, ensuring consistent and smooth performance across various environments. By choosing one of the official `ultralytics` images from [Docker Hub](https://hub.docker.com/r/ultralytics/ultralytics), you not only avoid the complexity of local installation but also benefit from access to a verified working environment. Ultralytics offers 5 main supported Docker images, each designed to provide high compatibility and efficiency for different platforms and use cases:

     <a href="https://hub.docker.com/r/ultralytics/ultralytics"><img src="https://img.shields.io/docker/pulls/ultralytics/ultralytics?logo=docker" alt="Docker Pulls"></a>

+    - **Dockerfile:** GPU image recommended for training.
+    - **Dockerfile-arm64:** Optimized for ARM64 architecture, allowing deployment on devices like Raspberry Pi and other ARM64-based platforms.
+    - **Dockerfile-cpu:** Ubuntu-based CPU-only version suitable for inference and environments without GPUs.
+    - **Dockerfile-jetson:** Tailored for NVIDIA Jetson devices, integrating GPU support optimized for these platforms.
+    - **Dockerfile-python:** Minimal image with just Python and necessary dependencies, ideal for lightweight applications and development.
+
+    Below are the commands to get the latest image and execute it:

     ```bash
     # Set image name as a variable
     t=ultralytics/ultralytics:latest

@@ -25,10 +25,6 @@ keywords: Ultralytics, Data Converter, coco91_to_coco80_class, merge_multi_segment
 ## ::: ultralytics.data.converter.convert_dota_to_yolo_obb
 <br><br>
 ---
-## ::: ultralytics.data.converter.rle2polygon
-<br><br>
----
 ## ::: ultralytics.data.converter.min_index
 <br><br>

@@ -20,3 +20,11 @@ keywords: Ultralytics, YOLO, YOLODataset, SemanticDataset, data handling, data m
 ---
 ## ::: ultralytics.data.dataset.SemanticDataset
 <br><br>
+---
+## ::: ultralytics.data.dataset.load_dataset_cache_file
+<br><br>
+---
+## ::: ultralytics.data.dataset.save_dataset_cache_file
+<br><br>

@@ -25,6 +25,10 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mask
 ## ::: ultralytics.data.utils.exif_size
 <br><br>
 ---
+## ::: ultralytics.data.utils.verify_image
+<br><br>
+---
 ## ::: ultralytics.data.utils.verify_image_label
 <br><br>

@@ -13,7 +13,7 @@ from torchvision.transforms import ToTensor
 from ultralytics import RTDETR, YOLO
 from ultralytics.data.build import load_inference_source
-from ultralytics.utils import ASSETS, DEFAULT_CFG, LINUX, ONLINE, ROOT, SETTINGS
+from ultralytics.utils import ASSETS, DEFAULT_CFG, LINUX, ONLINE, ROOT, SETTINGS, WINDOWS
 from ultralytics.utils.downloads import download
 from ultralytics.utils.torch_utils import TORCH_1_9
@@ -26,18 +26,24 @@ TMP = (ROOT / '../tests/tmp').resolve()  # temp directory for test files

 def test_model_forward():
     model = YOLO(CFG)
-    model(SOURCE, imgsz=32, augment=True)
+    model(source=None, imgsz=32, augment=True)  # also test no source and augment


 def test_model_methods():
     model = YOLO(MODEL)

     # Model methods
     model.info(verbose=True, detailed=True)
     model = model.reset_weights()
     model = model.load(MODEL)
     model.to('cpu')
     model.fuse()

+    # Model properties
+    _ = model.names
+    _ = model.device
+    _ = model.transforms
+    _ = model.task_map


 def test_predict_txt():
@@ -88,12 +94,13 @@ def test_predict_img():

 def test_predict_grey_and_4ch():
     # Convert SOURCE to greyscale and 4-ch
     im = Image.open(SOURCE)
-    stem = SOURCE.parent / SOURCE.stem
+    directory = TMP / 'im4'
+    directory.mkdir(parents=True, exist_ok=True)
-    source_greyscale = Path(f'{stem}_greyscale.jpg')
-    source_rgba = Path(f'{stem}_4ch.png')
-    source_non_utf = Path(f'{stem}_veículo.jpg')
-    source_spaces = Path(f'{stem} with spaces.jpg')
+    source_greyscale = directory / 'greyscale.jpg'
+    source_rgba = directory / '4ch.png'
+    source_non_utf = directory / 'non_UTF_测试文件_tést_image.jpg'
+    source_spaces = directory / 'image with spaces.jpg'
     im.convert('L').save(source_greyscale)  # greyscale
     im.convert('RGBA').save(source_rgba)  # 4-ch PNG with alpha
@@ -116,7 +123,7 @@ def test_track_stream():
     import yaml

     model = YOLO(MODEL)
-    model.predict('https://youtu.be/G17sBkb38XQ', imgsz=96)
+    model.predict('https://youtu.be/G17sBkb38XQ', imgsz=96, save=True)
     model.track('https://ultralytics.com/assets/decelera_portrait_min.mov', imgsz=160, tracker='bytetrack.yaml')
     model.track('https://ultralytics.com/assets/decelera_portrait_min.mov', imgsz=160, tracker='botsort.yaml')
@@ -150,7 +157,7 @@ def test_train_pretrained():

 def test_export_torchscript():
     model = YOLO(MODEL)
-    f = model.export(format='torchscript')
+    f = model.export(format='torchscript', optimize=True)
     YOLO(f)(SOURCE)  # exported model inference
@@ -166,11 +173,12 @@ def test_export_openvino():
     YOLO(f)(SOURCE)  # exported model inference


-def test_export_coreml():  # sourcery skip: move-assign
-    model = YOLO(MODEL)
-    model.export(format='coreml', nms=True)
-    # if MACOS:
-    #     YOLO(f)(SOURCE)  # model prediction only supported on macOS
+def test_export_coreml():
+    if not WINDOWS:  # RuntimeError: BlobWriter not loaded with coremltools 7.0 on windows
+        model = YOLO(MODEL)
+        model.export(format='coreml', nms=True)
+        # if MACOS:
+        #     YOLO(f)(SOURCE)  # model prediction only supported on macOS


 def test_export_tflite(enabled=False):
@@ -196,7 +204,7 @@ def test_export_paddle(enabled=False):
     model.export(format='paddle')


-def test_export_ncnn(enabled=False):
+def test_export_ncnn():
     model = YOLO(MODEL)
     f = model.export(format='ncnn')
     YOLO(f)(SOURCE)  # exported model inference

@@ -634,7 +634,7 @@ class CopyPaste:
             result = cv2.flip(im, 1)  # augment segments (flip left-right)
             i = cv2.flip(im_new, 1).astype(bool)
-            im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug
+            im[i] = result[i]

         labels['img'] = im
         labels['cls'] = cls

@@ -9,8 +9,6 @@ import cv2
 import numpy as np
 from tqdm import tqdm

-from ultralytics.utils.checks import check_requirements
-

 def coco91_to_coco80_class():
     """Converts 91-index COCO class IDs to 80-index COCO class IDs.
@@ -18,7 +16,6 @@ def coco91_to_coco80_class():
     Returns:
         (list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
             corresponding 91-index class ID.
-
     """
     return [
         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
@@ -119,9 +116,7 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp
                 if len(ann['segmentation']) == 0:
                     segments.append([])
                     continue
-                if isinstance(ann['segmentation'], dict):
-                    ann['segmentation'] = rle2polygon(ann['segmentation'])
-                if len(ann['segmentation']) > 1:
+                elif len(ann['segmentation']) > 1:
                     s = merge_multi_segment(ann['segmentation'])
                     s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
                 else:
@@ -131,9 +126,8 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp
                     if s not in segments:
                         segments.append(s)
                 if use_keypoints and ann.get('keypoints') is not None:
-                    k = (np.array(ann['keypoints']).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
-                    k = box + k
-                    keypoints.append(k)
+                    keypoints.append(box + (np.array(ann['keypoints']).reshape(-1, 3) /
+                                            np.array([w, h, 1])).reshape(-1).tolist())

             # Write
             with open((fn / f).with_suffix('.txt'), 'a') as file:
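
The consolidated keypoint expression is easier to follow with toy numbers; a minimal sketch (all values invented for illustration, `box` assumed pre-normalized as in `convert_coco`):

```python
import numpy as np

# Toy COCO keypoints [x1, y1, v1, x2, y2, v2] plus an assumed pre-normalized xywh box
ann_keypoints = [10, 20, 2, 30, 40, 1]
w, h = 100, 100  # image size
box = [0.2, 0.3, 0.1, 0.1]

# Normalize x and y by the image size, keep visibility v, flatten, and prepend the box
k = box + (np.array(ann_keypoints).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
print(k)  # [0.2, 0.3, 0.1, 0.1, 0.1, 0.2, 2.0, 0.3, 0.4, 1.0]
```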
@@ -237,34 +231,6 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
         convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)


-def rle2polygon(segmentation):
-    """
-    Convert Run-Length Encoding (RLE) mask to polygon coordinates.
-
-    Args:
-        segmentation (dict, list): RLE mask representation of the object segmentation.
-
-    Returns:
-        (list): A list of lists representing the polygon coordinates for each contour.
-
-    Note:
-        Requires the 'pycocotools' package to be installed.
-    """
-    check_requirements('pycocotools')
-    from pycocotools import mask
-
-    m = mask.decode(segmentation)
-    m[m > 0] = 255
-    contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
-    polygons = []
-    for contour in contours:
-        epsilon = 0.001 * cv2.arcLength(contour, True)
-        contour_approx = cv2.approxPolyDP(contour, epsilon, True)
-        polygon = contour_approx.flatten().tolist()
-        polygons.append(polygon)
-    return polygons
-
-
 def min_index(arr1, arr2):
     """
     Find a pair of indexes with the shortest distance between two arrays of 2D points.

@@ -144,9 +144,7 @@ def verify_image_label(args):
             if keypoint:
                 keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
                 if ndim == 2:
-                    kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
-                    kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
-                    kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
+                    kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                     keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
             lb = lb[:, :5]
         return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
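
A small sketch of the new single-pass visibility mask, with invented keypoints:

```python
import numpy as np

# One label with three keypoints; negative coordinates mark missing points, shape (nl, nkpt, 2)
keypoints = np.array([[[0.5, 0.5], [-1.0, 0.2], [0.3, -1.0]]])

# Single vectorized np.where replaces the previous three-step mask construction
kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
print(kpt_mask)  # [[1. 0. 0.]]

# Appended as a visibility channel to give (nl, nkpt, 3)
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)
```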

@@ -127,7 +127,7 @@ class Exporter:
     Attributes:
         args (SimpleNamespace): Configuration for the exporter.
         save_dir (Path): Directory to save results.
         callbacks (list, optional): List of callback functions. Defaults to None.
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -189,7 +189,7 @@ class Exporter:
         model.eval()
         model.float()
         model = model.fuse()
-        for k, m in model.named_modules():
+        for m in model.modules():
             if isinstance(m, (Detect, RTDETRDecoder)):  # Segment and Pose use Detect base class
                 m.dynamic = self.args.dynamic
                 m.export = True
@@ -427,7 +427,7 @@ class Exporter:
         system = 'macos' if MACOS else 'ubuntu' if LINUX else 'windows'  # operating system
         asset = [x for x in assets if system in x][0] if assets else \
             f'https://github.com/pnnx/pnnx/releases/download/20230816/pnnx-20230816-{system}.zip'  # fallback
-        attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
+        asset = attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
         unzip_dir = Path(asset).with_suffix('')
         pnnx = ROOT / pnnx_filename  # new location
         (unzip_dir / pnnx_filename).rename(pnnx)  # move binary to ROOT
@@ -502,9 +502,9 @@ class Exporter:
             check_requirements('scikit-learn')  # scikit-learn package required for k-means quantization
             if mlmodel:
                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
-            else:
+            elif bits == 8:  # mlprogram already quantized to FP16
                 import coremltools.optimize.coreml as cto
-                op_config = cto.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512)
+                op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=bits, weight_threshold=512)
                 config = cto.OptimizationConfig(global_config=op_config)
                 ct_model = cto.palettize_weights(ct_model, config=config)
         if self.args.nms and self.model.task == 'detect':
@@ -839,7 +839,7 @@ class Exporter:
         import coremltools as ct  # noqa

         LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
-        batch_size, ch, h, w = list(self.im.shape)  # BCHW
+        _, _, h, w = list(self.im.shape)  # BCHW

         # Output shapes
         spec = model.get_spec()
@@ -857,8 +857,8 @@ class Exporter:
         # Checks
         names = self.metadata['names']
         nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
-        na, nc = out0_shape
-        # na, nc = out0.type.multiArrayType.shape  # number anchors, classes
+        _, nc = out0_shape  # number of anchors, number of classes
+        # _, nc = out0.type.multiArrayType.shape
         assert len(names) == nc, f'{len(names)} names found for nc={nc}'  # check

         # Define output shapes (missing)
@@ -968,7 +968,7 @@ class IOSDetectModel(torch.nn.Module):

     def __init__(self, model, im):
         """Initialize the IOSDetectModel class with a YOLO model and example image."""
         super().__init__()
-        b, c, h, w = im.shape  # batch, channel, height, width
+        _, _, h, w = im.shape  # batch, channel, height, width
         self.model = model
         self.nc = len(model.names)  # number of classes
         if w == h:

@@ -343,8 +343,7 @@ class BasePredictor:
                     h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                 else:  # stream
                     fps, w, h = 30, im0.shape[1], im0.shape[0]
-                suffix = '.mp4' if MACOS else '.avi' if WINDOWS else '.avi'
-                fourcc = 'avc1' if MACOS else 'WMV2' if WINDOWS else 'MJPG'
+                suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
                 save_path = str(Path(save_path).with_suffix(suffix))
                 self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
             self.vid_writer[idx].write(im0)
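
The combined conditional now selects the container suffix and FourCC code together; a standalone sketch (output name and frame size are placeholders):

```python
import cv2
from ultralytics.utils import MACOS, WINDOWS

# One conditional expression picks the container suffix and FourCC codec per platform
suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG')
writer = cv2.VideoWriter('output' + suffix, cv2.VideoWriter_fourcc(*fourcc), 30, (640, 480))
writer.release()
```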

@@ -261,7 +261,6 @@ class DetectionModel(BaseModel):
         for si, fi in zip(s, f):
             xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
             yi = super().predict(xi)[0]  # forward
-            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
             yi = self._descale_pred(yi, fi, si, img_size)
             y.append(yi)
         y = self._clip_augmented(y)  # clip augmented tails

@@ -852,9 +852,10 @@ ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter'
 TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
 set_sentry()

-# Apply monkey patches if the script is being run from within the parent directory of the script's location
-from .patches import imread, imshow, imwrite
+# Apply monkey patches
+from .patches import imread, imshow, imwrite, torch_save

-# torch.save = torch_save
-if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
+torch.save = torch_save
+if WINDOWS:
+    # Apply cv2 patches for non-ASCII and non-UTF characters in image paths
+    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow
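
Once this block runs at import time, cv2 calls inside the package (and in user code) on a Windows machine resolve to the UTF-safe wrappers; a quick check, assuming Windows:

```python
import cv2
import ultralytics  # noqa: F401 - importing the package applies the patches at import time

# On Windows, the cv2 entry points now point at the wrappers defined in ultralytics.utils.patches
print(cv2.imread.__module__)  # expected on Windows: 'ultralytics.utils.patches'
```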

@@ -240,7 +240,7 @@ class ProfileModels:
         if path.is_dir():
             extensions = ['*.pt', '*.onnx', '*.yaml']
             files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
-        elif path.suffix in ('.pt', '.yaml', '.yml'):  # add non-existing
+        elif path.suffix in {'.pt', '.yaml', '.yml'}:  # add non-existing
             files.append(str(path))
         else:
             files.extend(glob.glob(str(path)))
@@ -262,7 +262,7 @@ class ProfileModels:
             data = clipped_data
         return data

-    def profile_tensorrt_model(self, engine_file: str):
+    def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-7):
         if not self.trt or not Path(engine_file).is_file():
             return 0.0, 0.0
@@ -279,7 +279,7 @@ class ProfileModels:
         elapsed = time.time() - start_time

         # Compute number of runs as higher of min_time or num_timed_runs
-        num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs * 50)
+        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)

         # Timed runs
         run_times = []
@@ -290,7 +290,7 @@ class ProfileModels:
         run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
         return np.mean(run_times), np.std(run_times)

-    def profile_onnx_model(self, onnx_file: str):
+    def profile_onnx_model(self, onnx_file: str, eps: float = 1e-7):
         check_requirements('onnxruntime')
         import onnxruntime as ort
@@ -330,7 +330,7 @@ class ProfileModels:
         elapsed = time.time() - start_time

         # Compute number of runs as higher of min_time or num_timed_runs
-        num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs)
+        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)

         # Timed runs
         run_times = []
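
The new `eps` argument guards the run-count estimate against `elapsed == 0` on very fast warmups, which could previously raise ZeroDivisionError; a minimal sketch with assumed profiling constants:

```python
import time

eps = 1e-7  # keeps the division finite when warmup completes in ~0 measured time
min_time, num_warmup_runs, num_timed_runs = 60, 10, 50  # assumed ProfileModels-style settings

start_time = time.time()
# ... warmup inference would run here ...
elapsed = time.time() - start_time  # can be 0.0 at timer resolution

num_runs = max(round(min_time / (elapsed + eps) * num_warmup_runs), num_timed_runs)
```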

@@ -101,7 +101,11 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
     zip_file = directory.with_suffix('.zip')
     compression = ZIP_DEFLATED if compress else ZIP_STORED
     with ZipFile(zip_file, 'w', compression) as f:
-        for file in tqdm(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
+        for file in tqdm(files_to_zip,
+                         desc=f'Zipping {directory} to {zip_file}...',
+                         unit='file',
+                         disable=not progress,
+                         bar_format=TQDM_BAR_FORMAT):
             f.write(file, file.relative_to(directory))

     return zip_file  # return path to zip file
@@ -159,7 +163,11 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
         LOGGER.info(f'Skipping {file} unzip (already unzipped)')
         return path

-    for f in tqdm(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
+    for f in tqdm(files,
+                  desc=f'Unzipping {file} to {Path(path).resolve()}...',
+                  unit='file',
+                  disable=not progress,
+                  bar_format=TQDM_BAR_FORMAT):
         zipObj.extract(f, path=extract_path)

     return path  # return unzip dir

@@ -13,20 +13,45 @@ import torch
 _imshow = cv2.imshow  # copy to avoid recursion errors


-def imread(filename, flags=cv2.IMREAD_COLOR):
+def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
+    """Read an image from a file.
+
+    Args:
+        filename (str): Path to the file to read.
+        flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
+
+    Returns:
+        (np.ndarray): The read image.
+    """
     return cv2.imdecode(np.fromfile(filename, np.uint8), flags)


-def imwrite(filename, img):
+def imwrite(filename: str, img: np.ndarray, params=None):
+    """Write an image to a file.
+
+    Args:
+        filename (str): Path to the file to write.
+        img (np.ndarray): Image to write.
+        params (list of ints, optional): Additional parameters. See OpenCV documentation.
+
+    Returns:
+        (bool): True if the file was written, False otherwise.
+    """
     try:
-        cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
+        cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
         return True
     except Exception:
         return False


-def imshow(path, im):
-    _imshow(path.encode('unicode_escape').decode(), im)
+def imshow(winname: str, mat: np.ndarray):
+    """Displays an image in the specified window.
+
+    Args:
+        winname (str): Name of the window.
+        mat (np.ndarray): Image to be shown.
+    """
+    _imshow(winname.encode('unicode_escape').decode(), mat)
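
The patched pair works because encoding and decoding happen in memory while the actual file I/O goes through NumPy, which accepts arbitrary Unicode filenames; a short usage sketch (the filename mirrors the one added to the tests):

```python
import numpy as np
from ultralytics.utils.patches import imread, imwrite

path = 'non_UTF_测试文件_tést_image.jpg'  # a path plain cv2.imread/cv2.imwrite mishandles on Windows

im = np.zeros((32, 32, 3), dtype=np.uint8)  # dummy image
assert imwrite(path, im)         # encodes in memory, then writes bytes with ndarray.tofile
assert imread(path) is not None  # reads bytes with np.fromfile, then decodes in memory
```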
# PyTorch functions ----------------------------------------------------------------------------------------------------
@@ -34,12 +59,17 @@ _torch_save = torch.save  # copy to avoid recursion errors


 def torch_save(*args, **kwargs):
-    """Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
+    """Use dill (if exists) to serialize the lambda functions where pickle does not do this.
+
+    Args:
+        *args (tuple): Positional arguments to pass to torch.save.
+        **kwargs (dict): Keyword arguments to pass to torch.save.
+    """
     try:
-        import dill as pickle
+        import dill as pickle  # noqa
     except ImportError:
         import pickle

     if 'pickle_module' not in kwargs:
-        kwargs['pickle_module'] = pickle
+        kwargs['pickle_module'] = pickle  # noqa
     return _torch_save(*args, **kwargs)
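
A usage sketch for the now-active torch.save patch, assuming dill is installed (the stdlib pickle cannot serialize lambdas):

```python
from ultralytics.utils.patches import torch_save

# A checkpoint containing a lambda; dill (when available) is injected as the pickle_module
ckpt = {'epoch': 1, 'lr_lambda': lambda x: 0.9 ** x}
torch_save(ckpt, 'checkpoint.pt')
```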
