You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.5 KiB
135 lines
4.5 KiB
"""Train the model""" |
|
|
|
import argparse |
|
import datetime |
|
import os |
|
|
|
import torch |
|
import torch.optim as optim |
|
from tqdm import tqdm |
|
# from apex import amp |
|
|
|
import dataset.data_loader as data_loader |
|
import model.net as net |
|
|
|
from common import utils |
|
from common.manager import Manager |
|
from evaluate import evaluate |
|
from loss.losses import compute_losses |
|
|
|
# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_dir',
    default='experiments/base_model',
    help="Directory containing params.json",
)
parser.add_argument(
    '--restore_file',
    default=None,
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, so no source-layout whitespace ends up in the text.
    help=("Optional, name of the file in --model_dir containing weights to reload before "
          "training"),
)
parser.add_argument(
    '-ow',
    '--only_weights',
    action='store_true',
    help='Only use weights to load or load all train status.',
)
|
|
|
|
|
def train(model, manager):
    """Run one pass over ``manager.dataloaders['train']`` and update the model.

    For each batch the function runs the forward pass, computes the losses,
    records them on the manager, back-propagates ``losses['total']``, applies
    the optimizer step and advances the manager's step counter. Once all
    batches are consumed, the scheduler is stepped and the manager's epoch
    counter is advanced.

    Args:
        model: network being trained; invoked as ``model(batch)``.
        manager: object providing ``dataloaders``, ``optimizer``,
            ``scheduler``, ``params`` and the loss/step/epoch bookkeeping
            methods called below.
    """
    # Start the pass with a fresh loss status.
    manager.reset_loss_status()

    # Release cached GPU memory, then put the model in training mode.
    torch.cuda.empty_cache()
    model.train()

    train_loader = manager.dataloaders['train']

    # tqdm renders the progress bar: one tick per batch.
    with tqdm(total=len(train_loader)) as progress:
        for batch in train_loader:
            # Move to GPU if available.
            batch = utils.tensor_gpu(batch)

            # Forward pass and loss computation.
            outputs = model(batch)
            losses = compute_losses(outputs, batch, manager.params)

            # Record the current losses for the "train" split.
            manager.update_loss_status(loss=losses, split="train")

            # Drop stale gradients, back-propagate the total loss, and
            # apply the parameter update.
            manager.optimizer.zero_grad()
            losses['total'].backward()
            manager.optimizer.step()

            # Advance the manager's step counter.
            manager.update_step()

            # Show the manager's training summary next to the progress bar.
            progress.set_description(desc=manager.print_train_info())
            progress.update()

    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the two calls below are placed after the batch loop (once per call to
    # train) -- confirm against the original layout.
    manager.scheduler.step()

    # Advance the manager's epoch counter.
    manager.update_epoch()
|
|
|
def train_and_evaluate(model, manager):
    """Optionally restore a checkpoint, then train for the configured epochs.

    Runs ``train(model, manager)`` ``manager.params.num_epochs`` times and
    asks the manager to handle checkpoint saving after each epoch. This
    function does not call ``evaluate`` directly.

    NOTE(review): reads the module-level ``args`` namespace, which is only
    assigned in the ``__main__`` section of this file.

    Args:
        model: network to train.
        manager: object holding params, checkpoint handling and training
            state; passed through to ``train``.
    """
    # Reload weights from restore_file if one was specified.
    restore_requested = args.restore_file is not None
    if restore_requested:
        manager.load_checkpoints()

    total_epochs = manager.params.num_epochs
    for _ in range(total_epochs):
        # One full pass over the training set.
        train(model, manager)

        # Save latest model, or best model weights according to the
        # params.major_metric.
        manager.check_best_save_last_checkpoints(latest_freq_val=999, latest_freq=1)
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, load params.json, build the data
    # pipeline, model, optimizer and scheduler, then start training.

    # NOTE: `args` must remain a module-level name; train_and_evaluate()
    # reads it to decide whether to restore a checkpoint.
    args = parser.parse_args()

    # Load the parameters from json file. An explicit raise is used instead
    # of `assert` so the check still runs under `python -O`.
    json_path = os.path.join(args.model_dir, 'params.json')
    if not os.path.isfile(json_path):
        raise FileNotFoundError("No json configuration file found at {}".format(json_path))
    params = utils.Params(json_path)

    # Update args into params
    params.update(vars(args))

    # Use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Set the logger
    logger = utils.set_logger(os.path.join(params.model_dir, 'train.log'))

    # Create the input data pipeline
    logger.info("Loading the train datasets from {}".format(params.train_data_dir))

    # Fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(params)

    # Define the model and optimizer. The optimizer and scheduler are built
    # the same way on CPU and GPU; only the device placement and the
    # DataParallel wrapping differ.
    model = net.fetch_net(params)
    if params.cuda:
        model = model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma)

    if params.cuda:
        # Wrap after the optimizer is created, matching the original order;
        # all visible CUDA devices are used.
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Initial status for checkpoint manager. No tensorboard writer is
    # created by this script, so `writer` is passed as None.
    manager = Manager(model=model, optimizer=optimizer, scheduler=scheduler, params=params,
                      dataloaders=dataloaders, writer=None, logger=logger)

    # Train the model
    logger.info("Starting training for {} epoch(s)".format(params.num_epochs))

    train_and_evaluate(model, manager)
|
|
|