diff --git a/docs/en/guides/workouts-monitoring.md b/docs/en/guides/workouts-monitoring.md
index f0c32cf44..02e8209d9 100644
--- a/docs/en/guides/workouts-monitoring.md
+++ b/docs/en/guides/workouts-monitoring.md
@@ -125,5 +125,6 @@ Monitoring workouts through pose estimation with [Ultralytics YOLOv8](https://gi
| `visualize` | `bool` | `False` | visualize model features |
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
+| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
-| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
+| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
diff --git a/docs/en/modes/predict.md b/docs/en/modes/predict.md
index 446d36edf..6c4561fa2 100644
--- a/docs/en/modes/predict.md
+++ b/docs/en/modes/predict.md
@@ -355,8 +355,9 @@ Inference arguments:
| `visualize` | `bool` | `False` | visualize model features |
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
+| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
-| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
+| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
Visualization arguments:
diff --git a/docs/en/reference/nn/autobackend.md b/docs/en/reference/nn/autobackend.md
index 462789e50..3e8c2f7a2 100644
--- a/docs/en/reference/nn/autobackend.md
+++ b/docs/en/reference/nn/autobackend.md
@@ -18,3 +18,7 @@ keywords: Ultralytics, AutoBackend, check_class_names, YOLO, YOLO models, optimi
## ::: ultralytics.nn.autobackend.check_class_names
+
+## ::: ultralytics.nn.autobackend.default_class_names
+
+
diff --git a/docs/en/usage/cfg.md b/docs/en/usage/cfg.md
index ce0b1a23c..2e822aeeb 100644
--- a/docs/en/usage/cfg.md
+++ b/docs/en/usage/cfg.md
@@ -156,8 +156,9 @@ Inference arguments:
| `visualize` | `bool` | `False` | visualize model features |
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
+| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
-| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
+| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
Visualization arguments:
diff --git a/tests/test_python.py b/tests/test_python.py
index 741ad5d08..710974d42 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -511,3 +511,13 @@ def test_model_tune():
"""Tune YOLO model for performance."""
YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+
+
+def test_model_embeddings():
+ """Test YOLO model embeddings."""
+ model_detect = YOLO(MODEL)
+ model_segment = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
+
+ for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2
+ assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)
+ assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 5862abb8b..fa5418361 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.228'
+__version__ = '8.0.229'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml
index b3499853a..f6edad234 100644
--- a/ultralytics/cfg/default.yaml
+++ b/ultralytics/cfg/default.yaml
@@ -61,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources
agnostic_nms: False # (bool) class-agnostic NMS
classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
retina_masks: False # (bool) use high-resolution segmentation masks
+embed: # (list[int], optional) return feature vectors/embeddings from given layers
# Visualize settings ---------------------------------------------------------------------------------------------------
show: False # (bool) show predicted images and videos if environment allows
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index af3a85077..374c872d3 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -94,7 +94,7 @@ class Model(nn.Module):
self._load(model, task)
def __call__(self, source=None, stream=False, **kwargs):
- """Calls the 'predict' function with given arguments to perform object detection."""
+ """Calls the predict() method with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
@@ -201,6 +201,24 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self.model.fuse()
+ def embed(self, source=None, stream=False, **kwargs):
+ """
+ Calls the predict() method and returns image embeddings.
+
+ Args:
+ source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+ Accepts all source types accepted by the YOLO model.
+ stream (bool): Whether to stream the predictions or not. Defaults to False.
+ **kwargs : Additional keyword arguments passed to the predictor.
+ Check the 'configuration' section in the documentation for all available options.
+
+ Returns:
+ (List[torch.Tensor]): A list of image embeddings.
+ """
+ if not kwargs.get('embed'):
+ kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
+ return self.predict(source, stream, **kwargs)
+
def predict(self, source=None, stream=False, predictor=None, **kwargs):
"""
Perform prediction using the YOLO model.
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index 5e04b1793..cda599441 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -134,7 +134,7 @@ class BasePredictor:
"""Runs inference on a given image using the specified model and arguments."""
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
- return self.model(im, augment=self.args.augment, visualize=visualize)
+ return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
def pre_transform(self, im):
"""
@@ -263,6 +263,9 @@ class BasePredictor:
# Inference
with profilers[1]:
preds = self.inference(im, *args, **kwargs)
+ if self.args.embed:
+ yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
+ continue
# Postprocess
with profilers[2]:
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 09f982cae..a5c964ced 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
self.__dict__.update(locals()) # assign all variables to self
- def forward(self, im, augment=False, visualize=False):
+ def forward(self, im, augment=False, visualize=False, embed=None):
"""
Runs inference on the YOLOv8 MultiBackend model.
@@ -341,6 +341,7 @@ class AutoBackend(nn.Module):
im (torch.Tensor): The image tensor to perform inference on.
augment (bool): whether to perform data augmentation during inference, defaults to False
visualize (bool): whether to visualize the output predictions, defaults to False
+ embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
@@ -352,7 +353,7 @@ class AutoBackend(nn.Module):
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pt or self.nn_module: # PyTorch
- y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+ y = self.model(im, augment=augment, visualize=visualize, embed=embed)
elif self.jit: # TorchScript
y = self.model(im)
elif self.dnn: # ONNX OpenCV DNN
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index a834f73dd..c7856ac16 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
return self.loss(x, *args, **kwargs)
return self.predict(x, *args, **kwargs)
- def predict(self, x, profile=False, visualize=False, augment=False):
+ def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
"""
Perform a forward pass through the network.
@@ -50,15 +50,16 @@ class BaseModel(nn.Module):
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
augment (bool): Augment image during prediction, defaults to False.
+ embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
"""
if augment:
return self._predict_augment(x)
- return self._predict_once(x, profile, visualize)
+ return self._predict_once(x, profile, visualize, embed)
- def _predict_once(self, x, profile=False, visualize=False):
+ def _predict_once(self, x, profile=False, visualize=False, embed=None):
"""
Perform a forward pass through the network.
@@ -66,11 +67,12 @@ class BaseModel(nn.Module):
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
+ embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
"""
- y, dt = [], [] # outputs
+ y, dt, embeddings = [], [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -80,6 +82,10 @@ class BaseModel(nn.Module):
y.append(x if m.i in self.save else None) # save output
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
+ if embed and m.i in embed:
+ embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
+ if m.i == max(embed):
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x
def _predict_augment(self, x):
@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
device=img.device)
- def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+ def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
"""
Perform a forward pass through the model.
@@ -464,11 +470,12 @@ class RTDETRDetectionModel(DetectionModel):
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
batch (dict, optional): Ground truth data for evaluation. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+ embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): Model's output tensor.
"""
- y, dt = [], [] # outputs
+ y, dt, embeddings = [], [], [] # outputs
for m in self.model[:-1]: # except the head part
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -478,6 +485,10 @@ class RTDETRDetectionModel(DetectionModel):
y.append(x if m.i in self.save else None) # save output
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
+ if embed and m.i in embed:
+ embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
+ if m.i == max(embed):
+ return torch.unbind(torch.cat(embeddings, 1), dim=0)
head = self.model[-1]
x = head([y[j] for j in head.f], batch) # head inference
return x