`ultralytics 8.2.59` use `Results.save_txt` for validation (#14496)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14306/merge v8.2.59
Laughing 4 months ago committed by GitHub
parent ebf7dcf5a8
commit bfcd85323d
1. docs/en/integrations/kaggle.md (8 changed lines)
2. docs/mkdocs_github_authors.yaml (1 changed line)
3. ultralytics/__init__.py (2 changed lines)
4. ultralytics/models/yolo/detect/val.py (22 changed lines)
5. ultralytics/models/yolo/obb/val.py (20 changed lines)
6. ultralytics/models/yolo/pose/val.py (22 changed lines)
7. ultralytics/models/yolo/segment/val.py (38 changed lines)

@@ -28,7 +28,7 @@ Once you sign in to your Kaggle account, you can click on the option to copy and
 ![Using kaggle for machine learning model training with a GPU](https://github.com/user-attachments/assets/264f01b9-207b-4a8d-be5d-7b51739e9726)
-On the [official YOLOv8 Kaggle notebook page](https://www.kaggle.com/code/ultralytics/yolov8), if you click on the three dots in the upper right-hand corner, you’ll notice more options will pop up.
+On the [official YOLOv8 Kaggle notebook page](https://www.kaggle.com/code/ultralytics/yolov8), if you click on the three dots in the upper right-hand corner, you'll notice more options will pop up.
 ![Overview of Options From the Official YOLOv8 Kaggle Notebook Page](https://github.com/user-attachments/assets/bca100a6-fae8-433d-8dfd-1ecf4cc4f691)
@@ -49,7 +49,7 @@ These options include:
 When working with Kaggle, you might come across some common issues. Here are some points to help you navigate the platform smoothly:
 - **Access to GPUs**: In your Kaggle notebooks, you can activate a GPU at any time, with usage allowed for up to 30 hours per week. Kaggle provides the Nvidia Tesla P100 GPU with 16GB of memory and also offers the option of using a Nvidia GPU T4 x2. Powerful hardware accelerates your machine-learning tasks, making model training and inference much faster.
-- **Kaggle Kernels**: Kaggle Kernels are free Jupyter notebook servers that can integrate GPUs, allowing you to perform machine learning operations on cloud computers. You don’t have to rely on your own computer's CPU, avoiding overload and freeing up your local resources.
+- **Kaggle Kernels**: Kaggle Kernels are free Jupyter notebook servers that can integrate GPUs, allowing you to perform machine learning operations on cloud computers. You don't have to rely on your own computer's CPU, avoiding overload and freeing up your local resources.
 - **Kaggle Datasets**: Kaggle datasets are free to download. However, it's important to check the license for each dataset to understand any usage restrictions. Some datasets may have limitations on academic publications or commercial use. You can download datasets directly to your Kaggle notebook or anywhere else via the Kaggle API.
 - **Saving and Committing Notebooks**: To save and commit a notebook on Kaggle, click "Save Version." This saves the current state of your notebook. Once the background kernel finishes generating the output files, you can access them from the Output tab on the main notebook page.
 - **Collaboration**: Kaggle supports collaboration, but multiple users cannot edit a notebook simultaneously. Collaboration on Kaggle is asynchronous, meaning users can share and work on the same notebook at different times.
@@ -57,7 +57,7 @@ When working with Kaggle, you might come across some common issues. Here are som
 ## Key Features of Kaggle
-Next, let’s understand the features Kaggle offers that make it an excellent platform for data science and machine learning enthusiasts. Here are some of the key highlights:
+Next, let's understand the features Kaggle offers that make it an excellent platform for data science and machine learning enthusiasts. Here are some of the key highlights:
 - **Datasets**: Kaggle hosts a massive collection of datasets on various topics. You can easily search and use these datasets in your projects, which is particularly handy for training and testing your YOLOv8 models.
 - **Competitions**: Known for its exciting competitions, Kaggle allows data scientists and machine learning enthusiasts to solve real-world problems. Competing helps you improve your skills, learn new techniques, and gain recognition in the community.
@@ -81,7 +81,7 @@ If you want to learn more about Kaggle, here are some helpful resources to guide
 - [**Kaggle Learn**](https://www.kaggle.com/learn): Discover a variety of free, interactive tutorials on Kaggle Learn. These courses cover essential data science topics and provide hands-on experience to help you master new skills.
 - [**Getting Started with Kaggle**](https://www.kaggle.com/code/alexisbcook/getting-started-with-kaggle): This comprehensive guide walks you through the basics of using Kaggle, from joining competitions to creating your first notebook. It's a great starting point for newcomers.
-- [**Kaggle Medium Page**](https://medium.com/@kaggleteam): Explore tutorials, updates, and community contributions on Kaggle’s Medium page. It’s an excellent source for staying up-to-date with the latest trends and gaining deeper insights into data science.
+- [**Kaggle Medium Page**](https://medium.com/@kaggleteam): Explore tutorials, updates, and community contributions on Kaggle's Medium page. It's an excellent source for staying up-to-date with the latest trends and gaining deeper insights into data science.
 ## Summary

@@ -25,6 +25,7 @@ andrei.kochin@intel.com: andrei-kochin
 ayush.chaurarsia@gmail.com: AyushExel
 chr043416@gmail.com: RizwanMunawar
 glenn.jocher@ultralytics.com: glenn-jocher
+hnliu_2@stu.xidian.edu.cn: null
 jpedrofonseca_94@hotmail.com: null
 k-2feng@hotmail.com: null
 lakshantha@ultralytics.com: lakshanthad

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.2.58"
+__version__ = "8.2.59"
 
 import os

@@ -160,8 +160,12 @@ class DetectionValidator(BaseValidator):
             if self.args.save_json:
                 self.pred_to_json(predn, batch["im_file"][si])
             if self.args.save_txt:
-                file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
-                self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
+                self.save_one_txt(
+                    predn,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
+                )
 
     def finalize_metrics(self, *args, **kwargs):
         """Set final values for metrics speed and confusion matrix."""
@@ -261,12 +265,14 @@ class DetectionValidator(BaseValidator):
     def save_one_txt(self, predn, save_conf, shape, file):
         """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
-        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
-        for *xyxy, conf, cls in predn.tolist():
-            xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
-            with open(file, "a") as f:
-                f.write(("%g " * len(line)).rstrip() % line + "\n")
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+        ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
         """Serialize YOLO predictions to COCO json format."""

@@ -130,13 +130,19 @@ class OBBValidator(DetectionValidator):
     def save_one_txt(self, predn, save_conf, shape, file):
         """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
-        gn = torch.tensor(shape)[[1, 0]]  # normalization gain whwh
-        for *xywh, conf, cls, angle in predn.tolist():
-            xywha = torch.tensor([*xywh, angle]).view(1, 5)
-            xyxyxyxy = (ops.xywhr2xyxyxyxy(xywha) / gn).view(-1).tolist()  # normalized xywh
-            line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format
-            with open(file, "a") as f:
-                f.write(("%g " * len(line)).rstrip() % line + "\n")
+        import numpy as np
+
+        from ultralytics.engine.results import Results
+
+        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        # xywh, r, conf, cls
+        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            obb=obb,
+        ).save_txt(file, save_conf=save_conf)
 
     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""

@@ -147,8 +147,14 @@ class PoseValidator(DetectionValidator):
             # Save
             if self.args.save_json:
                 self.pred_to_json(predn, batch["im_file"][si])
-            # if self.args.save_txt:
-            #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    pred_kpts,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
+                )
 
     def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
         """
@@ -217,6 +223,18 @@ class PoseValidator(DetectionValidator):
             on_plot=self.on_plot,
         )  # pred
 
+    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+            keypoints=pred_kpts,
+        ).save_txt(file, save_conf=save_conf)
+
     def pred_to_json(self, predn, filename):
         """Converts YOLO predictions to COCO JSON format."""
         stem = Path(filename).stem
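The same `Results.save_txt` mechanism covers keypoints: passing `pred_kpts` as `keypoints` makes the exported label lines also carry normalized keypoint coordinates after the box fields. A rough standalone sketch with dummy tensors (shapes, values, and file name invented purely for illustration):

```python
import numpy as np
import torch

from ultralytics.engine.results import Results

h, w = 480, 640  # pretend original image shape; used only for normalization
predn = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.9, 0.0]])  # xyxy, conf, cls
pred_kpts = torch.rand(1, 17, 3) * torch.tensor([w, h, 1.0])  # x, y, visibility per keypoint

Results(
    np.zeros((h, w), dtype=np.uint8),  # dummy image; only its shape matters here
    path=None,
    names={0: "person"},
    boxes=predn,
    keypoints=pred_kpts,
).save_txt("example_pose.txt", save_conf=True)
```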

@@ -48,9 +48,8 @@ class SegmentationValidator(DetectionValidator):
         self.plot_masks = []
         if self.args.save_json:
             check_requirements("pycocotools>=2.0.6")
-            self.process = ops.process_mask_upsample  # more accurate
-        else:
-            self.process = ops.process_mask  # faster
+        # more accurate vs faster
+        self.process = ops.process_mask_upsample if self.args.save_json or self.args.save_txt else ops.process_mask
         self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
def get_desc(self):
@@ -148,14 +147,23 @@ class SegmentationValidator(DetectionValidator):
             # Save
             if self.args.save_json:
-                pred_masks = ops.scale_image(
-                    pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                self.pred_to_json(
+                    predn,
+                    batch["im_file"][si],
+                    ops.scale_image(
+                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                        pbatch["ori_shape"],
+                        ratio_pad=batch["ratio_pad"][si],
+                    ),
+                )
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    pred_masks,
+                    self.args.save_conf,
                     pbatch["ori_shape"],
-                    ratio_pad=batch["ratio_pad"][si],
+                    self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
                 )
-                self.pred_to_json(predn, batch["im_file"][si], pred_masks)
-            # if self.args.save_txt:
-            #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
     def finalize_metrics(self, *args, **kwargs):
         """Sets speed and confusion matrix for evaluation metrics."""
@@ -235,6 +243,18 @@
         )  # pred
         self.plot_masks.clear()
 
+    def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+            masks=pred_masks,
+        ).save_txt(file, save_conf=save_conf)
+
     def pred_to_json(self, predn, filename, pred_masks):
         """
         Save one JSON result.
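The first segmentation hunk above selects `ops.process_mask_upsample` whenever `save_json` or `save_txt` is enabled (the "more accurate vs faster" trade-off noted in the comment), presumably so the exported masks come from the higher-resolution path. A minimal run that exercises the new segmentation txt export (the `yolov8n-seg.pt` checkpoint and `coco8-seg.yaml` dataset are illustrative choices):

```python
from ultralytics import YOLO

# Segmentation validation with txt export: each image gets a label file whose
# lines contain the class id followed by normalized polygon coordinates.
model = YOLO("yolov8n-seg.pt")  # illustrative checkpoint
metrics = model.val(data="coco8-seg.yaml", save_txt=True)
print(metrics.save_dir)  # label files are written under <save_dir>/labels/
```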
