From 64f247d6923e28108792cbeea16343606d8cd844 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 6 Feb 2023 21:57:10 +0400 Subject: [PATCH] `ultralytics 8.0.30` Docker, rect, data=*.zip updates (#832) Signed-off-by: dependabot[bot] Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/docker.yaml | 6 +++--- docker/Dockerfile-arm64 | 8 +++----- docs/cfg.md | 5 +++-- ultralytics/__init__.py | 2 +- ultralytics/nn/tasks.py | 5 +++-- ultralytics/yolo/cfg/default.yaml | 2 +- ultralytics/yolo/data/build.py | 4 ++-- ultralytics/yolo/data/dataset.py | 8 +++++--- ultralytics/yolo/data/utils.py | 5 +++-- ultralytics/yolo/engine/model.py | 1 + ultralytics/yolo/engine/trainer.py | 17 ++++++++++------- ultralytics/yolo/engine/validator.py | 2 ++ ultralytics/yolo/utils/__init__.py | 14 ++++++++------ ultralytics/yolo/utils/checks.py | 2 +- ultralytics/yolo/utils/downloads.py | 22 ++++++++++++++++++---- ultralytics/yolo/v8/detect/train.py | 4 ++-- ultralytics/yolo/v8/detect/val.py | 3 +-- 17 files changed, 67 insertions(+), 43 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 4d32b1c76..1d1d012e4 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -29,7 +29,7 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push arm64 image - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 continue-on-error: true with: context: . @@ -39,7 +39,7 @@ jobs: tags: ultralytics/ultralytics:latest-arm64 - name: Build and push CPU image - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 continue-on-error: true with: context: . @@ -48,7 +48,7 @@ jobs: tags: ultralytics/ultralytics:latest-cpu - name: Build and push GPU image - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 continue-on-error: true with: context: . diff --git a/docker/Dockerfile-arm64 b/docker/Dockerfile-arm64 index 63ca653dc..ad4d9e1e1 100644 --- a/docker/Dockerfile-arm64 +++ b/docker/Dockerfile-arm64 @@ -26,11 +26,9 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics # Install pip packages COPY requirements.txt . RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache ultralytics gsutil notebook \ - tensorflow-aarch64 - # tensorflowjs \ - # onnx onnx-simplifier onnxruntime \ - # coremltools openvino-dev>=2022.3 \ +RUN pip install --no-cache ultralytics albumentations gsutil notebook \ + coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3 + # tensorflow-aarch64 tensorflowjs \ # Cleanup ENV DEBIAN_FRONTEND teletype diff --git a/docs/cfg.md b/docs/cfg.md index aed688362..9c75ea2da 100644 --- a/docs/cfg.md +++ b/docs/cfg.md @@ -108,6 +108,7 @@ task. | overlap_mask | True | masks should overlap during training (segment train only) | | mask_ratio | 4 | mask downsample ratio (segment train only) | | dropout | 0.0 | use dropout regularization (classify train only) | +| val | True | validate/test during training | ### Prediction @@ -148,7 +149,6 @@ validation dataset and to detect and prevent overfitting. 
| Key | Value | Description | |-------------|-------|-----------------------------------------------------------------------------| -| val | True | validate/test during training | | save_json | False | save results to JSON file | | save_hybrid | False | save hybrid version of labels (labels + additional predictions) | | conf | 0.001 | object confidence threshold for detection (default 0.25 predict, 0.001 val) | @@ -157,6 +157,7 @@ validation dataset and to detect and prevent overfitting. | half | True | use half precision (FP16) | | dnn | False | use OpenCV DNN for ONNX inference | | plots | False | show plots during training | +| rect | False | support rectangular evaluation | ### Export @@ -222,4 +223,4 @@ it easier to debug and optimize the training process. | name | 'exp' | experiment name. `exp` gets automatically incremented if not specified, i.e, `exp`, `exp2` ... | | exist_ok | False | whether to overwrite existing experiment | | plots | False | save plots during train/val | -| save | False | save train checkpoints and predict results | \ No newline at end of file +| save | False | save train checkpoints and predict results | diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index f6fb57208..948698753 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = "8.0.29" +__version__ = "8.0.30" from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils import ops diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 8150b20df..4013b5b7c 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -338,9 +338,10 @@ def torch_safe_load(weight): if e.name == 'omegaconf': # e.name is missing module name LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements." f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future." 
- f"\nRecommend fixes are to train a new model using updated ultraltyics package or to " + f"\nRecommend fixes are to train a new model using updated ultralytics package or to " f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0") - check_requirements(e.name) # install missing module + if e.name != 'models': + check_requirements(e.name) # install missing module return torch.load(file, map_location='cpu') # load diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index 789f2b70e..5be02eb12 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -25,7 +25,7 @@ seed: 0 # random seed for reproducibility deterministic: True # whether to enable deterministic mode single_cls: False # train multi-class data as single-class image_weights: False # use weighted image selection for training -rect: False # support rectangular training +rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val' cos_lr: False # use cosine learning rate scheduler close_mosaic: 10 # disable mosaic augmentation for final 10 epochs resume: False # resume training from last checkpoint diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index a2e62faed..3448232e0 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -61,7 +61,7 @@ def seed_worker(worker_id): random.seed(worker_seed) -def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"): +def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_path=None, rank=-1, mode="train"): assert mode in ["train", "val"] shuffle = mode == "train" if cfg.rect and shuffle: @@ -75,7 +75,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank batch_size=batch_size, augment=mode == "train", # augmentation hyp=cfg, # TODO: probably add a get_hyps_from_cfg function - rect=cfg.rect if mode == "train" else True, # rectangular batches + rect=cfg.rect or rect, # rectangular batches cache=cfg.cache or None, single_cls=cfg.single_cls or False, stride=int(stride), diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index 08320e882..e2152ef5d 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -113,13 +113,15 @@ class YOLODataset(BaseDataset): tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results if cache["msgs"]: LOGGER.info("\n".join(cache["msgs"])) # display warnings - assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}" + if nf == 0: # number of labels found + raise FileNotFoundError(f"{self.prefix}No labels found in {cache_path}, can not start training. 
{HELP_URL}") # Read cache [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items labels = cache["labels"] # Check if the dataset is all boxes or all segments + len_cls = sum(len(lb["cls"]) for lb in labels) len_boxes = sum(len(lb["bboxes"]) for lb in labels) len_segments = sum(len(lb["segments"]) for lb in labels) if len_segments and len_boxes != len_segments: @@ -129,8 +131,8 @@ class YOLODataset(BaseDataset): "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.") for lb in labels: lb["segments"] = [] - nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels - assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}" + if len_cls == 0: + raise ValueError(f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}") return labels # TODO: use hyp config to set all these augmentations diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index f9ea01965..91d8fe022 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -192,7 +192,7 @@ def check_det_dataset(dataset, autodownload=True): # Download (optional) extract_dir = '' if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): - download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) + download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1) data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) extract_dir, autodownload = data.parent, False @@ -211,7 +211,8 @@ def check_det_dataset(dataset, autodownload=True): data['nc'] = len(data['names']) # Resolve paths - path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.' + path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root + if not path.is_absolute(): path = (DATASETS_DIR / path).resolve() data['path'] = path # download scripts diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 0b3a7ad19..5fba06d17 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -156,6 +156,7 @@ class YOLO: **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs """ overrides = self.overrides.copy() + overrides["rect"] = True # rect batches as default overrides.update(kwargs) overrides["mode"] = "val" args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index c2139208f..ce7c3449d 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -116,13 +116,16 @@ class BaseTrainer: # Model and Dataloaders. 
self.model = self.args.model - self.data = self.args.data - if self.data.endswith(".yaml"): - self.data = check_det_dataset(self.data) - elif self.args.task == 'classify': - self.data = check_cls_dataset(self.data) - else: - raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌")) + try: + if self.args.task == 'classify': + self.data = check_cls_dataset(self.args.data) + elif self.args.data.endswith(".yaml") or self.args.task in ('detect', 'segment'): + self.data = check_det_dataset(self.args.data) + if 'yaml_file' in self.data: + self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage + except Exception as e: + raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e + self.trainset, self.testset = self.get_dataset(self.data) self.ema = None diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 71710782a..b3e8f5874 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -117,6 +117,8 @@ class BaseValidator: if self.device.type == 'cpu': self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading + if not pt: + self.args.rect = False self.dataloader = self.dataloader or \ self.get_dataloader(self.data.get("val") or self.data.set("test"), self.args.batch) diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 616e527ac..f8eea1e8e 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -491,6 +491,7 @@ def set_sentry(): ((is_pip_package() and not is_git_dir()) or (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")): + import hashlib import sentry_sdk # noqa from ultralytics import __version__ @@ -502,13 +503,14 @@ def set_sentry(): environment='production', # 'dev' or 'production' before_send=before_send, ignore_errors=[KeyboardInterrupt, FileNotFoundError]) + sentry_sdk.set_user({"id": SETTINGS['uuid']}) # Disable all sentry logging for logger in "sentry_sdk", "sentry_sdk.errors": logging.getLogger(logger).setLevel(logging.CRITICAL) -def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): +def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'): """ Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. @@ -519,6 +521,7 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): Returns: dict: Dictionary of settings key-value pairs. """ + import hashlib from ultralytics.yolo.utils.checks import check_version from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first @@ -530,7 +533,7 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): 'weights_dir': str(root / 'weights'), # default weights directory. 'runs_dir': str(root / 'runs'), # default runs directory. 
'sync': True, # sync analytics to help with YOLO development - 'uuid': uuid.getnode(), # device UUID to align analytics + 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # anonymized uuid hash 'settings_version': version} # Ultralytics settings version with torch_distributed_zero_first(RANK): @@ -544,10 +547,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \ and check_version(settings['settings_version'], version) if not correct: - LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. ' - '\nThis is normal and may be due to a recent ultralytics package update, ' - 'but may have overwritten previous settings. ' - f"\nYou may view and update settings directly in '{file}'") + LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. This is normal and may be due to a ' + 'recent ultralytics package update, but may have overwritten previous settings. ' + f"\nView and update settings with 'yolo settings' or at '{file}'") settings = defaults # merge **defaults with **settings (prefer **settings) yaml_save(file, settings) # save updated defaults diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 64b93a4e5..0bb8e6192 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -247,7 +247,7 @@ def check_file(file, suffix=''): if Path(file).is_file(): LOGGER.info(f'Found {url} locally at {file}') # file already exists else: - downloads.safe_download(url=url, file=file) + downloads.safe_download(url=url, file=file, unzip=False) return file else: # search files = [] diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 622d1ae52..b7b6d7462 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -28,6 +28,19 @@ def is_url(url, check=True): return False +def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): + """ + Unzip a *.zip file to path/, excluding files containing strings in exclude list + Replaces: ZipFile(file).extractall(path=path) + """ + if path is None: + path = Path(file).parent # default path + with ZipFile(file) as zipObj: + for f in zipObj.namelist(): # list all archived filenames in the zip + if all(x not in f for x in exclude): + zipObj.extract(f, path=path) + + def safe_download(url, file=None, dir=None, @@ -96,13 +109,14 @@ def safe_download(url, LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}: - LOGGER.info(f'Unzipping {f}...') + unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place + LOGGER.info(f'Unzipping {f} to {unzip_dir}...') if f.suffix == '.zip': - ZipFile(f).extractall(path=f.parent) # unzip + unzip_file(file=f, path=unzip_dir) # unzip elif f.suffix == '.tar': - subprocess.run(['tar', 'xf', f, '--directory', f.parent], check=True) # unzip + subprocess.run(['tar', 'xf', f, '--directory', unzip_dir], check=True) # unzip elif f.suffix == '.gz': - subprocess.run(['tar', 'xfz', f, '--directory', f.parent], check=True) # unzip + subprocess.run(['tar', 'xfz', f, '--directory', unzip_dir], check=True) # unzip if delete: f.unlink() # remove zip diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index ed890a122..ce15ba451 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -33,14 +33,14 @@ class 
DetectionTrainer(BaseTrainer): augment=mode == "train", cache=self.args.cache, pad=0 if mode == "train" else 0.5, - rect=self.args.rect, + rect=self.args.rect or mode=="val", rank=rank, workers=self.args.workers, close_mosaic=self.args.close_mosaic != 0, prefix=colorstr(f'{mode}: '), shuffle=mode == "train", seed=self.args.seed)[0] if self.args.v5loader else \ - build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0] + build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode, rect=mode=="val")[0] def preprocess_batch(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 09826b111..bc2148dc6 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -22,7 +22,6 @@ class DetectionValidator(BaseValidator): def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): super().__init__(dataloader, save_dir, pbar, logger, args) self.args.task = 'detect' - self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None self.is_coco = False self.class_map = None self.metrics = DetMetrics(save_dir=self.save_dir) @@ -172,7 +171,7 @@ class DetectionValidator(BaseValidator): hyp=vars(self.args), cache=False, pad=0.5, - rect=True, + rect=self.args.rect, workers=self.args.workers, prefix=colorstr(f'{self.args.mode}: '), shuffle=False,
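
---

Usage note (not part of the patch itself): a minimal sketch of what the `rect` and `data=*.zip` changes enable from the Python API, assuming the archive URL below is a placeholder for any dataset zip and the epoch count is arbitrary.

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # any detection checkpoint

# data=*.zip: the archive is downloaded/extracted under the global datasets
# directory (DATASETS_DIR) and its dataset *.yaml is resolved automatically.
# 'https://example.com/coco128.zip' is a placeholder URL, not a real asset.
model.train(data='https://example.com/coco128.zip', epochs=3)

# Validation now defaults to rect=True (rectangular batches); like any other
# cfg key it can be overridden explicitly if square batches are preferred.
metrics = model.val(rect=False)
```

The equivalent CLI entry point referenced in the trainer comment is `yolo train data=url.zip`.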