Python refactorings and simplifications (#7549)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/7329/head
Glenn Jocher 11 months ago committed by GitHub
parent 0da13831cf
commit f6309b8e70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      examples/YOLOv8-OpenCV-int8-tflite-Python/main.py
  2. 2
      ultralytics/data/annotator.py
  3. 3
      ultralytics/data/augment.py
  4. 1
      ultralytics/data/base.py
  5. 1
      ultralytics/data/build.py
  6. 1
      ultralytics/data/dataset.py
  7. 23
      ultralytics/data/explorer/explorer.py
  8. 6
      ultralytics/data/explorer/utils.py
  9. 11
      ultralytics/data/split_dota.py
  10. 2
      ultralytics/engine/model.py
  11. 2
      ultralytics/engine/predictor.py
  12. 5
      ultralytics/hub/session.py
  13. 1
      ultralytics/models/fastsam/model.py
  14. 1
      ultralytics/models/nas/model.py
  15. 1
      ultralytics/models/rtdetr/train.py
  16. 3
      ultralytics/models/rtdetr/val.py
  17. 1
      ultralytics/models/sam/build.py
  18. 1
      ultralytics/models/sam/model.py
  19. 1
      ultralytics/models/sam/predict.py
  20. 1
      ultralytics/models/utils/loss.py
  21. 3
      ultralytics/models/yolo/detect/val.py
  22. 42
      ultralytics/models/yolo/obb/val.py
  23. 1
      ultralytics/nn/modules/head.py
  24. 2
      ultralytics/nn/tasks.py
  25. 2
      ultralytics/solutions/ai_gym.py
  26. 12
      ultralytics/solutions/distance_calculation.py
  27. 37
      ultralytics/solutions/heatmap.py
  28. 59
      ultralytics/solutions/object_counter.py
  29. 43
      ultralytics/solutions/speed_estimation.py
  30. 1
      ultralytics/trackers/track.py
  31. 2
      ultralytics/trackers/utils/gmc.py
  32. 4
      ultralytics/trackers/utils/matching.py
  33. 1
      ultralytics/utils/benchmarks.py
  34. 1
      ultralytics/utils/callbacks/base.py
  35. 4
      ultralytics/utils/callbacks/neptune.py
  36. 4
      ultralytics/utils/checks.py
  37. 1
      ultralytics/utils/loss.py
  38. 7
      ultralytics/utils/ops.py
  39. 12
      ultralytics/utils/plotting.py
  40. 3
      ultralytics/utils/tal.py

@ -175,9 +175,7 @@ class Yolov8TFLite:
img = np.ascontiguousarray(image)
# n, h, w, c
image = img.astype(np.float32)
image_data = image / 255
# Return the preprocessed image data
return image_data
return image / 255
def postprocess(self, input_image, output):
"""
@ -194,7 +192,7 @@ class Yolov8TFLite:
boxes = []
scores = []
class_ids = []
for i, pred in enumerate(output):
for pred in output:
pred = np.transpose(pred)
for box in pred:
x, y, w, h = box[:4]
@ -221,7 +219,7 @@ class Yolov8TFLite:
box[3] = box[3] / gain
score = scores[i]
class_id = class_ids[i]
if scores[i] > 0.25:
if score > 0.25:
print(box, score, class_id)
# Draw the detection on the input image
self.draw_detections(input_image, box, score, class_id)

@ -41,7 +41,7 @@ def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="",
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w") as f:
for i in range(len(segments)):
s = segments[i]
if len(s) == 0:

@ -15,7 +15,6 @@ from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
DEFAULT_MEAN = (0.0, 0.0, 0.0)
@ -1028,7 +1027,7 @@ def classify_transforms(
if isinstance(size, (tuple, list)):
assert len(size) == 2
scale_size = tuple([math.floor(x / crop_fraction) for x in size])
scale_size = tuple(math.floor(x / crop_fraction) for x in size)
else:
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)

@ -15,7 +15,6 @@ import psutil
from torch.utils.data import Dataset
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from .utils import HELP_URL, IMG_FORMATS

@ -22,7 +22,6 @@ from ultralytics.data.loaders import (
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from .dataset import YOLODataset
from .utils import PIN_MEMORY

@ -12,7 +12,6 @@ from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

@ -7,9 +7,9 @@ from typing import Any, List, Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image
from tqdm import tqdm
from ultralytics.data.augment import Format
@ -17,7 +17,6 @@ from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
@ -188,10 +187,10 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in [
assert return_type in {
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
import duckdb
if self.table is None:
@ -208,10 +207,10 @@ class Explorer:
LOGGER.info(f"Running query: {query}")
rs = duckdb.sql(query)
if return_type == "pandas":
return rs.df()
elif return_type == "arrow":
if return_type == "arrow":
return rs.arrow()
elif return_type == "pandas":
return rs.df()
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
"""
@ -264,17 +263,17 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in [
assert return_type in {
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
if return_type == "pandas":
return similar.to_pandas()
elif return_type == "arrow":
if return_type == "arrow":
return similar
elif return_type == "pandas":
return similar.to_pandas()
def plot_similar(
self,

@ -98,9 +98,9 @@ def plot_query_result(similar_set, plot_labels=True):
plot_kpts.append(kpt)
batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
imgs = np.stack(imgs, axis=0)
masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32)
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32)
masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)

@ -139,10 +139,9 @@ def get_window_obj(anno, windows, iof_thr=0.7):
label[:, 2::2] *= h
iofs = bbox_iof(label[:, 1:], windows)
# Unnormalized and misaligned coordinates
window_anns = [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]
return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns
else:
window_anns = [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]
return window_anns
return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
@ -170,7 +169,7 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
name = Path(anno["filepath"]).stem
for i, window in enumerate(windows):
x_start, y_start, x_stop, y_stop = window.tolist()
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
patch_im = im[y_start:y_stop, x_start:x_stop]
ph, pw = patch_im.shape[:2]
@ -271,7 +270,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
save_dir.mkdir(parents=True, exist_ok=True)
im_dir = Path(os.path.join(data_root, "images/test"))
assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
im_files = glob(str(im_dir / "*"))
for im_file in tqdm(im_files, total=len(im_files), desc="test"):
w, h = exif_size(Image.open(im_file))
@ -280,7 +279,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
name = Path(im_file).stem
for window in windows:
x_start, y_start, x_stop, y_stop = window.tolist()
new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
patch_im = im[y_start:y_stop, x_start:x_stop]
cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)

