Add `max_dim==2` argument to `check_imgsz()` (#789)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
pull/406/merge
Glenn Jocher 2 years ago committed by GitHub
parent 5a80ad98db
commit 0d182e80f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 24
      docker/Dockerfile
  2. 4
      docker/Dockerfile-cpu
  3. 2
      ultralytics/__init__.py
  4. 16
      ultralytics/nn/autobackend.py
  5. 30
      ultralytics/nn/tasks.py
  6. 8
      ultralytics/yolo/data/datasets/ImageNet.yaml
  7. 37
      ultralytics/yolo/engine/exporter.py
  8. 6
      ultralytics/yolo/engine/model.py
  9. 2
      ultralytics/yolo/engine/trainer.py
  10. 10
      ultralytics/yolo/utils/__init__.py
  11. 9
      ultralytics/yolo/utils/checks.py

@ -3,18 +3,18 @@
# Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference # Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference
# Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch # Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
FROM nvcr.io/nvidia/pytorch:23.01-py3 # FROM docker.io/pytorch/pytorch:latest
FROM pytorch/pytorch:latest
# Downloads to user config dir # Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/
# Remove torch nightly and install torch stable
RUN rm -rf /opt/pytorch # remove 1.2GB dir
RUN pip uninstall -y torchtext pillow torch torchvision
RUN pip install --no-cache torch torchvision
# Install linux packages # Install linux packages
RUN apt update && apt install --no-install-recommends -y zip htop screen libgl1-mesa-glx ENV DEBIAN_FRONTEND noninteractive
RUN apt update
RUN TZ=Etc/UTC apt install -y tzdata
RUN apt install --no-install-recommends -y git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3-dev gnupg
# RUN alias python=python3
# Create working directory # Create working directory
RUN mkdir -p /usr/src/ultralytics RUN mkdir -p /usr/src/ultralytics
@ -25,12 +25,18 @@ WORKDIR /usr/src/ultralytics
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
# Install pip packages # Install pip packages
RUN python -m pip install --upgrade pip wheel COPY requirements.txt .
RUN pip install --no-cache ultralytics albumentations comet gsutil notebook RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache ultralytics albumentations comet gsutil notebook \
coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3
# tensorflow tensorflowjs \
# Set environment variables # Set environment variables
ENV OMP_NUM_THREADS=1 ENV OMP_NUM_THREADS=1
# Cleanup
ENV DEBIAN_FRONTEND teletype
# Usage Examples ------------------------------------------------------------------------------------------------------- # Usage Examples -------------------------------------------------------------------------------------------------------

@ -27,8 +27,8 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
COPY requirements.txt . COPY requirements.txt .
RUN python3 -m pip install --upgrade pip wheel RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache ultralytics albumentations gsutil notebook \ RUN pip install --no-cache ultralytics albumentations gsutil notebook \
coremltools onnx onnx-simplifier onnxruntime tensorflow-cpu \ coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3 \
# openvino-dev>=2022.3 tensorflowjs \ # tensorflow-cpu tensorflowjs \
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
# Cleanup # Cleanup

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.26" __version__ = "8.0.27"
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops

@ -18,6 +18,16 @@ from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.ops import xywh2xyxy
def check_class_names(names):
# Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
return names
class AutoBackend(nn.Module): class AutoBackend(nn.Module):
def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
@ -228,11 +238,7 @@ class AutoBackend(nn.Module):
# class names # class names
if 'names' not in locals(): # names missing if 'names' not in locals(): # names missing
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default
elif isinstance(names, list): # names is a list names = check_class_names(names)
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
self.__dict__.update(locals()) # assign all variables to self self.__dict__.update(locals()) # assign all variables to self

@ -347,23 +347,24 @@ def torch_safe_load(weight):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble() ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch_safe_load(w) # load ckpt ckpt = torch_safe_load(w) # load ckpt
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates # Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
ckpt.pt_path = weights # attach *.pt file path to model model.pt_path = weights # attach *.pt file path to model
if not hasattr(ckpt, 'stride'): model.task = guess_model_task(model)
ckpt.stride = torch.tensor([32.]) if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])
# Append # Append
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
# Module compatibility updates # Module compatibility updates
for m in model.modules(): for m in ensemble.modules():
t = type(m) t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment): if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
m.inplace = inplace # torch 1.7.0 compatibility m.inplace = inplace # torch 1.7.0 compatibility
@ -371,16 +372,16 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
m.recompute_scale_factor = None # torch 1.11.0 compatibility m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model # Return model
if len(model) == 1: if len(ensemble) == 1:
return model[-1] return ensemble[-1]
# Return ensemble # Return ensemble
print(f'Ensemble created with {weights}\n') print(f'Ensemble created with {weights}\n')
for k in 'names', 'nc', 'yaml': for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k)) setattr(ensemble, k, getattr(ensemble[0], k))
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}' assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts: {[m.nc for m in ensemble]}'
return model return ensemble
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
@ -392,6 +393,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Model compatibility updates # Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model model.pt_path = weight # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, 'stride'): if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.]) model.stride = torch.tensor([32.])

