Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Ayush Chaurasia 2 years ago committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
  1. .github/workflows/ci.yaml (2 changes)
  2. tests/test_cli.py (28 changes)
  3. tests/test_engine.py (2 changes)
  4. tests/test_python.py (2 changes)
  5. ultralytics/nn/modules.py (8 changes)
  6. ultralytics/nn/tasks.py (34 changes)
  7. ultralytics/yolo/data/utils.py (25 changes)
  8. ultralytics/yolo/engine/model.py (13 changes)
  9. ultralytics/yolo/v8/classify/train.py (41 changes)
  10. ultralytics/yolo/v8/models/cls/yolov8l-cls.yaml (23 changes)
  11. ultralytics/yolo/v8/models/cls/yolov8m-cls.yaml (23 changes)
  12. ultralytics/yolo/v8/models/cls/yolov8n-cls.yaml (23 changes)
  13. ultralytics/yolo/v8/models/cls/yolov8s-cls.yaml (23 changes)
  14. ultralytics/yolo/v8/models/cls/yolov8x-cls.yaml (23 changes)

@@ -100,5 +100,5 @@ jobs:
- name: Test classification
shell: bash # for Windows compatibility
run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
yolo task=classify mode=train model=yolov8n-cls.yaml data=mnist160 epochs=1 imgsz=32
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160

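For reference, a minimal Python sketch of the same classification workflow these CI commands exercise, assuming this PR's ultralytics package is installed (the CLI and the YOLO class share the trainer path shown later in this diff):

# Sketch only: mirrors the CI commands above via the Python API.
from ultralytics import YOLO

model = YOLO("yolov8n-cls.yaml")                  # build a classifier from the new YAML
model.train(data="mnist160", epochs=1, imgsz=32)  # same arguments as the yolo task=classify run
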
@@ -3,8 +3,8 @@ from pathlib import Path
from ultralytics.yolo.utils import ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
CFG = 'yolov8n.yaml'
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
CFG = 'yolov8n'
def test_checks():
@@ -12,25 +12,25 @@ def test_checks():
# Train checks ---------------------------------------------------------------------------------------------------------
def test_train_detect():
os.system(f'yolo mode=train task=detect model={MODEL} data=coco128.yaml imgsz=32 epochs=1')
def test_train_det():
os.system(f'yolo mode=train task=detect model={CFG}.yaml data=coco128.yaml imgsz=32 epochs=1')
def test_train_segment():
os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
def test_train_seg():
os.system(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
def test_train_classify():
pass
def test_train_cls():
os.system(f'yolo mode=train task=classify model={CFG}-cls.yaml data=imagenette160 imgsz=32 epochs=1')
# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
os.system(f'yolo mode=val task=detect model={MODEL} data=coco128.yaml imgsz=32 epochs=1')
os.system(f'yolo mode=val task=detect model={MODEL}.pt data=coco128.yaml imgsz=32 epochs=1')
def test_val_segment():
pass
os.system(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco128-seg.yaml imgsz=32 epochs=1')
def test_val_classify():
@@ -39,11 +39,11 @@ def test_val_classify():
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
os.system(f"yolo mode=predict model={MODEL} source={ROOT / 'assets'}")
os.system(f"yolo mode=predict model={MODEL}.pt source={ROOT / 'assets'}")
def test_predict_segment():
pass
os.system(f"yolo mode=predict model={MODEL}-seg.pt source={ROOT / 'assets'}")
def test_predict_classify():
@@ -52,11 +52,11 @@ def test_predict_classify():
# Export checks --------------------------------------------------------------------------------------------------------
def test_export_detect_torchscript():
os.system(f'yolo mode=export model={MODEL} format=torchscript')
os.system(f'yolo mode=export model={MODEL}.pt format=torchscript')
def test_export_segment_torchscript():
pass
os.system(f'yolo mode=export model={MODEL}-seg.pt format=torchscript')
def test_export_classify_torchscript():

@@ -71,7 +71,7 @@ def test_segment():
def test_classify():
overrides = {
"data": "imagenette160",
"model": "squeezenet1_0",
"model": "yolov8n-cls.yaml",
"imgsz": 32,
"epochs": 1,
"batch": 64,

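For context, a hedged sketch of how an overrides dict like the one above reaches the trainer, using the ClassificationTrainer signature shown further down in this diff (import path assumed from the file layout):

# Assumed usage; keys and the trainer signature are taken from this diff.
from ultralytics.yolo.v8.classify.train import ClassificationTrainer

overrides = {"data": "imagenette160", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "batch": 64}
trainer = ClassificationTrainer(overrides=overrides)
trainer.train()
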
@@ -3,8 +3,8 @@ from pathlib import Path
from ultralytics import YOLO
from ultralytics.yolo.utils import ROOT, SETTINGS
CFG = 'yolov8n.yaml'
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
CFG = 'yolov8n.yaml'
SOURCE = ROOT / 'assets/bus.jpg'

@@ -662,12 +662,10 @@ class Segment(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
def forward(self, x):
p = self.proto(x[0])
p = self.proto(x[0]) # mask protos
bs = p.shape[0] # batch size
mc = [] # mask coefficient
for i in range(self.nl):
mc.append(self.cv4[i](x[i]))
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
x = self.detect(self, x)
if self.training:
return x, mc, p

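A small shape check of the Segment.forward refactor above: the old per-level loop and the new single torch.cat list comprehension produce the same (batch, nm, sum of h*w) tensor. The feature sizes below are illustrative only.

# Illustrative only: toy cv4 outputs at three strides, nm mask coefficients each.
import torch

bs, nm = 2, 32
feats = [torch.randn(bs, nm, s, s) for s in (80, 40, 20)]

mc_old = []                                                 # old style: explicit loop
for f in feats:
    mc_old.append(f.view(bs, nm, -1))
mc_old = torch.cat(mc_old, 2)

mc_new = torch.cat([f.view(bs, nm, -1) for f in feats], 2)  # refactored one-liner
assert torch.equal(mc_old, mc_new)
assert mc_new.shape == (bs, nm, 80 * 80 + 40 * 40 + 20 * 20)
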
@@ -1,11 +1,9 @@
import contextlib
from copy import deepcopy
from pathlib import Path
import thop
import torch
import torch.nn as nn
import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
@@ -226,9 +224,15 @@ class SegmentationModel(DetectionModel):
class ClassificationModel(BaseModel):
# YOLOv5 classification model
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
def __init__(self,
cfg=None,
model=None,
ch=3,
nc=1000,
cutoff=10,
verbose=True): # yaml, model, number of classes, cutoff index
super().__init__()
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
def _from_detection_model(self, model, nc=1000, cutoff=10):
# Create a YOLOv5 classification model from a YOLOv5 detection model
@@ -246,9 +250,15 @@ class ClassificationModel(BaseModel):
self.save = []
self.nc = nc
def _from_yaml(self, cfg):
# TODO: Create a YOLOv5 classification model from a *.yaml file
self.model = None
def _from_yaml(self, cfg, ch, nc, verbose):
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True) # cfg dict
# Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
def load(self, weights):
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
@@ -351,7 +361,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
# Parse a YOLO model.yaml dictionary
if verbose:
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
@@ -359,7 +369,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # print
no = nc + 4 # number of outputs = classes + box
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
@@ -370,10 +379,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in {
Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP,
C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]]
@@ -384,7 +393,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
# TODO: channel, gw, gd
elif m in {Detect, Segment}:
args.append([ch[x] for x in f])
if m is Segment:

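Putting the tasks.py changes together, a hedged sketch of building a classifier directly from one of the new YAMLs; the class and argument names are taken from the diff above, and exact behavior may differ:

# Assumed usage of the new ClassificationModel YAML path.
from ultralytics.nn.tasks import ClassificationModel

model = ClassificationModel(cfg="yolov8n-cls.yaml", nc=10)  # logs "Overriding model.yaml nc=1000 with nc=10"
print(model.names)                                          # default names dict: {0: '0', 1: '1', ...}
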
@@ -255,12 +255,28 @@ def check_dataset_yaml(data, autodownload=True):
def check_dataset(dataset: str):
data = Path.cwd() / "datasets" / dataset
data_dir = data if data.is_dir() else (Path.cwd() / data)
"""
Check a classification dataset such as Imagenet.
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
Args:
dataset (str): Name of the dataset.
Returns:
data (dict): A dictionary containing the following keys and values:
'train': Path object for the directory containing the training set of the dataset
'val': Path object for the directory containing the validation set of the dataset
'nc': Number of classes in the dataset
'names': List of class names in the dataset
"""
data_dir = (Path.cwd() / "datasets" / dataset).resolve()
if not data_dir.is_dir():
LOGGER.info(f'\nDataset not found ⚠, missing path {data_dir}, attempting download...')
t = time.time()
if str(data) == 'imagenet':
if dataset == 'imagenet':
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
@@ -271,5 +287,4 @@ def check_dataset(dataset: str):
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)]
data = {"train": train_set, "val": test_set, "nc": nc, "names": names}
return data
return {"train": train_set, "val": test_set, "nc": nc, "names": names}

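An example of the documented return structure, assuming ./datasets/imagenette160 already exists or downloads successfully (import path taken from the file being changed):

# Hedged usage sketch of check_dataset() above.
from ultralytics.yolo.data.utils import check_dataset

data = check_dataset("imagenette160")
print(data["nc"], data["names"][:3])  # number of classes and a few class-folder names
print(data["train"], data["val"])     # Path objects for train/ and val/ (or test/)
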
@@ -103,13 +103,9 @@ class YOLO:
Args:
verbose (bool): Controls verbosity.
"""
if not self.model:
LOGGER.info("model not initialized!")
self.model.info(verbose=verbose)
def fuse(self):
if not self.model:
LOGGER.info("model not initialized!")
self.model.fuse()
@smart_inference_mode()
@@ -139,9 +135,6 @@ class YOLO:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
if not self.model:
raise ModuleNotFoundError("model not initialized!")
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
@@ -177,8 +170,6 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
"""
if not self.model:
raise AttributeError("model not initialized. Use .new() or .load()")
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
@@ -193,10 +184,8 @@ class YOLO:
self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
cfg=self.model.yaml if self.task != "classify" else None)
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
def to(self, device):

@@ -1,5 +1,3 @@
from pathlib import Path
import hydra
import torch
import torchvision
@@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
super().__init__(config, overrides)
@@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
# Update defaults
if self.args.imgsz == 640:
self.args.imgsz = 224
return model
def setup_model(self):
@@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = self.model
pretrained = False
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
pretrained = True
else:
self.model = attempt_load_weights(model, device='cpu')
elif model.endswith(".yaml"):
self.model = self.get_model(cfg=model)
# order: check local file -> torchvision assets -> ultralytics asset
if Path(f"{model}.pt").is_file(): # local file
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
raise FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
return  # don't return ckpt. Classification doesn't support resume
@@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
@@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
return loss, loss
def check_resume(self):
pass
def resume_training(self, ckpt):
pass
@@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.model = cfg.model or "resnet18"
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.train()
# trainer = ClassificationTrainer(cfg)
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
if __name__ == "__main__":

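A standalone sketch of the resolution order setup_model() now implements (local or asset weights, then the new YAML configs, then torchvision model names); the helper below is illustrative and not part of the PR:

# Hypothetical helper mirroring the branch order in setup_model() above.
import torchvision


def resolve_classifier(model: str) -> str:
    if model.endswith(".pt"):
        return "load weights"            # attempt_load_weights(model, device='cpu')
    if model.endswith(".yaml"):
        return "build from YAML"         # self.get_model(cfg=model)
    if model in torchvision.models.__dict__:
        return "torchvision pretrained"  # torchvision.models.__dict__[model](weights='IMAGENET1K_V1')
    raise FileNotFoundError(f"model={model} not found locally or online. Please check model name.")


print(resolve_classifier("yolov8n-cls.yaml"))  # -> build from YAML
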
@@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.00 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

@@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.67 # scales module repeats
width_multiple: 0.75 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

@@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

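To make the scaling parameters concrete, a small arithmetic sketch of how parse_model() applies width_multiple and depth_multiple from the nano config above, mirroring the make_divisible(c2 * gw, 8) and max(round(n * gd), 1) calls shown earlier in this diff:

# Illustrative arithmetic only; multipliers are the yolov8n-cls.yaml values above.
import math


def make_divisible(x, divisor=8):            # assumed to match ultralytics' helper
    return math.ceil(x / divisor) * divisor


gw, gd = 0.25, 0.33                          # width_multiple, depth_multiple
print(make_divisible(64 * gw, 8))            # first Conv: 64 -> 16 output channels
print(make_divisible(1024 * gw, 8))          # last Conv: 1024 -> 256 output channels
print(max(round(3 * gd), 1))                 # C2f repeats: 3 -> 1
print(max(round(6 * gd), 1))                 # C2f repeats: 6 -> 2
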
@@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.50 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

@@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]