@ -73,7 +73,7 @@ class Model(nn.Module):
self.metrics = None # validation/training metrics
self.session = None # HUB session
self.task = task # task type
model = str(model).strip() # strip spaces
self.model_name = model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):

@ -210,7 +210,7 @@ class BasePredictor:
It uses always generator as outputs as not required by CLI mode.
"""
gen = self.stream_inference(source, model)
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
for _ in gen: # noqa, running CLI inference without accumulating any outputs (do not modify)
pass
def setup_source(self, source):

@ -70,6 +70,9 @@ class HUBTrainingSession:
def load_model(self, model_id):
# Initialize model
self.model = self.client.model(model_id)
if not self.model.data: # then model model does not exist
raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
self._set_train_args()
@ -257,7 +260,7 @@ class HUBTrainingSession:
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT,
}
return True if status_code in retry_codes else False
return status_code in retry_codes
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
"""

@ -3,7 +3,6 @@
from pathlib import Path
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator

@ -17,7 +17,6 @@ import torch
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator

@ -7,7 +7,6 @@ import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator

@ -122,8 +122,7 @@ class RTDETRValidator(DetectionValidator):
bbox = ops.xywh2xyxy(bbox) # target boxes
bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
bbox[..., [1, 3]] *= ori_shape[0] # native-space pred
prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
return prepared_batch
return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch with transformed bounding boxes and class labels."""

@ -11,7 +11,6 @@ from functools import partial
import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam

@ -18,7 +18,6 @@ from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor

@ -18,7 +18,6 @@ from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
from .amg import (
batch_iterator,
batched_mask_to_box,

@ -6,7 +6,6 @@ import torch.nn.functional as F
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher

@ -104,8 +104,7 @@ class DetectionValidator(BaseValidator):
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
return prepared_batch
return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""

@ -77,8 +77,7 @@ class OBBValidator(DetectionValidator):
if len(cls):
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
prepared_batch = dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
return prepared_batch
return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
@ -139,32 +138,21 @@ class OBBValidator(DetectionValidator):
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
for d in data:
image_id = d["image_id"]
score = d["score"]
classname = self.names[d["category_id"]].replace(" ", "-")
p = d["poly"]
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
d["poly"][0],
d["poly"][1],
d["poly"][2],
d["poly"][3],
d["poly"][4],
d["poly"][5],
d["poly"][6],
d["poly"][7],
)
with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
with open(f'{pred_txt / f"Task1_{classname}"}.txt', "a") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
# Save merged results, this could result slightly lower map than using official merging script,
# because of the probiou calculation.
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
for d in data:
image_id = d["image_id"].split("__")[0]
pattern = re.compile(r"\d+___\d+")
@ -188,22 +176,10 @@ class OBBValidator(DetectionValidator):
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
classname = self.names[int(x[-1])].replace(" ", "-")
poly = [round(i, 3) for i in x[:-2]]
p = [round(i, 3) for i in x[:-2]] # poly
score = round(x[-2], 3)
lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
image_id,
score,
poly[0],
poly[1],
poly[2],
poly[3],
poly[4],
poly[5],
poly[6],
poly[7],
)
with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
f.writelines(lines)
with open(f'{pred_merged_txt / f"Task1_{classname}"}.txt', "a") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
return stats

@ -8,7 +8,6 @@ import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer

@ -339,7 +339,7 @@ class DetectionModel(BaseModel):
class OBBModel(DetectionModel):
""""YOLOv8 Oriented Bounding Box (OBB) model."""
"""YOLOv8 Oriented Bounding Box (OBB) model."""
def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 OBB model with given config and parameters."""

@ -79,7 +79,7 @@ class AIGym:
self.annotator = Annotator(im0, line_width=2)
for ind, k in enumerate(reversed(self.keypoints)):
if self.pose_type == "pushup" or self.pose_type == "pullup":
if self.pose_type in ["pushup", "pullup"]:
self.angle[ind] = self.annotator.estimate_pose_angle(
k[int(self.kpts_to_check[0])].cpu(),
k[int(self.kpts_to_check[1])].cpu(),

@ -86,10 +86,9 @@ class DistanceCalculation:
self.left_mouse_count += 1
if self.left_mouse_count <= 2:
for box, track_id in zip(self.boxes, self.trk_ids):
if box[0] < x < box[2] and box[1] < y < box[3]:
if track_id not in self.selected_boxes:
self.selected_boxes[track_id] = []
self.selected_boxes[track_id] = box
if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
self.selected_boxes[track_id] = []
self.selected_boxes[track_id] = box
if event == cv2.EVENT_RBUTTONDOWN:
self.selected_boxes = {}
@ -149,10 +148,7 @@ class DistanceCalculation:
if tracks[0].boxes.id is None:
if self.view_img:
self.display_frames()
return
else:
return
return
self.extract_tracks(tracks)
self.annotator = Annotator(self.im0, line_width=2)

@ -169,10 +169,7 @@ class Heatmap:
if tracks[0].boxes.id is None:
if self.view_img and self.env_check:
self.display_frames()
return
else:
return
return
self.heatmap *= self.decay_factor # decay factor
self.extract_results(tracks)
self.annotator = Annotator(self.im0, self.count_txt_thickness, None)
@ -207,23 +204,21 @@ class Heatmap:
# Count objects
if len(self.count_reg_pts) == 4:
if self.counting_region.contains(Point(track_line[-1])):
if track_id not in self.counting_list:
self.counting_list.append(track_id)
if box[0] < self.counting_region.centroid.x:
self.out_counts += 1
else:
self.in_counts += 1
if self.counting_region.contains(Point(track_line[-1])) and track_id not in self.counting_list:
self.counting_list.append(track_id)
if box[0] < self.counting_region.centroid.x:
self.out_counts += 1
else:
self.in_counts += 1
elif len(self.count_reg_pts) == 2:
distance = Point(track_line[-1]).distance(self.counting_region)
if distance < self.line_dist_thresh:
if track_id not in self.counting_list:
self.counting_list.append(track_id)
if box[0] < self.counting_region.centroid.x:
self.out_counts += 1
else:
self.in_counts += 1
if distance < self.line_dist_thresh and track_id not in self.counting_list:
self.counting_list.append(track_id)
if box[0] < self.counting_region.centroid.x:
self.out_counts += 1
else:
self.in_counts += 1
else:
for box, cls in zip(self.boxes, self.clss):
if self.shape == "circle":
@ -244,8 +239,8 @@ class Heatmap:
heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
incount_label = "In Count : " + f"{self.in_counts}"
outcount_label = "OutCount : " + f"{self.out_counts}"
incount_label = f"In Count : {self.in_counts}"
outcount_label = f"OutCount : {self.out_counts}"
# Display counts based on user choice
counts_label = None
@ -256,7 +251,7 @@ class Heatmap:
elif not self.view_out_counts:
counts_label = incount_label
else:
counts_label = incount_label + " " + outcount_label
counts_label = f"{incount_label} {outcount_label}"
if self.count_reg_pts is not None and counts_label is not None:
self.annotator.count_labels(

@ -139,11 +139,14 @@ class ObjectCounter:
# global is_drawing, selected_point
if event == cv2.EVENT_LBUTTONDOWN:
for i, point in enumerate(self.reg_pts):
if isinstance(point, (tuple, list)) and len(point) >= 2:
if abs(x - point[0]) < 10 and abs(y - point[1]) < 10:
self.selected_point = i
self.is_drawing = True
break
if (
isinstance(point, (tuple, list))
and len(point) >= 2
and (abs(x - point[0]) < 10 and abs(y - point[1]) < 10)
):
self.selected_point = i
self.is_drawing = True
break
elif event == cv2.EVENT_MOUSEMOVE:
if self.is_drawing and self.selected_point is not None:
@ -166,9 +169,8 @@ class ObjectCounter:
# Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss):
self.annotator.box_label(
box, label=str(track_id) + ":" + self.names[cls], color=colors(int(cls), True)
) # Draw bounding box
# Draw bounding box
self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
# Draw Tracks
track_line = self.track_history[track_id]
@ -186,28 +188,29 @@ class ObjectCounter:
# Count objects
if len(self.reg_pts) == 4:
if prev_position is not None:
if self.counting_region.contains(Point(track_line[-1])):
if track_id not in self.counting_list:
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
else:
self.out_counts += 1
if (
prev_position is not None
and self.counting_region.contains(Point(track_line[-1]))
and track_id not in self.counting_list
):
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
else:
self.out_counts += 1
elif len(self.reg_pts) == 2:
if prev_position is not None:
distance = Point(track_line[-1]).distance(self.counting_region)
if distance < self.line_dist_thresh:
if track_id not in self.counting_list:
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
else:
self.out_counts += 1
if distance < self.line_dist_thresh and track_id not in self.counting_list:
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
else:
self.out_counts += 1
incount_label = "In Count : " + f"{self.in_counts}"
outcount_label = "OutCount : " + f"{self.out_counts}"
incount_label = f"In Count : {self.in_counts}"
outcount_label = f"OutCount : {self.out_counts}"
# Display counts based on user choice
counts_label = None
@ -218,7 +221,7 @@ class ObjectCounter:
elif not self.view_out_counts:
counts_label = incount_label
else:
counts_label = incount_label + " " + outcount_label
counts_label = f"{incount_label} {outcount_label}"
if counts_label is not None:
self.annotator.count_labels(
@ -254,9 +257,7 @@ class ObjectCounter:
if tracks[0].boxes.id is None:
if self.view_img:
self.display_frames()
return
else:
return
return
self.extract_and_process_tracks(tracks)
if self.view_img:

@ -114,9 +114,7 @@ class SpeedEstimator:
cls (str): object class name
track (list): tracking history for tracks path drawing
"""
speed_label = (
str(int(self.dist_data[track_id])) + "km/ph" if track_id in self.dist_data else self.names[int(cls)]
)
speed_label = f"{int(self.dist_data[track_id])}km/ph" if track_id in self.dist_data else self.names[int(cls)]
bbox_color = colors(int(track_id)) if track_id in self.dist_data else (255, 0, 255)
self.annotator.box_label(box, speed_label, bbox_color)
@ -132,28 +130,28 @@ class SpeedEstimator:
track (list): tracking history for tracks path drawing
"""
if self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh:
direction = "known"
if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
return
if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh:
direction = "known"
elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh:
direction = "known"
elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh:
direction = "known"
else:
direction = "unknown"
else:
direction = "unknown"
if self.trk_previous_times[trk_id] != 0 and direction != "unknown":
if trk_id not in self.trk_idslist:
self.trk_idslist.append(trk_id)
if self.trk_previous_times[trk_id] != 0 and direction != "unknown" and trk_id not in self.trk_idslist:
self.trk_idslist.append(trk_id)
time_difference = time() - self.trk_previous_times[trk_id]
if time_difference > 0:
dist_difference = np.abs(track[-1][1] - self.trk_previous_points[trk_id][1])
speed = dist_difference / time_difference
self.dist_data[trk_id] = speed
time_difference = time() - self.trk_previous_times[trk_id]
if time_difference > 0:
dist_difference = np.abs(track[-1][1] - self.trk_previous_points[trk_id][1])
speed = dist_difference / time_difference
self.dist_data[trk_id] = speed
self.trk_previous_times[trk_id] = time()
self.trk_previous_points[trk_id] = track[-1]
self.trk_previous_times[trk_id] = time()
self.trk_previous_points[trk_id] = track[-1]
def estimate_speed(self, im0, tracks):
"""
@ -166,10 +164,7 @@ class SpeedEstimator:
if tracks[0].boxes.id is None:
if self.view_img and self.env_check:
self.display_frames()
return
else:
return
return
self.extract_tracks(tracks)
self.annotator = Annotator(self.im0, line_width=2)

