Refactor Python code (#13448)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/13462/head
Glenn Jocher 9 months ago committed by GitHub
parent 6a234f3639
commit 1b26838def
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files (changed line count in parentheses):

1. ultralytics/data/augment.py (5)
2. ultralytics/data/build.py (2)
3. ultralytics/engine/exporter.py (7)
4. ultralytics/engine/model.py (1)
5. ultralytics/engine/predictor.py (4)
6. ultralytics/hub/session.py (2)
7. ultralytics/models/fastsam/prompt.py (2)
8. ultralytics/models/nas/val.py (2)
9. ultralytics/models/yolo/detect/val.py (16)
10. ultralytics/models/yolo/world/train_world.py (9)
11. ultralytics/nn/autobackend.py (4)
12. ultralytics/nn/modules/block.py (6)
13. ultralytics/solutions/ai_gym.py (2)
14. ultralytics/solutions/object_counter.py (2)
15. ultralytics/trackers/track.py (3)
16. ultralytics/trackers/utils/gmc.py (2)
17. ultralytics/utils/benchmarks.py (44)
18. ultralytics/utils/callbacks/mlflow.py (6)
19. ultralytics/utils/callbacks/wb.py (2)
20. ultralytics/utils/checks.py (3)
21. ultralytics/utils/plotting.py (4)

ultralytics/data/augment.py
@@ -1114,10 +1114,7 @@ class RandomLoadText:
             pos_labels = set(random.sample(pos_labels, k=self.max_samples))
         neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
-        neg_labels = []
-        for i in range(num_classes):
-            if i not in pos_labels:
-                neg_labels.append(i)
+        neg_labels = [i for i in range(num_classes) if i not in pos_labels]
         neg_labels = random.sample(neg_labels, k=neg_samples)
         sampled_labels = pos_labels + neg_labels
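
The loop-and-append on the removed lines collapses into a single list comprehension with identical behavior. A minimal standalone sketch of the same negative-label selection, using hypothetical values for num_classes and pos_labels:

    import random

    num_classes = 10
    pos_labels = {1, 4, 7}  # hypothetical positive classes
    neg_labels = [i for i in range(num_classes) if i not in pos_labels]  # replaces the removed loop
    print(random.sample(neg_labels, k=3))  # e.g. [0, 5, 9]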

ultralytics/data/build.py
@@ -21,7 +21,7 @@ from ultralytics.data.loaders import (
     autocast_list,
 )
 from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
-from ultralytics.utils import LINUX, RANK, colorstr
+from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file

ultralytics/engine/exporter.py
@@ -209,9 +209,10 @@ class Exporter:
         if self.args.optimize:
             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
-        if edgetpu and not LINUX:
-            raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
-        elif edgetpu and self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
-            LOGGER.warning("WARNING ⚠ Edge TPU export requires batch size 1, setting batch=1.")
-            self.args.batch = 1
+        if edgetpu:
+            if not LINUX:
+                raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
+            elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
+                LOGGER.warning("WARNING ⚠ Edge TPU export requires batch size 1, setting batch=1.")
+                self.args.batch = 1
         if isinstance(model, WorldModel):

ultralytics/engine/model.py
@@ -742,7 +742,6 @@ class Model(nn.Module):
         if hasattr(self.model, "names"):
             return check_class_names(self.model.names)
-        else:
-            if not self.predictor:  # export formats will not have predictor defined until predict() is called
-                self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
-                self.predictor.setup_model(model=self.model, verbose=False)
+        if not self.predictor:  # export formats will not have predictor defined until predict() is called
+            self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+            self.predictor.setup_model(model=self.model, verbose=False)

ultralytics/engine/predictor.py
@@ -319,13 +319,13 @@ class BasePredictor:
                 frame = self.dataset.count
             else:
                 match = re.search(r"frame (\d+)/", s[i])
-                frame = int(match.group(1)) if match else None  # 0 if frame undetermined
+                frame = int(match[1]) if match else None  # 0 if frame undetermined
 
             self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
             string += "%gx%g " % im.shape[2:]
             result = self.results[i]
             result.save_dir = self.save_dir.__str__()  # used in other locations
-            string += result.verbose() + f"{result.speed['inference']:.1f}ms"
+            string += f"{result.verbose()}{result.speed['inference']:.1f}ms"
 
             # Add predictions to image
             if self.args.save or self.args.show:

ultralytics/hub/session.py
@@ -368,5 +368,5 @@ class HUBTrainingSession:
         Returns:
             None
         """
-        for data in response.iter_content(chunk_size=1024):
+        for _ in response.iter_content(chunk_size=1024):
             pass  # Do nothing with data chunks
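
Renaming the loop variable to `_` makes the intent explicit: the chunks are deliberately discarded, but iterating still drains the HTTP body so the underlying connection can be reused. A rough sketch of the same pattern with `requests`, against a hypothetical endpoint:

    import requests

    response = requests.get("https://example.com/stream", stream=True)  # hypothetical URL
    for _ in response.iter_content(chunk_size=1024):
        pass  # consume the body without storing it; frees the connection for reuse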

ultralytics/models/fastsam/prompt.py
@@ -25,7 +25,7 @@ class FastSAMPrompt:
     def __init__(self, source, results, device="cuda") -> None:
         """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
         if isinstance(source, (str, Path)) and os.path.isdir(source):
-            raise ValueError(f"FastSAM only accepts image paths and PIL Image sources, not directories.")
+            raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
         self.device = device
         self.results = results
         self.source = source

ultralytics/models/nas/val.py
@@ -17,7 +17,7 @@ class NASValidator(DetectionValidator):
     ultimately producing the final detections.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
         lb (torch.Tensor): Optional tensor for multilabel NMS.
 
     Example:

ultralytics/models/yolo/detect/val.py
@@ -300,22 +300,22 @@ class DetectionValidator(BaseValidator):
                     anno = COCO(str(anno_json))  # init annotations api
                     pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                    eval = COCOeval(anno, pred, "bbox")
+                    val = COCOeval(anno, pred, "bbox")
                 else:
                     from lvis import LVIS, LVISEval

                     anno = LVIS(str(anno_json))  # init annotations api
                     pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
-                    eval = LVISEval(anno, pred, "bbox")
-                eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
-                eval.evaluate()
-                eval.accumulate()
-                eval.summarize()
+                    val = LVISEval(anno, pred, "bbox")
+                val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                val.evaluate()
+                val.accumulate()
+                val.summarize()
                 if self.is_lvis:
-                    eval.print_results()  # explicitly call print_results
+                    val.print_results()  # explicitly call print_results
                 # update mAP50-95 and mAP50
                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                    eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                    val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
                 )
             except Exception as e:
                 LOGGER.warning(f"{pkg} unable to run: {e}")

ultralytics/models/yolo/world/train_world.py
@@ -54,7 +54,8 @@ class WorldTrainerFromScratch(WorldTrainer):
             batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-        if mode == "train":
+        if mode != "train":
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
         dataset = [
             build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
             if isinstance(im_path, str)
@@ -62,8 +63,6 @@ class WorldTrainerFromScratch(WorldTrainer):
             for im_path in img_path
         ]
         return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
-        else:
-            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

     def get_dataset(self):
         """
@@ -71,7 +70,7 @@ class WorldTrainerFromScratch(WorldTrainer):
         Returns None if data format is not recognized.
         """
-        final_data = dict()
+        final_data = {}
         data_yaml = self.args.data
         assert data_yaml.get("train", False)  # object365.yaml
         assert data_yaml.get("val", False)  # lvis.yaml
@@ -88,7 +87,7 @@ class WorldTrainerFromScratch(WorldTrainer):
             grounding_data = data_yaml[s].get("grounding_data")
             if grounding_data is None:
                 continue
-            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
+            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
             for g in grounding_data:
                 assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
             final_data[s] += grounding_data

ultralytics/nn/autobackend.py
@@ -320,10 +320,8 @@ class AutoBackend(nn.Module):
             with open(w, "rb") as f:
                 gd.ParseFromString(f.read())
             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
-            try:  # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
+            with contextlib.suppress(StopIteration):  # find metadata in SavedModel alongside GraphDef
                 metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
-            except StopIteration:
-                pass  # no metadata file found

         # TFLite or TFLite Edge TPU
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
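
`contextlib.suppress` replaces the four-line try/except/pass with a single context manager that swallows the named exception. A minimal sketch of the pattern (the search path here is hypothetical, and the sketch pre-assigns None so a miss is observable):

    import contextlib
    from pathlib import Path

    metadata = None
    with contextlib.suppress(StopIteration):
        metadata = next(Path(".").rglob("metadata.yaml"))  # silently skipped if nothing matches
    print(metadata)  # a Path if found, else None

One caveat of the idiom: if the suppressed exception fires before the assignment, the name is simply never bound, so code after the block must not assume it exists.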

ultralytics/nn/modules/block.py
@@ -666,8 +666,7 @@ class CBLinear(nn.Module):
     def forward(self, x):
         """Forward pass through CBLinear layer."""
-        outs = self.conv(x).split(self.c2s, dim=1)
-        return outs
+        return self.conv(x).split(self.c2s, dim=1)

 class CBFuse(nn.Module):
@@ -682,5 +681,4 @@ class CBFuse(nn.Module):
         """Forward pass through CBFuse layer."""
         target_size = xs[-1].shape[2:]
         res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
-        out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
-        return out
+        return torch.sum(torch.stack(res + xs[-1:]), dim=0)

ultralytics/solutions/ai_gym.py
@@ -93,7 +93,7 @@ class AIGym:
                     self.stage[ind] = "up"
                     self.count[ind] += 1
-            elif self.pose_type == "pushup" or self.pose_type == "squat":
+            elif self.pose_type in {"pushup", "squat"}:
                 if self.angle[ind] > self.poseup_angle:
                     self.stage[ind] = "up"
                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":

ultralytics/solutions/object_counter.py
@@ -172,7 +172,7 @@ class ObjectCounter:
                 if self.draw_tracks:
                     self.annotator.draw_centroid_and_tracks(
                         track_line,
-                        color=self.track_color if self.track_color else colors(int(track_id), True),
+                        color=self.track_color or colors(int(track_id), True),
                         track_thickness=self.track_thickness,
                     )
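
`self.track_color or colors(...)` relies on `or` returning its right operand whenever the left is falsy. That is equivalent to the old conditional here because `track_color` is either None or a non-empty RGB tuple; the shorthand would not be safe for values where 0 or an empty container is meaningful. Quick illustration:

    default = (255, 0, 255)
    print(None or default)         # (255, 0, 255) -> falls through to the default
    print((0, 255, 0) or default)  # (0, 255, 0)   -> any non-empty tuple is truthy, even (0, 0, 0)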

ultralytics/trackers/track.py
@@ -73,8 +73,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
         idx = tracks[:, -1].astype(int)
         predictor.results[i] = predictor.results[i][idx]

-        update_args = dict()
-        update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
+        update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
         predictor.results[i].update(**update_args)
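
The one-liner works because a dict key in a literal can be any expression, including a conditional. Sketch with placeholder values in place of the tensor:

    is_obb = False
    update_args = {"obb" if is_obb else "boxes": [1.0, 2.0, 3.0]}  # hypothetical stand-in data
    print(update_args)  # {'boxes': [1.0, 2.0, 3.0]}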

ultralytics/trackers/utils/gmc.py
@@ -44,7 +44,7 @@ class GMC:
         super().__init__()

         self.method = method
-        self.downscale = max(1, int(downscale))
+        self.downscale = max(1, downscale)

         if self.method == "orb":
             self.detector = cv2.FastFeatureDetector_create(20)

ultralytics/utils/benchmarks.py
@@ -208,9 +208,10 @@ class RF100Benchmark:
         return self.ds_names, self.ds_cfg_list

-    def fix_yaml(self, path):
+    @staticmethod
+    def fix_yaml(path):
         """
-        Function to fix yaml train and val path.
+        Function to fix YAML train and val path.

         Args:
             path (str): YAML file path.
@@ -245,32 +246,19 @@ class RF100Benchmark:
                 entries = line.split(" ")
                 entries = list(filter(lambda val: val != "", entries))
                 entries = [e.strip("\n") for e in entries]
-                start_class = False
-                for e in entries:
-                    if e == "all":
-                        if "(AP)" not in entries:
-                            if "(AR)" not in entries:
-                                # parse all
-                                eval = {}
-                                eval["class"] = entries[0]
-                                eval["images"] = entries[1]
-                                eval["targets"] = entries[2]
-                                eval["precision"] = entries[3]
-                                eval["recall"] = entries[4]
-                                eval["map50"] = entries[5]
-                                eval["map95"] = entries[6]
-                                eval_lines.append(eval)
-                    if e in class_names:
-                        eval = {}
-                        eval["class"] = entries[0]
-                        eval["images"] = entries[1]
-                        eval["targets"] = entries[2]
-                        eval["precision"] = entries[3]
-                        eval["recall"] = entries[4]
-                        eval["map50"] = entries[5]
-                        eval["map95"] = entries[6]
-                        eval_lines.append(eval)
+                eval_lines.extend(
+                    {
+                        "class": entries[0],
+                        "images": entries[1],
+                        "targets": entries[2],
+                        "precision": entries[3],
+                        "recall": entries[4],
+                        "map50": entries[5],
+                        "map95": entries[6],
+                    }
+                    for e in entries
+                    if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
+                )
             map_val = 0.0
             if len(eval_lines) > 1:
                 print("There's more dicts")

ultralytics/utils/callbacks/mlflow.py
@@ -103,7 +103,8 @@ def on_fit_epoch_end(trainer):
 def on_train_end(trainer):
     """Log model artifacts at the end of the training."""
-    if mlflow:
-        mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
-        for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
-            if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
+    if not mlflow:
+        return
+    mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
+    for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
+        if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
@@ -116,8 +117,7 @@ def on_train_end(trainer):
     LOGGER.debug(f"{PREFIX}mlflow run ended")
     LOGGER.info(
-        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
-        f"{PREFIX}disable with 'yolo settings mlflow=False'"
+        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
     )
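
Inverting the condition into an early return (a guard clause) removes one indentation level from the rest of the function; the same restructuring appears in train_world.py above. A generic sketch of the pattern with hypothetical names:

    from pathlib import Path

    def log_artifacts(client, run_dir: Path):
        """Hypothetical logger illustrating the early-return guard."""
        if client is None:  # guard clause: bail out before doing any work
            return
        for f in run_dir.glob("*"):  # the happy path stays at one indent level
            client.log_artifact(str(f))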

ultralytics/utils/callbacks/wb.py
@@ -19,7 +19,7 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
     """
     Create and log a custom metric visualization to wandb.plot.pr_curve.

-    This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
+    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
     curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
     different classes.

ultralytics/utils/checks.py
@@ -434,10 +434,9 @@ def check_torchvision():
     # Extract only the major and minor versions
     v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
+    v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])

     if v_torch in compatibility_table:
         compatible_versions = compatibility_table[v_torch]
-        v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
         if all(v_torchvision != v for v in compatible_versions):
             print(
                 f"WARNING ⚠ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"

ultralytics/utils/plotting.py
@@ -493,7 +493,7 @@ class Annotator:
             angle = 360 - angle
         return angle

-    def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25):
+    def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
         """
         Draw specific keypoints for gym steps counting.
@@ -503,6 +503,8 @@ class Annotator:
             shape (tuple): imgsz for model inference
             radius (int): Keypoint radius value
         """
+        if indices is None:
+            indices = [2, 5, 7]
         for i, k in enumerate(keypoints):
             if i in indices:
                 x_coord, y_coord = k[0], k[1]
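
The signature change fixes the classic mutable-default-argument pitfall: a default like `indices=[2, 5, 7]` is created once at function definition and shared across every call, so any mutation leaks between calls. The `None` sentinel sidesteps that. A self-contained demonstration:

    def bad(x, acc=[]):  # one shared list for all calls
        acc.append(x)
        return acc

    def good(x, acc=None):  # fresh list per call
        if acc is None:
            acc = []
        acc.append(x)
        return acc

    print(bad(1), bad(2))    # [1, 2] [1, 2] -- state leaked between calls
    print(good(1), good(2))  # [1] [2]

In this particular method the default appears never to be mutated, so the change reads as defensive hygiene rather than a live bug fix.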
