Update metrics names (#85)

pull/89/head
Glenn Jocher 2 years ago committed by GitHub
parent 6432afc5f9
commit 248d54ca03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      ultralytics/yolo/data/dataset.py
  2. 11
      ultralytics/yolo/engine/trainer.py
  3. 19
      ultralytics/yolo/utils/callbacks/clearml.py
  4. 3
      ultralytics/yolo/utils/configs/__init__.py
  5. 10
      ultralytics/yolo/utils/metrics.py
  6. 6
      ultralytics/yolo/v8/detect/train.py
  7. 2
      ultralytics/yolo/v8/detect/val.py
  8. 5
      ultralytics/yolo/v8/segment/train.py
  9. 8
      ultralytics/yolo/v8/segment/val.py

@ -100,7 +100,7 @@ class YOLODataset(BaseDataset):
self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
assert cache["version"] == self.cache_version # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except Exception:

@ -82,6 +82,7 @@ class BaseTrainer:
self.fitness = None
self.loss = None
self.tloss = None
self.loss_names = None
self.csv = self.save_dir / 'results.csv'
for callback, func in callbacks.default_callbacks.items():
@ -106,7 +107,7 @@ class BaseTrainer:
def train(self):
world_size = torch.cuda.device_count()
if world_size > 1 and not ("LOCAL_RANK" in os.environ):
if world_size > 1 and "LOCAL_RANK" not in os.environ:
command = generate_ddp_command(world_size, self)
subprocess.Popen(command)
ddp_cleanup(command, self)
@ -154,11 +155,9 @@ class BaseTrainer:
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
if rank in {0, -1}:
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
validator = self.get_validator()
# init metric, for plot_results
metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.validator = validator
self.validator = self.get_validator()
# metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
# self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
def _do_train(self, rank=-1, world_size=1):

@ -24,29 +24,22 @@ def before_train(trainer):
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={'pytorch': False})
task.connect(trainer.args, name='parameters')
task.connect(dict(trainer.args), name='General')
def on_batch_end(trainer):
train_loss = trainer.tloss
_log_scalers(trainer.label_loss_items(train_loss), "train", trainer.epoch)
_log_scalers(trainer.label_loss_items(trainer.tloss, prefix="train"), "train", trainer.epoch)
def on_val_end(trainer):
metrics = trainer.metrics
val_losses = trainer.validator.loss
val_loss_dict = trainer.label_loss_items(val_losses)
_log_scalers(val_loss_dict, "val", trainer.epoch)
_log_scalers(metrics, "metrics", trainer.epoch)
_log_scalers(trainer.label_loss_items(trainer.validator.loss, prefix="val"), "val", trainer.epoch)
_log_scalers({k: v for k, v in trainer.metrics.items() if k.startswith("metrics")}, "metrics", trainer.epoch)
if trainer.epoch == 0:
infer_speed = trainer.validator.speed[1]
model_info = {
"inference_speed": infer_speed,
"inference_speed": trainer.validator.speed[1],
"flops@640": get_flops(trainer.model),
"params": get_num_params(trainer.model)}
_log_scalers(model_info, "model")
Task.current_task().connect(model_info, 'Model')
def on_train_end(trainer):

@ -6,10 +6,11 @@ from omegaconf import DictConfig, OmegaConf
from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict]):
"""
Accepts yaml file name or DictConfig containing experiment configuration.
Returns training args namespace
:param overrides: Overrides str or Dict
:param config: Optional file name or DictConfig object
"""
if isinstance(config, (str, Path)):

@ -514,7 +514,7 @@ class DetMetrics:
@property
def keys(self):
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP_0.5(B)", "metrics/mAP_0.5:0.95(B)"]
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
return self.metric.mean_results()
@ -567,12 +567,12 @@ class SegmentMetrics:
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP_0.5(B)",
"metrics/mAP_0.5:0.95(B)", # metrics
"metrics/mAP50(B)",
"metrics/mAP50-95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP_0.5(M)",
"metrics/mAP_0.5:0.95(M)"]
"metrics/mAP50(M)",
"metrics/mAP50-95(M)"]
def mean_results(self):
return self.metric_box.mean_results() + self.metric_mask.mean_results()

@ -46,6 +46,7 @@ class DetectionTrainer(BaseTrainer):
return model
def get_validator(self):
self.loss_names = 'box_loss', 'obj_loss', 'cls_loss'
return v8.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
@ -190,15 +191,14 @@ class DetectionTrainer(BaseTrainer):
loss = lbox + lobj + lcls
return loss * bs, torch.cat((lbox, lobj, lcls)).detach()
# TODO: improve from API users perspective
def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future
keys = [f"{prefix}/lbox", f"{prefix}/lobj", f"{prefix}/lcls"]
keys = [f"{prefix}/{x}" for x in self.loss_names]
return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self):
return ('\n' + '%11s' * 6) % \
('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Size')
('Epoch', 'GPU_mem', *self.loss_names, 'Size')
def plot_training_samples(self, batch, ni):
images = batch["img"]

@ -173,7 +173,7 @@ class DetectionValidator(BaseValidator):
# TODO: align with train loss metrics
@property
def metric_keys(self):
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP_0.5(B)", "metrics/mAP_0.5:0.95(B)"]
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def plot_val_samples(self, batch, ni):
images = batch["img"]

@ -29,6 +29,7 @@ class SegmentationTrainer(DetectionTrainer):
return model
def get_validator(self):
self.loss_names = 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss'
return v8.segment.SegmentationValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
@ -212,12 +213,12 @@ class SegmentationTrainer(DetectionTrainer):
def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future
keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
keys = [f"{prefix}/{x}" for x in self.loss_names]
return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self):
return ('\n' + '%11s' * 7) % \
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
('Epoch', 'GPU_mem', *self.loss_names, 'Size')
def plot_training_samples(self, batch, ni):
images = batch["img"]

@ -178,12 +178,12 @@ class SegmentationValidator(DetectionValidator):
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP_0.5(B)",
"metrics/mAP_0.5:0.95(B)", # metrics
"metrics/mAP50(B)",
"metrics/mAP50-95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP_0.5(M)",
"metrics/mAP_0.5:0.95(M)",]
"metrics/mAP50(M)",
"metrics/mAP50-95(M)",]
def plot_val_samples(self, batch, ni):
images = batch["img"]

Loading…
Cancel
Save