diff --git a/docs/en/guides/workouts-monitoring.md b/docs/en/guides/workouts-monitoring.md index f0c32cf44..02e8209d9 100644 --- a/docs/en/guides/workouts-monitoring.md +++ b/docs/en/guides/workouts-monitoring.md @@ -125,5 +125,6 @@ Monitoring workouts through pose estimation with [Ultralytics YOLOv8](https://gi | `visualize` | `bool` | `False` | visualize model features | | `augment` | `bool` | `False` | apply image augmentation to prediction sources | | `agnostic_nms` | `bool` | `False` | class-agnostic NMS | +| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | | `retina_masks` | `bool` | `False` | use high-resolution segmentation masks | -| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | +| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers | diff --git a/docs/en/modes/predict.md b/docs/en/modes/predict.md index 446d36edf..6c4561fa2 100644 --- a/docs/en/modes/predict.md +++ b/docs/en/modes/predict.md @@ -355,8 +355,9 @@ Inference arguments: | `visualize` | `bool` | `False` | visualize model features | | `augment` | `bool` | `False` | apply image augmentation to prediction sources | | `agnostic_nms` | `bool` | `False` | class-agnostic NMS | +| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | | `retina_masks` | `bool` | `False` | use high-resolution segmentation masks | -| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | +| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers | Visualization arguments: diff --git a/docs/en/reference/nn/autobackend.md b/docs/en/reference/nn/autobackend.md index 462789e50..3e8c2f7a2 100644 --- a/docs/en/reference/nn/autobackend.md +++ b/docs/en/reference/nn/autobackend.md @@ -18,3 +18,7 @@ keywords: Ultralytics, AutoBackend, check_class_names, YOLO, YOLO models, optimi ## ::: ultralytics.nn.autobackend.check_class_names

+
+## ::: ultralytics.nn.autobackend.default_class_names
+
+<br><br>
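Taken with the `embed` row added to the prediction-argument tables above, the new argument is passed like any other inference option. A minimal usage sketch, assuming a pretrained `yolov8n.pt` checkpoint and a local `bus.jpg` image (both illustrative, not mandated by this change):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # any YOLOv8 weights work the same way

# Ask for pooled feature vectors from layer 10 instead of postprocessed detections.
# The layer index is illustrative; any valid model layer index may be listed.
embeddings = model.predict('bus.jpg', embed=[10])

for vector in embeddings:  # one flattened 1-D tensor per input image
    print(vector.shape)
```

With `embed` set, the predictor short-circuits postprocessing and yields the raw pooled tensors, so the return value is a list of `torch.Tensor` rather than `Results` objects.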

diff --git a/docs/en/usage/cfg.md b/docs/en/usage/cfg.md index ce0b1a23c..2e822aeeb 100644 --- a/docs/en/usage/cfg.md +++ b/docs/en/usage/cfg.md @@ -156,8 +156,9 @@ Inference arguments: | `visualize` | `bool` | `False` | visualize model features | | `augment` | `bool` | `False` | apply image augmentation to prediction sources | | `agnostic_nms` | `bool` | `False` | class-agnostic NMS | +| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | | `retina_masks` | `bool` | `False` | use high-resolution segmentation masks | -| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | +| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers | Visualization arguments: diff --git a/tests/test_python.py b/tests/test_python.py index 741ad5d08..710974d42 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -511,3 +511,13 @@ def test_model_tune(): """Tune YOLO model for performance.""" YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') + + +def test_model_embeddings(): + """Test YOLO model embeddings.""" + model_detect = YOLO(MODEL) + model_segment = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt') + + for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2 + assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch) + assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 5862abb8b..fa5418361 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.228' +__version__ = '8.0.229' from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml index b3499853a..f6edad234 100644 --- a/ultralytics/cfg/default.yaml +++ b/ultralytics/cfg/default.yaml @@ -61,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources agnostic_nms: False # (bool) class-agnostic NMS classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3] retina_masks: False # (bool) use high-resolution segmentation masks +embed: # (list[int], optional) return feature vectors/embeddings from given layers # Visualize settings --------------------------------------------------------------------------------------------------- show: False # (bool) show predicted images and videos if environment allows diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index af3a85077..374c872d3 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -94,7 +94,7 @@ class Model(nn.Module): self._load(model, task) def __call__(self, source=None, stream=False, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" + """Calls the predict() method with given arguments to perform object detection.""" return self.predict(source, stream, **kwargs) @staticmethod @@ -201,6 +201,24 @@ class Model(nn.Module): self._check_is_pytorch_model() self.model.fuse() + def embed(self, source=None, stream=False, **kwargs): + """ + Calls the predict() method and returns image embeddings. 
+
+        Args:
+            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                Accepts all source types accepted by the YOLO model.
+            stream (bool): Whether to stream the predictions or not. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the predictor.
+                Check the 'configuration' section in the documentation for all available options.
+
+        Returns:
+            (List[torch.Tensor]): A list of image embeddings.
+        """
+        if not kwargs.get('embed'):
+            kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
+        return self.predict(source, stream, **kwargs)
+
     def predict(self, source=None, stream=False, predictor=None, **kwargs):
         """
         Perform prediction using the YOLO model.
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index 5e04b1793..cda599441 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -134,7 +134,7 @@ class BasePredictor:
         """Runs inference on a given image using the specified model and arguments."""
         visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
                                    mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
-        return self.model(im, augment=self.args.augment, visualize=visualize)
+        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
 
     def pre_transform(self, im):
         """
@@ -263,6 +263,9 @@
             # Inference
             with profilers[1]:
                 preds = self.inference(im, *args, **kwargs)
+                if self.args.embed:
+                    yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
+                    continue
 
             # Postprocess
             with profilers[2]:
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 09f982cae..a5c964ced 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
 
         self.__dict__.update(locals())  # assign all variables to self
 
-    def forward(self, im, augment=False, visualize=False):
+    def forward(self, im, augment=False, visualize=False, embed=None):
         """
         Runs inference on the YOLOv8 MultiBackend model.
 
@@ -341,6 +341,7 @@
             im (torch.Tensor): The image tensor to perform inference on.
             augment (bool): whether to perform data augmentation during inference, defaults to False
             visualize (bool): whether to visualize the output predictions, defaults to False
+            embed (list, optional): Indices of the layers from which to return feature vectors/embeddings.
 
         Returns:
             (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
@@ -352,7 +353,7 @@
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)
 
         if self.pt or self.nn_module:  # PyTorch
-            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+            y = self.model(im, augment=augment, visualize=visualize, embed=embed)
         elif self.jit:  # TorchScript
             y = self.model(im)
         elif self.dnn:  # ONNX OpenCV DNN
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index a834f73dd..c7856ac16 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
             return self.loss(x, *args, **kwargs)
         return self.predict(x, *args, **kwargs)
 
-    def predict(self, x, profile=False, visualize=False, augment=False):
+    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
         """
         Perform a forward pass through the network.
 
@@ -50,15 +50,16 @@
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
             augment (bool): Augment image during prediction, defaults to False.
+            embed (list, optional): Indices of the layers from which to return feature vectors/embeddings.
 
         Returns:
             (torch.Tensor): The last output of the model.
         """
         if augment:
             return self._predict_augment(x)
-        return self._predict_once(x, profile, visualize)
+        return self._predict_once(x, profile, visualize, embed)
 
-    def _predict_once(self, x, profile=False, visualize=False):
+    def _predict_once(self, x, profile=False, visualize=False, embed=None):
         """
         Perform a forward pass through the network.
 
@@ -66,11 +67,12 @@
             x (torch.Tensor): The input tensor to the model.
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
+            embed (list, optional): Indices of the layers from which to return feature vectors/embeddings.
 
         Returns:
             (torch.Tensor): The last output of the model.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model:
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -80,6 +82,10 @@
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
     def _predict_augment(self, x):
@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
         return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                    device=img.device)
 
-    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.
 
@@ -464,11 +470,12 @@
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
             batch (dict, optional): Ground truth data for evaluation. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+            embed (list, optional): Indices of the layers from which to return feature vectors/embeddings.
 
         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model[:-1]:  # except the head part
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -478,6 +485,10 @@
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         head = self.model[-1]
         x = head([y[j] for j in head.f], batch)  # head inference
         return x
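To see the pieces above working together, here is a short end-to-end sketch of the new `Model.embed()` method; the two image files and the cosine-similarity comparison are illustrative additions, not part of this diff:

```python
import torch.nn.functional as F

from ultralytics import YOLO

model = YOLO('yolov8n.pt')

# With no indices given, embed() defaults to the second-to-last layer,
# mirroring kwargs['embed'] = [len(self.model.model) - 2] above.
e1, e2 = model.embed(['bus.jpg', 'zidane.jpg'])  # illustrative image paths

print(e1.shape)  # 1-D vector; length equals the pooled layer's channel width
print(F.cosine_similarity(e1, e2, dim=0).item())  # quick image-similarity check
```

Each requested layer's feature map is reduced by `adaptive_avg_pool2d` to a per-image vector; multiple layers are concatenated along the channel dimension before `torch.unbind` splits the batch back into per-image tensors.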