Tests and docstrings improvements (#4475)

pull/4482/head
Glenn Jocher authored 1 year ago, committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
22 changed files:
  1. docs/modes/benchmark.md (2)
  2. docs/usage/python.md (2)
  3. docs/yolov5/tutorials/clearml_logging_integration.md (4)
  4. docs/yolov5/tutorials/tips_for_best_training_results.md (2)
  5. tests/test_python.py (17)
  6. ultralytics/data/loaders.py (11)
  7. ultralytics/data/utils.py (8)
  8. ultralytics/engine/model.py (2)
  9. ultralytics/models/sam/modules/decoders.py (4)
  10. ultralytics/models/sam/modules/encoders.py (40)
  11. ultralytics/models/sam/modules/sam.py (28)
  12. ultralytics/models/sam/modules/tiny_encoder.py (89)
  13. ultralytics/models/sam/modules/transformer.py (25)
  14. ultralytics/models/yolo/segment/val.py (4)
  15. ultralytics/nn/modules/transformer.py (9)
  16. ultralytics/nn/tasks.py (6)
  17. ultralytics/utils/__init__.py (4)
  18. ultralytics/utils/callbacks/clearml.py (21)
  19. ultralytics/utils/callbacks/dvc.py (2)
  20. ultralytics/utils/downloads.py (11)
  21. ultralytics/utils/metrics.py (2)
  22. ultralytics/utils/ops.py (2)

@ -46,7 +46,7 @@ the benchmarks to their specific needs and compare the performance of different
| Key | Value | Description |
|-----------|---------|-----------------------------------------------------------------------|
| `model` | `None` | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
| `data` | `None` | path to yaml referencing the benchmarking dataset (under `val` label) |
| `data` | `None` | path to YAML referencing the benchmarking dataset (under `val` label) |
| `imgsz` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) |
| `half` | `False` | FP16 quantization |
| `int8` | `False` | INT8 quantization |
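For orientation, here is a minimal sketch of how these arguments map onto the Python benchmark entry point (assuming the `ultralytics` package is installed and `yolov8n.pt`/`coco8.yaml` resolve locally; keyword support may vary slightly between versions):

```python
from ultralytics.utils.benchmarks import benchmark

# Benchmark YOLOv8n across export formats at 640 px on CPU.
# half=True would request FP16 and int8=True INT8 where the backend supports it.
benchmark(model='yolov8n.pt', data='coco8.yaml', imgsz=640, half=False, int8=False, device='cpu')
```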

@ -93,7 +93,7 @@ of the model to improve its performance.
from ultralytics import YOLO
model = YOLO("model.pt")
# It'll use the data yaml file in model.pt if you don't set data.
# It'll use the data YAML file in model.pt if you don't set data.
model.val()
# or you can set the data you want to val
model.val(data='coco128.yaml')

@ -107,7 +107,7 @@ Versioning your data separately from your code is generally a good idea and make
### Prepare Your Dataset
The YOLOv5 repository supports a number of different datasets by using yaml files containing their information. By default datasets are downloaded to the `../datasets` folder in relation to the repository root folder. So if you downloaded the `coco128` dataset using the link in the yaml or with the scripts provided by yolov5, you get this folder structure:
The YOLOv5 repository supports a number of different datasets by using YAML files containing their information. By default datasets are downloaded to the `../datasets` folder in relation to the repository root folder. So if you downloaded the `coco128` dataset using the link in the YAML or with the scripts provided by yolov5, you get this folder structure:
```
..
@ -122,7 +122,7 @@ The YOLOv5 repository supports a number of different datasets by using yaml file
But this can be any dataset you wish. Feel free to use your own, as long as you keep to this folder structure.
Next, ⚠**copy the corresponding yaml file to the root of the dataset folder**⚠. This yaml files contains the information ClearML will need to properly use the dataset. You can make this yourself too, of course, just follow the structure of the example yamls.
Next, ⚠**copy the corresponding YAML file to the root of the dataset folder**⚠. This YAML files contains the information ClearML will need to properly use the dataset. You can make this yourself too, of course, just follow the structure of the example YAMLs.
Basically we need the following keys: `path`, `train`, `test`, `val`, `nc`, `names`.
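As an illustration only (hypothetical dataset name and paths, not part of this change), such a root-level YAML can be written from Python with exactly those keys:

```python
import yaml  # PyYAML

# Hypothetical dataset description with the keys ClearML expects:
# path, train, test, val, nc, names.
data = {
    'path': '../datasets/my_dataset',  # dataset root
    'train': 'images/train',           # paths relative to 'path'
    'val': 'images/val',
    'test': '',
    'nc': 2,                           # number of classes
    'names': ['person', 'car'],
}

with open('../datasets/my_dataset/my_dataset.yaml', 'w') as f:
    yaml.safe_dump(data, f, sort_keys=False)
```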

@ -41,7 +41,7 @@ python train.py --data custom.yaml --weights yolov5s.pt
custom_pretrained.pt
```
- **Start from Scratch.** Recommended for large datasets (i.e. [COCO](https://github.com/ultralytics/yolov5/blob/master/data/coco.yaml), [Objects365](https://github.com/ultralytics/yolov5/blob/master/data/Objects365.yaml), [OIv6](https://storage.googleapis.com/openimages/web/index.html)). Pass the model architecture yaml you are interested in, along with an empty `--weights ''` argument:
- **Start from Scratch.** Recommended for large datasets (i.e. [COCO](https://github.com/ultralytics/yolov5/blob/master/data/coco.yaml), [Objects365](https://github.com/ultralytics/yolov5/blob/master/data/Objects365.yaml), [OIv6](https://storage.googleapis.com/openimages/web/index.html)). Pass the model architecture YAML you are interested in, along with an empty `--weights ''` argument:
```bash
python train.py --data custom.yaml --weights '' --cfg yolov5s.yaml

@ -35,28 +35,19 @@ def test_model_methods():
model = model.reset_weights()
model = model.load(MODEL)
model.to('cpu')
model.fuse()
_ = model.names
_ = model.device
def test_model_fuse():
model = YOLO(MODEL)
model.fuse()
def test_predict_dir():
model = YOLO(MODEL)
model(source=ASSETS, imgsz=32)
def test_predict_txt():
# Write a list of sources to a txt file
# Write a list of sources (file, dir, glob, recursive glob) to a txt file
txt_file = TMP / 'sources.txt'
with open(txt_file, 'w') as f:
for x in [ASSETS / 'bus.jpg', ASSETS / 'zidane.jpg']:
for x in [ASSETS / 'bus.jpg', ASSETS, ASSETS / '*', ASSETS / '**/*.jpg']:
f.write(f'{x}\n')
model = YOLO(MODEL)
model(source=txt_file, imgsz=640)
model(source=txt_file, imgsz=32)
def test_predict_img():

@ -16,7 +16,7 @@ import torch
from PIL import Image
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import ASSETS, LOGGER, is_colab, is_kaggle, ops
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
from ultralytics.utils.checks import check_requirements
@ -167,7 +167,7 @@ class LoadScreenshots:
def __next__(self):
"""mss screen capture: get raw pixels from the screen as np array."""
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
self.frame += 1
@ -400,10 +400,3 @@ def get_best_youtube_url(url, use_pafy=False):
good_size = (f.get('width') or 0) >= 1920 or (f.get('height') or 0) >= 1080
if good_size and f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
return f.get('url')
if __name__ == '__main__':
img = cv2.imread(str(ASSETS / 'bus.jpg'))
dataset = LoadPilAndNumpy(im0=img)
for d in dataset:
print(d[0])

@ -204,7 +204,7 @@ def check_det_dataset(dataset, autodownload=True):
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False
# Read yaml (optional)
# Read YAML (optional)
if isinstance(data, (str, Path)):
data = yaml_load(data, append_filename=True) # dictionary
@ -244,7 +244,7 @@ def check_det_dataset(dataset, autodownload=True):
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse yaml
# Parse YAML
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
@ -321,12 +321,12 @@ def check_cls_dataset(dataset: str, split=''):
# Print to console
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
if v is None:
LOGGER.info(colorstr(k) + f': {v}')
LOGGER.info(f'{colorstr(k)}: {v}')
else:
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
nf = len(files) # number of files
nd = len({file.parent for file in files}) # number of directories
LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space
LOGGER.info(f'{colorstr(k)}: {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
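For context, a hedged sketch of a typical call into this checker (`coco128.yaml` as an example dataset; the returned keys follow the dictionary built above):

```python
from ultralytics.data.utils import check_det_dataset

# Resolve a dataset YAML into a dict with absolute train/val paths and class names,
# downloading the dataset first if it is missing and autodownload is enabled.
data = check_det_dataset('coco128.yaml', autodownload=True)
print(data['train'], data['val'], len(data['names']))
```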

@ -122,7 +122,7 @@ class Model:
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
# Below added to allow export from yamls
# Below added to allow export from YAMLs
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
self.model.task = self.task
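A minimal sketch of the behaviour this supports (building from a model YAML and exporting without trained weights; assumes `yolov8n.yaml` resolves and the ONNX exporter's dependencies are installed):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.yaml')            # build a model from a YAML definition (no weights)
model.export(format='onnx', imgsz=640)  # export works because default args were attached above
```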

@ -24,7 +24,7 @@ class MaskDecoder(nn.Module):
"""
Predicts masks given an image and prompt embeddings, using a transformer architecture.
Arguments:
Args:
transformer_dim (int): the channel dimension of the transformer module
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
@ -65,7 +65,7 @@ class MaskDecoder(nn.Module):
"""
Predict masks given image and prompt embeddings.
Arguments:
Args:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes

@ -103,13 +103,9 @@ class ImageEncoderViT(nn.Module):
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
return self.neck(x.permute(0, 3, 1, 2))
class PromptEncoder(nn.Module):
@ -125,7 +121,7 @@ class PromptEncoder(nn.Module):
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
Args:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
@ -165,8 +161,7 @@ class PromptEncoder(nn.Module):
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
@ -231,21 +226,17 @@ class PromptEncoder(nn.Module):
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Embeds different types of prompts, returning both sparse and dense embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
and labels to embed.
Args:
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
boxes (torch.Tensor, None): boxes to embed
masks (torch.Tensor, None): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
by the number of input points and boxes.
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
@ -372,9 +363,7 @@ class Block(nn.Module):
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
return x + self.mlp(self.norm2(x))
class Attention(nn.Module):
@ -427,9 +416,7 @@ class Attention(nn.Module):
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
return self.proj(x)
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
@ -577,7 +564,4 @@ class PatchEmbed(nn.Module):
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C

@ -29,7 +29,7 @@ class Sam(nn.Module):
"""
SAM predicts object masks from an image and input prompts.
Arguments:
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
@ -60,14 +60,12 @@ class Sam(nn.Module):
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
SamPredictor is recommended over calling the model directly.
Args:
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
key can be excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
@ -81,12 +79,11 @@ class Sam(nn.Module):
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
mask.
Returns:
(list(dict)): A list over input images, where each element is
as dictionary with the following keys.
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determined by multimask_output, and (H, W) is the
@ -139,7 +136,7 @@ class Sam(nn.Module):
"""
Remove padding and upscale masks to the original image size.
Arguments:
Args:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
@ -158,8 +155,7 @@ class Sam(nn.Module):
align_corners=False,
)
masks = masks[..., :input_size[0], :input_size[1]]
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
return masks
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""

@ -35,8 +35,7 @@ class Conv2d_BN(torch.nn.Sequential):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
w.size(0),
w.shape[2:],
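The fused `w`/`b` above follow the usual Conv+BatchNorm folding identity; a small self-contained check of that identity in plain PyTorch (a sketch, not the module itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

# w' = w * gamma / sqrt(var + eps),  b' = beta - mean * gamma / sqrt(var + eps)
scale = bn.weight / (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(3, 8, 3, padding=1, bias=True)
with torch.no_grad():
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # fused conv matches conv -> BN
```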
@ -72,8 +71,7 @@ class PatchEmbed(nn.Module):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * \
self.patches_resolution[1]
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
@ -110,21 +108,14 @@ class MBConv(nn.Module):
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act3(x)
return x
return self.act3(x)
class PatchMerging(nn.Module):
@ -137,9 +128,7 @@ class PatchMerging(nn.Module):
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 2
if (out_dim == 320 or out_dim == 448 or out_dim == 576):
stride_c = 1
stride_c = 1 if out_dim in [320, 448, 576] else 2
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
@ -156,8 +145,7 @@ class PatchMerging(nn.Module):
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
return x.flatten(2).transpose(1, 2)
class ConvLayer(nn.Module):
@ -174,7 +162,6 @@ class ConvLayer(nn.Module):
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@ -192,20 +179,13 @@ class ConvLayer(nn.Module):
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
class Mlp(nn.Module):
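The `use_checkpoint` branch above is standard activation checkpointing (recompute the block during backward instead of caching its activations); a tiny standalone sketch of the same pattern, not this module:

```python
import torch
import torch.nn as nn
from torch.utils import checkpoint

blocks = nn.ModuleList(nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4))
use_checkpoint = True

x = torch.randn(8, 64, requires_grad=True)
for blk in blocks:
    # Trade compute for memory: re-run blk in the backward pass rather than storing activations.
    x = checkpoint.checkpoint(blk, x) if use_checkpoint else blk(x)
x.sum().backward()
```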
@ -222,13 +202,11 @@ class Mlp(nn.Module):
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
return self.drop(x)
class Attention(torch.nn.Module):
@ -297,12 +275,12 @@ class Attention(torch.nn.Module):
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
return self.proj(x)
class TinyViTBlock(nn.Module):
r""" TinyViT Block.
"""
TinyViT Block.
Args:
dim (int): Number of input channels.
@ -312,8 +290,7 @@ class TinyViTBlock(nn.Module):
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
"""
@ -391,8 +368,7 @@ class TinyViTBlock(nn.Module):
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
@ -400,7 +376,8 @@ class TinyViTBlock(nn.Module):
class BasicLayer(nn.Module):
""" A basic TinyViT layer for one stage.
"""
A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
@ -434,7 +411,6 @@ class BasicLayer(nn.Module):
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@ -456,20 +432,13 @@ class BasicLayer(nn.Module):
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
@ -487,8 +456,7 @@ class LayerNorm2d(nn.Module):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
return self.weight[:, None, None] * x + self.bias[:, None, None]
class TinyViT(nn.Module):
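The arithmetic above is a channels-first LayerNorm; a standalone check that it matches `F.layer_norm` applied over the channel dimension (mirroring the lines shown rather than importing the module):

```python
import torch
import torch.nn.functional as F

B, C, H, W, eps = 2, 16, 8, 8, 1e-6
x = torch.randn(B, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

# Channels-first LayerNorm exactly as written above
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
y1 = weight[:, None, None] * ((x - u) / torch.sqrt(s + eps)) + bias[:, None, None]

# Same result via F.layer_norm on a channels-last view
y2 = F.layer_norm(x.permute(0, 2, 3, 1), (C,), weight, bias, eps).permute(0, 3, 1, 2)
assert torch.allclose(y1, y2, atol=1e-5)
```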
@ -548,10 +516,7 @@ class TinyViT(nn.Module):
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
)
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
@ -622,7 +587,7 @@ class TinyViT(nn.Module):
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
@ -645,9 +610,7 @@ class TinyViT(nn.Module):
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
x = self.neck(x)
return x
return self.neck(x)
def forward(self, x):
x = self.forward_features(x)
return x
return self.forward_features(x)

@ -61,16 +61,14 @@ class TwoWayTransformer(nn.Module):
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
(torch.Tensor): the processed point_embedding
(torch.Tensor): the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
@ -112,12 +110,11 @@ class TwoWayAttentionBlock(nn.Module):
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
@ -175,8 +172,8 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
def __init__(
@ -230,6 +227,4 @@ class Attention(nn.Module):
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
return self.out_proj(out)

@ -145,9 +145,11 @@ class SegmentationValidator(DetectionValidator):
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
Args:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""

@ -97,8 +97,7 @@ class AIFI(TransformerEncoderLayer):
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.concat([torch.sin(out_w), torch.cos(out_w),
torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
class TransformerLayer(nn.Module):
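The concatenation above is the tail of a 2D sine-cosine positional embedding; a standalone sketch of the whole construction, assuming the conventional temperature of 10000 and an `embed_dim` divisible by 4:

```python
import torch

def sincos_2d(w, h, embed_dim=256, temperature=10000.0):
    """Return a (1, w*h, embed_dim) 2D sine-cosine positional embedding."""
    grid_w, grid_h = torch.meshgrid(torch.arange(w, dtype=torch.float32),
                                    torch.arange(h, dtype=torch.float32), indexing='ij')
    pos_dim = embed_dim // 4
    omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[..., None] @ omega[None]  # (w*h, pos_dim)
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]

print(sincos_2d(20, 20).shape)  # torch.Size([1, 400, 256])
```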
@ -170,9 +169,11 @@ class MLP(nn.Module):
return x
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
"""
LayerNorm2d module from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
"""
def __init__(self, num_channels, eps=1e-6):
super().__init__()

@ -229,7 +229,7 @@ class DetectionModel(BaseModel):
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.yaml['nc'] = nc # override YAML value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.inplace = self.yaml.get('inplace', True)
@ -329,7 +329,7 @@ class ClassificationModel(BaseModel):
ch=3,
nc=None,
cutoff=10,
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
verbose=True): # YAML, model, channels, number of classes, cutoff index, verbose flag
super().__init__()
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
@ -357,7 +357,7 @@ class ClassificationModel(BaseModel):
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.yaml['nc'] = nc # override YAML value
elif not nc and not self.yaml.get('nc', None):
raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist

@ -341,10 +341,10 @@ def yaml_load(file='data.yaml', append_filename=False):
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
"""
Pretty prints a yaml file or a yaml-formatted dictionary.
Pretty prints a YAML file or a YAML-formatted dictionary.
Args:
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
yaml_file: The file path of the YAML file or a YAML-formatted dictionary.
Returns:
None

@ -29,8 +29,7 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
task = Task.current_task()
if task:
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r'_batch(\d+)', f.name)
@ -63,8 +62,7 @@ def _log_plot(title, plot_path) -> None:
def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try:
task = Task.current_task()
if task:
if task := Task.current_task():
# Make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
PatchPyTorchModelIO.update_current_task(None)
@ -86,21 +84,19 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer):
task = Task.current_task()
if task:
"""Logs debug samples for the first epoch of YOLO training."""
"""Logs debug samples for the first epoch of YOLO training and report current training progress."""
if task := Task.current_task():
# Log debug samples
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
"""Report the current training progress."""
# Report the current training progress
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
task = Task.current_task()
if task:
if task := Task.current_task():
# You should have access to the validation bboxes under jdict
task.get_logger().report_scalar(title='Epoch Time',
series='Epoch Time',
@ -120,8 +116,7 @@ def on_val_end(validator):
def on_train_end(trainer):
"""Logs final model and its name on training completion."""
task = Task.current_task()
if task:
if task := Task.current_task():
# Log final results, CM matrix + PR plots
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',

@ -40,7 +40,7 @@ def _log_images(path, prefix=''):
# Group images by batch to enable sliders in UI
if m := re.search(r'_batch(\d+)', name):
ni = m.group(1)
ni = m[1]
new_stem = re.sub(r'_batch(\d+)', '_batch', path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
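`m[1]` is simply the subscript form of `m.group(1)` on a `re.Match` object, so the behaviour is unchanged:

```python
import re

m = re.search(r'_batch(\d+)', 'train_batch2.jpg')
assert m[1] == m.group(1) == '2'
```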

@ -93,7 +93,7 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
# Unzip with progress bar
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and not any(x in f.name for x in exclude)]
files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
zip_file = directory.with_suffix('.zip')
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, 'w', compression) as f:
@ -185,11 +185,9 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
if hard:
raise MemoryError(text)
else:
LOGGER.warning(text)
return False
LOGGER.warning(text)
return False
# Pass if error
return True
@ -332,6 +330,9 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
r = requests.get(url) # github api
if r.status_code != 200 and retry:
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f' GitHub assets check failure for {url}: {r.status_code} {r.reason}')
return '', []
data = r.json()
return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets

@ -382,7 +382,7 @@ def compute_ap(recall, precision):
"""
Compute the average precision (AP) given the recall and precision curves.
Arguments:
Args:
recall (list): The recall curve.
precision (list): The precision curve.
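As a rough sketch of the standard computation (monotonic precision envelope plus 101-point interpolation, as in COCO-style evaluation; not necessarily byte-identical to this function):

```python
import numpy as np

def ap_from_curve(recall, precision):
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))  # precision envelope (non-increasing)
    x = np.linspace(0, 1, 101)                            # 101 recall points
    return np.trapz(np.interp(x, mrec, mpre), x)          # area under the interpolated curve

print(ap_from_curve([0.2, 0.6, 1.0], [1.0, 0.8, 0.5]))
```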

@ -140,7 +140,7 @@ def non_max_suppression(
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
Arguments:
Args:
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
containing the predicted boxes, classes, and masks. The tensor should be in the format
output by a model, such as YOLO.
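A hedged usage sketch with a dummy tensor shaped like raw YOLOv8 detector output (1 image, 4 box coordinates + 80 class scores, 8400 candidate boxes); keyword names beyond `conf_thres`/`iou_thres` may differ between versions:

```python
import torch
from ultralytics.utils import ops

prediction = torch.rand(1, 84, 8400)  # (batch, 4 + num_classes, num_boxes), random values only
results = ops.non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45)
print(results[0].shape)  # (num_kept, 6): x1, y1, x2, y2, conf, cls
```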
