scipy_removal
Glenn Jocher 1 year ago
commit 6c206aefbb
  1. docs/reference/hub/__init__.md (4 changes)
  2. docs/reference/utils/metrics.md (4 changes)
  3. docs/reference/utils/ops.md (4 changes)
  4. tests/conftest.py (2 changes)
  5. tests/test_python.py (56 changes)
  6. ultralytics/__init__.py (3 changes)
  7. ultralytics/hub/__init__.py (27 changes)
  8. ultralytics/hub/auth.py (24 changes)
  9. ultralytics/models/sam/modules/sam.py (56 changes)
  10. ultralytics/nn/modules/conv.py (7 changes)
  11. ultralytics/trackers/utils/kalman_filter.py (164 changes)
  12. ultralytics/utils/downloads.py (11 changes)
  13. ultralytics/utils/metrics.py (11 changes)
  14. ultralytics/utils/ops.py (201 changes)

@@ -17,10 +17,6 @@ keywords: Ultralytics, hub functions, model export, dataset check, reset model,
## ::: ultralytics.hub.logout
<br><br>
---
## ::: ultralytics.hub.start
<br><br>
---
## ::: ultralytics.hub.reset_model
<br><br>

@@ -33,10 +33,6 @@ keywords: Ultralytics, YOLO, YOLOv3, YOLOv4, metrics, confusion matrix, detectio
## ::: ultralytics.utils.metrics.ClassifyMetrics
<br><br>
---
## ::: ultralytics.utils.metrics.box_area
<br><br>
---
## ::: ultralytics.utils.metrics.bbox_ioa
<br><br>

@@ -57,10 +57,6 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
## ::: ultralytics.utils.ops.xyxy2xywhn
<br><br>
---
## ::: ultralytics.utils.ops.xyn2xy
<br><br>
---
## ::: ultralytics.utils.ops.xywh2ltwh
<br><br>

@@ -6,6 +6,7 @@ from pathlib import Path
import pytest
from ultralytics.utils import ROOT
from ultralytics.utils.torch_utils import init_seeds
TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
@@ -32,6 +33,7 @@ def pytest_sessionstart(session):
"""
Called after the 'Session' object has been created and before performing test collection.
"""
init_seeds()
shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory
TMP.mkdir(parents=True, exist_ok=True) # create a new empty directory
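The `init_seeds()` call added to `pytest_sessionstart` makes every test session start from the same random state. A minimal sketch of the effect, assuming `init_seeds` accepts a seed argument as in `ultralytics.utils.torch_utils`:

```python
import torch
from ultralytics.utils.torch_utils import init_seeds

init_seeds(0)
a = torch.rand(3)
init_seeds(0)
b = torch.rand(3)
assert torch.equal(a, b)  # identical seeds reproduce identical tensors
```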

@@ -128,7 +128,7 @@ def test_track_stream():
def test_val():
model = YOLO(MODEL)
model.val(data='coco8.yaml', imgsz=32)
model.val(data='coco8.yaml', imgsz=32, save_hybrid=True)
def test_train_scratch():
@@ -348,9 +348,20 @@ def test_utils_downloads():
def test_utils_ops():
from ultralytics.utils.ops import make_divisible
from ultralytics.utils.ops import (ltwh2xywh, ltwh2xyxy, make_divisible, xywh2ltwh, xywh2xyxy, xywhn2xyxy,
xywhr2xyxyxyxy, xyxy2ltwh, xyxy2xywh, xyxy2xywhn, xyxyxyxy2xywhr)
make_divisible(17, 8)
make_divisible(17, torch.tensor([8]))
boxes = torch.rand(10, 4) # xywh
torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes)))
torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes)))
torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes)))
torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes)))
boxes = torch.rand(10, 5) # xywhr for OBB
boxes[:, 4] = torch.randn(10) * 30
torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)
def test_utils_files():
@@ -364,3 +375,42 @@ def test_utils_files():
path.mkdir(parents=True, exist_ok=True)
with spaces_in_path(path) as new_path:
print(new_path)
def test_nn_modules_conv():
from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focus
c1, c2 = 8, 16 # input and output channels
x = torch.zeros(4, c1, 10, 10) # BCHW
# Run all modules not otherwise covered in tests
DWConvTranspose2d(c1, c2)(x)
ConvTranspose(c1, c2)(x)
Focus(c1, c2)(x)
CBAM(c1)(x)
# Fuse ops
m = Conv2(c1, c2)
m.fuse_convs()
m(x)
def test_nn_modules_block():
from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, C3x
c1, c2 = 8, 16 # input and output channels
x = torch.zeros(4, c1, 10, 10) # BCHW
# Run all modules not otherwise covered in tests
C1(c1, c2)(x)
C3x(c1, c2)(x)
C3TR(c1, c2)(x)
C3Ghost(c1, c2)(x)
BottleneckCSP(c1, c2)(x)
def test_hub():
from ultralytics.hub import export_fmts_hub, logout
export_fmts_hub()
logout()
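The new `test_utils_ops` assertions exercise a round-trip property: converting boxes to another format and back should reproduce the original tensor. A minimal sketch of the same check, assuming `ultralytics` is installed:

```python
import torch
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

boxes = torch.rand(10, 4)  # boxes in (x, y, w, h) format
restored = xyxy2xywh(xywh2xyxy(boxes))  # xywh -> xyxy -> xywh
assert torch.allclose(boxes, restored)  # round trip is lossless
```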

@@ -2,7 +2,6 @@
__version__ = '8.0.159'
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
@@ -10,4 +9,4 @@ from ultralytics.utils import SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings' # allow simpler import
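Since `start` is dropped from both the imports and `__all__`, importing it from the package root now fails. A short sketch of the resulting behavior:

```python
try:
    from ultralytics import start  # removed in this commit
except ImportError:
    from ultralytics import YOLO, hub  # use hub.login() and YOLO(model_url) instead
```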

@@ -5,7 +5,7 @@ import requests
from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
from ultralytics.utils import LOGGER, SETTINGS
def login(api_key=''):
@@ -37,29 +37,10 @@ def logout():
```
"""
SETTINGS['api_key'] = ''
yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
SETTINGS.save()
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
def start(key=''):
"""
Start training models with Ultralytics HUB (DEPRECATED).
Args:
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
"""
api_key, model_id = key.split('_')
LOGGER.warning(f"""
WARNING ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
from ultralytics import YOLO, hub
hub.login('{api_key}')
model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
model.train()""")
def reset_model(model_id=''):
"""Reset a trained model to an untrained state."""
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
@@ -117,7 +98,3 @@ def check_dataset(path='', task='detect'):
"""
HUBDatasetStats(path=path, task=task).get_json()
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
if __name__ == '__main__':
start()
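The deleted `start()` warning already spelled out its replacement; a sketch of that workflow, where 'YOUR_API_KEY' and 'MODEL_ID' are placeholders:

```python
from ultralytics import YOLO, hub

hub.login('YOUR_API_KEY')  # authenticate with Ultralytics HUB
model = YOLO('https://hub.ultralytics.com/models/MODEL_ID')
model.train()  # train the HUB-managed model
```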

@@ -73,8 +73,7 @@ class Auth:
bool: True if authentication is successful, False otherwise.
"""
try:
header = self.get_auth_header()
if header:
if header := self.get_auth_header():
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
if not r.json().get('success', False):
raise ConnectionError('Unable to authenticate.')
@@ -117,23 +116,4 @@ class Auth:
return {'authorization': f'Bearer {self.id_token}'}
elif self.api_key:
return {'x-api-key': self.api_key}
else:
return None
def get_state(self) -> bool:
"""
Get the authentication state.
Returns:
bool: True if either id_token or API key is set, False otherwise.
"""
return self.id_token or self.api_key
def set_api_key(self, key: str):
"""
Set the API key for authentication.
Args:
key (str): The API key string.
"""
self.api_key = key
# else returns None
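The two-line fetch-then-check in `Auth` is collapsed with the walrus operator (Python 3.8+), which assigns and truth-tests in one expression. A generic sketch of the idiom with a hypothetical helper:

```python
def get_auth_header():
    """Hypothetical stand-in returning a header dict or None."""
    return {'x-api-key': 'demo'}

if header := get_auth_header():  # bind and test in a single expression
    print('authenticated request with', header)
```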

@@ -30,11 +30,10 @@ class Sam(nn.Module):
SAM predicts object masks from an image and input prompts.
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
@@ -65,34 +64,25 @@
Args:
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
key can be excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
key can be excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
transformed to the input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
mask.
Returns:
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determined by multimask_output, and (H, W) is the
original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions
of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
as mask input to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
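A sketch of the `batched_input` structure the reflowed docstring describes; shapes follow the docstring and the tensors here are placeholders:

```python
import torch

batched_input = [{
    'image': torch.zeros(3, 1024, 1024),   # 3xHxW, already transformed for the model
    'original_size': (480, 640),           # (H, W) before transformation
    'point_coords': torch.zeros(1, 2, 2),  # BxNx2 batched point prompts
    'point_labels': torch.zeros(1, 2),     # BxN labels for the point prompts
}]
# outputs = sam(batched_input, multimask_output=False)  # assuming a constructed Sam instance
```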
@@ -137,16 +127,12 @@ class Sam(nn.Module):
Remove padding and upscale masks to the original image size.
Args:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
"""
masks = F.interpolate(
masks,

@@ -9,7 +9,7 @@ import numpy as np
import torch
import torch.nn as nn
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
@@ -54,6 +54,10 @@ class Conv2(Conv):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x) + self.cv2(x)))
def forward_fuse(self, x):
"""Apply fused convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def fuse_convs(self):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
@@ -61,6 +65,7 @@ class Conv2(Conv):
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
self.conv.weight.data += w
self.__delattr__('cv2')
self.forward = self.forward_fuse
class LightConv(nn.Module):
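With `self.forward = self.forward_fuse` now rebound inside `fuse_convs()`, fusion should be output-preserving. A minimal check, sketched under the assumption that `Conv2` is importable as exported above:

```python
import torch
from ultralytics.nn.modules.conv import Conv2

m = Conv2(8, 16).eval()  # eval() so batch norm uses running statistics
x = torch.randn(2, 8, 10, 10)
y_before = m(x)
m.fuse_convs()           # folds the parallel 1x1 branch into the 3x3 kernel
y_after = m(x)
assert torch.allclose(y_before, y_after, atol=1e-6)  # fusion preserves outputs
```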

@@ -6,20 +6,13 @@ import scipy.linalg
class KalmanFilterXYAH:
"""
For bytetrack
A simple Kalman filter for tracking bounding boxes in image space.
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
aspect ratio a, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
observation of the state space (linear observation model).
"""
def __init__(self):
@@ -32,14 +25,14 @@ class KalmanFilterXYAH:
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
# the amount of uncertainty in the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""Create track from unassociated measurement.
"""
Create track from unassociated measurement.
Parameters
----------
@@ -53,7 +46,6 @@
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
@@ -67,23 +59,21 @@
return mean, covariance
def predict(self, mean, covariance):
"""Run Kalman filter prediction step.
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous
time step.
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the
previous time step.
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
@@ -100,7 +90,8 @@
return mean, covariance
def project(self, mean, covariance):
"""Project state distribution to measurement space.
"""
Project state distribution to measurement space.
Parameters
----------
@@ -112,9 +103,7 @@
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state
estimate.
Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
@@ -126,20 +115,21 @@
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""Run Kalman filter prediction step (Vectorized version).
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous
time step.
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the
previous time step.
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
@@ -159,7 +149,8 @@ class KalmanFilterXYAH:
return mean, covariance
def update(self, mean, covariance, measurement):
"""Run Kalman filter correction step.
"""
Run Kalman filter correction step.
Parameters
----------
@@ -168,14 +159,13 @@
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
is the center position, a is the aspect ratio, and h is the height of the bounding box.
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
@@ -195,10 +185,11 @@
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
"""Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
"""
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Parameters
----------
mean : ndarray
@@ -206,18 +197,16 @@
covariance : ndarray
Covariance of the state distribution (8x8 dimensional).
measurements : ndarray
An Nx4 dimensional matrix of N measurements, each in
format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
center position, a the aspect ratio, and h the height.
only_position : Optional[bool]
If True, distance computation is done with respect to the bounding
box center position only.
If True, distance computation is done with respect to the bounding box center position only.
Returns
-------
ndarray
Returns an array of length N, where the i-th element contains the
squared Mahalanobis distance between (mean, covariance) and
`measurements[i]`.
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
(mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
@@ -237,38 +226,29 @@
class KalmanFilterXYWH(KalmanFilterXYAH):
"""
For BoT-SORT
A simple Kalman filter for tracking bounding boxes in image space.
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, w, h, vx, vy, vw, vh
contains the bounding box center position (x, y), width w, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, w, h) is taken as direct observation of the state space (linear
observation model).
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
width w, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
observation of the state space (linear observation model).
"""
def initiate(self, measurement):
"""Create track from unassociated measurement.
"""
Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, w, h) with center position (x, y),
width w, and height h.
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
Unobserved velocities are initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
@@ -283,23 +263,21 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def predict(self, mean, covariance):
"""Run Kalman filter prediction step.
"""
Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous
time step.
The 8 dimensional mean vector of the object state at the previous time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the
previous time step.
The 8x8 dimensional covariance matrix of the object state at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
@@ -315,7 +293,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def project(self, mean, covariance):
"""Project state distribution to measurement space.
"""
Project state distribution to measurement space.
Parameters
----------
@@ -327,9 +306,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state
estimate.
Returns the projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
@@ -341,20 +318,21 @@
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""Run Kalman filter prediction step (Vectorized version).
"""
Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous
time step.
The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrix of the object states at the
previous time step.
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
@@ -374,7 +352,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
return mean, covariance
def update(self, mean, covariance, measurement):
"""Run Kalman filter correction step.
"""
Run Kalman filter correction step.
Parameters
----------
@@ -383,13 +362,12 @@
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w is the width, and
h is the height of the bounding box.
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
and h the height of the bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
"""
return super().update(mean, covariance, measurement)
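A sketch of the `initiate` → `predict` → `update` cycle the docstrings above describe, assuming the module path of this file:

```python
import numpy as np
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

kf = KalmanFilterXYAH()
z = np.array([320., 240., 0.5, 100.])  # measurement (x, y, a, h)
mean, cov = kf.initiate(z)             # 8-dim state, 8x8 covariance, zero velocities
mean, cov = kf.predict(mean, cov)      # constant-velocity prediction step
mean, cov = kf.update(mean, cov, z)    # measurement-corrected distribution
```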

@@ -212,21 +212,18 @@ def get_google_drive_file_info(link):
"""
file_id = link.split('/d/')[1].split('/view')[0]
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
filename = None
# Start session
filename = None
with requests.Session() as session:
response = session.get(drive_url, stream=True)
if 'quota exceeded' in str(response.content.lower()):
raise ConnectionError(
emojis(f'❌ Google Drive file download quota exceeded. '
f'Please try again later or download this file manually at {link}.'))
token = None
for key, value in response.cookies.items():
if key.startswith('download_warning'):
token = value
if token:
drive_url = f'https://drive.google.com/uc?export=download&confirm={token}&id={file_id}'
for k, v in response.cookies.items():
if k.startswith('download_warning'):
drive_url += f'&confirm={v}' # v is token
cd = response.headers.get('content-disposition')
if cd:
filename = re.findall('filename="(.+)"', cd)[0]
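A sketch of the simplification above: the confirm token is appended to the existing URL rather than rebuilding it from scratch (the id and cookie values here are placeholders):

```python
drive_url = 'https://drive.google.com/uc?export=download&id=FILE_ID'  # placeholder id
cookies = {'download_warning_xyz': 'TOKEN'}  # hypothetical response cookies
for k, v in cookies.items():
    if k.startswith('download_warning'):
        drive_url += f'&confirm={v}'  # v is the confirm token
```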

@@ -15,12 +15,6 @@ from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
# Boxes
def box_area(box):
"""Return box area, where box shape is xyxy(4,n)."""
return (box[2] - box[0]) * (box[3] - box[1])
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
"""
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
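A usage sketch of the retained `bbox_ioa` (intersection over `box2` area), with boxes in x1y1x2y2 format as the docstring states:

```python
import numpy as np
from ultralytics.utils.metrics import bbox_ioa

box1 = np.array([[0., 0., 10., 10.]])
box2 = np.array([[5., 5., 15., 15.]])
print(bbox_ioa(box1, box2))  # intersection area / box2 area -> ~0.25
```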
@@ -869,11 +863,6 @@ class PoseMetrics(SegmentMetrics):
self.pose = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
"""Raises an AttributeError if an invalid attribute is accessed."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
"""
Processes the detection and pose metrics over the given set of predictions.

@@ -13,8 +13,6 @@ import torchvision
from ultralytics.utils import LOGGER
from .metrics import box_iou
class Profile(contextlib.ContextDecorator):
"""
@@ -32,23 +30,17 @@ class Profile(contextlib.ContextDecorator):
self.cuda = torch.cuda.is_available()
def __enter__(self):
"""
Start timing.
"""
"""Start timing."""
self.start = self.time()
return self
def __exit__(self, type, value, traceback): # noqa
"""
Stop timing.
"""
"""Stop timing."""
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
"""
Get current time.
"""
"""Get current time."""
if self.cuda:
torch.cuda.synchronize()
return time.time()
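A usage sketch of the `Profile` context decorator whose docstrings were condensed above:

```python
import torch
from ultralytics.utils.ops import Profile

with Profile() as dt:
    _ = torch.rand(512, 512) @ torch.rand(512, 512)  # work to be timed
print(dt.dt)  # elapsed seconds for the block
```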
@@ -56,15 +48,15 @@
def segment2box(segment, width=640, height=640):
"""
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
Args:
segment (torch.Tensor): the segment label
width (int): the width of the image. Defaults to 640
height (int): The height of the image. Defaults to 640
segment (torch.Tensor): the segment label
width (int): the width of the image. Defaults to 640
height (int): The height of the image. Defaults to 640
Returns:
(np.ndarray): the minimum and maximum x and y values of the segment.
(np.ndarray): the minimum and maximum x and y values of the segment.
"""
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
@@ -80,16 +72,16 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
(img1_shape) to the shape of a different image (img0_shape).
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@@ -186,9 +178,7 @@ def non_max_suppression(
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 0.5 + max_time_img * bs # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
@@ -226,10 +216,6 @@
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
@@ -242,13 +228,18 @@
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
# # Experimental
# merge = False # use merge-NMS
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# from .metrics import box_iou
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# redundant = True # require redundant detections
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if mps:
@@ -262,8 +253,7 @@
def clip_boxes(boxes, shape):
"""
It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
shape
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
Args:
boxes (torch.Tensor): the bounding boxes to clip
@@ -303,12 +293,12 @@ def scale_image(masks, im0_shape, ratio_pad=None):
Takes a mask, and resizes it to the original image size
Args:
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
im0_shape (tuple): the original image shape
ratio_pad (tuple): the ratio of the padding to the original image.
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
im0_shape (tuple): the original image shape
ratio_pad (tuple): the ratio of the padding to the original image.
Returns:
masks (torch.Tensor): The masks that are being returned.
masks (torch.Tensor): The masks that are being returned.
"""
# Rescale coordinates (xyxy) from im1_shape to im0_shape
im1_shape = masks.shape
@@ -340,6 +330,7 @@ def xyxy2xywh(x):
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
"""
@@ -359,6 +350,7 @@ def xywh2xyxy(x):
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
@@ -407,6 +399,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
h (int): The height of the image. Defaults to 640
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
eps (float): The minimum value of the box's width and height. Defaults to 0.0
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
"""
@@ -421,31 +414,13 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
return y
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
"""
Convert normalized coordinates to pixel coordinates of shape (n,2)
Args:
x (np.ndarray | torch.Tensor): The input tensor of normalized bounding box coordinates
w (int): The width of the image. Defaults to 640
h (int): The height of the image. Defaults to 640
padw (int): The width of the padding. Defaults to 0
padh (int): The height of the padding. Defaults to 0
Returns:
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = w * x[..., 0] + padw # top left x
y[..., 1] = h * x[..., 1] + padh # top left y
return y
def xywh2ltwh(x):
"""
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
"""
@@ -460,9 +435,10 @@ def xyxy2ltwh(x):
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
Args:
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] - x[..., 0] # width
@@ -475,7 +451,10 @@ def ltwh2xywh(x):
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
Args:
x (torch.Tensor): the input tensor
x (torch.Tensor): the input tensor
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
@@ -493,14 +472,8 @@ def xyxyxyxy2xywhr(corners):
Returns:
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
"""
if isinstance(corners, torch.Tensor):
is_numpy = False
atan2 = torch.atan2
sqrt = torch.sqrt
else:
is_numpy = True
atan2 = np.arctan2
sqrt = np.sqrt
is_numpy = isinstance(corners, np.ndarray)
atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
cx = (x1 + x3) / 2
@@ -527,14 +500,8 @@ def xywhr2xyxyxyxy(center):
Returns:
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
"""
if isinstance(center, torch.Tensor):
is_numpy = False
cos = torch.cos
sin = torch.sin
else:
is_numpy = True
cos = np.cos
sin = np.sin
is_numpy = isinstance(center, np.ndarray)
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
cx, cy, w, h, rotation = center.T
rotation *= math.pi / 180.0 # degrees to radians
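Both conversions above now pick their math backend with the same compact dispatch; a generic sketch of the idiom:

```python
import numpy as np
import torch

def cos_any(x):
    """Apply cos using the backend that matches the input type."""
    is_numpy = isinstance(x, np.ndarray)
    cos = np.cos if is_numpy else torch.cos
    return cos(x)

print(cos_any(np.zeros(1)), cos_any(torch.zeros(1)))  # array([1.]) tensor([1.])
```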
@@ -567,10 +534,10 @@ def ltwh2xyxy(x):
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
Args:
x (np.ndarray | torch.Tensor): the input image
x (np.ndarray | torch.Tensor): the input image
Returns:
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] + x[..., 0] # width
@@ -583,10 +550,10 @@ def segments2boxes(segments):
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
Args:
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
Returns:
(np.ndarray): the xywh coordinates of the bounding boxes.
(np.ndarray): the xywh coordinates of the bounding boxes.
"""
boxes = []
for s in segments:
@@ -600,11 +567,11 @@ def resample_segments(segments, n=1000):
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
Args:
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
n (int): number of points to resample the segment to. Defaults to 1000
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
n (int): number of points to resample the segment to. Defaults to 1000
Returns:
segments (list): the resampled segments.
segments (list): the resampled segments.
"""
for i, s in enumerate(segments):
s = np.concatenate((s, s[0:1, :]), axis=0)
@@ -617,14 +584,14 @@
def crop_mask(masks, boxes):
"""
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.
Args:
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns:
(torch.Tensor): The masks are being cropped to the bounding box.
(torch.Tensor): The masks are being cropped to the bounding box.
"""
n, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
@@ -636,17 +603,17 @@ def crop_mask(masks, boxes):
def process_mask_upsample(protos, masks_in, bboxes, shape):
"""
It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
quality but is slower.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
(torch.Tensor): The upsampled masks.
(torch.Tensor): The upsampled masks.
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@@ -692,13 +659,13 @@ def process_mask_native(protos, masks_in, bboxes, shape):
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
shape (tuple): the size of the input image (h,w)
Returns:
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@@ -733,19 +700,19 @@ def scale_masks(masks, shape, padding=True):
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
"""
Rescale segment coordinates (xyxy) from img1_shape to img0_shape
Rescale segment coordinates (xy) from img1_shape to img0_shape
Args:
img1_shape (tuple): The shape of the image that the coords are from.
coords (torch.Tensor): the coords to be scaled
img0_shape (tuple): the shape of the image that the segmentation is being applied to
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
img1_shape (tuple): The shape of the image that the coords are from.
coords (torch.Tensor): the coords to be scaled of shape n,2.
img0_shape (tuple): the shape of the image that the segmentation is being applied to.
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
Returns:
coords (torch.Tensor): the segmented image.
coords (torch.Tensor): The scaled coordinates.
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@@ -771,11 +738,11 @@ def masks2segments(masks, strategy='largest'):
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
Args:
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
strategy (str): 'concat' or 'largest'. Defaults to largest
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
strategy (str): 'concat' or 'largest'. Defaults to largest
Returns:
segments (List): list of segment masks
segments (List): list of segment masks
"""
segments = []
for x in masks.int().cpu().numpy().astype('uint8'):
@@ -796,9 +763,9 @@ def clean_str(s):
Cleans a string by replacing special characters with underscore _
Args:
s (str): a string needing special characters replaced
s (str): a string needing special characters replaced
Returns:
(str): a string with special characters replaced by an underscore _
(str): a string with special characters replaced by an underscore _
"""
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
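A quick usage sketch of `clean_str`, assuming the package import path of this file:

```python
from ultralytics.utils.ops import clean_str

print(clean_str('video?stream!name'))  # -> 'video_stream_name'
```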
