"""Train the model"""

import argparse
import datetime
import os

import torch
import torch.optim as optim
from tqdm import tqdm
# from apex import amp

import dataset.data_loader as data_loader
import model.net as net

from common import utils
from common.manager import Manager
from evaluate import evaluate
from loss.losses import compute_losses

parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model', help="Directory containing params.json")
parser.add_argument('--restore_file',
                    default=None,
                    help="Optional, name of the file in --model_dir containing weights to reload before training")
parser.add_argument('-ow', '--only_weights', action='store_true', help='Only load model weights instead of the full training state.')
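
# Example invocations (illustrative; 'best' is a hypothetical checkpoint name):
#   python train.py --model_dir experiments/base_model
#   python train.py --model_dir experiments/base_model --restore_file best -ow
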
def train(model, manager):
    # reset loss status
    manager.reset_loss_status()

    # set model to training mode
    torch.cuda.empty_cache()
    model.train()

    # Use tqdm for progress bar
    with tqdm(total=len(manager.dataloaders['train'])) as t:
        for i, data_batch in enumerate(manager.dataloaders['train']):
            # move to GPU if available
            data_batch = utils.tensor_gpu(data_batch)

            # compute model output and loss
            output_batch = model(data_batch)
            loss = compute_losses(output_batch, data_batch, manager.params)
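            # NOTE: compute_losses is assumed to return a mapping of named loss
            # terms; only the 'total' entry is backpropagated below.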

            # update loss status and print current loss and average loss
            manager.update_loss_status(loss=loss, split="train")

            # clear previous gradients, then compute gradients of all variables w.r.t. the total loss
            manager.optimizer.zero_grad()
            loss['total'].backward()

            # perform updates using calculated gradients
            manager.optimizer.step()
            # manager.logger.info("Loss/train: step {}: {}".format(manager.step, manager.loss_status['total'].val))

            # update step: step += 1
            manager.update_step()

            # print training info
            print_str = manager.print_train_info()

            t.set_description(desc=print_str)
            t.update()

    manager.scheduler.step()
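    # The ExponentialLR scheduler decays the learning rate by a factor of
    # params.gamma once per epoch, since it is stepped here rather than per batch.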

    # update epoch: epoch += 1
    manager.update_epoch()


def train_and_evaluate(model, manager):
    # reload weights from restore_file if specified
    if args.restore_file is not None:
        manager.load_checkpoints()

    for epoch in range(manager.params.num_epochs):
        # train for one epoch (one full pass over the training set)
        train(model, manager)

        # Save the latest model, or the best model weights according to params.major_metric
        manager.check_best_save_last_checkpoints(latest_freq_val=999, latest_freq=1)


if __name__ == '__main__':
    # Load the parameters from json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)
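    # Sketch of the expected params.json, inferred from the keys this script
    # reads (values are illustrative; dataset/model code may need more keys):
    #   {
    #       "num_epochs": 100,
    #       "learning_rate": 0.0001,
    #       "gamma": 0.95,
    #       "major_metric": "...",
    #       "train_data_dir": "..."
    #   }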

    # Update args into params
    params.update(vars(args))

    # use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Set the logger
    logger = utils.set_logger(os.path.join(params.model_dir, 'train.log'))

    # Set the tensorboard writer
    log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
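    # NOTE: log_dir is computed but no SummaryWriter is instantiated here;
    # the Manager below is constructed with writer=None.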

    # Create the input data pipeline
    logger.info("Loading the train datasets from {}".format(params.train_data_dir))

    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(params)

    # Define the model and optimizer
    if params.cuda:
        model = net.fetch_net(params).cuda()
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
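        # DataParallel replicates the model on every visible GPU and splits each
        # input batch across them; with one GPU it is effectively a thin wrapper.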
    else:
        model = net.fetch_net(params)
        optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)

    # initial status for checkpoint manager
    manager = Manager(model=model, optimizer=optimizer, scheduler=scheduler, params=params, dataloaders=dataloaders, writer=None, logger=logger)

    # Train the model
    logger.info("Starting training for {} epoch(s)".format(params.num_epochs))

    train_and_evaluate(model, manager)