@ -1153,7 +1153,7 @@ map:
n02009229: little_blue_heron n02009229: little_blue_heron
n02009912: American_egret n02009912: American_egret
n02011460: bittern n02011460: bittern
n02012849: crane n02012849: crane_(bird)
n02013706: limpkin n02013706: limpkin
n02017213: European_gallinule n02017213: European_gallinule
n02018207: American_coot n02018207: American_coot
@ -1536,7 +1536,7 @@ map:
n03124043: cowboy_boot n03124043: cowboy_boot
n03124170: cowboy_hat n03124170: cowboy_hat
n03125729: cradle n03125729: cradle
n03126707: crane n03126707: crane_(machine)
n03127747: crash_helmet n03127747: crash_helmet
n03127925: crate n03127925: crate
n03131574: crib n03131574: crib
@ -1657,8 +1657,8 @@ map:
n03706229: magnetic_compass n03706229: magnetic_compass
n03709823: mailbag n03709823: mailbag
n03710193: mailbox n03710193: mailbox
n03710637: maillot n03710637: maillot_(tights)
n03710721: maillot n03710721: maillot_(tank_suit)
n03717622: manhole_cover n03717622: manhole_cover
n03720891: maraca n03720891: maraca
n03721384: marimba n03721384: marimba

@ -65,6 +65,7 @@ import pandas as pd
import torch import torch
import ultralytics import ultralytics
from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import Detect, Segment from ultralytics.nn.modules import Detect, Segment
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
@ -151,9 +152,12 @@ class Exporter:
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic' assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
# Checks # Checks
model.names = check_class_names(model.names)
# if self.args.batch == model.args['batch_size']: # user has not modified training batch_size # if self.args.batch == model.args['batch_size']: # user has not modified training batch_size
self.args.batch = 1 self.args.batch = 1
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if model.task == 'classify':
self.args.nms = self.args.agnostic_nms = False
if self.args.optimize: if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu' assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
@ -194,8 +198,14 @@ class Exporter:
self.model = model self.model = model
self.file = file self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y) self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y)
self.metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
self.pretty_name = self.file.stem.replace('yolo', 'YOLO') self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
self.metadata = {
'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}",
'author': 'Ultralytics',
'license': 'GPL-3.0 https://ultralytics.com/license',
'version': ultralytics.__version__,
'stride': int(max(model.stride)),
'names': model.names} # model metadata
# Exports # Exports
f = [''] * len(fmts) # exported filenames f = [''] * len(fmts) # exported filenames
@ -235,12 +245,11 @@ class Exporter:
# Finish # Finish
f = [str(x) for x in f if x] # filter out '' and None f = [str(x) for x in f if x] # filter out '' and None
if any(f): if any(f):
task = guess_model_task(model)
s = "-WARNING ⚠ not yet supported for YOLOv8 exported models" s = "-WARNING ⚠ not yet supported for YOLOv8 exported models"
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}" f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}" f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
f"\nVisualize: https://netron.app") f"\nVisualize: https://netron.app")
self.run_callbacks("on_export_end") self.run_callbacks("on_export_end")
@ -375,9 +384,13 @@ class Exporter:
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = self.file.with_suffix('.mlmodel') f = self.file.with_suffix('.mlmodel')
task = self.model.task
model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model
ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])]) classifier_config = ct.ClassifierConfig(list(model.names.values())) if task == 'classify' else None
ct_model = ct.convert(ts,
inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])],
classifier_config=classifier_config)
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None) bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
if bits < 32: if bits < 32:
if MACOS: # quantization only supported on macOS if MACOS: # quantization only supported on macOS
@ -387,6 +400,10 @@ class Exporter:
if self.args.nms: if self.args.nms:
ct_model = self._pipeline_coreml(ct_model) ct_model = self._pipeline_coreml(ct_model)
ct_model.short_description = self.metadata['description']
ct_model.author = self.metadata['author']
ct_model.license = self.metadata['license']
ct_model.version = self.metadata['version']
ct_model.save(str(f)) ct_model.save(str(f))
return f, ct_model return f, ct_model
@ -687,8 +704,8 @@ class Exporter:
out0_shape = out[out0.name].shape out0_shape = out[out0.name].shape
out1_shape = out[out1.name].shape out1_shape = out[out1.name].shape
else: # linux and windows can not run model.predict(), get sizes from pytorch output y else: # linux and windows can not run model.predict(), get sizes from pytorch output y
out0_shape = self.output_shape[1], self.output_shape[2] - 5 # (3780, 80) out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
out1_shape = self.output_shape[1], 4 # (3780, 4) out1_shape = self.output_shape[2], 4 # (3780, 4)
# Checks # Checks
names = self.metadata['names'] names = self.metadata['names']
@ -714,7 +731,7 @@ class Exporter:
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r) # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
# Print # Print
print(spec.description) # print(spec.description)
# Model from spec # Model from spec
model = ct.models.MLModel(spec) model = ct.models.MLModel(spec)
@ -771,10 +788,6 @@ class Exporter:
# Update metadata # Update metadata
pipeline.spec.specificationVersion = 5 pipeline.spec.specificationVersion = 5
pipeline.spec.description.metadata.versionString = f'Ultralytics YOLOv{ultralytics.__version__}'
pipeline.spec.description.metadata.shortDescription = f'Ultralytics {self.pretty_name} CoreML model'
pipeline.spec.description.metadata.author = 'Ultralytics (https://ultralytics.com)'
pipeline.spec.description.metadata.license = 'GPL-3.0 license (https://ultralytics.com/license)'
pipeline.spec.description.metadata.userDefined.update({ pipeline.spec.description.metadata.userDefined.update({
'IoU threshold': str(nms.iouThreshold), 'IoU threshold': str(nms.iouThreshold),
'Confidence threshold': str(nms.confidenceThreshold)}) 'Confidence threshold': str(nms.confidenceThreshold)})

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
from pathlib import Path from pathlib import Path
from typing import List
import sys import sys
from ultralytics import yolo # noqa from ultralytics import yolo # noqa
@ -9,7 +10,7 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_yaml from ultralytics.yolo.utils.checks import check_yaml, check_imgsz
from ultralytics.yolo.utils.torch_utils import smart_inference_mode from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes # Map head to model, trainer, validator, and predictor classes
@ -131,7 +132,7 @@ class YOLO:
Check the 'configuration' section in the documentation for all available options. Check the 'configuration' section in the documentation for all available options.
Returns: Returns:
(dict): The prediction results. (List[ultralytics.yolo.engine.results.Results]): The prediction results.
""" """
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides["conf"] = 0.25 overrides["conf"] = 0.25
@ -161,6 +162,7 @@ class YOLO:
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data args.data = data or args.data
args.task = self.task args.task = self.task
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = self.ValidatorClass(args=args) validator = self.ValidatorClass(args=args)
validator(model=self.model) validator(model=self.model)

