"""Train the model"""
import argparse
import datetime
import os

import torch
import torch.optim as optim
from tqdm import tqdm

# from apex import amp
import dataset.data_loader as data_loader
import model.net as net
from common import utils
from common.manager import Manager
from evaluate import evaluate
from loss.losses import compute_losses

parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model', help="Directory containing params.json")
parser.add_argument('--restore_file',
                    default=None,
                    help="Optional, name of the file in --model_dir containing weights to reload before training")
parser.add_argument('-ow', '--only_weights', action='store_true', help='Only load model weights instead of the full training state.')


def train(model, manager):
    """Train the model for one epoch."""
    # reset the running loss status
    manager.reset_loss_status()
    # empty the CUDA cache and set the model to training mode
    torch.cuda.empty_cache()
    model.train()
    # use tqdm for progress bar
    with tqdm(total=len(manager.dataloaders['train'])) as t:
        for i, data_batch in enumerate(manager.dataloaders['train']):
            # move to GPU if available
            data_batch = utils.tensor_gpu(data_batch)
            # compute model output and loss
            output_batch = model(data_batch)
            loss = compute_losses(output_batch, data_batch, manager.params)
            # update loss status and print current loss and average loss
            manager.update_loss_status(loss=loss, split="train")
            # clear previous gradients, then compute gradients of all variables w.r.t. the total loss
            manager.optimizer.zero_grad()
            loss['total'].backward()
            # perform parameter updates using the calculated gradients
            manager.optimizer.step()
            # manager.logger.info("Loss/train: step {}: {}".format(manager.step, manager.loss_status['total'].val))
            # update step: step += 1
            manager.update_step()
            # print training info in the progress bar
            print_str = manager.print_train_info()
            t.set_description(desc=print_str)
            t.update()
    # step the learning-rate scheduler once per epoch
    manager.scheduler.step()
    # update epoch: epoch += 1
    manager.update_epoch()


def train_and_evaluate(model, manager):
    """Run the training loop for params.num_epochs epochs."""
    # reload weights from restore_file if specified
    if args.restore_file is not None:
        manager.load_checkpoints()
    for epoch in range(manager.params.num_epochs):
        # train for one epoch (one full pass over the training set)
        train(model, manager)
        # save the latest model, or the best model weights according to params.major_metric
        manager.check_best_save_last_checkpoints(latest_freq_val=999, latest_freq=1)


if __name__ == '__main__':
    # Load the parameters from the json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)
    # Update args into params
    params.update(vars(args))

    # Use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Set the logger
    logger = utils.set_logger(os.path.join(params.model_dir, 'train.log'))

    # Timestamped log directory name for the tensorboard writer
    log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # Create the input data pipeline
    logger.info("Loading the train datasets from {}".format(params.train_data_dir))
    # Fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(params)

    # Define the model, optimizer and scheduler
    if params.cuda:
        model = net.fetch_net(params).cuda()
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    else:
        model = net.fetch_net(params)
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)

    # Initial status for the checkpoint manager
    manager = Manager(model=model, optimizer=optimizer, scheduler=scheduler, params=params, dataloaders=dataloaders, writer=None, logger=logger)

    # Train the model
    logger.info("Starting training for {} epoch(s)".format(params.num_epochs))
    train_and_evaluate(model, manager)
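
# Example usage (a minimal sketch; assumes experiments/base_model/params.json defines the fields
# read above: num_epochs, learning_rate, gamma, train_data_dir and major_metric):
#   python train.py --model_dir experiments/base_model
#   python train.py --model_dir experiments/base_model --restore_file <checkpoint_file> --only_weights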