import json
import logging
import os
import pickle
import shutil
import time
from collections import defaultdict

import numpy as np
import torch
from termcolor import colored

from common import utils
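
# Manager bundles the model, optimizer, scheduler, dataloaders, and per-split
# running statistics into one object that the training loop drives: it tracks
# losses/metrics with AverageMeters, prints progress, and saves/loads
# checkpoints under params.model_dir.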
class Manager:
    def __init__(self, model, optimizer, scheduler, params, dataloaders, writer, logger):
        # params status
        self.params = params
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataloaders = dataloaders
        self.writer = writer
        self.logger = logger
        self.epoch = 0
        self.epoch_val = 0
        self.step = 0
        # best scores track params.major_metric, for which lower is better
        self.best_val_score = 100
        self.best_test_score = 100
        self.cur_val_score = 0
        self.cur_test_score = 0
        # train status
        self.train_status = defaultdict(utils.AverageMeter)
        # val status
        self.val_status = defaultdict(utils.AverageMeter)
        # test status
        self.test_status = defaultdict(utils.AverageMeter)
        # loss status
        self.loss_status = defaultdict(utils.AverageMeter)

    def update_step(self):
        self.step += 1

    def update_epoch(self):
        self.epoch += 1
        self.epoch_val = 0

    def update_epoch_val(self):
        self.epoch_val += 1
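
    # The status dicts above map loss/metric names to utils.AverageMeter
    # objects; the methods below rely on that class exposing update(val, num),
    # reset(), and the running fields .val and .avg.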

    def update_loss_status(self, loss, split, batch_size=None):
        # loss: dict of {"loss_name": tensor, ...}
        if split == "train":
            num = batch_size if batch_size is not None else self.params.train_batch_size
        elif split in ("val", "test"):
            num = batch_size if batch_size is not None else self.params.eval_batch_size
        else:
            raise ValueError("Wrong eval type: {}".format(split))
        for k, v in loss.items():
            self.loss_status[k].update(val=v.item(), num=num)
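
    # Illustrative call from a training step (the loss names match those read
    # back in print_train_info; the exact dict is up to the caller):
    #   loss = {"total": total, "photo_loss_l1": photo_l1, "fea_loss_l1": fea_l1}
    #   manager.update_loss_status(loss, split="train")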

    def update_metric_status(self, metrics, split, batch_size=None):
        # metrics: dict of {"metric_name": tensor, ...}
        if split == "val":
            for k, v in metrics.items():
                self.val_status[k].update(val=v.item(), num=batch_size if batch_size is not None else self.params.eval_batch_size)
            self.cur_val_score = self.val_status[self.params.major_metric].avg
        elif split == "test":
            for k, v in metrics.items():
                self.test_status[k].update(val=v.item(), num=batch_size if batch_size is not None else self.params.eval_batch_size)
            self.cur_test_score = self.test_status[self.params.major_metric].avg
        else:
            raise ValueError("Wrong eval type: {}".format(split))
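
    # Illustrative call from an evaluation loop; params.major_metric names the
    # entry whose running average becomes cur_val_score / cur_test_score (the
    # metric name here is a placeholder):
    #   manager.update_metric_status({"err_avg": err}, split="val", batch_size=imgs.shape[0])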

    def reset_loss_status(self):
        for v in self.loss_status.values():
            v.reset()

    def reset_metric_status(self, split):
        if split == "val":
            for v in self.val_status.values():
                v.reset()
        elif split == "test":
            for v in self.test_status.values():
                v.reset()
        else:
            raise ValueError("Wrong eval type: {}".format(split))

    def print_train_info(self):
        exp_name = self.params.model_dir.split('/')[-1]
        # on PyTorch >= 1.4.0, prefer self.scheduler.get_last_lr()[0] here
        print_str = "{} Epoch: {:4d}, lr={:.4f} ".format(exp_name, self.epoch, self.scheduler.get_lr()[0])
        print_str += "total loss: %.4f(%.4f) " % (self.loss_status['total'].val, self.loss_status['total'].avg)
        print_str += "photo_loss_l1: %.4f(%.4f) " % (self.loss_status['photo_loss_l1'].val, self.loss_status['photo_loss_l1'].avg)
        print_str += "fea_loss_l1: %.4f(%.4f)" % (self.loss_status['fea_loss_l1'].val, self.loss_status['fea_loss_l1'].avg)
        return print_str

    def print_metrics(self, split, title="Eval", color="red"):
        if split == "val":
            metric_status = self.val_status
        elif split == "test":
            metric_status = self.test_status
        else:
            raise ValueError("Wrong eval type: {}".format(split))
        print_str = " | ".join("{}: {:.4f}".format(k, v.avg) for k, v in metric_status.items())
        self.logger.info(colored("{} Results: {}".format(title, print_str), color, attrs=["bold"]))
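
    # Checkpoint policy: a "latest" checkpoint is written whenever the epoch
    # hits latest_freq or latest_freq_val; the latest metrics are always
    # dumped to JSON, and a per-split "best" checkpoint (plus its metrics) is
    # written when the current major-metric score drops below the best seen
    # so far (lower is better, hence best_*_score starts at 100).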

    def check_best_save_last_checkpoints(self, latest_freq_val=5, latest_freq=5):
        state = {
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
        }
        if "val" in self.dataloaders:
            state["best_val_score"] = self.best_val_score
        if "test" in self.dataloaders:
            state["best_test_score"] = self.best_test_score

        # save latest checkpoint
        if self.epoch % latest_freq == 0 or self.epoch % latest_freq_val == 0:
            latest_ckpt_name = os.path.join(self.params.model_dir, "model_latest.pth")
            if self.params.save_mode == "local":
                torch.save(state, latest_ckpt_name)
            else:
                raise NotImplementedError
            self.logger.info("Saved latest checkpoint to: {}".format(latest_ckpt_name))

        # save latest val metrics, and check whether this is the best val checkpoint
        if "val" in self.dataloaders:
            val_latest_metrics_name = os.path.join(self.params.model_dir, "val_metrics_latest.json")
            utils.save_dict_to_json(self.val_status, val_latest_metrics_name)
            is_best = self.cur_val_score < self.best_val_score
            if is_best:
                # save metrics
                self.best_val_score = self.cur_val_score
                best_metrics_name = os.path.join(self.params.model_dir, "val_metrics_best.json")
                val_status_save = self.val_status.copy()
                val_status_save.update(epoch=self.epoch, epoch_val=self.epoch_val, step=self.step)
                utils.save_dict_to_json(val_status_save, best_metrics_name)
                self.logger.info("Current is val best, score={:.4f}".format(self.best_val_score))
                # save checkpoint
                best_ckpt_name = os.path.join(self.params.model_dir, "val_model_best.pth")
                if self.params.save_mode == "local":
                    torch.save(state, best_ckpt_name)
                self.logger.info("Saved val best checkpoint to: {}".format(best_ckpt_name))

        # save latest test metrics, and check whether this is the best test checkpoint
        if "test" in self.dataloaders:
            test_latest_metrics_name = os.path.join(self.params.model_dir, "test_metrics_latest.json")
            utils.save_dict_to_json(self.test_status, test_latest_metrics_name)
            is_best = self.cur_test_score < self.best_test_score
            if is_best:
                # save metrics
                self.best_test_score = self.cur_test_score
                best_metrics_name = os.path.join(self.params.model_dir, "test_metrics_best.json")
                test_status_save = self.test_status.copy()
                test_status_save.update(epoch=self.epoch, epoch_val=self.epoch_val, step=self.step)
                utils.save_dict_to_json(test_status_save, best_metrics_name)
                self.logger.info("Current is test best, score={:.4f}".format(self.best_test_score))
                # save checkpoint
                best_ckpt_name = os.path.join(self.params.model_dir, "test_model_best.pth")
                if self.params.save_mode == "local":
                    torch.save(state, best_ckpt_name)
                self.logger.info("Saved test best checkpoint to: {}".format(best_ckpt_name))

    def load_checkpoints(self):
        if self.params.save_mode == "local":
            if self.params.cuda:
                state = torch.load(self.params.restore_file)
            else:
                state = torch.load(self.params.restore_file, map_location=torch.device('cpu'))
        else:
            # only local checkpoints are supported, mirroring the save path
            raise NotImplementedError

        ckpt_component = []
        if "state_dict" in state and self.model is not None:
            try:
                self.model.load_state_dict(state["state_dict"])
            except Exception:
                print("Using custom loading net")
                net_dict = self.model.state_dict()
                if "module" not in list(state["state_dict"].keys())[0]:
                    # checkpoint was saved without DataParallel: add the "module." prefix
                    state_dict = {"module." + k: v for k, v in state["state_dict"].items() if "module." + k in net_dict.keys()}
                else:
                    # checkpoint was saved with DataParallel: strip the "module." prefix
                    state_dict = {k.replace("module.", ""): v for k, v in state["state_dict"].items() if k.replace("module.", "") in net_dict.keys()}
                net_dict.update(state_dict)
                self.model.load_state_dict(net_dict, strict=False)
            ckpt_component.append("net")

        if not self.params.only_weights:
            if "optimizer" in state and self.optimizer is not None:
                try:
                    self.optimizer.load_state_dict(state["optimizer"])
                except Exception:
                    print("Using custom loading optimizer")
                    optimizer_dict = self.optimizer.state_dict()
                    state_dict = {k: v for k, v in state["optimizer"].items() if k in optimizer_dict.keys()}
                    optimizer_dict.update(state_dict)
                    self.optimizer.load_state_dict(optimizer_dict)
                ckpt_component.append("opt")
            if "scheduler" in state and self.scheduler is not None:
                try:
                    self.scheduler.load_state_dict(state["scheduler"])
                except Exception:
                    print("Using custom loading scheduler")
                    scheduler_dict = self.scheduler.state_dict()
                    state_dict = {k: v for k, v in state["scheduler"].items() if k in scheduler_dict.keys()}
                    scheduler_dict.update(state_dict)
                    self.scheduler.load_state_dict(scheduler_dict)
                ckpt_component.append("sch")
if "step" in state:
self.train_status["step"] = state["step"] + 1
ckpt_component.append("step")
if "epoch" in state:
self.train_status["epoch"] = state["epoch"] + 1
ckpt_component.append("epoch")
if "best_val_score" in state:
self.best_val_score = state["best_val_score"]
ckpt_component.append("best val score: {:.3g}".format(self.best_val_score))
if "best_test_score" in state:
self.best_test_score = state["best_test_score"]
ckpt_component.append("best test score: {:.3g}".format(self.best_test_score))
ckpt_component = ", ".join(i for i in ckpt_component)
self.logger.info("Loaded models from: {}".format(self.params.restore_file))
self.logger.info("Ckpt load: {}".format(ckpt_component))