@ -202,7 +202,7 @@ class BaseTrainer:
self.model = DDP(self.model, device_ids=[rank]) self.model = DDP(self.model, device_ids=[rank])
# Check imgsz # Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride) gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs) self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
# Batch size # Batch size
if self.batch_size == -1: if self.batch_size == -1:
if RANK == -1: # single-GPU only, estimate best batch size if RANK == -1: # single-GPU only, estimate best batch size

@ -467,6 +467,13 @@ def set_sentry():
""" """
def before_send(event, hint): def before_send(event, hint):
if 'exc_info' in hint:
exc_type, exc_value, tb = hint['exc_info']
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
or 'out of memory' in str(exc_value) \
or not sys.argv[0].endswith('yolo'):
return None # do not send event
env = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \ env = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
'Docker' if is_docker() else platform.system() 'Docker' if is_docker() else platform.system()
event['tags'] = { event['tags'] = {
@ -477,6 +484,7 @@ def set_sentry():
return event return event
if SETTINGS['sync'] and \ if SETTINGS['sync'] and \
RANK in {-1, 0} and \
not is_pytest_running() and \ not is_pytest_running() and \
not is_github_actions_ci() and \ not is_github_actions_ci() and \
((is_pip_package() and not is_git_dir()) or ((is_pip_package() and not is_git_dir()) or
@ -491,7 +499,7 @@ def set_sentry():
release=ultralytics.__version__, release=ultralytics.__version__,
environment='production', # 'dev' or 'production' environment='production', # 'dev' or 'production'
before_send=before_send, before_send=before_send,
ignore_errors=[KeyboardInterrupt]) ignore_errors=[KeyboardInterrupt, FileNotFoundError])
# Disable all sentry logging # Disable all sentry logging
for logger in "sentry_sdk", "sentry_sdk.errors": for logger in "sentry_sdk", "sentry_sdk.errors":

@ -40,7 +40,7 @@ def is_ascii(s) -> bool:
return all(ord(c) < 128 for c in s) return all(ord(c) < 128 for c in s)
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
""" """
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
@ -66,6 +66,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'") f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
# Apply max_dim
if max_dim == 1:
LOGGER.warning(f"WARNING ⚠ 'train' and 'val' imgsz types must be integer, updating to 'imgsz={max(imgsz)}'. "
f"'predict' and 'export' imgsz may be list or integer, "
f"i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride # Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

Loading…
Cancel
Save