# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import List
from typing import Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    return __initialized


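# initialize() expects to be started by a torch.distributed launcher (e.g. torchrun or
# torch.distributed.launch), which sets the RANK environment variable. It pins each
# process to GPU RANK % num_gpus, creates the process group (NCCL by default), and
# fills in the module-level state read by the getters below.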
def initialize(backend='nccl'):
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()  # int('error') raises if RANK is unset
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)  # do not pass init_method='env://'

    global __rank, __local_rank, __world_size, __device, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'


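# Cheap getters over the cached module-level state; none of them issue collective calls.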
def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


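# parallelize() converts BatchNorm layers to SyncBatchNorm (if requested) before wrapping
# the model in DistributedDataParallel; broadcast_buffers=False means DDP itself will not
# sync any remaining buffers at each forward pass.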
def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def new_group(ranks: List[int]):
    return tdist.new_group(ranks=ranks)


def barrier():
    tdist.barrier()


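# In-place sum-reduce across all ranks. CPU tensors are staged through a CUDA copy
# (required by the NCCL backend) and the result is copied back into the original tensor.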
def allreduce(t: torch.Tensor) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.all_reduce(cu)
        t.copy_(cu.cpu())
    else:
        tdist.all_reduce(t)


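# Gathers one tensor per rank; with cat=True the per-rank tensors are concatenated along
# dim 0. all_gather assumes the tensor has the same shape on every rank.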
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if not t.is_cuda:
        t = t.cuda()
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


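# Same CPU-tensor staging as allreduce(), but broadcasting from src_rank to all ranks.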
def broadcast(t: torch.Tensor, src_rank) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.broadcast(cu, src=src_rank)
        t.copy_(cu.cpu())
    else:
        tdist.broadcast(t, src=src_rank)


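# Collects one scalar per rank: each rank writes its value into its own slot of a zero
# tensor, so the all-reduce sum yields the full per-rank vector, returned either as a
# tensor (fmt=None) or formatted as strings (e.g. '%.2f' % value).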
def dist_fmt_vals(val, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


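# Decorator: the wrapped function runs only on the global master (rank 0), or on every
# rank if called with force=True; a barrier keeps all ranks in step afterwards.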
def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


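# Per-node variant of master_only(): runs only where get_local_rank() == 0.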
def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


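# Minimal usage sketch (assumes this file is importable as `dist`, a launch such as
# `torchrun --nproc_per_node=8 train.py` so that RANK is set, and a hypothetical
# `MyModel`; this example is not part of the module itself):
#
#   import dist
#   dist.initialize(backend='nccl')
#   model = dist.parallelize(MyModel(), syncbn=True)
#   if dist.is_master():
#       print('world size =', dist.get_world_size())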