@ -7,7 +7,6 @@ import torch
from ultralytics.utils import IterableSimpleNamespace, yaml_load
from ultralytics.utils.checks import check_yaml
from .bot_sort import BOTSORT
from .byte_tracker import BYTETracker

@ -67,7 +67,7 @@ class GMC:
maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
)
elif self.method in ["none", "None", None]:
elif self.method in {"none", "None", None}:
self.method = None
else:
raise ValueError(f"Error: Unknown GMC method:{method}")

@ -70,9 +70,7 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
(np.ndarray): Cost matrix computed based on IoU.
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
len(btracks) > 0 and isinstance(btracks[0], np.ndarray)
):
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
atlbrs = atracks
btlbrs = btracks
else:

@ -26,7 +26,6 @@ ncnn | `ncnn` | yolov8n_ncnn_model/
import glob
import platform
import sys
import time
from pathlib import Path

@ -4,6 +4,7 @@
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------

@ -96,9 +96,7 @@ def on_train_end(trainer):
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f"weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}"].upload(
File(str(trainer.best))
)
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = (

@ -214,9 +214,9 @@ def check_version(
try:
name = current # assigned package name to 'name' arg
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError:
except metadata.PackageNotFoundError as e:
if hard:
raise ModuleNotFoundError(emojis(f"WARNING ⚠ {current} package is required but not installed"))
raise ModuleNotFoundError(emojis(f"WARNING ⚠ {current} package is required but not installed")) from e
else:
return False

@ -7,7 +7,6 @@ import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from .metrics import bbox_iou, probiou
from .tal import bbox2dist

@ -40,7 +40,7 @@ class Profile(contextlib.ContextDecorator):
"""
self.t = t
self.device = device
self.cuda = True if (device and str(device)[:4] == "cuda") else False
self.cuda = bool(device and str(device).startswith("cuda"))
def __enter__(self):
"""Start timing."""
@ -534,12 +534,11 @@ def xyxyxyxy2xywhr(corners):
# especially some objects are cut off by augmentations in dataloader.
(x, y), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([x, y, w, h, angle / 180 * np.pi])
rboxes = (
return (
torch.tensor(rboxes, device=corners.device, dtype=corners.dtype)
if is_torch
else np.asarray(rboxes, dtype=points.dtype)
)
return rboxes
) # rboxes
def xywhr2xyxyxyxy(center):

@ -13,7 +13,6 @@ from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
from .checks import check_font, check_version, is_ascii
from .files import increment_path
@ -433,7 +432,7 @@ class Annotator:
center_kpt (int): centroid pose index for workout monitoring
line_thickness (int): thickness for text display
"""
angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}")
angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
font_scale = 0.6 + (line_thickness / 10.0)
# Draw angle
@ -773,12 +772,11 @@ def plot_images(
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
annotator.fromarray(im)
if save:
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
else:
if not save:
return np.asarray(annotator.im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()

@ -288,8 +288,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
norm_ad = (ad * ad).sum(dim=-1)
ap_dot_ab = (ap * ab).sum(dim=-1)
ap_dot_ad = (ap * ad).sum(dim=-1)
is_in_box = (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)
return is_in_box
return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
def make_anchors(feats, strides, grid_cell_offset=0.5):

Loading…
Cancel
Save