Merge branch 'main' into uv-pytests

Authored by Burhan 3 months ago, committed via GitHub.
commit 2ef57ef7ea
16 changed files, with the number of changed lines in each:

  1. .github/workflows/docs.yml (2)
  2. docs/en/guides/triton-inference-server.md (18)
  3. docs/en/macros/augmentation-args.md (2)
  4. docs/en/reference/utils/ops.md (4)
  5. docs/en/solutions/index.md (3)
  6. ultralytics/__init__.py (2)
  7. ultralytics/engine/exporter.py (4)
  8. ultralytics/engine/model.py (1)
  9. ultralytics/models/sam/predict.py (2)
  10. ultralytics/models/yolo/classify/predict.py (3)
  11. ultralytics/models/yolo/classify/val.py (4)
  12. ultralytics/nn/autobackend.py (3)
  13. ultralytics/nn/modules/head.py (7)
  14. ultralytics/utils/loss.py (1)
  15. ultralytics/utils/ops.py (15)
  16. ultralytics/utils/triton.py (1)

@@ -23,7 +23,7 @@ on:
  inputs:
    publish_docs:
      description: "Publish live to https://docs.ultralytics.com"
-      default: "true"
+      default: true
      type: boolean
jobs:

@@ -48,6 +48,16 @@ from ultralytics import YOLO
# Load a model
model = YOLO("yolo11n.pt")  # load an official model

+# Retrieve metadata during export
+metadata = []
+
+
+def export_cb(exporter):
+    metadata.append(exporter.metadata)
+
+
+model.add_callback("on_export_end", export_cb)
+
# Export the model
onnx_file = model.export(format="onnx", dynamic=True)
```
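For orientation, `exporter.metadata` here is the dict of model details Ultralytics attaches at export time. A minimal sketch of inspecting it after the snippet above runs, assuming keys such as `names`, `stride`, and `imgsz` are present (exact keys may vary by version):

```python
# Hedged sketch: inspect the metadata captured by export_cb above.
# Key names below are assumptions; print the dict to see what your version emits.
meta = metadata[0]
print(sorted(meta))        # list the available keys
print(meta.get("names"))   # class index -> name mapping (assumed key)
print(meta.get("stride"))  # model stride (assumed key)
print(meta.get("imgsz"))   # export image size (assumed key)
```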
@@ -107,7 +117,13 @@ The Triton Model Repository is a storage location where Triton can access and lo
        }
      }
    }
-    """
+    parameters {
+      key: "metadata"
+      value: {
+        string_value: "%s"
+      }
+    }
+    """ % metadata[0]

with open(triton_model_path / "config.pbtxt", "w") as f:
    f.write(data)

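For context on the hunk above: once the dict is embedded in `config.pbtxt`, a client can read it back from the server's model config. A hedged sketch using plain `tritonclient` (the server URL and model name are placeholders; `TritonRemoteModel`, patched later in this diff, does the equivalent internally):

```python
# Sketch: fetch the "metadata" parameter from a running Triton server.
# Assumes a server is up and serving the model; URL and name are placeholders.
import ast

from tritonclient.http import InferenceServerClient

client = InferenceServerClient(url="localhost:8000")
config = client.get_model_config("yolo")  # parsed config.pbtxt as a dict
value = config["parameters"]["metadata"]["string_value"]
metadata = ast.literal_eval(value)  # the guide stores a Python dict repr
print(metadata)
```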
@@ -13,7 +13,7 @@
| `bgr` | `float` | `0.0` | `0.0 - 1.0` | Flips the image channels from RGB to BGR with the specified probability, useful for increasing robustness to incorrect channel ordering. |
| `mosaic` | `float` | `1.0` | `0.0 - 1.0` | Combines four training images into one, simulating different scene compositions and object interactions. Highly effective for complex scene understanding. |
| `mixup` | `float` | `0.0` | `0.0 - 1.0` | Blends two images and their labels, creating a composite image. Enhances the model's ability to generalize by introducing label noise and visual variability. |
-| `copy_paste` | `float` | `0.0` | `0.0 - 1.0` | Copies objects from one image and pastes them onto another, useful for increasing object instances and learning object occlusion. |
+| `copy_paste` | `float` | `0.0` | `0.0 - 1.0` | Copies and pastes objects across images, useful for increasing object instances and learning object occlusion. Requires segmentation labels. |
| `copy_paste_mode` | `str` | `flip` | - | Selects the Copy-Paste augmentation method, one of `"flip"` or `"mixup"`. |
| `auto_augment` | `str` | `randaugment` | - | Automatically applies a predefined augmentation policy (`randaugment`, `autoaugment`, `augmix`), optimizing for classification tasks by diversifying the visual features. |
| `erasing` | `float` | `0.4` | `0.0 - 0.9` | Randomly erases a portion of the image during classification training, encouraging the model to focus on less obvious features for recognition. |

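A short usage sketch for the two rows changed above; a minimal example assuming a segmentation model and dataset, since `copy_paste` requires segmentation labels (`coco8-seg.yaml` is the small demo dataset bundled with Ultralytics):

```python
# Sketch: enable Copy-Paste augmentation with the "mixup" paste mode.
from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")  # segmentation model, since copy_paste needs masks
model.train(data="coco8-seg.yaml", epochs=1, copy_paste=0.5, copy_paste_mode="mixup")
```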
@@ -129,4 +129,8 @@ keywords: Ultralytics, utility operations, non-max suppression, bounding box tra
## ::: ultralytics.utils.ops.clean_str

-<br><br>
+<br><br><hr><br>
+
+## ::: ultralytics.utils.ops.empty_like
+
+<br><br>

@@ -29,7 +29,6 @@ Here's our curated list of Ultralytics solutions that can be used to create awes
- [Parking Management](../guides/parking-management.md) 🚀: Organize and direct vehicle flow in parking areas with YOLO11, optimizing space utilization and user experience.
- [Analytics](../guides/analytics.md) 📊: Conduct comprehensive data analysis to discover patterns and make informed decisions, leveraging YOLO11 for descriptive, predictive, and prescriptive analytics.
- [Live Inference with Streamlit](../guides/streamlit-live-inference.md) 🚀: Leverage the power of YOLO11 for real-time [object detection](https://www.ultralytics.com/glossary/object-detection) directly through your web browser with a user-friendly Streamlit interface.
-- [Live Inference with Streamlit](../guides/streamlit-live-inference.md) 🚀: Leverage the power of YOLO11 for real-time [object detection](https://www.ultralytics.com/glossary/object-detection) directly through your web browser with a user-friendly Streamlit interface.
- [Track Objects in Zone](../guides/trackzone.md) 🎯 NEW: Learn how to track objects within specific zones of video frames using YOLO11 for precise and efficient monitoring.
## Solutions Usage
@@ -39,7 +38,7 @@ Here's our curated list of Ultralytics solutions that can be used to create awes
`yolo SOLUTIONS SOLUTION_NAME ARGS`

- **SOLUTIONS** is a required keyword.
-- **SOLUTION_NAME** (optional) is one of: `['count', 'heatmap', 'queue', 'speed', 'workout', 'analytics']`.
+- **SOLUTION_NAME** (optional) is one of: `['count', 'heatmap', 'queue', 'speed', 'workout', 'analytics', 'trackzone']`.
- **ARGS** (optional) are custom `arg=value` pairs, such as `show_in=True`, to override default settings.
=== "CLI"

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.43"
+__version__ = "8.3.47"
import os

@@ -73,7 +73,7 @@ from ultralytics.data import build_dataloader
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
-from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
+from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import (
    ARM64,
@@ -287,6 +287,8 @@ class Exporter:
        model = FXModel(model)
        for m in model.modules():
+            if isinstance(m, Classify):
+                m.export = True
            if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
                m.dynamic = self.args.dynamic
                m.export = True

@@ -136,6 +136,7 @@ class Model(nn.Module):
        # Check if Triton Server model
        elif self.is_triton_model(model):
            self.model_name = self.model = model
+            self.overrides["task"] = task or "detect"  # set `task=detect` if not explicitly set
            return

        # Load or create new YOLO model

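The line added above matters when a model is loaded straight from a Triton URL, where there is no checkpoint to infer the task from. A minimal usage sketch (the endpoint and model name are placeholders):

```python
# Sketch: load a Triton-served model by URL; with this change, task falls
# back to "detect" when not passed explicitly.
from ultralytics import YOLO

model = YOLO("http://localhost:8000/yolo", task="detect")  # task now optional
results = model("https://ultralytics.com/images/bus.jpg")
```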
@@ -1105,7 +1105,7 @@ class SAM2VideoPredictor(SAM2Predictor):
        for obj_temp_output_dict in temp_output_dict_per_obj.values():
            temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
        consolidated_frame_inds[storage_key].update(temp_frame_inds)
-        # consolidate the temprary output across all objects on this frame
+        # consolidate the temporary output across all objects on this frame
        for frame_idx in temp_frame_inds:
            consolidated_out = self._consolidate_temp_output_across_obj(
                frame_idx, is_cond=is_cond, run_mem_encoder=True

@@ -53,7 +53,8 @@ class ClassificationPredictor(BasePredictor):
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
        return [
-            Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0))
+            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]

@@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
        self.metrics.confusion_matrix = self.confusion_matrix
        self.metrics.save_dir = self.save_dir

+    def postprocess(self, preds):
+        """Postprocess the classification predictions."""
+        return preds[0] if isinstance(preds, (list, tuple)) else preds

    def get_stats(self):
        """Returns a dictionary of metrics obtained by processing targets and predictions."""
        self.metrics.process(self.targets, self.pred)

@@ -96,7 +96,7 @@ class AutoBackend(nn.Module):
        Initialize the AutoBackend for inference.

        Args:
-            weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'.
+            weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'.
            device (torch.device): Device to run the model on. Defaults to CPU.
            dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
            data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
@@ -462,6 +462,7 @@ class AutoBackend(nn.Module):
            from ultralytics.utils.triton import TritonRemoteModel

            model = TritonRemoteModel(w)
+            metadata = model.metadata

        # Any other format (unsupported)
        else:

@@ -282,6 +282,8 @@ class Pose(Detect):
class Classify(nn.Module):
    """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

+    export = False  # export mode
+
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
        """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
        super().__init__()
@@ -296,7 +298,10 @@
        if isinstance(x, list):
            x = torch.cat(x, 1)
        x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
-        return x
+        if self.training:
+            return x
+        y = x.softmax(1)  # get final output
+        return y if self.export else (y, x)
class WorldDetect(Detect):
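Taken together, the `Classify` changes above give three output modes: raw logits while training (for the loss), a `(probs, logits)` tuple in eval, and probabilities only in export. A minimal sketch exercising all three on a standalone head:

```python
# Sketch: the three Classify output modes introduced by this diff.
import torch

from ultralytics.nn.modules.head import Classify

head = Classify(16, 10)  # c1=16 input channels, c2=10 classes
x = torch.randn(2, 16, 20, 20)

head.train()
logits = head(x)  # raw logits, consumed by v8ClassificationLoss

head.eval()
probs, logits = head(x)  # softmax probabilities plus raw logits

head.export = True
probs = head(x)  # probabilities only, as in exported models
```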

@@ -604,6 +604,7 @@ class v8ClassificationLoss:
    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
+        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items

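This one-liner pairs with the head change above: in eval mode `Classify` now returns a tuple, so the loss selects the raw logits at index 1. The selection logic in isolation, as a tiny sketch:

```python
# Sketch: the loss accepts either a bare logits tensor or the eval-mode
# (probs, logits) tuple, mirroring the line added above.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

preds = (logits.softmax(1), logits)  # eval-mode Classify output
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
loss = F.cross_entropy(preds, targets, reduction="mean")
print(loss.item())
```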
@@ -400,7 +400,7 @@ def xyxy2xywh(x):
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
-    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
+    y = empty_like(x)  # faster than clone/copy
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
@@ -420,7 +420,7 @@ def xywh2xyxy(x):
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
-    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
+    y = empty_like(x)  # faster than clone/copy
    xy = x[..., :2]  # centers
    wh = x[..., 2:] / 2  # half width-height
    y[..., :2] = xy - wh  # top left xy
@@ -443,7 +443,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
        x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
-    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
+    y = empty_like(x)  # faster than clone/copy
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
@@ -469,7 +469,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    if clip:
        x = clip_boxes(x, (h - eps, w - eps))
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
-    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=float)  # faster than clone/copy
+    y = empty_like(x)  # faster than clone/copy
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
@@ -838,3 +838,10 @@ def clean_str(s):
        (str): a string with special characters replaced by an underscore _
    """
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
+
+
+def empty_like(x):
+    """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
+    return (
+        torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
+    )

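A quick behavioral note on the new helper: unlike the inline expressions it replaces, `empty_like` always returns float32 regardless of the input dtype, for both torch and numpy inputs. A small sketch:

```python
# Sketch: empty_like yields an uninitialized float32 container of the same shape.
import numpy as np
import torch

from ultralytics.utils.ops import empty_like

print(empty_like(torch.zeros(3, 4, dtype=torch.float16)).dtype)  # torch.float32
print(empty_like(np.zeros((3, 4), dtype=np.int64)).dtype)        # float32
```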
@@ -66,6 +66,7 @@ class TritonRemoteModel:
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
+        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """

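The string parsed above is the Python dict repr that the Triton guide writes into `config.pbtxt`, and it defaults to `None` when the parameter is absent. A hedged round-trip sketch; `ast.literal_eval` is shown as a drop-in for plain literals when the config is not fully trusted:

```python
# Sketch: round-trip of the metadata string embedded in config.pbtxt.
import ast

metadata = {"names": {0: "person"}, "stride": 32}  # illustrative contents
string_value = str(metadata)  # what the guide writes into config.pbtxt

restored = eval(string_value)  # what TritonRemoteModel does above
safer = ast.literal_eval(string_value)  # equivalent for plain literals
assert restored == safer == metadata
```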