|
|
|
@ -54,6 +54,22 @@ class BaseModel(nn.Module): |
|
|
|
|
visualize (bool): Save the feature maps of the model if True, defaults to False. |
|
|
|
|
augment (bool): Augment image during prediction, defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torch.Tensor): The last output of the model. |
|
|
|
|
""" |
|
|
|
|
if augment: |
|
|
|
|
return self._predict_augment(x) |
|
|
|
|
return self._predict_once(x, profile, visualize) |
|
|
|
|
|
|
|
|
|
def _predict_once(self, x, profile=False, visualize=False): |
|
|
|
|
""" |
|
|
|
|
Perform a forward pass through the network. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
x (torch.Tensor): The input tensor to the model. |
|
|
|
|
profile (bool): Print the computation time of each layer if True, defaults to False. |
|
|
|
|
visualize (bool): Save the feature maps of the model if True, defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torch.Tensor): The last output of the model. |
|
|
|
|
""" |
|
|
|
@ -69,6 +85,13 @@ class BaseModel(nn.Module): |
|
|
|
|
feature_visualization(x, m.type, m.i, save_dir=visualize) |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
def _predict_augment(self, x):
    """
    Fall back to single-scale inference because augmented inference is not implemented here.

    A warning naming the concrete class is logged, then `x` is run through `_predict_once` with that
    method's default `profile` and `visualize` settings.

    Args:
        x: The input forwarded unchanged to `_predict_once`.

    Returns:
        The result of `self._predict_once(x)`.
    """
    notice = f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
    LOGGER.warning(notice)
    single_scale_output = self._predict_once(x)
    return single_scale_output
|
|
|
|
|
|
|
|
|
def _profile_one_layer(self, m, x, dt): |
|
|
|
|
""" |
|
|
|
|
Profile the computation time and FLOPs of a single layer of the model on a given input. |
|
|
|
@ -225,13 +248,7 @@ class DetectionModel(BaseModel): |
|
|
|
|
self.info() |
|
|
|
|
LOGGER.info('') |
|
|
|
|
|
|
|
|
|
def predict(self, x, augment=False, profile=False, visualize=False):
    """
    Dispatch a forward pass to either the augmented or the single-scale inference path.

    Args:
        x: The input forwarded unchanged to the selected inference path.
        augment (bool): When truthy, delegate to `self._forward_augment(x)`; `profile` and `visualize`
            are not forwarded on this path.
        profile (bool): Forwarded to the parent class `predict` on the non-augmented path.
        visualize (bool): Forwarded to the parent class `predict` on the non-augmented path.

    Returns:
        Whatever the selected inference path returns.
    """
    if not augment:
        # single-scale inference, train
        return super().predict(x, profile=profile, visualize=visualize)
    # augmented inference, None
    return self._forward_augment(x)
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x): |
|
|
|
|
def _predict_augment(self, x): |
|
|
|
|
"""Perform augmentations on input image x and return augmented inference and train outputs.""" |
|
|
|
|
img_size = x.shape[-2:] # height, width |
|
|
|
|
s = [1, 0.83, 0.67] # scales |
|
|
|
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel): |
|
|
|
|
"""Initialize YOLOv8 segmentation model with given config and parameters.""" |
|
|
|
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) |
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
    """
    Reject augmented inference, which this model does not implement.

    Args:
        x: Unused; accepted only so the signature matches the augmented-inference hook.

    Raises:
        NotImplementedError: Always, with a message passed through `emojis`.
    """
    message = emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
|
def init_criterion(self):
    """Build and return a `v8SegmentationLoss` constructed from this model instance."""
    criterion = v8SegmentationLoss(self)
    return criterion
|
|
|
|
|
|
|
|
|
def _predict_augment(self, x):
    """
    Fall back to single-scale inference because augmented inference is not implemented here.

    A warning naming the concrete class is logged, then `x` is run through `_predict_once` with that
    method's default keyword settings.

    Args:
        x: The input forwarded unchanged to `_predict_once`.

    Returns:
        The result of `self._predict_once(x)`.
    """
    notice = f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
    LOGGER.warning(notice)
    single_scale_output = self._predict_once(x)
    return single_scale_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PoseModel(DetectionModel): |
|
|
|
|
"""YOLOv8 pose model.""" |
|
|
|
@ -302,9 +322,12 @@ class PoseModel(DetectionModel): |
|
|
|
|
def init_criterion(self):
    """Build and return a `v8PoseLoss` constructed from this model instance."""
    criterion = v8PoseLoss(self)
    return criterion
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
    """
    Reject augmented inference, which this model does not implement.

    Args:
        x: Unused; accepted only so the signature matches the augmented-inference hook.

    Raises:
        NotImplementedError: Always, with a message passed through `emojis`.
    """
    message = emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!')
    raise NotImplementedError(message)
|
|
|
|
def _predict_augment(self, x):
    """
    Fall back to single-scale inference because augmented inference is not implemented here.

    A warning naming the concrete class is logged, then `x` is run through `_predict_once` with that
    method's default keyword settings.

    Args:
        x: The input forwarded unchanged to `_predict_once`.

    Returns:
        The result of `self._predict_once(x)`.
    """
    notice = f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
    LOGGER.warning(notice)
    single_scale_output = self._predict_once(x)
    return single_scale_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassificationModel(BaseModel): |
|
|
|
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel): |
|
|
|
|
x = head([y[j] for j in head.f], batch) # head inference |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
    """
    Reject augmented inference, which this model does not implement.

    Args:
        x: Unused; accepted only so the signature matches the augmented-inference hook.

    Raises:
        NotImplementedError: Always, with a message passed through `emojis`.
    """
    message = emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!')
    raise NotImplementedError(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Ensemble(nn.ModuleList): |
|
|
|
|
"""Ensemble of models.""" |
|
|
|
|