# Source code for easycv.utils.dist_utils

# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import os
import pickle
from collections import OrderedDict
from contextlib import contextmanager

import torch
import torch.distributed as dist
from mmcv.parallel import data_parallel as mm_data_parallel
from mmcv.parallel import distributed as mm_distributed
from mmcv.runner.dist_utils import get_dist_info
from torch import nn
from torch.distributed import ReduceOp


def is_master():
    """Tell whether the current process is the global rank-0 process.

    Returns:
        bool: True if the rank reported by ``get_dist_info`` equals 0.
    """
    current_rank = get_dist_info()[0]
    return current_rank == 0
def local_rank():
    """Return the local rank read from the ``LOCAL_RANK`` environment variable.

    Returns:
        int: ``int(os.environ['LOCAL_RANK'])``, or 0 when the variable is unset.
    """
    raw_value = os.environ.get('LOCAL_RANK', 0)
    return int(raw_value)
@contextmanager
def dist_zero_exec(rank=None):
    """Run the wrapped block on rank 0 first, then on the other ranks.

    Ranks other than -1 and 0 wait at a barrier before entering the body.
    Rank 0 runs the body and then hits the matching barrier, which releases
    the waiting ranks. Rank -1 (non-distributed) runs the body without any
    barrier.

    Args:
        rank (int, optional): Rank of the current process; -1 means
            non-distributed. Default: None, meaning ``local_rank()`` is read
            each time the context is entered.
    """
    if rank is None:
        # Resolved per call: a ``rank=local_rank()`` default in the signature
        # is evaluated once at import time and would miss a LOCAL_RANK value
        # set after this module was imported.
        rank = local_rank()
    if rank not in (-1, 0):
        barrier()
    try:
        # Execute the caller's block, then come back here to release others.
        yield
    finally:
        if rank == 0:
            # Also reached when the body raises, so ranks blocked in the
            # barrier above are released instead of waiting forever.
            barrier()
def get_num_gpu_per_node():
    """Get the number of GPUs (processes) per node.

    Computed as the maximum ``LOCAL_RANK`` over all ranks plus one. This
    presumably assumes every node runs the same number of processes with
    contiguous local ranks starting at 0 (TODO confirm for heterogeneous
    setups).

    Returns:
        int: 1 for a single-process run, otherwise ``max(LOCAL_RANK) + 1``.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return 1
    # NCCL can only reduce CUDA tensors; other backends (e.g. gloo on a
    # CPU-only host) get a CPU tensor instead of failing on 'cuda'.
    device = 'cuda' if dist.get_backend() == 'nccl' else 'cpu'
    max_local_rank = torch.tensor([local_rank()], device=device)
    dist.all_reduce(max_local_rank, op=ReduceOp.MAX)
    return int(max_local_rank.item()) + 1
def barrier():
    """Synchronize all processes of the default group.

    Does nothing when ``torch.distributed`` is unavailable or has not been
    initialized, so it is safe to call in single-process runs.
    """
    if not torch.distributed.is_available():
        return
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
def is_parallel(model):
    """Check whether ``model`` is wrapped in a data-parallel container.

    ``isinstance`` is used (rather than an exact ``type`` match) so that
    subclasses of the wrapper classes are recognized as well.

    Args:
        model: Any object, typically an ``nn.Module``.

    Returns:
        bool: True for (subclasses of) torch ``DataParallel`` /
        ``DistributedDataParallel`` or mmcv ``MMDataParallel`` /
        ``MMDistributedDataParallel``.
    """
    torch_wrappers = (nn.parallel.DataParallel,
                      nn.parallel.DistributedDataParallel)
    if isinstance(model, torch_wrappers):
        return True
    mmcv_wrappers = (mm_data_parallel.MMDataParallel,
                     mm_distributed.MMDistributedDataParallel)
    return isinstance(model, mmcv_wrappers)
# For YOLOX
def obj2tensor(pyobj, device='cuda'):
    """Pickle ``pyobj`` and pack the bytes into a uint8 tensor.

    Args:
        pyobj: Any picklable Python object.
        device (str | torch.device): Device to place the tensor on.
            Default: 'cuda'.

    Returns:
        torch.Tensor: 1-D ``ByteTensor`` holding ``pickle.dumps(pyobj)``.
    """
    serialized = pickle.dumps(pyobj)
    byte_storage = torch.ByteStorage.from_buffer(serialized)
    byte_tensor = torch.ByteTensor(byte_storage)
    return byte_tensor.to(device=device)
def tensor2obj(tensor):
    """Unpickle a Python object from a tensor of pickled bytes.

    Inverse of ``obj2tensor``. Note that this calls ``pickle.loads`` on the
    tensor contents, so it must only be fed data from trusted peers.

    Args:
        tensor (torch.Tensor): Tensor whose raw bytes are a pickle payload.

    Returns:
        object: The deserialized Python object.
    """
    host_tensor = tensor.cpu()
    raw_bytes = host_tensor.numpy().tobytes()
    return pickle.loads(raw_bytes)
@functools.lru_cache()
def _get_global_gloo_group():
    """Return a gloo-backed process group containing every rank.

    With an NCCL default backend a new gloo group is created; otherwise the
    default ``WORLD`` group is reused. The result is cached so the group is
    created at most once per process.
    """
    if dist.get_backend() != 'nccl':
        return dist.group.WORLD
    return dist.new_group(backend='gloo')
def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
    """Apply all reduce function for python dict object.

    The code is modified from
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.

    NOTE: make sure that py_dict in different ranks has the same keys and
    the values should be in the same shape.

    Args:
        py_dict (dict): Dict to be applied all reduce op.
        op (str): Operator, could be 'sum' or 'mean'. Any other value is
            treated as 'sum'. Default: 'sum'
        group (:obj:`torch.distributed.group`, optional): Distributed group
            the reduction runs over. Default: None, meaning a gloo group
            spanning all ranks.
        to_float (bool): Whether to convert all values of dict to float.
            Default: True.

    Returns:
        OrderedDict: reduced python dict object. When there is nothing to
        reduce (single process, or a group of size 1), ``py_dict`` itself
        is returned unchanged.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return py_dict
    if group is None:
        # TODO: May try not to use gloo in the future
        group = _get_global_gloo_group()
    group_size = dist.get_world_size(group)
    if group_size == 1:
        return py_dict

    # Rank 0 dictates the key order so every rank concatenates its values
    # in the same sequence. The broadcast runs on the default group, as
    # before.
    py_key = list(py_dict.keys())
    py_key_tensor = obj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2obj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    if to_float:
        flatten_tensor = torch.cat(
            [py_dict[k].flatten().float() for k in py_key])
    else:
        flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    # Reduce over the requested group; previously ``group`` was resolved
    # but never passed, so the default group was always used.
    dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM, group=group)
    if op == 'mean':
        # Out-of-place: in-place ``/=`` raises on integer tensors, which can
        # occur with to_float=False. The divisor is the size of the group
        # that was actually reduced over.
        flatten_tensor = flatten_tensor / group_size

    split_tensors = [
        x.reshape(shape) for x, shape in zip(
            torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict(zip(py_key, split_tensors))