Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable autograd for the sharded linear, we need to make some changes to the current nn functional API (the c10d API with autograd enabled). We made the following changes:
1. Add a new API, `reduce_scatter`, since it is needed for row-wise sharding.
2. Modify the `all_to_all` API to make sure it is consistent with the one in distributed_c10d.py.
3. Found that the cpp input params of `reduce_scatter` were missing an input param; added more unit tests to cover these cases.
4. Sync the NN tests from gloo to nccl.

ghstack-source-id: 144860208

Test Plan: CI + Unit Test

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
import torch
import torch.distributed as dist
from torch.autograd import Function


def broadcast(tensor, src, group=dist.group.WORLD):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of the current
            process.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Received tensor from the broadcast op.

    """
    return _Broadcast.apply(src, group, tensor)
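
# Illustrative usage sketch for ``broadcast`` (not from the original file); it
# assumes the default process group has already been initialized, e.g. via
# ``dist.init_process_group("gloo", rank=rank, world_size=world_size)``:
#
#     t = torch.ones(2, 2, requires_grad=True)
#     out = broadcast(t, src=0)   # every rank receives rank 0's tensor
#     out.sum().backward()        # non-source ranks get a zero gradient on ``t``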


def gather(tensor, dst=0, group=dist.group.WORLD):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.

    """
    return _Gather.apply(dst, group, tensor)
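
# Illustrative usage sketch for ``gather`` (not from the original file); assumes
# an initialized default process group:
#
#     t = torch.full((2,), float(dist.get_rank()), requires_grad=True)
#     out = gather(t, dst=0)   # tuple with one entry per rank
#     # on rank ``dst`` the tuple holds every rank's tensor; on the other ranks
#     # the wrapper returns zero-filled placeholders of the same shape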


def scatter(tensors, src=0, group=dist.group.WORLD):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor, which is returned by this
    function.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter, one per rank. Every
            rank must pass the list, since it is also used to allocate the
            output tensor on non-source ranks.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.

    """
    return _Scatter.apply(src, group, *tensors)
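
# Illustrative usage sketch for ``scatter`` (not from the original file); assumes
# an initialized default process group. Unlike ``torch.distributed.scatter``,
# every rank constructs the input list here:
#
#     world_size = dist.get_world_size()
#     ts = [torch.full((2,), float(i)) for i in range(world_size)]
#     out = scatter(ts, src=0)   # rank r receives ``ts[r]`` as sent by rank 0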


def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce.apply(dst, op, group, tensor)
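
# Illustrative usage sketch for ``reduce`` (not from the original file); assumes
# an initialized default process group:
#
#     t = torch.ones(3, requires_grad=True)
#     out = reduce(t, dst=0)   # rank 0 holds the element-wise sum over all ranks
#     # the returned tensor on the other ranks is not the final result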


def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)
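
# Illustrative usage sketch for ``reduce_scatter`` (not from the original file);
# assumes an initialized default process group whose backend implements
# reduce_scatter (e.g. NCCL):
#
#     world_size = dist.get_world_size()
#     inputs = [torch.ones(2) * dist.get_rank() for _ in range(world_size)]
#     out = reduce_scatter(torch.empty(2), inputs)
#     # rank r receives the element-wise sum over all ranks of their inputs[r]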


def all_gather(tensor, group=dist.group.WORLD):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Output of the collective.

    """
    return _AllGather.apply(group, tensor)
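
# Illustrative usage sketch for ``all_gather`` (not from the original file);
# assumes an initialized default process group:
#
#     t = torch.ones(2) * dist.get_rank()
#     outs = all_gather(t)   # tuple where outs[r] is rank r's tensor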


def all_to_all(output_tensor_list, input_tensor_list, group=dist.group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in a group
    and returns the gathered list of tensors in the output list.

    Arguments:
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: Output of the collective.

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
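
# Illustrative usage sketch for ``all_to_all`` (not from the original file);
# assumes an initialized default process group:
#
#     world_size = dist.get_world_size()
#     inputs = [torch.ones(2) * dist.get_rank() for _ in range(world_size)]
#     outputs = [torch.empty(2) for _ in range(world_size)]
#     outputs = all_to_all(outputs, inputs)
#     # outputs[r] now holds what rank r placed in its inputs[dist.get_rank()]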


def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=dist.group.WORLD,
):
    """
    Each process splits the input tensor and then scatters the split list to all
    processes in a group. Then it concatenates the received tensors from all the
    processes in the group and returns a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[int], optional): Output split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``output`` tensor must
            divide equally by ``world_size``.
        input_split_sizes (list[int], optional): Input split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``input`` tensor must
            divide equally by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )
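
# Illustrative usage sketch for ``all_to_all_single`` (not from the original
# file); assumes an initialized default process group and equal splits:
#
#     world_size = dist.get_world_size()
#     inp = torch.arange(2 * world_size, dtype=torch.float) + dist.get_rank() * 100
#     out = all_to_all_single(torch.empty(2 * world_size), inp)
#     # rows 2*r : 2*r + 2 of ``out`` come from rank r's chunk for this rank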


def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    After the call the returned tensor is going to be bitwise
    identical in all processes.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AllReduce.apply(op, group, tensor)
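
# Illustrative usage sketch for ``all_reduce`` (not from the original file);
# assumes an initialized default process group:
#
#     t = torch.ones(2, requires_grad=True)
#     out = all_reduce(t)    # every rank gets the element-wise sum
#     out.sum().backward()   # the gradient is itself all-reduced in backward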


class _Broadcast(Function):
    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank()
        # torch.distributed makes all the calls in place
        # we allocate new tensors to avoid this
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # The output on every rank came from the tensor on ``src``, so the
        # gradient w.r.t. the input is the reduced (summed) grad_output on
        # ``src`` and zero on every other rank.
        gx = _Reduce.apply(ctx.src, dist.ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Create a list of same-sized tensors, one per rank in the group, to
        # receive the gathered data on the destination rank.
        tensor_list = [
            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
        ]
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(
            ctx.group, grad_output.contiguous()
        )


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for i in range(dist.get_world_size(group=group))
        ]
        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
        gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
        gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        # Implement it by means of scatter/gather; send/recv async operations have issues
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(
                out_tensor_list,
                list(tensors),
                group=group,
            )
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(size, device=grad_outputs[0].device)
            for size in ctx.input_tensor_size_list
        ]
        grad_outputs = tuple(tensor.contiguous() for tensor in grad_outputs)
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # The split sizes swap roles in the backward pass: what was sent
        # becomes what is received, and vice versa.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(ctx.input_size, device=grad_output.device)
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)