You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
273 lines
9.0 KiB
273 lines
9.0 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import datetime |
|
import functools |
|
import os |
|
import subprocess |
|
import sys |
|
import time |
|
from collections import defaultdict, deque |
|
from typing import Iterator |
|
|
|
import numpy as np |
|
import pytz |
|
import torch |
|
import torch.distributed as tdist |
|
|
|
import dist |
|
|
|
# Run ``cmd`` through the shell and return its exit status (like ``os.system``).
# NOTE: all helpers here use ``shell=True``; never pass untrusted strings as ``cmd``.
os_system = functools.partial(subprocess.call, shell=True)


def os_system_get_stdout(cmd):
    """Run ``cmd`` through the shell and return its captured stdout decoded as UTF-8.

    stderr is not captured (it goes to the parent's stderr).
    """
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')


def os_system_get_stdout_stderr(cmd):
    """Run ``cmd`` through the shell and return ``(stdout, stderr)``, both decoded as UTF-8."""
    sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
|
|
|
|
|
def is_pow2n(x):
    """Return whether ``x`` is a positive integer power of two (1, 2, 4, 8, ...).

    Non-positive values are never powers of two.
    """
    positive = x > 0
    # A power of two has a single set bit, so clearing its lowest set bit
    # via ``x & (x - 1)`` must leave zero.
    return positive and ((x & (x - 1)) == 0)
|
|
|
|
|
def time_str():
    """Return the current Asia/Shanghai wall-clock time as ``[MM-DD HH:MM:SS]``."""
    shanghai_tz = pytz.timezone('Asia/Shanghai')
    now = datetime.datetime.now(tz=shanghai_tz)
    return now.strftime('[%m-%d %H:%M:%S]')
|
|
|
|
|
def init_distributed_environ(exp_dir):
    """Initialise the distributed runtime and per-process console logging.

    Steps: ``dist.initialize()`` followed by ``dist.barrier()``; cuDNN set to
    ``benchmark=True`` / ``deterministic=False``; ``print`` restricted to the
    local master process; and, on the local master only, stdout/stderr teed
    into ``exp_dir`` when ``exp_dir`` is non-empty.

    Args:
        exp_dir: directory for ``stdout.txt`` / ``stderr.txt``. An empty string
            or ``None`` disables file logging.
    """
    dist.initialize()
    dist.barrier()

    import torch.backends.cudnn as cudnn
    # Favour speed: let cuDNN autotune algorithms; results may be non-deterministic.
    cudnn.benchmark = True
    cudnn.deterministic = False

    is_local_master = dist.is_local_master()
    _set_print_only_on_master_proc(is_master=is_local_master)
    # Truthiness instead of ``len(exp_dir)`` so that ``exp_dir=None`` is accepted
    # (for strings the two checks are equivalent).
    if is_local_master and exp_dir:
        sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False)
|
|
|
|
|
def save_checkpoint(fname, args, epoch, performance_desc, model_without_ddp_state, optimizer_state):
    """Save a training checkpoint to ``os.path.join(args.exp_dir, fname)``.

    Only the local master process writes the file. All processes then call
    ``dist.barrier()`` (presumably so no rank reads the file before it is fully
    written).

    The saved dict has keys ``args`` (``str(args)``), ``arch`` (``args.model``),
    ``epoch``, ``performance_desc``, ``module`` and ``optimizer``.
    """
    checkpoint_path = os.path.join(args.exp_dir, fname)
    if dist.is_local_master():
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'optimizer': optimizer_state,
        }
        # Write to a sibling temp file, then atomically swap it in. An interrupted
        # save therefore never leaves a truncated checkpoint at ``checkpoint_path``,
        # and any previous checkpoint there stays intact.
        tmp_path = checkpoint_path + '.tmp'
        torch.save(to_save, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    dist.barrier()
|
|
|
|
|
def load_checkpoint(fname, model_without_ddp, optimizer):
    """Restore model (and optionally optimizer) state from checkpoint file ``fname``.

    The model weights are loaded with ``strict=False``; missing and unexpected
    keys are printed rather than raised. Optimizer state is restored only when
    the checkpoint contains an ``'optimizer'`` entry and ``optimizer`` is not
    ``None``, so ``optimizer=None`` can be passed for weight-only restores.

    Args:
        fname: path to a file written by ``torch.save`` with at least the keys
            ``'epoch'``, ``'performance_desc'`` and ``'module'``.
        model_without_ddp: module whose ``load_state_dict`` receives ``'module'``.
        optimizer: optimizer to restore, or ``None`` to skip.

    Returns:
        ``(next_ep, performance_desc)``: the saved epoch plus one, and the saved
        performance description.
    """
    print(f'[try to resume from file `{fname}`]')
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(fname, map_location='cpu')

    next_ep, performance_desc = checkpoint['epoch'] + 1, checkpoint['performance_desc']
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint['module'], strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] next_ep={next_ep}, performance_desc={performance_desc}')

    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return next_ep, performance_desc
