# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import List
from typing import Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
from torch.distributed import barrier as __barrier

barrier = __barrier

# Module-level distributed state. These defaults describe a single process on
# 'cpu'; `initialize()` overwrites them.
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    """Return True once `initialize()` has completed, False before that."""
    return __initialized


def initialize(backend='nccl'):
    """Join the default process group and bind this process to one CUDA device.

    The global rank is read from the ``RANK`` environment variable (set by the
    distributed launcher). The per-node rank is read from ``LOCAL_RANK`` when
    present, otherwise it falls back to ``RANK % torch.cuda.device_count()``.

    Args:
        backend: backend name forwarded to ``torch.distributed.init_process_group``.

    Raises:
        ValueError: if ``RANK`` is unset or not an integer.
        RuntimeError: if no CUDA device is visible to this process.
    """
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    # Validate the launch environment before mutating any global state.
    rank_str = os.environ.get('RANK')
    if rank_str is None:
        raise ValueError('[dist initialize] environment variable RANK is not set; '
                         'launch this program with a distributed launcher such as torchrun')
    try:
        global_rank = int(rank_str)
    except ValueError as err:
        raise ValueError(f'[dist initialize] RANK must be an integer, got {rank_str!r}') from err

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError('[dist initialize] no CUDA device is visible to this process')

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    # LOCAL_RANK is authoritative when the launcher exports it. The modulo
    # fallback is only right when each node runs one process per visible GPU.
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank % num_gpus))
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __device, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    # Allocating a tiny tensor yields the current CUDA device (set just above).
    __device = torch.empty(1).cuda().device
    __initialized = True
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
def get_rank():
    """Return this process's global rank (0 until `initialize()` runs)."""
    return __rank


def get_local_rank():
    """Return this process's rank within its node (0 until `initialize()` runs)."""
    return __local_rank


def get_world_size():
    """Return the number of processes in the group (1 until `initialize()` runs)."""
    return __world_size


def get_device():
    """Return the CUDA device chosen by `initialize()`, or the string 'cpu' before that."""
    return __device


def is_master():
    """Return True on global rank 0."""
    return __rank == 0


def is_local_master():
    """Return True on local rank 0 (one process per node)."""
    return __local_rank == 0


def parallelize(net, syncbn=False):
    """Move `net` to the current CUDA device and wrap it in DistributedDataParallel.

    Args:
        net: the module to wrap.
        syncbn: if True, first convert BatchNorm layers to SyncBatchNorm.

    Returns:
        The DDP-wrapped module, pinned to ``get_local_rank()``. It is created with
        ``find_unused_parameters=False`` and ``broadcast_buffers=False``.
    """
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()],
                                                    find_unused_parameters=False,
                                                    broadcast_buffers=False)
    return net


def _inplace_collective(t: torch.Tensor, op) -> None:
    """Run the in-place collective `op` on `t`, staging through a contiguous CUDA copy if needed."""
    if t.is_cuda and t.is_contiguous():
        op(t)
        return
    # Collectives (NCCL in particular) need contiguous CUDA tensors. `.cuda()`
    # keeps the source strides, so `.contiguous()` is required on both paths.
    staged = t.detach() if t.is_cuda else t.detach().cuda()
    staged = staged.contiguous()
    op(staged)
    t.copy_(staged)  # copy_ also handles the CUDA -> CPU transfer


def allreduce(t: torch.Tensor) -> None:
    """Sum `t` across all processes in place (default `all_reduce` op); `t` keeps its device."""
    _inplace_collective(t, tdist.all_reduce)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather `t` from every process.

    Every process must pass a tensor of the same shape and dtype.

    Args:
        t: local tensor. A CPU tensor is copied to the current CUDA device first.
        cat: if True, concatenate the gathered tensors along dim 0.

    Returns:
        A list of ``world_size`` CUDA tensors ordered by rank, or their dim-0
        concatenation when `cat` is True. The result is on CUDA even if `t` was on CPU.
    """
    if not t.is_cuda:
        t = t.cuda()
    t = t.contiguous()  # collectives reject non-contiguous inputs
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank: int) -> None:
    """Overwrite `t` in place with `src_rank`'s value; `t` keeps its device."""
    _inplace_collective(t, lambda x: tdist.broadcast(x, src=src_rank))