From ea527507fecb72cbdb1881dbd8d46c246efce90b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 1 Apr 2024 00:16:52 +0200 Subject: [PATCH] `ultralytics 8.1.40` search in Python sets `{}` for speed (#9450) Signed-off-by: Glenn Jocher --- tests/test_python.py | 2 +- ultralytics/__init__.py | 2 +- ultralytics/cfg/__init__.py | 6 ++--- ultralytics/data/augment.py | 4 ++-- ultralytics/data/base.py | 4 ++-- ultralytics/data/converter.py | 4 ++-- ultralytics/data/dataset.py | 8 +++---- ultralytics/data/loaders.py | 19 ++++++++------- ultralytics/data/split_dota.py | 2 +- ultralytics/data/utils.py | 15 ++++++------ ultralytics/engine/exporter.py | 6 ++--- ultralytics/engine/model.py | 4 ++-- ultralytics/engine/results.py | 4 ++-- ultralytics/engine/trainer.py | 24 +++++++++---------- ultralytics/engine/validator.py | 4 ++-- ultralytics/hub/utils.py | 2 +- ultralytics/models/fastsam/model.py | 2 +- ultralytics/models/nas/model.py | 2 +- ultralytics/models/sam/model.py | 2 +- .../models/sam/modules/tiny_encoder.py | 2 +- ultralytics/models/yolo/classify/train.py | 2 +- ultralytics/models/yolo/detect/train.py | 2 +- ultralytics/models/yolo/world/train.py | 2 +- ultralytics/nn/autobackend.py | 10 ++++---- ultralytics/nn/modules/conv.py | 2 +- ultralytics/nn/modules/head.py | 8 +++---- ultralytics/nn/tasks.py | 2 +- ultralytics/solutions/ai_gym.py | 2 +- ultralytics/solutions/heatmap.py | 2 +- ultralytics/trackers/byte_tracker.py | 2 +- ultralytics/trackers/track.py | 2 +- ultralytics/trackers/utils/gmc.py | 2 +- ultralytics/utils/__init__.py | 8 +++---- ultralytics/utils/benchmarks.py | 4 ++-- ultralytics/utils/callbacks/comet.py | 2 +- ultralytics/utils/callbacks/mlflow.py | 2 +- ultralytics/utils/checks.py | 4 ++-- ultralytics/utils/downloads.py | 4 ++-- ultralytics/utils/metrics.py | 2 +- ultralytics/utils/plotting.py | 2 +- ultralytics/utils/torch_utils.py | 6 ++--- 41 files changed, 97 insertions(+), 93 deletions(-) diff --git a/tests/test_python.py b/tests/test_python.py index fcda953899..f2bc886ba9 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -351,7 +351,7 @@ def test_labels_and_crops(): crop_dirs = [p for p in (save_path / "crops").iterdir()] crop_files = [f for p in crop_dirs for f in p.glob("*")] # Crop directories match detections - assert all([r.names.get(c) in [d.name for d in crop_dirs] for c in cls_idxs]) + assert all([r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs]) # Same number of crops as detections assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 07232b4414..8c16411d82 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.39" +__version__ = "8.1.40" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 45f95485c4..b907a8eb13 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -272,7 +272,7 @@ def get_save_dir(args, name=None): project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task name = name or args.name or f"{args.mode}" - save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True) return Path(save_dir) @@ -566,10 +566,10 @@ def entrypoint(debug=""): task = model.task # Mode - if mode in ("predict", "track") and "source" not in overrides: + if mode in {"predict", "track"} and "source" not in overrides: overrides["source"] = DEFAULT_CFG.source or ASSETS LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") - elif mode in ("train", "val"): + elif mode in {"train", "val"}: if "data" not in overrides and "resume" not in overrides: overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index c72fa077e1..d07bb4673e 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform): def __init__(self, dataset, imgsz=640, p=1.0, n=4): """Initializes the object with a dataset, image size, probability, and border.""" assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." - assert n in (4, 9), "grid must be equal to 4 or 9." + assert n in {4, 9}, "grid must be equal to 4 or 9." super().__init__(dataset=dataset, p=p) self.dataset = dataset self.imgsz = imgsz @@ -685,7 +685,7 @@ class RandomFlip: Default is 'horizontal'. flip_idx (array-like, optional): Index mapping for flipping keypoints, if any. """ - assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}" + assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}" assert 0 <= p <= 1.0 self.p = p diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index 62ac869c57..7aa3928a77 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -15,7 +15,7 @@ import psutil from torch.utils.data import Dataset from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM -from .utils import HELP_URL, IMG_FORMATS +from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS class BaseDataset(Dataset): @@ -118,7 +118,7 @@ class BaseDataset(Dataset): raise FileNotFoundError(f"{self.prefix}{p} does not exist") im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib - assert im_files, f"{self.prefix}No images found in {img_path}" + assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}" except Exception as e: raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e if self.fraction < 1: diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py index 62370f8143..ebfb8af66d 100644 --- a/ultralytics/data/converter.py +++ b/ultralytics/data/converter.py @@ -481,7 +481,7 @@ def merge_multi_segment(segments): segments[i] = np.roll(segments[i], -idx[0], axis=0) segments[i] = np.concatenate([segments[i], segments[i][:1]]) # Deal with the first segment and the last one - if i in [0, len(idx_list) - 1]: + if i in {0, len(idx_list) - 1}: s.append(segments[i]) else: idx = [0, idx[1] - idx[0]] @@ -489,7 +489,7 @@ def merge_multi_segment(segments): else: for i in range(len(idx_list) - 1, -1, -1): - if i not in [0, len(idx_list) - 1]: + if i not in {0, len(idx_list) - 1}: idx = idx_list[i] nidx = abs(idx[1] - idx[0]) s.append(segments[i][nidx:]) diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index 7acf7689a2..8a46f7b770 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -77,7 +77,7 @@ class YOLODataset(BaseDataset): desc = f"{self.prefix}Scanning {path.parent / path.stem}..." total = len(self.im_files) nkpt, ndim = self.data.get("kpt_shape", (0, 0)) - if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)): + if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}): raise ValueError( "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" @@ -142,7 +142,7 @@ class YOLODataset(BaseDataset): # Display cache nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total - if exists and LOCAL_RANK in (-1, 0): + if exists and LOCAL_RANK in {-1, 0}: d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results if cache["msgs"]: @@ -235,7 +235,7 @@ class YOLODataset(BaseDataset): value = values[i] if k == "img": value = torch.stack(value, 0) - if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]: + if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}: value = torch.cat(value, 0) new_batch[k] = value new_batch["batch_idx"] = list(new_batch["batch_idx"]) @@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): assert cache["version"] == DATASET_CACHE_VERSION # matches current version assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total - if LOCAL_RANK in (-1, 0): + if LOCAL_RANK in {-1, 0}: d = f"{desc} {nf} images, {nc} corrupt" TQDM(None, desc=d, total=n, initial=n) if cache["msgs"]: diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py index 4b89770c74..95116b5cc1 100644 --- a/ultralytics/data/loaders.py +++ b/ultralytics/data/loaders.py @@ -15,7 +15,7 @@ import requests import torch from PIL import Image -from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS +from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops from ultralytics.utils.checks import check_requirements @@ -83,7 +83,7 @@ class LoadStreams: for i, s in enumerate(sources): # index, source # Start thread to read frames from video stream st = f"{i + 1}/{n}: {s}... " - if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video + if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4' s = get_best_youtube_url(s) s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam @@ -291,8 +291,14 @@ class LoadImagesAndVideos: else: raise FileNotFoundError(f"{p} does not exist") - images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] - videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] + # Define files as images or videos + images, videos = [], [] + for f in files: + suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase + if suffix in IMG_FORMATS: + images.append(f) + elif suffix in VID_FORMATS: + videos.append(f) ni, nv = len(images), len(videos) self.files = images + videos @@ -307,10 +313,7 @@ class LoadImagesAndVideos: else: self.cap = None if self.nf == 0: - raise FileNotFoundError( - f"No images or videos found in {p}. " - f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" - ) + raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}") def __iter__(self): """Returns an iterator object for VideoStream or ImageFolder.""" diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py index 8a5469b853..f0a85d91fd 100644 --- a/ultralytics/data/split_dota.py +++ b/ultralytics/data/split_dota.py @@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"): - train - val """ - assert split in ["train", "val"] + assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}." im_dir = Path(data_root) / "images" / split assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." im_files = glob(str(Path(data_root) / "images" / split / "*")) diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index 1ad926c9d5..c3853db631 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders +FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" def img2label_paths(img_paths): @@ -63,7 +64,7 @@ def exif_size(img: Image.Image): exif = img.getexif() if exif: rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274 - if rotation in [6, 8]: # rotation 270 or 90 + if rotation in {6, 8}: # rotation 270 or 90 s = s[1], s[0] return s @@ -79,8 +80,8 @@ def verify_image(args): shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" - assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" - if im.format.lower() in ("jpg", "jpeg"): + assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: with open(im_file, "rb") as f: f.seek(-2, 2) if f.read() != b"\xff\xd9": # corrupt JPEG @@ -105,8 +106,8 @@ def verify_image_label(args): shape = exif_size(im) # image size shape = (shape[1], shape[0]) # hw assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" - assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" - if im.format.lower() in ("jpg", "jpeg"): + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: with open(im_file, "rb") as f: f.seek(-2, 2) if f.read() != b"\xff\xd9": # corrupt JPEG @@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True): else: # python script exec(s, {"yaml": data}) dt = f"({round(time.time() - t, 1)}s)" - s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" + s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌" LOGGER.info(f"Dataset download {s}\n") check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts @@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""): # Download (optional if dataset=https://file.zip is passed directly) if str(dataset).startswith(("http:/", "https:/")): dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False) - elif Path(dataset).suffix in (".zip", ".tar", ".gz"): + elif Path(dataset).suffix in {".zip", ".tar", ".gz"}: file = check_file(dataset) dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 9c32e0ac5d..85feec547c 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -159,7 +159,7 @@ class Exporter: _callbacks (dict, optional): Dictionary of callback functions. Defaults to None. """ self.args = get_cfg(cfg, overrides) - if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors + if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback self.callbacks = _callbacks or callbacks.get_default_callbacks() @@ -171,9 +171,9 @@ class Exporter: self.run_callbacks("on_export_start") t = time.time() fmt = self.args.format.lower() # to lowercase - if fmt in ("tensorrt", "trt"): # 'engine' aliases + if fmt in {"tensorrt", "trt"}: # 'engine' aliases fmt = "engine" - if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases + if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases fmt = "coreml" fmts = tuple(export_formats()["Argument"][1:]) # available export formats flags = [x == fmt for x in fmts] diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index ed03547cf1..cae5023b20 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -145,7 +145,7 @@ class Model(nn.Module): return # Load or create new YOLO model - if Path(model).suffix in (".yaml", ".yml"): + if Path(model).suffix in {".yaml", ".yml"}: self._new(model, task=task, verbose=verbose) else: self._load(model, task=task) @@ -666,7 +666,7 @@ class Model(nn.Module): self.trainer.hub_session = self.session # attach optional HUB session self.trainer.train() # Update model and cfg after training - if RANK in (-1, 0): + if RANK in {-1, 0}: ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last self.model, _ = attempt_load_one_weight(ckpt) self.overrides = self.model.args diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 85849c34d0..ba6f213746 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -470,7 +470,7 @@ class Boxes(BaseTensor): if boxes.ndim == 1: boxes = boxes[None, :] n = boxes.shape[-1] - assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls + assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 7 self.orig_shape = orig_shape @@ -687,7 +687,7 @@ class OBB(BaseTensor): if boxes.ndim == 1: boxes = boxes[None, :] n = boxes.shape[-1] - assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls + assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 8 self.orig_shape = orig_shape diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index f344b157bb..96931fabb0 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -107,7 +107,7 @@ class BaseTrainer: self.save_dir = get_save_dir(self.args) self.args.name = self.save_dir.name # update name for loggers self.wdir = self.save_dir / "weights" # weights dir - if RANK in (-1, 0): + if RANK in {-1, 0}: self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.args.save_dir = str(self.save_dir) yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args @@ -121,7 +121,7 @@ class BaseTrainer: print_args(vars(self.args)) # Device - if self.device.type in ("cpu", "mps"): + if self.device.type in {"cpu", "mps"}: self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading # Model and Dataset @@ -144,7 +144,7 @@ class BaseTrainer: # Callbacks self.callbacks = _callbacks or callbacks.get_default_callbacks() - if RANK in (-1, 0): + if RANK in {-1, 0}: callbacks.add_integration_callbacks(self) def add_callback(self, event: str, callback): @@ -251,7 +251,7 @@ class BaseTrainer: # Check AMP self.amp = torch.tensor(self.args.amp).to(self.device) # True or False - if self.amp and RANK in (-1, 0): # Single-GPU and DDP + if self.amp and RANK in {-1, 0}: # Single-GPU and DDP callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them self.amp = torch.tensor(check_amp(self.model), device=self.device) callbacks.default_callbacks = callbacks_backup # restore callbacks @@ -274,7 +274,7 @@ class BaseTrainer: # Dataloaders batch_size = self.batch_size // max(world_size, 1) self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train") - if RANK in (-1, 0): + if RANK in {-1, 0}: # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. self.test_loader = self.get_dataloader( self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" @@ -340,7 +340,7 @@ class BaseTrainer: self._close_dataloader_mosaic() self.train_loader.reset() - if RANK in (-1, 0): + if RANK in {-1, 0}: LOGGER.info(self.progress_string()) pbar = TQDM(enumerate(self.train_loader), total=nb) self.tloss = None @@ -392,7 +392,7 @@ class BaseTrainer: mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) - if RANK in (-1, 0): + if RANK in {-1, 0}: pbar.set_description( ("%11s" * 2 + "%11.4g" * (2 + loss_len)) % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]) @@ -405,7 +405,7 @@ class BaseTrainer: self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers self.run_callbacks("on_train_epoch_end") - if RANK in (-1, 0): + if RANK in {-1, 0}: final_epoch = epoch + 1 >= self.epochs self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) @@ -447,7 +447,7 @@ class BaseTrainer: break # must break all DDP ranks epoch += 1 - if RANK in (-1, 0): + if RANK in {-1, 0}: # Do final val with best.pt LOGGER.info( f"\n{epoch - self.start_epoch + 1} epochs completed in " @@ -503,12 +503,12 @@ class BaseTrainer: try: if self.args.task == "classify": data = check_cls_dataset(self.args.data) - elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ( + elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in { "detect", "segment", "pose", "obb", - ): + }: data = check_det_dataset(self.args.data) if "yaml_file" in data: self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage @@ -740,7 +740,7 @@ class BaseTrainer: else: # weight (with decay) g[0].append(param) - if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): + if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) elif name == "RMSProp": optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 17666e3859..9e7b6c16f8 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -139,14 +139,14 @@ class BaseValidator: self.args.batch = 1 # export.py models default to batch-size 1 LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") - if str(self.args.data).split(".")[-1] in ("yaml", "yml"): + if str(self.args.data).split(".")[-1] in {"yaml", "yml"}: self.data = check_det_dataset(self.args.data) elif self.args.task == "classify": self.data = check_cls_dataset(self.args.data, split=self.args.split) else: raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) - if self.device.type in ("cpu", "mps"): + if self.device.type in {"cpu", "mps"}: self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading if not pt: self.args.rect = False diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py index 1f3de94c3c..e760d4c333 100644 --- a/ultralytics/hub/utils.py +++ b/ultralytics/hub/utils.py @@ -198,7 +198,7 @@ class Events: } self.enabled = ( SETTINGS["sync"] - and RANK in (-1, 0) + and RANK in {-1, 0} and not TESTS_RUNNING and ONLINE and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py index c01e66b7f7..226b43d193 100644 --- a/ultralytics/models/fastsam/model.py +++ b/ultralytics/models/fastsam/model.py @@ -24,7 +24,7 @@ class FastSAM(Model): """Call the __init__ method of the parent class (YOLO) with the updated default model.""" if str(model) == "FastSAM.pt": model = "FastSAM-x.pt" - assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models." + assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models." super().__init__(model=model, task="segment") @property diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 7997e96bd1..c94e2d81c7 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -45,7 +45,7 @@ class NAS(Model): def __init__(self, model="yolo_nas_s.pt") -> None: """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" - assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models." + assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." super().__init__(model, task="detect") @smart_inference_mode() diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index cb12bc7ed7..b931da8332 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -41,7 +41,7 @@ class SAM(Model): Raises: NotImplementedError: If the model file extension is not .pt or .pth. """ - if model and Path(model).suffix not in (".pt", ".pth"): + if model and Path(model).suffix not in {".pt", ".pth"}: raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") super().__init__(model=model, task="segment") diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py index 98f5ac04a4..66ef10d605 100644 --- a/ultralytics/models/sam/modules/tiny_encoder.py +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -112,7 +112,7 @@ class PatchMerging(nn.Module): self.out_dim = out_dim self.act = activation() self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) - stride_c = 1 if out_dim in [320, 448, 576] else 2 + stride_c = 1 if out_dim in {320, 448, 576} else 2 self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index 42c65542fc..7a1f84888e 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -68,7 +68,7 @@ class ClassificationTrainer(BaseTrainer): self.model, ckpt = attempt_load_one_weight(model, device="cpu") for p in self.model.parameters(): p.requires_grad = True # for training - elif model.split(".")[-1] in ("yaml", "yml"): + elif model.split(".")[-1] in {"yaml", "yml"}: self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index 3326512bde..65b7efff27 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer): def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): """Construct and return dataloader.""" - assert mode in ["train", "val"] + assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}." with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = self.build_dataset(dataset_path, mode, batch_size) shuffle = mode == "train" diff --git a/ultralytics/models/yolo/world/train.py b/ultralytics/models/yolo/world/train.py index 6f51d44304..12d9b43c44 100644 --- a/ultralytics/models/yolo/world/train.py +++ b/ultralytics/models/yolo/world/train.py @@ -11,7 +11,7 @@ from ultralytics.utils.torch_utils import de_parallel def on_pretrain_routine_end(trainer): """Callback.""" - if RANK in (-1, 0): + if RANK in {-1, 0}: # NOTE: for evaluation names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())] de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 93c5fa078d..0b473bf325 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -374,9 +374,9 @@ class AutoBackend(nn.Module): metadata = yaml_load(metadata) if metadata: for k, v in metadata.items(): - if k in ("stride", "batch"): + if k in {"stride", "batch"}: metadata[k] = int(v) - elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str): + elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str): metadata[k] = eval(v) stride = metadata["stride"] task = metadata["task"] @@ -531,8 +531,8 @@ class AutoBackend(nn.Module): self.names = {i: f"class{i}" for i in range(nc)} else: # Lite or Edge TPU details = self.input_details[0] - integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model - if integer: + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: scale, zero_point = details["quantization"] im = (im / scale + zero_point).astype(details["dtype"]) # de-scale self.interpreter.set_tensor(details["index"], im) @@ -540,7 +540,7 @@ class AutoBackend(nn.Module): y = [] for output in self.output_details: x = self.interpreter.get_tensor(output["index"]) - if integer: + if is_int: scale, zero_point = output["quantization"] x = (x.astype(np.float32) - zero_point) * scale # re-scale if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py index 399c42255c..6b51813ed6 100644 --- a/ultralytics/nn/modules/conv.py +++ b/ultralytics/nn/modules/conv.py @@ -296,7 +296,7 @@ class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): """Initialize Spatial-attention module with kernel size argument.""" super().__init__() - assert kernel_size in (3, 7), "kernel size must be 3 or 7" + assert kernel_size in {3, 7}, "kernel size must be 3 or 7" padding = 3 if kernel_size == 7 else 1 self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.act = nn.Sigmoid() diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 13b4c7f44a..28e2f5d675 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -54,13 +54,13 @@ class Detect(nn.Module): self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape - if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops box = x_cat[:, : self.reg_max * 4] cls = x_cat[:, self.reg_max * 4 :] else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) - if self.export and self.format in ("tflite", "edgetpu"): + if self.export and self.format in {"tflite", "edgetpu"}: # Precompute normalization factor to increase numerical stability # See https://github.com/ultralytics/ultralytics/issues/7371 grid_h = shape[2] @@ -230,13 +230,13 @@ class WorldDetect(Detect): self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape - if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops box = x_cat[:, : self.reg_max * 4] cls = x_cat[:, self.reg_max * 4 :] else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) - if self.export and self.format in ("tflite", "edgetpu"): + if self.export and self.format in {"tflite", "edgetpu"}: # Precompute normalization factor to increase numerical stability # See https://github.com/ultralytics/ultralytics/issues/7371 grid_h = shape[2] diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 9b746d7a6e..7e4adf054c 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -896,7 +896,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) ) # num heads args = [c1, c2, *args[1:]] - if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3): + if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}: args.insert(2, n) # number of repeats n = 1 elif m is AIFI: diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py index b78cf598d8..495250ed1c 100644 --- a/ultralytics/solutions/ai_gym.py +++ b/ultralytics/solutions/ai_gym.py @@ -81,7 +81,7 @@ class AIGym: self.annotator = Annotator(im0, line_width=2) for ind, k in enumerate(reversed(self.keypoints)): - if self.pose_type in ["pushup", "pullup"]: + if self.pose_type in {"pushup", "pullup"}: self.angle[ind] = self.annotator.estimate_pose_angle( k[int(self.kpts_to_check[0])].cpu(), k[int(self.kpts_to_check[1])].cpu(), diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index 6524abb06d..807c6bb72b 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -153,7 +153,7 @@ class Heatmap: self.cls_txtdisplay_gap = cls_txtdisplay_gap # shape of heatmap, if not selected - if self.shape not in ["circle", "rect"]: + if self.shape not in {"circle", "rect"}: print("Unknown shape value provided, 'circle' & 'rect' supported") print("Using Circular shape now") self.shape = "circle" diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py index 01cbca9751..d733fde63d 100644 --- a/ultralytics/trackers/byte_tracker.py +++ b/ultralytics/trackers/byte_tracker.py @@ -47,7 +47,7 @@ class STrack(BaseTrack): """Initialize new STrack instance.""" super().__init__() # xywh+idx or xywha+idx - assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}" + assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) self.kalman_filter = None self.mean, self.covariance = None, None diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index 7146a40178..d8d1d0ef04 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -31,7 +31,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None: tracker = check_yaml(predictor.args.tracker) cfg = IterableSimpleNamespace(**yaml_load(tracker)) - if cfg.tracker_type not in ["bytetrack", "botsort"]: + if cfg.tracker_type not in {"bytetrack", "botsort"}: raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") trackers = [] diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py index 36ea6a2a8f..052bf94c8d 100644 --- a/ultralytics/trackers/utils/gmc.py +++ b/ultralytics/trackers/utils/gmc.py @@ -94,7 +94,7 @@ class GMC: array([[1, 2, 3], [4, 5, 6]]) """ - if self.method in ["orb", "sift"]: + if self.method in {"orb", "sift"}: return self.applyFeatures(raw_frame, detections) elif self.method == "ecc": return self.applyEcc(raw_frame) diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index fdc2b95b8c..c005e3c8f3 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -41,7 +41,7 @@ VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbo TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format LOGGING_NAME = "ultralytics" MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans -ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans +ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans HELP_MSG = """ Usage examples for running YOLOv8: @@ -359,7 +359,7 @@ def yaml_load(file="data.yaml", append_filename=False): Returns: (dict): YAML data and file name. """ - assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()" + assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()" with open(file, errors="ignore", encoding="utf-8") as f: s = f.read() # string @@ -866,7 +866,7 @@ def set_sentry(): """ if "exc_info" in hint: exc_type, exc_value, tb = hint["exc_info"] - if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value): + if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value): return None # do not send event event["tags"] = { @@ -879,7 +879,7 @@ def set_sentry(): if ( SETTINGS["sync"] - and RANK in (-1, 0) + and RANK in {-1, 0} and Path(ARGV[0]).name == "yolo" and not TESTS_RUNNING and ONLINE diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index 3bc63510e7..d48b300eda 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -115,7 +115,7 @@ def benchmark( # Predict assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" - assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported + assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half) @@ -220,7 +220,7 @@ class ProfileModels: output = [] for file in files: engine_file = file.with_suffix(".engine") - if file.suffix in (".pt", ".yaml", ".yml"): + if file.suffix in {".pt", ".yaml", ".yml"}: model = YOLO(str(file)) model.fuse() # to report correct params and GFLOPs in model.info() model_info = model.info() diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py index 1c5f585da0..518860c5f9 100644 --- a/ultralytics/utils/callbacks/comet.py +++ b/ultralytics/utils/callbacks/comet.py @@ -71,7 +71,7 @@ def _get_experiment_type(mode, project_name): def _create_experiment(args): """Ensures that the experiment object is only created in a single process during distributed training.""" - if RANK not in (-1, 0): + if RANK not in {-1, 0}: return try: comet_mode = _get_comet_mode() diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py index e5546200f3..b1dc101121 100644 --- a/ultralytics/utils/callbacks/mlflow.py +++ b/ultralytics/utils/callbacks/mlflow.py @@ -108,7 +108,7 @@ def on_train_end(trainer): for f in trainer.save_dir.glob("*"): # log all other files in save_dir if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: mlflow.log_artifact(str(f)) - keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true") + keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true" if keep_run_active: LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()") else: diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index ae6dbce58d..8d3a69b3db 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -237,7 +237,7 @@ def check_version( result = False elif op == "!=" and c == v: result = False - elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required' + elif op in {">=", ""} and not (c >= v): # if no constraint passed assume '>=required' result = False elif op == "<=" and not (c <= v): result = False @@ -632,7 +632,7 @@ def check_amp(model): (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. """ device = next(model.parameters()).device # get model device - if device.type in ("cpu", "mps"): + if device.type in {"cpu", "mps"}: return False # AMP only used on CUDA devices def amp_allclose(m, im): diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index 6191ade233..f24bd3b144 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -356,13 +356,13 @@ def safe_download( raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...") - if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"): + if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}: from zipfile import is_zipfile unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place if is_zipfile(f): unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip - elif f.suffix in (".tar", ".gz"): + elif f.suffix in {".tar", ".gz"}: LOGGER.info(f"Unzipping {f} to {unzip_dir}...") subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) if delete: diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index b49ee8b38a..22dfd78745 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -298,7 +298,7 @@ class ConfusionMatrix: self.task = task self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc)) self.nc = nc # number of classes - self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed + self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed self.iou_thres = iou_thres def process_cls_preds(self, preds, targets): diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 04d305bcef..818464d0c5 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -904,7 +904,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line ax[i].set_title(s[j], fontsize=12) - # if j in [8, 9, 10]: # share train and val loss y axes + # if j in {8, 9, 10}: # share train and val loss y axes # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) except Exception as e: LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 77449b0440..d5d91e1396 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -37,7 +37,7 @@ TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0") def torch_distributed_zero_first(local_rank: int): """Decorator to make all processes in distributed training wait for each local_master to do something.""" initialized = torch.distributed.is_available() and torch.distributed.is_initialized() - if initialized and local_rank not in (-1, 0): + if initialized and local_rank not in {-1, 0}: dist.barrier(device_ids=[local_rank]) yield if initialized and local_rank == 0: @@ -109,7 +109,7 @@ def select_device(device="", batch=0, newline=False, verbose=True): for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ": device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' cpu = device == "cpu" - mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS) + mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS) if cpu or mps: os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False elif device: # non-cpu device requested @@ -347,7 +347,7 @@ def initialize_weights(model): elif t is nn.BatchNorm2d: m.eps = 1e-3 m.momentum = 0.03 - elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}: m.inplace = True