You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
2.5 KiB
100 lines
2.5 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import os |
|
from typing import List |
|
from typing import Union |
|
|
|
import torch |
|
import torch.distributed as tdist |
|
import torch.multiprocessing as mp |
|
from torch.distributed import barrier as __barrier |
|
|
|
# Re-export torch.distributed.barrier under this module's namespace.
barrier = __barrier

# Cached per-process distributed state. These defaults describe a single
# CPU process; initialize() overwrites all of them.
__rank = 0
__local_rank = 0
__world_size = 1
__device = 'cpu'
__initialized = False
|
|
|
|
|
def initialized():
    """Return the module-level flag that initialize() sets to True on success."""
    return __initialized
|
|
|
|
|
def initialize(backend='nccl'):
    """Initialize torch.distributed for this process and cache its rank info.

    Reads the global rank from the ``RANK`` environment variable, binds the
    process to CUDA device ``RANK % torch.cuda.device_count()``, creates the
    default process group, and stores rank, local rank, world size and device
    in this module's globals.

    Args:
        backend: backend name forwarded to ``tdist.init_process_group``.

    Raises:
        ValueError: if ``RANK`` is unset or is not an integer.
        RuntimeError: if no CUDA device is visible.
    """
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    # Validate RANK explicitly so a missing/garbled launcher environment is
    # reported by name instead of as an opaque int() parsing failure.
    rank_env = os.environ.get('RANK')
    if rank_env is None:
        raise ValueError(
            'environment variable RANK is not set; '
            'start this process with a distributed launcher that exports it'
        )
    try:
        global_rank = int(rank_env)
    except ValueError:
        raise ValueError(
            'environment variable RANK must be an integer, got {!r}'.format(rank_env)
        ) from None

    # Guard the modulo below: with zero visible GPUs it would otherwise fail
    # with a bare ZeroDivisionError.
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError('no CUDA device is visible; cannot initialize distributed training')

    # NOTE(review): deriving the local rank this way presumes ranks are laid
    # out contiguously per node with one process per GPU -- confirm against
    # the launcher in use.
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __device, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    # Device of a tensor placed on the current CUDA device (set above).
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
|
|
|
|
|
def get_rank():
    """Return the cached global rank (0 until initialize() has stored the real one)."""
    return __rank
|
|
|
|
|
def get_local_rank():
    """Return the cached local rank (0 until initialize() has stored the real one)."""
    return __local_rank
|
|
|
|
|
def get_world_size():
    """Return the cached world size (1 until initialize() has stored the real one)."""
    return __world_size
|
|
|
|
|
def get_device():
    """Return the cached device: the string 'cpu' by default, or the CUDA device recorded by initialize()."""
    return __device
|
|
|
|
|
def is_master():
    """Return True when this process's cached global rank is 0."""
    return __rank == 0
|
|
|
|
|
def is_local_master():
    """Return True when this process's cached local rank is 0."""
    return __local_rank == 0
|
|
|
|
|
def parallelize(net, syncbn=False):
    """Move ``net`` to CUDA and wrap it in DistributedDataParallel.

    Args:
        net: the module to wrap.
        syncbn: when true, the module is first passed through
            ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.

    Returns:
        The DistributedDataParallel wrapper, bound to this process's local
        rank as its only device id, with ``find_unused_parameters`` and
        ``broadcast_buffers`` both disabled.
    """
    module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) if syncbn else net
    return torch.nn.parallel.DistributedDataParallel(
        module.cuda(),
        device_ids=[get_local_rank()],
        find_unused_parameters=False,
        broadcast_buffers=False,
    )
|
|
|
|
|
def allreduce(t: torch.Tensor) -> None:
    """All-reduce ``t`` in place across the default process group.

    Uses ``tdist.all_reduce`` with its default reduction op. A CUDA tensor is
    reduced directly; a non-CUDA tensor is reduced through a detached CUDA
    copy whose result is then copied back into ``t``.
    """
    if t.is_cuda:
        tdist.all_reduce(t)
        return

    # Stage through the GPU, then write the reduced values back into t.
    staged = t.detach().cuda()
    tdist.all_reduce(staged)
    t.copy_(staged.cpu())
|
|
|
|
|
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather ``t`` from every rank via ``tdist.all_gather``.

    A non-CUDA input is moved to CUDA first, so the gathered tensors live on
    CUDA either way. One receive buffer is allocated per rank, using the
    world size cached by this module.

    Args:
        t: the local tensor to contribute.
        cat: when true, concatenate the gathered tensors along dim 0.

    Returns:
        A single concatenated tensor if ``cat`` is true, otherwise the list
        of per-rank tensors.
    """
    local = t if t.is_cuda else t.cuda()
    gathered = [torch.empty_like(local) for _ in range(__world_size)]
    tdist.all_gather(gathered, local)
    return torch.cat(gathered, dim=0) if cat else gathered
|
|
|
|
|
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast ``t`` in place from rank ``src_rank`` via ``tdist.broadcast``.

    A CUDA tensor is broadcast directly; a non-CUDA tensor is broadcast
    through a detached CUDA copy whose contents are then copied back into
    ``t``.
    """
    if t.is_cuda:
        tdist.broadcast(t, src=src_rank)
        return

    # Stage through the GPU, then write the received values back into t.
    staged = t.detach().cuda()
    tdist.broadcast(staged, src=src_rank)
    t.copy_(staged.cpu())
|
|
|