import os
import ctypes
import warnings
import torch.cuda
from torch.backends.cudnn import int_array

lib = None

__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']


def _libnccl():
    global lib
    if lib is None:
        lib = ctypes.cdll.LoadLibrary(None)
        if hasattr(lib, 'ncclCommDestroy'):
            lib.ncclCommDestroy.restype = None
        else:
            lib = None
    return lib


def is_available(tensors):
    devices = set()
    for tensor in tensors:
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    if _libnccl() is None:
        warnings.warn('NCCL library not found. Check your LD_LIBRARY_PATH')
        return False

    return True


_communicators = {}

# ncclDataType_t
ncclChar = 0
ncclInt = 1
ncclHalf = 2
ncclFloat = 3
ncclDouble = 4
ncclInt64 = 5
ncclUint64 = 6

# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3

status_codes = {
    0: "Success",
    1: "Unhandled Cuda Error",
    2: "System Error",
    3: "Internal Error",
    4: "Invalid Device Pointer",
    5: "Invalid Rank",
    6: "Unsupported Device Count",
    7: "Device Not Found",
    8: "Invalid Device Index",
    9: "Lib Wrapper Not Set",
    10: "Cuda Malloc Failed",
    11: "Rank Mismatch",
    12: "Invalid Argument",
    13: "Invalid Type",
    14: "Invalid Operation",
}

nccl_types = {
    'torch.cuda.ByteTensor': ncclChar,
    'torch.cuda.CharTensor': ncclChar,
    'torch.cuda.IntTensor': ncclInt,
    'torch.cuda.HalfTensor': ncclHalf,
    'torch.cuda.FloatTensor': ncclFloat,
    'torch.cuda.DoubleTensor': ncclDouble,
    'torch.cuda.LongTensor': ncclInt64,
}


class NcclError(RuntimeError):

    def __init__(self, status):
        self.status = status
        msg = '{0} ({1})'.format(status_codes.get(status), status)
        super(NcclError, self).__init__(msg)


class NcclComm(ctypes.c_void_p):
    pass


class NcclCommList(object):

    def __init__(self, devices):
        self.devices = devices
        ptrs = (NcclComm * len(devices))()
        self._as_parameter_ = ptrs
        check_error(lib.ncclCommInitAll(self, len(devices), int_array(devices)))

    def __getitem__(self, i):
        return self._as_parameter_[i]

    def __del__(self):
        for i in range(len(self.devices)):
            lib.ncclCommDestroy(self[i])


def check_error(status):
    if status != 0:
        raise NcclError(status)


def communicator(inputs, outputs=None):
    if _libnccl() is None:
        raise RuntimeError('Unable to load NCCL library')
    devices = [input.get_device() for input in inputs]
    if outputs is not None:
        for device, output in zip(devices, outputs):
            if output.get_device() != device:
                raise ValueError("inputs and outputs must be on the same devices")
    key = ','.join(str(d) for d in devices)
    if key not in _communicators:
        _communicators[key] = NcclCommList(devices)
    return _communicators[key]


def cudaStream():
    # TODO: return the current stream
    # ffi.C.THCState_getCurrentStream(cutorch.getState())
    return None


def all_reduce(inputs, outputs=None, op=SUM):
    if outputs is None:
        outputs = inputs
    _check_inputs(inputs, outputs)
    comm = communicator(inputs, outputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    count, data_type, op, comm[i], cudaStream()))


def reduce(inputs, outputs=None, root=0, op=SUM):
    assert(root >= 0 and root < len(inputs))
    if outputs is None:
        outputs = inputs
    _check_inputs(inputs, outputs)
    comm = communicator(inputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    count, data_type, op, root, comm[i], cudaStream()))


def broadcast(inputs, root=0):
    assert(root >= 0 and root < len(inputs))
    _check_inputs(inputs, inputs)
    comm = communicator(inputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclBcast(
                    ctypes.c_void_p(inputs[i].data_ptr()), count,
                    data_type, root, comm[i], cudaStream()))


def all_gather(inputs, outputs):
    _check_inputs(inputs, outputs, len(inputs))
    comm = communicator(inputs, outputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllGather(
                    ctypes.c_void_p(inputs[i].data_ptr()), count, data_type,
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    comm[i], cudaStream()))


def reduce_scatter(inputs, outputs, op=SUM):
    _check_inputs(inputs, outputs, 1.0 / len(inputs))
    comm = communicator(inputs, outputs)
    count = inputs[0].numel() // len(inputs)
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduceScatter(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    count, data_type, op, comm[i], cudaStream()))


def _check_inputs(inputs, outputs=None, size_multiplier=1):
    devices = set()
    size = inputs[0].numel()
    if len(inputs) != len(outputs):
        raise ValueError('inputs and outputs must be the same length')
    for input, output in zip(inputs, outputs):
        if not input.is_cuda:
            raise TypeError('inputs must be CUDA inputs')
        if not input.is_contiguous():
            raise ValueError('inputs must be contiguous')

        device = input.get_device()
        if device in devices:
            raise ValueError('inputs must be on unique devices')
        devices.add(device)
        if input.numel() != size:
            raise ValueError('inputs must be the same size')

        if not output.is_contiguous():
            raise ValueError('outputs must be contiguous')
        if output.get_device() != device:
            raise ValueError('inputs and outputs must be on the same devices')
        if output.numel() != size * size_multiplier:
            raise ValueError(('incorrect output size; expected {0} but got {1}'
                              .format(size * size_multiplier, output.numel())))
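

# Illustrative usage sketch (hypothetical, not part of the module's API surface):
# assumes at least two visible CUDA devices and a loadable NCCL library. One
# tensor is allocated per GPU and summed in place with all_reduce, so every
# replica ends up holding the elementwise total.
if __name__ == '__main__':
    if torch.cuda.device_count() >= 2:
        tensors = []
        for dev in range(2):
            with torch.cuda.device(dev):
                # each replica starts out filled with its device index + 1
                tensors.append(torch.cuda.FloatTensor(4).fill_(dev + 1))
        if is_available(tensors):
            all_reduce(tensors)  # defaults to op=SUM, reducing in place
            print([t[0] for t in tensors])  # every replica now holds 3.0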