YOLOv8 architecture updates from R&D branch (#88)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/91/head
parent
5fbea25f0b
commit
ebd3cfb2fd
23 changed files with 722 additions and 572 deletions
@ -0,0 +1,22 @@ |
||||
#!/bin/bash |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
# Download latest models from https://github.com/ultralytics/yolov5/releases |
||||
# Example usage: bash data/scripts/download_weights.sh |
||||
# parent |
||||
# └── yolov5 |
||||
# ├── yolov5s.pt ← downloads here |
||||
# ├── yolov5m.pt |
||||
# └── ... |
||||
|
||||
python - <<EOF |
||||
from utils.downloads import attempt_download |
||||
|
||||
p5 = list('nsmlx') # P5 models |
||||
p6 = [f'{x}6' for x in p5] # P6 models |
||||
cls = [f'{x}-cls' for x in p5] # classification models |
||||
seg = [f'{x}-seg' for x in p5] # classification models |
||||
|
||||
for x in p5 + p6 + cls + seg: |
||||
attempt_download(f'weights/yolov5{x}.pt') |
||||
|
||||
EOF |
@ -0,0 +1,60 @@ |
||||
#!/bin/bash |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
# Download COCO 2017 dataset http://cocodataset.org |
||||
# Example usage: bash data/scripts/get_coco.sh |
||||
# parent |
||||
# ├── yolov5 |
||||
# └── datasets |
||||
# └── coco ← downloads here |
||||
|
||||
# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments |
||||
if [ "$#" -gt 0 ]; then |
||||
for opt in "$@"; do |
||||
case "${opt}" in |
||||
--train) train=true ;; |
||||
--val) val=true ;; |
||||
--test) test=true ;; |
||||
--segments) segments=true ;; |
||||
--sama) sama=true ;; |
||||
esac |
||||
done |
||||
else |
||||
train=true |
||||
val=true |
||||
test=false |
||||
segments=false |
||||
sama=false |
||||
fi |
||||
|
||||
# Download/unzip labels |
||||
d='../datasets' # unzip directory |
||||
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/ |
||||
if [ "$segments" == "true" ]; then |
||||
f='coco2017labels-segments.zip' # 169 MB |
||||
elif [ "$sama" == "true" ]; then |
||||
f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/ |
||||
else |
||||
f='coco2017labels.zip' # 46 MB |
||||
fi |
||||
echo 'Downloading' $url$f ' ...' |
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & |
||||
|
||||
# Download/unzip images |
||||
d='../datasets/coco/images' # unzip directory |
||||
url=http://images.cocodataset.org/zips/ |
||||
if [ "$train" == "true" ]; then |
||||
f='train2017.zip' # 19G, 118k images |
||||
echo 'Downloading' $url$f '...' |
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & |
||||
fi |
||||
if [ "$val" == "true" ]; then |
||||
f='val2017.zip' # 1G, 5k images |
||||
echo 'Downloading' $url$f '...' |
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & |
||||
fi |
||||
if [ "$test" == "true" ]; then |
||||
f='test2017.zip' # 7G, 41k images (optional) |
||||
echo 'Downloading' $url$f '...' |
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & |
||||
fi |
||||
wait # finish background tasks |
@ -0,0 +1,17 @@ |
||||
#!/bin/bash |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) |
||||
# Example usage: bash data/scripts/get_coco128.sh |
||||
# parent |
||||
# ├── yolov5 |
||||
# └── datasets |
||||
# └── coco128 ← downloads here |
||||
|
||||
# Download/unzip images and labels |
||||
d='../datasets' # unzip directory |
||||
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/ |
||||
f='coco128.zip' # or 'coco128-segments.zip', 68 MB |
||||
echo 'Downloading' $url$f ' ...' |
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & |
||||
|
||||
wait # finish background tasks |
@ -0,0 +1,51 @@ |
||||
#!/bin/bash |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
# Download ILSVRC2012 ImageNet dataset https://image-net.org |
||||
# Example usage: bash data/scripts/get_imagenet.sh |
||||
# parent |
||||
# ├── yolov5 |
||||
# └── datasets |
||||
# └── imagenet ← downloads here |
||||
|
||||
# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val |
||||
if [ "$#" -gt 0 ]; then |
||||
for opt in "$@"; do |
||||
case "${opt}" in |
||||
--train) train=true ;; |
||||
--val) val=true ;; |
||||
esac |
||||
done |
||||
else |
||||
train=true |
||||
val=true |
||||
fi |
||||
|
||||
# Make dir |
||||
d='../datasets/imagenet' # unzip directory |
||||
mkdir -p $d && cd $d |
||||
|
||||
# Download/unzip train |
||||
if [ "$train" == "true" ]; then |
||||
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images |
||||
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train |
||||
tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar |
||||
find . -name "*.tar" | while read NAME; do |
||||
mkdir -p "${NAME%.tar}" |
||||
tar -xf "${NAME}" -C "${NAME%.tar}" |
||||
rm -f "${NAME}" |
||||
done |
||||
cd .. |
||||
fi |
||||
|
||||
# Download/unzip val |
||||
if [ "$val" == "true" ]; then |
||||
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images |
||||
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar |
||||
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs |
||||
fi |
||||
|
||||
# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail) |
||||
# rm train/n04266014/n04266014_10835.JPEG |
||||
|
||||
# TFRecords (optional) |
||||
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt |
@ -0,0 +1,53 @@ |
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
|
||||
from .metrics import bbox_iou |
||||
from .tal import bbox2dist |
||||
|
||||
|
||||
class VarifocalLoss(nn.Module): |
||||
# Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367 |
||||
def __init__(self): |
||||
super().__init__() |
||||
|
||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0): |
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label |
||||
with torch.cuda.amp.autocast(enabled=False): |
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * |
||||
weight).sum() |
||||
return loss |
||||
|
||||
|
||||
class BboxLoss(nn.Module): |
||||
|
||||
def __init__(self, reg_max, use_dfl=False): |
||||
super().__init__() |
||||
self.reg_max = reg_max |
||||
self.use_dfl = use_dfl |
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): |
||||
# IoU loss |
||||
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1) |
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) |
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum |
||||
|
||||
# DFL loss |
||||
if self.use_dfl: |
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max) |
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight |
||||
loss_dfl = loss_dfl.sum() / target_scores_sum |
||||
else: |
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device) |
||||
|
||||
return loss_iou, loss_dfl |
||||
|
||||
@staticmethod |
||||
def _df_loss(pred_dist, target): |
||||
# Return sum of left and right DFL losses |
||||
tl = target.long() # target left |
||||
tr = tl + 1 # target right |
||||
wl = tr - target # weight left |
||||
wr = 1 - wl # weight right |
||||
return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + |
||||
F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True) |
@ -0,0 +1,211 @@ |
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
|
||||
from .checks import check_version |
||||
from .metrics import bbox_iou |
||||
|
||||
TORCH_1_10 = check_version(torch.__version__, '1.10.0') |
||||
|
||||
|
||||
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): |
||||
"""select the positive anchor center in gt |
||||
|
||||
Args: |
||||
xy_centers (Tensor): shape(h*w, 4) |
||||
gt_bboxes (Tensor): shape(b, n_boxes, 4) |
||||
Return: |
||||
(Tensor): shape(b, n_boxes, h*w) |
||||
""" |
||||
n_anchors = xy_centers.shape[0] |
||||
bs, n_boxes, _ = gt_bboxes.shape |
||||
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom |
||||
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) |
||||
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) |
||||
return bbox_deltas.amin(3).gt_(eps) |
||||
|
||||
|
||||
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): |
||||
"""if an anchor box is assigned to multiple gts, |
||||
the one with the highest iou will be selected. |
||||
|
||||
Args: |
||||
mask_pos (Tensor): shape(b, n_max_boxes, h*w) |
||||
overlaps (Tensor): shape(b, n_max_boxes, h*w) |
||||
Return: |
||||
target_gt_idx (Tensor): shape(b, h*w) |
||||
fg_mask (Tensor): shape(b, h*w) |
||||
mask_pos (Tensor): shape(b, n_max_boxes, h*w) |
||||
""" |
||||
# (b, n_max_boxes, h*w) -> (b, h*w) |
||||
fg_mask = mask_pos.sum(-2) |
||||
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes |
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w) |
||||
max_overlaps_idx = overlaps.argmax(1) # (b, h*w) |
||||
is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes) |
||||
is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w) |
||||
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w) |
||||
fg_mask = mask_pos.sum(-2) |
||||
# find each grid serve which gt(index) |
||||
target_gt_idx = mask_pos.argmax(-2) # (b, h*w) |
||||
return target_gt_idx, fg_mask, mask_pos |
||||
|
||||
|
||||
class TaskAlignedAssigner(nn.Module): |
||||
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): |
||||
super().__init__() |
||||
self.topk = topk |
||||
self.num_classes = num_classes |
||||
self.bg_idx = num_classes |
||||
self.alpha = alpha |
||||
self.beta = beta |
||||
self.eps = eps |
||||
|
||||
@torch.no_grad() |
||||
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): |
||||
"""This code referenced to |
||||
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py |
||||
|
||||
Args: |
||||
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) |
||||
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) |
||||
anc_points (Tensor): shape(num_total_anchors, 2) |
||||
gt_labels (Tensor): shape(bs, n_max_boxes, 1) |
||||
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) |
||||
mask_gt (Tensor): shape(bs, n_max_boxes, 1) |
||||
Returns: |
||||
target_labels (Tensor): shape(bs, num_total_anchors) |
||||
target_bboxes (Tensor): shape(bs, num_total_anchors, 4) |
||||
target_scores (Tensor): shape(bs, num_total_anchors, num_classes) |
||||
fg_mask (Tensor): shape(bs, num_total_anchors) |
||||
""" |
||||
self.bs = pd_scores.size(0) |
||||
self.n_max_boxes = gt_bboxes.size(1) |
||||
|
||||
if self.n_max_boxes == 0: |
||||
device = gt_bboxes.device |
||||
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device), |
||||
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device)) |
||||
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, |
||||
mask_gt) |
||||
|
||||
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) |
||||
|
||||
# assigned target |
||||
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) |
||||
|
||||
# normalize |
||||
align_metric *= mask_pos |
||||
pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj |
||||
pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj |
||||
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) |
||||
target_scores = target_scores * norm_align_metric |
||||
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool() |
||||
|
||||
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): |
||||
# get anchor_align metric, (b, max_num_obj, h*w) |
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes) |
||||
# get in_gts mask, (b, max_num_obj, h*w) |
||||
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes) |
||||
# get topk_metric mask, (b, max_num_obj, h*w) |
||||
mask_topk = self.select_topk_candidates(align_metric * mask_in_gts, |
||||
topk_mask=mask_gt.repeat([1, 1, self.topk]).bool()) |
||||
# merge all mask to a final mask, (b, max_num_obj, h*w) |
||||
mask_pos = mask_topk * mask_in_gts * mask_gt |
||||
|
||||
return mask_pos, align_metric, overlaps |
||||
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): |
||||
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj |
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj |
||||
ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj |
||||
# get the scores of each grid for each gt cls |
||||
bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w |
||||
|
||||
overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0) |
||||
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) |
||||
return align_metric, overlaps |
||||
|
||||
def select_topk_candidates(self, metrics, largest=True, topk_mask=None): |
||||
""" |
||||
Args: |
||||
metrics: (b, max_num_obj, h*w). |
||||
topk_mask: (b, max_num_obj, topk) or None |
||||
""" |
||||
|
||||
num_anchors = metrics.shape[-1] # h*w |
||||
# (b, max_num_obj, topk) |
||||
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest) |
||||
if topk_mask is None: |
||||
topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk]) |
||||
# (b, max_num_obj, topk) |
||||
topk_idxs = torch.where(topk_mask, topk_idxs, 0) |
||||
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) |
||||
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2) |
||||
# filter invalid bboxes |
||||
# assigned topk should be unique, this is for dealing with empty labels |
||||
# since empty labels will generate index `0` through `F.one_hot` |
||||
# NOTE: but what if the topk_idxs include `0`? |
||||
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk) |
||||
return is_in_topk.to(metrics.dtype) |
||||
|
||||
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): |
||||
""" |
||||
Args: |
||||
gt_labels: (b, max_num_obj, 1) |
||||
gt_bboxes: (b, max_num_obj, 4) |
||||
target_gt_idx: (b, h*w) |
||||
fg_mask: (b, h*w) |
||||
""" |
||||
|
||||
# assigned target labels, (b, 1) |
||||
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] |
||||
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w) |
||||
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) |
||||
|
||||
# assigned target boxes, (b, max_num_obj, 4) -> (b, h*w) |
||||
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx] |
||||
|
||||
# assigned target scores |
||||
target_labels.clamp(0) |
||||
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80) |
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) |
||||
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) |
||||
|
||||
return target_labels, target_bboxes, target_scores |
||||
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5): |
||||
"""Generate anchors from features.""" |
||||
anchor_points, stride_tensor = [], [] |
||||
assert feats is not None |
||||
dtype, device = feats[0].dtype, feats[0].device |
||||
for i, stride in enumerate(strides): |
||||
_, _, h, w = feats[i].shape |
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x |
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y |
||||
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx) |
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) |
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) |
||||
return torch.cat(anchor_points), torch.cat(stride_tensor) |
||||
|
||||
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1): |
||||
"""Transform distance(ltrb) to box(xywh or xyxy).""" |
||||
lt, rb = torch.split(distance, 2, dim) |
||||
x1y1 = anchor_points - lt |
||||
x2y2 = anchor_points + rb |
||||
if xywh: |
||||
c_xy = (x1y1 + x2y2) / 2 |
||||
wh = x2y2 - x1y1 |
||||
return torch.cat((c_xy, wh), dim) # xywh bbox |
||||
return torch.cat((x1y1, x2y2), dim) # xyxy bbox |
||||
|
||||
|
||||
def bbox2dist(anchor_points, bbox, reg_max): |
||||
"""Transform bbox(xyxy) to dist(ltrb).""" |
||||
x1y1, x2y2 = torch.split(bbox, 2, -1) |
||||
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb) |
@ -1,48 +0,0 @@ |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
|
||||
# Parameters |
||||
nc: 80 # number of classes |
||||
depth_multiple: 0.33 # model depth multiple |
||||
width_multiple: 0.25 # layer channel multiple |
||||
anchors: |
||||
- [10,13, 16,30, 33,23] # P3/8 |
||||
- [30,61, 62,45, 59,119] # P4/16 |
||||
- [116,90, 156,198, 373,326] # P5/32 |
||||
|
||||
# YOLOv5 v6.0 backbone |
||||
backbone: |
||||
# [from, number, module, args] |
||||
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 |
||||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 |
||||
[-1, 3, C3, [128]], |
||||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 |
||||
[-1, 6, C3, [256]], |
||||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 |
||||
[-1, 9, C3, [512]], |
||||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 |
||||
[-1, 3, C3, [1024]], |
||||
[-1, 1, SPPF, [1024, 5]], # 9 |
||||
] |
||||
|
||||
# YOLOv5 v6.0 head |
||||
head: |
||||
[[-1, 1, Conv, [512, 1, 1]], |
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 |
||||
[-1, 3, C3, [512, False]], # 13 |
||||
|
||||
[-1, 1, Conv, [256, 1, 1]], |
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 |
||||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) |
||||
|
||||
[-1, 1, Conv, [256, 3, 2]], |
||||
[[-1, 14], 1, Concat, [1]], # cat head P4 |
||||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) |
||||
|
||||
[-1, 1, Conv, [512, 3, 2]], |
||||
[[-1, 10], 1, Concat, [1]], # cat head P5 |
||||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) |
||||
|
||||
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5) |
||||
] |
@ -1,48 +0,0 @@ |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
|
||||
# Parameters |
||||
nc: 80 # number of classes |
||||
depth_multiple: 0.33 # model depth multiple |
||||
width_multiple: 0.25 # layer channel multiple |
||||
anchors: |
||||
- [10,13, 16,30, 33,23] # P3/8 |
||||
- [30,61, 62,45, 59,119] # P4/16 |
||||
- [116,90, 156,198, 373,326] # P5/32 |
||||
|
||||
# YOLOv5 v6.0 backbone |
||||
backbone: |
||||
# [from, number, module, args] |
||||
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 |
||||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 |
||||
[-1, 3, C3, [128]], |
||||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 |
||||
[-1, 6, C3, [256]], |
||||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 |
||||
[-1, 9, C3, [512]], |
||||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 |
||||
[-1, 3, C3, [1024]], |
||||
[-1, 1, SPPF, [1024, 5]], # 9 |
||||
] |
||||
|
||||
# YOLOv5 v6.0 head |
||||
head: |
||||
[[-1, 1, Conv, [512, 1, 1]], |
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 |
||||
[-1, 3, C3, [512, False]], # 13 |
||||
|
||||
[-1, 1, Conv, [256, 1, 1]], |
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 |
||||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) |
||||
|
||||
[-1, 1, Conv, [256, 3, 2]], |
||||
[[-1, 14], 1, Concat, [1]], # cat head P4 |
||||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) |
||||
|
||||
[-1, 1, Conv, [512, 3, 2]], |
||||
[[-1, 10], 1, Concat, [1]], # cat head P5 |
||||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) |
||||
|
||||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) |
||||
] |
@ -0,0 +1,43 @@ |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
|
||||
# Parameters |
||||
nc: 80 # number of classes |
||||
depth_multiple: 0.33 # model depth multiple |
||||
width_multiple: 0.25 # layer channel multiple |
||||
anchors: [[16,19], [55,65], [178,192]] |
||||
|
||||
# YOLOv8n v0.0 backbone |
||||
backbone: |
||||
# [from, number, module, args] |
||||
[[-1, 1, Conv, [64, 3, 2]], # 0-P1/2 |
||||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 |
||||
[-1, 3, C2f, [128, True]], |
||||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 |
||||
[-1, 6, C2f, [256, True]], |
||||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 |
||||
[-1, 6, C2f, [512, True]], |
||||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 |
||||
[-1, 3, C2f, [1024, True]], |
||||
[-1, 1, SPPF, [1024, 5]], # 9 |
||||
] |
||||
|
||||
# YOLOv8n v0.0 head |
||||
head: |
||||
[[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 |
||||
[-1, 3, C2f, [512]], # 13 |
||||
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 |
||||
[-1, 3, C2f, [256]], # 17 (P3/8-small) |
||||
|
||||
[-1, 1, Conv, [256, 3, 2]], |
||||
[[-1, 12], 1, Concat, [1]], # cat head P4 |
||||
[-1, 3, C2f, [512]], # 20 (P4/16-medium) |
||||
|
||||
[-1, 1, Conv, [512, 3, 2]], |
||||
[[-1, 9], 1, Concat, [1]], # cat head P5 |
||||
[-1, 3, C2f, [1024]], # 23 (P5/32-large) |
||||
|
||||
[[15, 18, 21], 1, Segment, [nc, 32, 256]], # Detect(P3, P4, P5) |
||||
] |
@ -0,0 +1,42 @@ |
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license |
||||
|
||||
# Parameters |
||||
nc: 80 # number of classes |
||||
depth_multiple: 0.33 # model depth multiple |
||||
width_multiple: 0.25 # layer channel multiple |
||||
|
||||
# YOLOv8.0n backbone |
||||
backbone: |
||||
# [from, number, module, args] |
||||
[[-1, 1, Conv, [64, 3, 2]], # 0-P1/2 |
||||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 |
||||
[-1, 3, C2f, [128, True]], |
||||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 |
||||
[-1, 6, C2f, [256, True]], |
||||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 |
||||
[-1, 6, C2f, [512, True]], |
||||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 |
||||
[-1, 3, C2f, [1024, True]], |
||||
[-1, 1, SPPF, [1024, 5]], # 9 |
||||
] |
||||
|
||||
# YOLOv8.0n head |
||||
head: |
||||
[[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 |
||||
[-1, 3, C2f, [512]], # 13 |
||||
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], |
||||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 |
||||
[-1, 3, C2f, [256]], # 17 (P3/8-small) |
||||
|
||||
[-1, 1, Conv, [256, 3, 2]], |
||||
[[-1, 12], 1, Concat, [1]], # cat head P4 |
||||
[-1, 3, C2f, [512]], # 20 (P4/16-medium) |
||||
|
||||
[-1, 1, Conv, [512, 3, 2]], |
||||
[[-1, 9], 1, Concat, [1]], # cat head P5 |
||||
[-1, 3, C2f, [1024]], # 23 (P5/32-large) |
||||
|
||||
[[15, 18, 21], 1, Detect, [nc]], # Detect(P3, P4, P5) |
||||
] |
Loading…
Reference in new issue