pytorch/torch/distributed/nn/functional.py
Junjie Wang 7c2489bdae [PyTorch][Distributed] Enable Reduce Scatter and modify all_to_all for sharded linear with more test cases. (#68786)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable the auto grad for the sharded linear, we find we need to make some changes to the current nn function api (c10d api with auto grad enabled). So we made the following several changes:

1. Add a new api `reduce_scatter` since we need it in the rowwise sharding.
2. Modify the `all_to_all` api to make sure it is consistent with the one in distributed_c10d.py.
3. Found that the cpp input params of `reduce_scatter` were missing an input param; added more unit tests to cover these cases.
4. Sync the NN test from gloo to nccl.
ghstack-source-id: 144860208

Test Plan: CI + Unit Test

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
2021-12-06 13:38:58 -08:00

368 lines
11 KiB
Python

import torch
import torch.distributed as dist
from torch.autograd import Function
def broadcast(tensor, src, group=dist.group.WORLD):
    """
    Autograd-aware broadcast of ``tensor`` from rank ``src`` to the group.

    The input is not modified: the underlying autograd function broadcasts
    into a clone and returns that clone. ``tensor`` must have the same number
    of elements on every process taking part in the collective.

    Arguments:
        tensor (Tensor): Data to send when the current process is ``src``.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: The tensor received from the broadcast.
    """
    received = _Broadcast.apply(src, group, tensor)
    return received
def gather(tensor, dst=0, group=dist.group.WORLD):
    """
    Autograd-aware gather of one tensor per rank onto rank ``dst``.

    Every rank gets back a tuple with one tensor per member of ``group``;
    the buffers are zero-initialised and only handed to the collective on
    the destination rank.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Appropriately-sized tensors holding the gathered data.
    """
    gathered = _Gather.apply(dst, group, tensor)
    return gathered
def scatter(tensors, src=0, group=dist.group.WORLD):
    """
    Autograd-aware scatter of a list of tensors from rank ``src``.

    Each process receives exactly one tensor, returned as a freshly
    allocated output.

    Arguments:
        tensors (list[Tensor]): Tensors to scatter. The list is unpacked on
            every rank and its first element determines the shape of the
            output, so each rank must supply a non-empty list of
            equally-sized tensors; only the list on rank ``src`` is handed
            to the collective.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.
    """
    received = _Scatter.apply(src, group, *tensors)
    return received
def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Autograd-aware reduction of ``tensor`` across the group onto rank ``dst``.

    Only the process with rank ``dst`` receives the final result. The input
    is left untouched; the reduction runs on a clone, which is returned.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies the operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
    """
    reduced = _Reduce.apply(dst, op, group, tensor)
    return reduced
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Autograd-aware reduce-scatter: reduce a list of tensors, then scatter it.

    The result is written into ``output`` in place, and that same tensor is
    returned.

    Arguments:
        output (Tensor): Output tensor; filled by the collective.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies the operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
    """
    result = _Reduce_Scatter.apply(op, group, output, *input_list)
    return result
def all_gather(tensor, group=dist.group.WORLD):
    """
    Autograd-aware all-gather: collect ``tensor`` from every rank in the group.

    Arguments:
        tensor (Tensor): Tensor contributed by the current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): One tensor per rank of ``group``, each shaped like
        ``tensor``.
    """
    gathered = _AllGather.apply(group, tensor)
    return gathered
def all_to_all(output_tensor_list, input_tensor_list, group=dist.group.WORLD):
    """
    Autograd-aware all-to-all exchange of tensor lists.

    Each process scatters its list of input tensors to all processes in the
    group and collects what it receives into ``output_tensor_list``.

    Arguments:
        output_tensor_list (list[Tensor]): Tensors to be filled with the
            gathered data, one per rank.
        input_tensor_list (list[Tensor]): Tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): The tensors of ``output_tensor_list`` after the
        collective.
    """
    exchanged = _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
    return exchanged
def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=dist.group.WORLD,
):
    """
    Autograd-aware single-tensor all-to-all.

    Each process splits ``input`` and scatters the pieces to all processes
    in the group; the pieces received from every process are concatenated
    into ``output``, which is also the returned tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes: (list[Int], optional): Output split sizes for
            dim 0. If None or empty, dim 0 of ``output`` must divide equally
            by ``world_size``.
        input_split_sizes: (list[Int], optional): Input split sizes for
            dim 0. If None or empty, dim 0 of ``input`` must divide equally
            by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
    """
    # Note the argument order expected by the autograd function: the tensor
    # that carries gradients (``input``) goes last.
    result = _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )
    return result
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Autograd-aware all-reduce: reduce ``tensor`` across all processes so that
    every one of them obtains the final result.

    After the call the returned tensor is going to be bitwise identical in
    all processes. The input is left untouched; the reduction runs on a
    clone, which is returned.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies the operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
    """
    reduced = _AllReduce.apply(op, group, tensor)
    return reduced
class _Broadcast(Function):
    """Autograd function: broadcast in forward, sum-reduce onto ``src`` in backward."""

    @staticmethod
    def forward(ctx, src, group, tensor):
        # Stash what backward needs; the rank is queried without a group
        # argument, i.e. for the default group.
        ctx.src, ctx.group = src, group
        ctx.rank = dist.get_rank()
        # The c10d collectives operate in place, so broadcast into a copy
        # and leave the caller's tensor alone.
        received = tensor.clone()
        dist.broadcast(received, src, group=group)
        return received

    @staticmethod
    def backward(ctx, grad_output):
        # Sum every rank's gradient onto the source rank.
        grad_input = _Reduce.apply(
            ctx.src, dist.ReduceOp.SUM, ctx.group, grad_output
        )
        if ctx.rank != ctx.src:
            # Ranks other than the source report a zero gradient for their input.
            grad_input.zero_()
        # No gradients for ``src`` and ``group``.
        return (None, None, grad_input)
class _Gather(Function):
    """Autograd function: gather onto ``dst`` in forward, scatter from ``dst`` in backward."""

    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # One receive buffer per group member, each shaped like the input.
        world_size = dist.get_world_size(group=group)
        buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        # Only the destination hands its buffers to the collective; on every
        # other rank they stay zero-filled.
        # NOTE(review): this compares the group-local rank with ``dst`` --
        # confirm the intended meaning of ``dst`` when ``group`` is a sub-group.
        is_destination = dist.get_rank(group=group) == dst
        dist.gather(tensor, buffers if is_destination else None, dst, group=group)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_input = _Scatter.apply(ctx.dst, ctx.group, *grad_outputs)
        # No gradients for ``dst`` and ``group``.
        return (None, None, grad_input)
class _Scatter(Function):
    """Autograd function: scatter from ``src`` in forward, gather onto ``src`` in backward."""

    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        # All candidate tensors must share one shape, since the receive
        # buffer is modelled on the first of them.
        assert all(t.size() == tensors[0].size() for t in tensors)
        received = torch.zeros_like(tensors[0])
        # Only the source rank supplies the list that is actually sent.
        on_source = dist.get_rank(group=group) == src
        dist.scatter(received, list(tensors) if on_source else None, src, group=group)
        return received

    @staticmethod
    def backward(ctx, grad_output):
        gathered = _Gather.apply(ctx.src, ctx.group, grad_output)
        # No gradients for ``src`` and ``group``; one gradient per scattered tensor.
        return (None, None, *gathered)
class _Reduce(Function):
    """Autograd function: reduce onto ``src`` in forward, broadcast from ``src`` in backward."""

    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src, ctx.group = src, group
        # Reduce into a copy so the caller's tensor is not overwritten.
        reduced = tensor.clone()
        dist.reduce(reduced, src, op=op, group=group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _Broadcast.apply(ctx.src, ctx.group, grad_output)
        # No gradients for ``src``, ``op`` and ``group``.
        return (None, None, None, grad_input)
class _Reduce_Scatter(Function):
    """Autograd function: reduce-scatter in forward, all-gather in backward."""

    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # ``tensor`` is the caller-provided output buffer; it is filled in
        # place and returned as-is.
        inputs = list(input_tensor_list)
        dist.reduce_scatter(tensor, inputs, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        gathered = _AllGather.apply(ctx.group, grad_output.contiguous())
        # No gradients for ``op``, ``group`` and the output buffer; one
        # gradient per entry of the input list.
        return (None, None, None, *gathered)
class _AllGather(Function):
    """Autograd function: all-gather in forward, all-to-all plus sum in backward."""

    @staticmethod
    def forward(ctx, group, tensor):
        ctx.group = group
        # One uninitialised receive buffer per group member.
        world_size = dist.get_world_size(group=group)
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor, group=group)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Exchange the per-rank gradients, then sum the pieces received here.
        receive_buffers = [torch.empty_like(grad) for grad in grad_outputs]
        exchanged = _AlltoAll.apply(ctx.group, receive_buffers, *grad_outputs)
        grad_input = torch.stack(exchanged).sum(dim=0)
        # No gradient for ``group``.
        return (None, grad_input)
class _AlltoAll(Function):
    """Autograd function for all-to-all on tensor lists.

    Forward fills ``out_tensor_list`` in place and returns it as a tuple;
    backward runs another all-to-all on the incoming gradients, receiving
    into buffers shaped like the forward inputs.
    """

    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        world_size = dist.get_world_size(group=group)
        # Remember the input shapes so backward can allocate matching
        # gradient buffers.
        ctx.input_tensor_size_list = [tensors[i].size() for i in range(world_size)]
        my_rank = dist.get_rank(group=group)
        # Compare backend names by value: they are strings, and identity
        # (``is``) is not a reliable test for string equality.
        if dist.get_backend(group=group) == dist.Backend.GLOO:
            # Gloo path: emulate all-to-all with one scatter per rank, each
            # rank acting as the source in turn (the original author notes
            # that async send/recv operations have issues).
            for i in range(world_size):
                to_send = list(tensors) if i == my_rank else None
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(
                out_tensor_list,
                list(tensors),
                group=group,
            )
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Receive buffers must match the incoming gradients in both device
        # and dtype; without an explicit dtype ``torch.empty`` falls back to
        # the global default and the collective would see mismatched types
        # for e.g. float64 or half gradients.
        reference = grad_outputs[0]
        tensor_list = [
            torch.empty(size, device=reference.device, dtype=reference.dtype)
            for size in ctx.input_tensor_size_list
        ]
        grad_outputs = tuple(tensor.contiguous() for tensor in grad_outputs)
        # No gradients for ``group`` and ``out_tensor_list``; one gradient
        # per forward input tensor.
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
class _AlltoAllSingle(Function):
    """Autograd function for the single-tensor all-to-all.

    Forward fills ``output`` in place and returns it; backward runs the
    reverse exchange on the gradient, receiving into a buffer shaped like
    the forward ``input``.
    """

    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # Deliberately swapped: in backward the gradient travels the
        # opposite way, so the forward input splits become the output splits
        # of the reverse exchange and vice versa.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The receive buffer must match the gradient in device and dtype;
        # without an explicit dtype ``torch.empty`` uses the global default,
        # which mismatches e.g. float64 or half gradients.
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input = _AlltoAllSingle.apply(
            ctx.group,
            tensor,
            ctx.output_split_sizes,
            ctx.input_split_sizes,
            grad_output.contiguous(),
        )
        # No gradients for ``group``, ``output`` and the two split-size lists.
        return (None, None, None, None) + (grad_input,)
class _AllReduce(Function):
    """Autograd function: all-reduce in forward, all-reduce of the gradient in backward."""

    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group, ctx.op = group, op
        # Reduce into a copy so the caller's tensor is not overwritten.
        reduced = tensor.clone()
        dist.all_reduce(reduced, op=op, group=group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient is all-reduced with the same op and group as forward.
        grad_input = _AllReduce.apply(ctx.op, ctx.group, grad_output)
        # No gradients for ``op`` and ``group``.
        return (None, None, grad_input)