pytorch/torch/cuda/nccl.py
Wes Bland 9c331be919 [pytorch] Remove dot if no suffix (#113273)
Summary: Appending the suffix (and its leading dot) to the version string shouldn't happen if there is no suffix.
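
In effect the formatting becomes conditional on the suffix being non-empty. A rough Python sketch of the intended behavior (the actual change is in the C++ version-string code; `format_version` and its arguments are illustrative stand-ins, not real APIs):

```
def format_version(major: int, minor: int, patch: int, suffix: str = "") -> str:
    # Only join the suffix (and its dot) when a suffix actually exists, so
    # suffix-less builds yield "2.17.1" rather than "2.17.1.".
    base = f"{major}.{minor}.{patch}"
    return f"{base}.{suffix}" if suffix else base
```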

Test Plan:
```
/data/users/wbland/fbsource/buck-out/v2/gen/fbcode/param_bench/train/comms/pt/comms.par \
--backend nccl --device cuda --collective all_gather \
--master-ip <snip> --log INFO --b 256 --e 1K \
--num-coll-per-iteration 10 --mode comms --num_iters 5 --w 1 --z 1
...
I1108 07:58:33.852557 2344130 ProcessGroupNCCL.cpp:990] [Rank 0] ProcessGroupNCCL initialization options: NCCL version: 2.17.1, NCCL_ASYNC_ERROR_HANDLING: 3, NCCL_DESYNC_DEBUG: 0, NCCL_ENABLE_TIMING: 0, NCCL_BLOCKING_WAIT: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, TORCH_DISTRIBUTED_DEBUG: OFF, NCCL_DEBUG: OFF, ID=139992854228992
...
```

Differential Revision: D51116095

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113273
Approved by: https://github.com/kwen2501
2023-11-12 15:41:27 +00:00


import collections
import warnings
from typing import Optional, Sequence, Union

import torch.cuda

__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    """Return True if the NCCL bindings are compiled in and every tensor is a
    dense, contiguous CUDA tensor, with no two tensors sharing a device."""
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            # NCCL expects at most one tensor per device.
            return False
        devices.add(device)
    return True
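
# Illustrative usage (hypothetical values, not part of the module): the check
# passes only for dense, contiguous CUDA tensors on pairwise-distinct devices.
#
#   xs = [torch.randn(4, device=f"cuda:{i}") for i in range(torch.cuda.device_count())]
#   torch.cuda.nccl.is_available(xs)  # True on a multi-GPU NCCL build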


def version():
    """Return the NCCL version as (major, minor, patch), extended to
    (major, minor, patch, suffix) when the build reports a version suffix."""
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    if suffix == "":
        return (major, minor, patch)
    else:
        return (major, minor, patch, suffix)
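
# Illustrative return values (hypothetical builds): a suffix-less build yields
# a 3-tuple such as (2, 17, 1), while a build with a version suffix yields a
# 4-tuple such as (2, 19, 0, "rc1"), so callers should not assume a fixed length.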


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
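
# Typical rendezvous sketch (illustrative; `world_size` and `rank` are
# placeholders): rank 0 creates the id and shares it with the other processes
# out of band, then every rank joins the same communicator.
#
#   uid = torch.cuda.nccl.unique_id()   # on rank 0, then distribute to peers
#   comm = torch.cuda.nccl.init_rank(world_size, uid, rank)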


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(
        inputs, torch.Tensor
    ):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        # No explicit outputs: reduce in place into the input tensors.
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
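
# Illustrative call (assuming two GPUs): with `outputs` omitted the reduction
# is in place, so every buffer ends up holding the elementwise sum.
#
#   bufs = [torch.ones(8, device=f"cuda:{i}") for i in range(2)]
#   torch.cuda.nccl.all_reduce(bufs)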


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead."
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called old API with positional arguments of list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor."
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
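
# Illustrative call (assuming two GPUs): the current API reduces into a single
# tensor on the root; passing a list via `output`/`outputs` is the deprecated
# spelling, and only the root entry of such a list is written.
#
#   ins = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#   torch.cuda.nccl.reduce(ins, output=ins[0], root=0)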


def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)


def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
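
# Illustrative shapes (assuming two GPUs): each input holds one chunk per
# rank, and each rank's output receives its own reduced chunk.
#
#   ins = [torch.ones(2 * 4, device=f"cuda:{i}") for i in range(2)]
#   outs = [torch.empty(4, device=f"cuda:{i}") for i in range(2)]
#   torch.cuda.nccl.reduce_scatter(ins, outs)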