"""Train the model"""
import argparse
import datetime
import os
import torch
import torch.optim as optim
from tqdm import tqdm
# from apex import amp
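# (optional: NVIDIA apex mixed-precision support, left disabled in this script)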
import dataset.data_loader as data_loader
import model.net as net
from common import utils
from common.manager import Manager
from evaluate import evaluate
from loss.losses import compute_losses
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model', help="Directory containing params.json")
parser.add_argument('--restore_file',
                    default=None,
                    help="Optional, name of the file in --model_dir containing weights to reload before training")
parser.add_argument('-ow', '--only_weights', action='store_true', help='Only load the model weights, rather than the full training state.')
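# Example usage (assuming this file is saved as train.py; the restore file name
# "best" is illustrative and depends on what Manager actually writes):
#   python train.py --model_dir experiments/base_model
#   python train.py --model_dir experiments/base_model --restore_file best --only_weights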


def train(model, manager):
    # reset the loss status for this epoch
    manager.reset_loss_status()
    # set model to training mode
    torch.cuda.empty_cache()
    model.train()
    # Use tqdm for progress bar
    with tqdm(total=len(manager.dataloaders['train'])) as t:
        for i, data_batch in enumerate(manager.dataloaders['train']):
            # move to GPU if available
            data_batch = utils.tensor_gpu(data_batch)
            # compute model output and loss (compute_losses returns a dict with a 'total' entry)
            output_batch = model(data_batch)
            loss = compute_losses(output_batch, data_batch, manager.params)
            # update loss status and print current loss and average loss
            manager.update_loss_status(loss=loss, split="train")
            # clear previous gradients, then back-propagate the total loss
            manager.optimizer.zero_grad()
            loss['total'].backward()
            # perform parameter updates using the computed gradients
            manager.optimizer.step()
            # manager.logger.info("Loss/train: step {}: {}".format(manager.step, manager.loss_status['total'].val))
            # update step: step += 1
            manager.update_step()
            # print training info
            print_str = manager.print_train_info()
            t.set_description(desc=print_str)
            t.update()
    # step the learning-rate scheduler once per epoch
    manager.scheduler.step()
    # update epoch: epoch += 1
    manager.update_epoch()


def train_and_evaluate(model, manager):
    # reload weights from restore_file if specified
    if args.restore_file is not None:
        manager.load_checkpoints()

    for epoch in range(manager.params.num_epochs):
        # train for one epoch (one full pass over the training set)
        train(model, manager)
        # Save latest model, or best model weights according to params.major_metric
        manager.check_best_save_last_checkpoints(latest_freq_val=999, latest_freq=1)


if __name__ == '__main__':
    # Load the parameters from the json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)
    # Merge the command-line args into params
    params.update(vars(args))
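    # A minimal params.json sketch for reference -- the keys below are the ones
    # this script reads; the values are illustrative assumptions, not defaults:
    # {
    #     "num_epochs": 100,
    #     "learning_rate": 1e-3,
    #     "gamma": 0.95,
    #     "train_data_dir": "data/train",
    #     "major_metric": "total"
    # }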
    # use GPU if available
    params.cuda = torch.cuda.is_available()
    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)
    # Set the logger
    logger = utils.set_logger(os.path.join(params.model_dir, 'train.log'))
    # Timestamped name for the tensorboard log directory (the writer itself is not created here)
    log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # Create the input data pipeline
    logger.info("Loading the train datasets from {}".format(params.train_data_dir))
    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(params)
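    # If tensorboard logging is wanted, a writer could be built from log_dir and
    # passed to Manager below instead of None (assumption: Manager accepts a
    # torch.utils.tensorboard.SummaryWriter):
    #   from torch.utils.tensorboard import SummaryWriter
    #   writer = SummaryWriter(os.path.join(params.model_dir, log_dir))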
    # Define the model and optimizer
    if params.cuda:
        model = net.fetch_net(params).cuda()
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    else:
        model = net.fetch_net(params)
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)
    # initial status for checkpoint manager
    manager = Manager(model=model, optimizer=optimizer, scheduler=scheduler, params=params, dataloaders=dataloaders, writer=None, logger=logger)
    # Train the model
    logger.info("Starting training for {} epoch(s)".format(params.num_epochs))
    train_and_evaluate(model, manager)