|
|
|
|
|
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``deque`` holds only the most recent ``window_size`` raw values and backs
    ``median`` / ``avg`` / ``max`` / ``value``. ``total`` / ``count`` accumulate
    over every update and back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``.

        ``n`` weights the value in the global total/count only; the window always
        receives a single entry.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!

        All-reduce (sum) ``count`` and ``total`` across the default process group.
        This is a no-op when torch.distributed is unavailable or not initialised,
        so single-process runs can call it unconditionally. The reduction tensor
        is placed on CUDA when available, otherwise on CPU.
        """
        if not (tdist.is_available() and tdist.is_initialized()):
            return
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the window (torch semantics: lower middle value for even sizes)."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the window, computed in float32."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update; 0.0 before anything has been recorded."""
        # Guard the empty case instead of raising ZeroDivisionError.
        return self.total / self.count if self.count else 0.0

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def time_preds(self, counts):
        """Predict the time for ``counts`` more steps from the median step time.

        Returns:
            ``(remaining_seconds, remaining_as_H:MM:SS_string, finish_time)``, where
            ``finish_time`` is local time formatted as ``MM-DD HH:MM``.
        """
        remain_secs = counts * self.median
        remain_time = datetime.timedelta(seconds=round(remain_secs))
        finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
        return remain_secs, str(remain_time), finish_time

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
|
|
|
|
|
class MetricLogger(object):
    """A collection of named ``SmoothedValue`` meters plus a timing/logging iterator wrapper."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first update.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword argument.

        ``None`` values are skipped; tensors are converted with ``.item()``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss`` -> ``logger.meters['loss']``."""
        # Read ``meters`` through ``__dict__``. ``__getattr__`` also runs on
        # instances created without ``__init__`` (copy / pickle), where
        # ``self.meters`` would re-enter ``__getattr__`` and recurse forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, max_iters, itrt, print_freq, header=None):
        """Yield items from ``itrt`` while timing each step and printing progress.

        Args:
            max_iters: expected number of iterations. It is used for the progress
                counter, the ETA and the per-iteration average. When ``itrt`` is a
                plain iterator without ``preload``/``set_epoch`` attributes,
                exactly this many items are pulled from it.
            itrt: iterator or iterable to wrap.
            print_freq: number of progress lines to print, spread evenly over
                iterations ``0 .. max_iters - 1`` (a count, not a period).
            header: optional prefix for every printed line.
        """
        print_iters = set(np.linspace(0, max_iters - 1, print_freq, dtype=int).tolist())
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        self.iter_time = SmoothedValue(fmt='{avg:.4f}')
        self.data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(max_iters))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ])

        # Plain iterators (possibly infinite) are consumed for exactly max_iters
        # steps. Loaders exposing ``preload``/``set_epoch`` and ordinary iterables
        # are iterated until exhaustion.
        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            source = self._take_exactly(itrt, max_iters)
        else:
            source = itrt

        for i, obj in enumerate(source):
            # data_time: time spent waiting for the next item.
            self.data_time.update(time.time() - end)
            yield obj
            # iter_time: data time plus the consumer's work on ``obj``.
            self.iter_time.update(time.time() - end)
            if i in print_iters:
                eta_seconds = self.iter_time.global_avg * (max_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(log_msg.format(
                    i, max_iters, eta=eta_string,
                    meters=str(self),
                    time=str(self.iter_time), data=str(self.data_time)))
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Avoid ZeroDivisionError when max_iters == 0.
        per_iter = total_time / max_iters if max_iters else 0.0
        print('{} Total time: {} ({:.3f} s / it)'.format(
            header, total_time_str, per_iter))

    @staticmethod
    def _take_exactly(itrt, n):
        """Yield exactly ``n`` items from iterator ``itrt``."""
        # If ``itrt`` runs dry, the StopIteration raised here becomes RuntimeError
        # (PEP 479). This matches calling ``next(itrt)`` inline in ``log_every``.
        for _ in range(n):
            yield next(itrt)
|
|
|
|
|
def _set_print_only_on_master_proc(is_master): |
|
import builtins as __builtin__ |
|
|
|
builtin_print = __builtin__.print |
|
|
|
def prt(msg, *args, **kwargs): |
|
force = kwargs.pop('force', False) |
|
clean = kwargs.pop('clean', False) |
|
deeper = kwargs.pop('deeper', False) |
|
if is_master or force: |
|
if not clean: |
|
f_back = sys._getframe().f_back |
|
if deeper and f_back.f_back is not None: |
|
f_back = f_back.f_back |
|
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] |
|
msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}' |
|
builtin_print(msg, *args, **kwargs) |
|
|
|
__builtin__.print = prt |
|
|
|
|
|
class _SyncPrintToFile(object): |
|
def __init__(self, exp_dir, stdout=True): |
|
self.terminal = sys.stdout if stdout else sys.stderr |
|
fname = os.path.join(exp_dir, 'stdout.txt' if stdout else 'stderr.txt') |
|
self.log = open(fname, 'w') |
|
self.log.flush() |
|
|
|
def write(self, message): |
|
self.terminal.write(message) |
|
self.log.write(message) |
|
self.log.flush() |
|
|
|
def flush(self): |
|
self.terminal.flush() |
|
self.log.flush()
|
|
|