[PATCH] Back out "Move functional collectives implementation to python. (#98595)" (#99168)

Summary:
Original commit changeset: ba36f8751adc

Original Phabricator Diff: D44788697

Test Plan: model loading is fine after reverting the diff

Reviewed By: zyan0, sayitmemory

Differential Revision: D44921259
---

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
Authored by Rodrigo Kumpera on 2023-04-14 23:48:15 +00:00; committed by PyTorch MergeBot
parent 20019f7c56
commit a910045add
15 changed files with 165 additions and 85 deletions
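For context on what the revert restores: the user-facing wrappers in torch/distributed/_functional_collectives.py keep their signatures, but underneath they dispatch to the native aten ops again (torch.ops.aten.all_reduce, wait_tensor, etc.) instead of the python-defined c10d_functional ops. Below is a minimal sketch of the post-revert call path; it is not part of this patch, and the helper name `allreduce_sum` and the `ranks` list are placeholders.

```python
# Sketch only: assumes dist.init_process_group(...) has already been run and
# that `ranks` lists the ranks participating in the collective.
import torch
import torch.distributed._functional_collectives as ft_c

def allreduce_sum(t: torch.Tensor, ranks):
    # After this revert the wrapper calls the native aten op (exposed via
    # torch._C._nn.all_reduce) rather than torch.ops.c10d_functional.all_reduce.
    out = ft_c.all_reduce(t, "sum", ranks)
    # wait_tensor blocks on CPU and synchronizes streams on CUDA
    # (see the wait_tensor docstring in the diff below).
    return ft_c.wait_tensor(out)
```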

View File

@ -0,0 +1,36 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#endif
namespace at {
namespace native {
// Dummy impls required by codegen infra, not used
// These should never get called
// Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
TORCH_INTERNAL_ASSERT(false);
}
at::Tensor all_gather_into_tensor(at::Tensor const& shard, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
TORCH_INTERNAL_ASSERT(false);
}
at::Tensor reduce_scatter_tensor(at::Tensor const& input, const c10::string_view reduceOp, int64_t scatter_dim, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
TORCH_INTERNAL_ASSERT(false);
}
at::Tensor wait_tensor(at::Tensor const& self) {
TORCH_INTERNAL_ASSERT(false);
}
} // namespace native
} // namespace at
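The stubs above exist only so that codegen emits the aten schemas and bindings; the kernels that actually run are installed from python onto the CPU/CUDA dispatch keys, which is the pattern the restored torch/distributed/_functional_collectives.py uses further down (torch.library.Library("aten", "IMPL", "CPU") plus .impl(...)). A minimal illustration of that pattern follows; it is not part of this patch and uses a placeholder body.

```python
# Sketch only: assumes the post-revert state where aten::wait_tensor exists.
import torch

_aten_cpu = torch.library.Library("aten", "IMPL", "CPU")

def _wait_tensor_cpu(tensor):
    # Placeholder: the real implementation (shown in the diff below) waits on
    # the pending collective's Work object before returning the tensor.
    return tensor

# Routes CPU calls to torch.ops.aten.wait_tensor to the python kernel above,
# taking precedence over the TORCH_INTERNAL_ASSERT(false) stub.
_aten_cpu.impl("wait_tensor", _wait_tensor_cpu)
```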

View File

@ -14732,6 +14732,34 @@
CUDA: _fused_adamw_kernel_cuda_
autogen: _fused_adamw, _fused_adamw.out
# Collectives
- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
# This should be changed to distributed but it requires changes all over the place to work
python_module: nn
dispatch:
CompositeExplicitAutograd: all_reduce
variants: function
- func: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
# This should be changed to distributed but it requires changes all over the place to work
python_module: nn
dispatch:
CompositeExplicitAutograd: all_gather_into_tensor
variants: function
- func: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
# This should be changed to distributed but it requires changes all over the place to work
python_module: nn
dispatch:
CompositeExplicitAutograd: reduce_scatter_tensor
variants: function
- func: wait_tensor(Tensor self) -> Tensor
# This should be changed to distributed but it requires changes all over the place to work
python_module: nn
dispatch:
CompositeExplicitAutograd: wait_tensor
# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
variants: function
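The four collective schemas above use python_module: nn, so their generated python bindings land on torch._C._nn, which is what the restored wrappers call; torch.ops.aten.* resolves to the same registered kernels. A small sketch of the two equivalent entry points, not part of this patch, run on the meta device so no process group is needed:

```python
# Sketch only: assumes the post-revert op set. Meta tensors carry no data,
# so only the shape path of the registered kernels is exercised.
import torch

t = torch.ones(4, 4, device="meta")
a = torch.ops.aten.wait_tensor(t)   # overload packet on the aten namespace
b = torch._C._nn.wait_tensor(t)     # python_module: nn binding, same kernel
assert a.size() == b.size() == t.size()
```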

View File

@ -1237,6 +1237,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/Bucketization.cpp",
"aten/src/ATen/native/CPUBlas.cpp",
"aten/src/ATen/native/ChanelShuffle.cpp",
"aten/src/ATen/native/Collectives.cpp",
"aten/src/ATen/native/Col2Im.cpp",
"aten/src/ATen/native/PadNd.cpp",
"aten/src/ATen/native/Convolution.cpp",

View File

@ -553,7 +553,7 @@ class TraceTrainStepTest(DTensorTestBase):
[
n
for n in gm.graph.nodes
if n.target == torch.ops.c10d_functional.all_reduce.default
if n.target == torch.ops.aten.all_reduce.default
]
),
1,

View File

@ -56,9 +56,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
x = torch.matmul(a, b)
y = torch.matmul(c, d)
z = torch.cat((x, y))
ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
g = torch.matmul(e, f)
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.wait_tensor(ar)
out = torch.add(ar, g.repeat(2, 1))
return (out, )
@ -89,12 +89,12 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
x = torch.matmul(a, b)
y = torch.matmul(c, d)
z = torch.cat((x, y))
ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
return ar
def inductor_func(ar, e, f):
g = torch.matmul(e, f)
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.wait_tensor(ar)
out = torch.add(ar, g.repeat(2, 1))
return (out, )
@ -128,12 +128,12 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
x = torch.matmul(a, b)
y = torch.matmul(c, d)
z = torch.cat((x, y))
ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
return ar
def eager_func(ar, e, f):
g = torch.matmul(e, f)
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.wait_tensor(ar)
out = torch.add(ar, g.repeat(2, 1))
return (out, )
@ -166,8 +166,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
def example(a, b, *, tag, ranks, group_size):
c = torch.matmul(a, b)
ag = torch.ops.c10d_functional.all_gather_into_tensor(c, tag, ranks, group_size)
ag = torch.ops.c10d_functional.wait_tensor(ag)
ag = torch.ops.aten.all_gather_into_tensor(c, tag, ranks, group_size)
ag = torch.ops.aten.wait_tensor(ag)
return (ag, )
def compile(func, example_inputs):
@ -194,10 +194,10 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
def test_reduce_scatter_tensor_inductor(self):
def example(a, b, *, tag, ranks, group_size):
c = torch.matmul(a, b)
ag = torch.ops.c10d_functional.reduce_scatter_tensor(
c, "sum", tag, ranks, group_size
ag = torch.ops.aten.reduce_scatter_tensor(
c, "sum", 0, tag, ranks, group_size
)
ag = torch.ops.c10d_functional.wait_tensor(ag)
ag = torch.ops.aten.wait_tensor(ag)
return (ag,)
def compile(func, example_inputs):
@ -234,8 +234,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
torch._inductor.config.debug = True
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
ar = torch.ops.aten.wait_tensor(ar)
return ar
inputs = torch.ones(4, 4, device="cuda")
@ -264,8 +264,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def func(inp, *, tag, ranks, group_size):
x = inp + 1
ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
ar = torch.ops.aten.wait_tensor(ar)
# ensure other is not incorrectly aliasing ar's buffer
other = torch.ones_like(inp) + 22
return ar, other
@ -298,9 +298,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def func(inp, *, tag, ranks, group_size):
x = inp + 1
ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
y = x + 2
ar = torch.ops.c10d_functional.wait_tensor(ar)
ar = torch.ops.aten.wait_tensor(ar)
# ensure other is not incorrectly aliasing ar's buffer
other = torch.ones_like(inp) + 22
return ar, y, other
@ -327,7 +327,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def test_dynamo_trace_allreduce(self):
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
return ar
inputs = torch.ones(4, 4, device="cuda")
@ -346,7 +346,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
However, I wanted to at least see if it was possible to support it as a design goal.
"""
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
return ar
input = torch.ones(4, 4, device="cuda", requires_grad=True)
@ -364,7 +364,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def test_meta(self):
x = torch.rand((2, 3, 4), device="meta")
out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
assert x.size() == out.size()
if __name__ == "__main__":

View File

@ -574,6 +574,8 @@ aten::affine_grid_generator
aten::affine_grid_generator.out
aten::alias_copy
aten::alias_copy.out
aten::all_gather_into_tensor
aten::all_reduce
aten::allclose
aten::angle
aten::angle.out
@ -1058,6 +1060,7 @@ aten::range.out
aten::range.out_
aten::range.step
aten::record_stream
aten::reduce_scatter_tensor
aten::reflection_pad1d
aten::reflection_pad1d.out
aten::reflection_pad1d_backward
@ -1328,6 +1331,7 @@ aten::view_copy
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
aten::wait_tensor
aten::zeros.names
aten::zeros.names_out
aten::zeros.out

View File

@ -357,12 +357,6 @@ ALLOW_LIST = [
("aten::_nested_view_from_buffer_copy.out", datetime.date(2023, 5, 1)),
("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
# These ops were moved to python under the c10d_functional namespace
("aten::wait_tensor", datetime.date(9999, 1, 30)),
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
("aten::all_reduce", datetime.date(9999, 1, 30)),
]
ALLOW_LIST_COMPILED = [

View File

@ -3009,3 +3009,15 @@
- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
- name: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
self: non_differentiable
- name: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
shard: non_differentiable
- name: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
input: non_differentiable
- name: wait_tensor(Tensor self) -> Tensor
self: non_differentiable

View File

@ -552,6 +552,11 @@ DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
"_nested_tensor_size",
"_nested_tensor_strides",
"_nested_tensor_storage_offsets",
# Functional collectives keep an internal ref through the Work object
"all_reduce",
"all_gather_into_tensor",
"reduce_scatter_tensor",
"wait_tensor",
}
DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {

View File

@ -4176,15 +4176,18 @@ class AllGatherIntoTensor(CollectiveKernel):
class ReduceScatterTensor(CollectiveKernel):
def __init__(self, layout, inputs, constant_args, reduce_op):
def __init__(self, layout, inputs, constant_args, reduce_op, scatter_dim):
super().__init__(layout, inputs, constant_args)
self.reduce_op = reduce_op
# TODO support dim
self.scatter_dim = scatter_dim
@classmethod
def create(
cls,
x: "TensorBox",
reduce_op: str,
scatter_dim: int,
tag: str,
ranks: List[int],
group_size: int,
@ -4194,7 +4197,7 @@ class ReduceScatterTensor(CollectiveKernel):
# is there a difference between literally using x.data.layout below, vs
# creating a new one that has the same properties?
new_size = x.get_size()
new_size[0] /= group_size
new_size[scatter_dim] /= group_size
new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)
return ReduceScatterTensor(
@ -4202,6 +4205,7 @@ class ReduceScatterTensor(CollectiveKernel):
inputs=[x],
constant_args=[tag, ranks, group_size],
reduce_op=reduce_op,
scatter_dim=scatter_dim,
)
def codegen_collective(self, wrapper, output_name, input_names):

View File

@ -3894,28 +3894,28 @@ def _realize(x):
try:
import torch.distributed._functional_collectives
c10d_functional = torch.ops.c10d_functional
@register_lowering(c10d_functional.wait_tensor)
@register_lowering(aten.wait_tensor)
def wait(input):
return TensorBox.create(ir.Wait.create(input))
@register_lowering(c10d_functional.all_reduce)
@register_lowering(aten.all_reduce)
def allreduce(input, reduce_op, tag, ranks, group_size):
return TensorBox.create(
ir.AllReduce.create(input, reduce_op, tag, ranks, group_size)
)
@register_lowering(c10d_functional.all_gather_into_tensor)
@register_lowering(aten.all_gather_into_tensor)
def all_gather_into_tensor(shard, tag, ranks, group_size):
return TensorBox.create(
ir.AllGatherIntoTensor.create(shard, tag, ranks, group_size)
)
@register_lowering(c10d_functional.reduce_scatter_tensor)
def reduce_scatter_tensor(input, reduce_op, tag, ranks, group_size):
@register_lowering(aten.reduce_scatter_tensor)
def reduce_scatter_tensor(input, reduce_op, scatter_dim, tag, ranks, group_size):
return TensorBox.create(
ir.ReduceScatterTensor.create(input, reduce_op, tag, ranks, group_size)
ir.ReduceScatterTensor.create(
input, reduce_op, scatter_dim, tag, ranks, group_size
)
)
except ImportError:

View File

@ -3089,4 +3089,28 @@ def activate_meta():
_meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
@register_meta(aten.all_reduce)
def all_reduce_meta(self, reduceOp, tag, rankset, group_size):
return torch.empty_like(self)
@register_meta(aten.all_gather_into_tensor)
def all_gather_into_tensor_meta(shard, tag, rankset, group_size):
out_size = list(shard.size())
out_size[0] *= group_size
return shard.new_empty(out_size)
@register_meta(aten.reduce_scatter_tensor)
def reduce_scatter_tensor_meta(input, reduce_op, scatter_dim, tag, rankset, group_size):
out_size = list(input.size())
out_size[scatter_dim] //= group_size
return input.new_empty(out_size)
@register_meta(aten.wait_tensor)
def wait_tensor_meta(self):
return torch.empty_like(self)
activate_meta()
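These meta kernels give tracing and fake-tensor execution the correct output shapes without running a real collective (the test_meta change above exercises the all_reduce case). A shape-only sketch for the reduce_scatter case, not part of this patch; the empty tag, rank list, and group size are made-up placeholders.

```python
# Sketch only: assumes the post-revert aten ops and a nominal 4-rank group.
import torch

x = torch.rand((8, 4), device="meta")
out = torch.ops.aten.reduce_scatter_tensor(x, "sum", 0, "", [0, 1, 2, 3], 4)
# reduce_scatter_tensor_meta divides the scatter_dim by group_size: 8 // 4 == 2
assert out.size() == torch.Size([2, 4])
```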

View File

@ -109,7 +109,7 @@ class AsyncCollectiveTensor(torch.Tensor):
Use it inside functional collective pytorch wrappers like the following:
def functional_collective(self, group, tag):
tag, rankset, group_size = _expand_group(group, tag)
tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
tensor = torch._C._nn.{collective}(self, tag, rankset, group_size)
res = AsyncCollectiveTensor(tensor)
_register_wrapper_tensor(res, tensor)
return res
@ -188,16 +188,18 @@ def _all_gather_into_tensor(shard, tag, ranks, group_size):
def _reduce_scatter_tensor(
input: torch.Tensor,
reduceOp: str,
scatter_dim: int,
tag: str,
ranks: List[int],
group_size: int,
):
# TODO add dim support?
assert scatter_dim == 0, "Only scatter_dim = 0 is supported for now."
group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
assert group is not None
op = _str_to_reduce_op(reduceOp)
out_size = list(input.size())
out_size[0] //= group_size
out_size[scatter_dim] //= group_size
out_tensor = input.new_empty(out_size)
work = dist.reduce_scatter_tensor(
out_tensor, input, op=op, group=group, async_op=True
@ -273,7 +275,7 @@ def wait_tensor(tensor):
Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
"""
return torch.ops.c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
return torch._C._nn.wait_tensor(tensor) # type: ignore[attr-defined]
def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
@ -294,7 +296,7 @@ def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str =
that information and perform collective algebraic optimization. Use other forms of input for that.
"""
tag, rankset, group_size = _expand_group(group, tag)
tensor = torch.ops.c10d_functional.all_reduce(self, reduceOp, tag, rankset, group_size) # type: ignore[attr-defined]
tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size) # type: ignore[attr-defined]
return _maybe_wrap_tensor(tensor)
@ -323,7 +325,7 @@ def all_gather_tensor(
"""
assert self.is_contiguous()
tag, rankset, group_size = _expand_group(group, tag)
tensor = torch.ops.c10d_functional.all_gather_into_tensor(self, tag, rankset, group_size) # type: ignore[attr-defined]
tensor = torch._C._nn.all_gather_into_tensor(self, tag, rankset, group_size) # type: ignore[attr-defined]
res: torch.Tensor = AsyncCollectiveTensor(tensor)
_register_wrapper_tensor(res, tensor)
# TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
@ -361,49 +363,26 @@ def reduce_scatter_tensor(
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list)
tensor = torch.ops.c10d_functional.reduce_scatter_tensor(self, reduceOp, tag, rankset, group_size) # type: ignore[attr-defined]
tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, 0, tag, rankset, group_size) # type: ignore[attr-defined]
res = _maybe_wrap_tensor(tensor)
return res
c10_lib = torch.library.Library("c10d_functional", "DEF")
c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
# We now register meta kernels to deal with tracing
def _all_reduce_meta(self, *args):
return torch.empty_like(self)
def _wait_tensor_meta(self, *args):
return torch.empty_like(self)
def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
out_size = list(shard.size())
out_size[0] *= group_size
return shard.new_empty(out_size)
def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
out_size = list(input.size())
out_size[0] //= group_size
return input.new_empty(out_size)
c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")
def _register_ops():
ops_defs = [
"all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
"wait_tensor(Tensor self) -> Tensor",
"all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
"reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
]
c10_lib_cpu.impl("all_reduce", _all_reduce)
c10_lib_cuda.impl("all_reduce", _all_reduce)
my_module = sys.modules[__name__]
for op_def in ops_defs:
op_name = op_def[0:op_def.index('(')]
backend_impl = getattr(my_module, f"_{op_name}")
meta_impl = getattr(my_module, f"_{op_name}_meta")
c10_lib.define(op_def)
c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd")
c10_lib_impl.impl(op_name, meta_impl, "Meta")
c10_lib_cpu.impl("wait_tensor", _wait_tensor)
c10_lib_cuda.impl("wait_tensor", _wait_tensor)
c10_lib_cpu.impl("all_gather_into_tensor", _all_gather_into_tensor)
c10_lib_cuda.impl("all_gather_into_tensor", _all_gather_into_tensor)
c10_lib_cpu.impl("reduce_scatter_tensor", _reduce_scatter_tensor)
c10_lib_cuda.impl("reduce_scatter_tensor", _reduce_scatter_tensor)
if sys.executable != 'torch_deploy':
_register_ops()
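The reduce_scatter_tensor wrapper above re-lays out its input (torch.chunk along scatter_dim, then torch.cat) so it can always pass scatter_dim=0 to the underlying op. A small, collective-free sketch of why that layout is equivalent, not part of this patch; the group_size and scatter_dim values are illustrative.

```python
# Sketch only: chunking along the requested dim and concatenating along dim 0
# puts rank i's slice into the i-th dim-0 block, which a scatter_dim=0
# reduce_scatter then hands to rank i.
import torch

group_size, scatter_dim = 2, 1
x = torch.arange(8).reshape(2, 4)
chunks = torch.chunk(x, group_size, dim=scatter_dim)
stacked = torch.cat(chunks)                 # shape (4, 2)
assert torch.equal(stacked[:2], x[:, :2])   # block for rank 0
assert torch.equal(stacked[2:], x[:, 2:])   # block for rank 1
```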

View File

@ -21,9 +21,6 @@ from functorch import make_fx
import torch
import torch.distributed as dist
# We need to import _functional_collectives to trigger op registration
import torch.distributed._functional_collectives
import torch.nn as nn
import torch.utils._pytree as pytree
@ -388,8 +385,8 @@ FOREACH_DECOMP_TABLE = {
DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
torch.ops.c10d_functional.all_reduce.default,
torch.ops.c10d_functional.wait_tensor.default,
aten.all_reduce.default,
aten.wait_tensor.default,
}

View File

@ -323,10 +323,6 @@ class ProcessLocalGroup(dist.ProcessGroup):
ProcessLocalGroup._end_coll(coll, self)
return res
def _reduce_scatter_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
tensor_list = list(torch.chunk(input_tensor, self._world_size))
return self.reduce_scatter([output_tensor], [tensor_list], opts)
def __init__(self, rank, world_size):
super().__init__(rank, world_size)
self._rank = rank