diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py index 65c138bce1..224a6e0f15 100644 --- a/ultralytics/utils/callbacks/comet.py +++ b/ultralytics/utils/callbacks/comet.py @@ -177,7 +177,7 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c return {"name": "ground_truth", "data": data} -def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None): +def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None, class_map=None): """Format YOLO predictions for object detection visualization.""" stem = image_path.stem image_id = int(stem) if stem.isnumeric() else stem @@ -187,26 +187,32 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions") return None + label_index_offset = 0 + if class_map is not None: + # offset to align indices of class labels (starting from zero) + # with prediction's category ID indices (can start from one) + label_index_offset = sorted(class_map)[0] + data = [] for prediction in predictions: boxes = prediction["bbox"] score = _scale_confidence_score(prediction["score"]) cls_label = prediction["category_id"] if class_label_map: - cls_label = str(class_label_map[cls_label]) + cls_label = str(class_label_map[cls_label - label_index_offset]) data.append({"boxes": [boxes], "label": cls_label, "score": score}) return {"name": "prediction", "data": data} -def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map): +def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map): """Join the ground truth and prediction annotations if they exist.""" ground_truth_annotations = _format_ground_truth_annotations_for_detection( img_idx, image_path, batch, class_label_map ) prediction_annotations = _format_prediction_annotations_for_detection( - image_path, prediction_metadata_map, class_label_map + image_path, prediction_metadata_map, class_label_map, class_map ) annotations = [ @@ -260,6 +266,7 @@ def _log_image_predictions(experiment, validator, curr_step): predictions_metadata_map = _create_prediction_metadata_map(jdict) dataloader = validator.dataloader class_label_map = validator.names + class_map = getattr(validator, "class_map", None) batch_logging_interval = _get_eval_batch_logging_interval() max_image_predictions = _get_max_image_predictions_to_log() @@ -280,6 +287,7 @@ def _log_image_predictions(experiment, validator, curr_step): batch, predictions_metadata_map, class_label_map, + class_map=class_map, ) _log_images( experiment,