Summary:
Original commit changeset: ba36f8751adc
Original Phabricator Diff: D44788697

Test Plan: model loading is fine after reverting the diff

Reviewed By: zyan0, sayitmemory

Differential Revision: D44921259

---

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
This commit is contained in:
parent 20019f7c56
commit a910045add
aten/src/ATen/native/Collectives.cpp (new file, 36 additions)
@@ -0,0 +1,36 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+#include <ATen/core/Tensor.h>
+#include <ATen/Parallel.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#endif
+
+namespace at {
+namespace native {
+
+// Dummy impls required by codegen infra, not used
+// These should never get called
+// Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
+
+at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+at::Tensor all_gather_into_tensor(at::Tensor const& shard, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+at::Tensor reduce_scatter_tensor(at::Tensor const& input, const c10::string_view reduceOp, int64_t scatter_dim, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+at::Tensor wait_tensor(at::Tensor const& self) {
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+} // namespace native
+} // namespace at
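Note: these C++ functions are deliberately unreachable stubs; the callable kernels are registered from Python over the aten namespace (see the torch.library hunks further down in this diff). A minimal sketch of that override pattern, with illustrative names (_aten_cpu, _all_reduce_py) and a placeholder body:

    import torch

    # Override the CPU kernel of the aten stub with a Python impl;
    # the real impls live in torch/distributed/_functional_collectives.py.
    _aten_cpu = torch.library.Library("aten", "IMPL", "CPU")

    def _all_reduce_py(self, reduceOp, tag, ranks, group_size):
        # Placeholder only: the real kernel resolves a process group from
        # (tag, ranks, group_size) and launches an async c10d all_reduce.
        return self.clone()

    _aten_cpu.impl("all_reduce", _all_reduce_py)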
@@ -14732,6 +14732,34 @@
     CUDA: _fused_adamw_kernel_cuda_
   autogen: _fused_adamw, _fused_adamw.out

+# Collectives
+- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
+  # This should be changed to distributed but it requires changes all over the place to work
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: all_reduce
+  variants: function
+
+- func: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
+  # This should be changed to distributed but it requires changes all over the place to work
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: all_gather_into_tensor
+  variants: function
+
+- func: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
+  # This should be changed to distributed but it requires changes all over the place to work
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: reduce_scatter_tensor
+  variants: function
+
+- func: wait_tensor(Tensor self) -> Tensor
+  # This should be changed to distributed but it requires changes all over the place to work
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: wait_tensor

 # This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
 - func: _propagate_xla_data(Tensor input, Tensor output) -> ()
   variants: function
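Because these entries use python_module: nn, the generated bindings are exposed on torch._C._nn, which is how the wrappers later in this diff call them; torch.ops.aten.<name> reaches the same ops. A hedged usage sketch (assumes a process group is already initialized; the tag/ranks/group_size values are illustrative):

    import torch

    t = torch.ones(4)
    tag, ranks, group_size = "", [0, 1], 2
    out = torch._C._nn.all_reduce(t, "sum", tag, ranks, group_size)
    # equivalent spelling via the op registry:
    # torch.ops.aten.all_reduce(t, "sum", tag, ranks, group_size)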
@@ -1237,6 +1237,7 @@ aten_native_source_non_codegen_list = [
     "aten/src/ATen/native/Bucketization.cpp",
     "aten/src/ATen/native/CPUBlas.cpp",
     "aten/src/ATen/native/ChanelShuffle.cpp",
+    "aten/src/ATen/native/Collectives.cpp",
     "aten/src/ATen/native/Col2Im.cpp",
     "aten/src/ATen/native/PadNd.cpp",
     "aten/src/ATen/native/Convolution.cpp",
@@ -553,7 +553,7 @@ class TraceTrainStepTest(DTensorTestBase):
                 [
                     n
                     for n in gm.graph.nodes
-                    if n.target == torch.ops.c10d_functional.all_reduce.default
+                    if n.target == torch.ops.aten.all_reduce.default
                 ]
             ),
             1,
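The assertion above counts aten.all_reduce nodes in the captured graph. For reference, the same node-filtering idiom on a standalone make_fx trace (the traced function here is illustrative):

    import torch
    from functorch import make_fx

    def f(x):
        return (x + 1).mul(2)

    gm = make_fx(f)(torch.randn(4))
    adds = [n for n in gm.graph.nodes if n.target == torch.ops.aten.add.Tensor]
    assert len(adds) == 1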
@@ -56,9 +56,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
-            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            g = torch.matmul(e, f)
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out, )
@@ -89,12 +89,12 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
-            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out, )
@@ -128,12 +128,12 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
-            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out, )
@@ -166,8 +166,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
-            ag = torch.ops.c10d_functional.all_gather_into_tensor(c, tag, ranks, group_size)
-            ag = torch.ops.c10d_functional.wait_tensor(ag)
+            ag = torch.ops.aten.all_gather_into_tensor(c, tag, ranks, group_size)
+            ag = torch.ops.aten.wait_tensor(ag)
            return (ag, )

        def compile(func, example_inputs):
@@ -194,10 +194,10 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    def test_reduce_scatter_tensor_inductor(self):
        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
-            ag = torch.ops.c10d_functional.reduce_scatter_tensor(
-                c, "sum", tag, ranks, group_size
+            ag = torch.ops.aten.reduce_scatter_tensor(
+                c, "sum", 0, tag, ranks, group_size
            )
-            ag = torch.ops.c10d_functional.wait_tensor(ag)
+            ag = torch.ops.aten.wait_tensor(ag)
            return (ag,)

        def compile(func, example_inputs):
@@ -234,8 +234,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        torch._inductor.config.debug = True

        def func(inp, *, tag, ranks, group_size):
-            ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")
@@ -264,8 +264,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
-            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other
@@ -298,9 +298,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
-            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
-            ar = torch.ops.c10d_functional.wait_tensor(ar)
+            ar = torch.ops.aten.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other
@@ -327,7 +327,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

    def test_dynamo_trace_allreduce(self):
        def func(inp, *, tag, ranks, group_size):
-            ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        inputs = torch.ones(4, 4, device="cuda")
@@ -346,7 +346,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        However, I wanted to at least see if it was possible to support it as a design goal.
        """
        def func(inp, *, tag, ranks, group_size):
-            ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size)
+            ar = torch.ops.aten.all_reduce(inp, "sum", tag, ranks, group_size)
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
@@ -364,7 +364,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

    def test_meta(self):
        x = torch.rand((2, 3, 4), device="meta")
-        out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
+        out = torch.ops.aten.all_reduce(x, "sum", **self.get_world_trs())
        assert x.size() == out.size()

if __name__ == "__main__":
@@ -574,6 +574,8 @@ aten::affine_grid_generator
 aten::affine_grid_generator.out
 aten::alias_copy
 aten::alias_copy.out
+aten::all_gather_into_tensor
+aten::all_reduce
 aten::allclose
 aten::angle
 aten::angle.out

@@ -1058,6 +1060,7 @@ aten::range.out
 aten::range.out_
 aten::range.step
 aten::record_stream
+aten::reduce_scatter_tensor
 aten::reflection_pad1d
 aten::reflection_pad1d.out
 aten::reflection_pad1d_backward

@@ -1328,6 +1331,7 @@ aten::view_copy
 aten::view_copy.dtype
 aten::view_copy.dtype_out
 aten::view_copy.out
+aten::wait_tensor
 aten::zeros.names
 aten::zeros.names_out
 aten::zeros.out
@@ -357,12 +357,6 @@ ALLOW_LIST = [
     ("aten::_nested_view_from_buffer_copy.out", datetime.date(2023, 5, 1)),
     ("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
     ("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
-    # These ops were moved to python under the c10d_functional namespace
-    ("aten::wait_tensor", datetime.date(9999, 1, 30)),
-    ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
-    ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
-    ("aten::all_reduce", datetime.date(9999, 1, 30)),
-
 ]

 ALLOW_LIST_COMPILED = [
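These allowlist entries suppressed the schema-compatibility check while the four ops lived outside aten; restoring the native ops makes them unnecessary. For reference, each entry pairs a qualified op name with an expiry date, and the far-future date 9999-01-30 effectively allowlisted the ops indefinitely. A hypothetical entry (op name and date are illustrative):

    import datetime

    ALLOW_LIST_EXAMPLE = [
        # (op that is allowed to change schema, date the waiver expires)
        ("aten::some_changed_op", datetime.date(2023, 6, 1)),
    ]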
@@ -3009,3 +3009,15 @@

 - name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
   self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
+
+- name: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
+  self: non_differentiable
+
+- name: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor
+  shard: non_differentiable
+
+- name: reduce_scatter_tensor(Tensor input, str reduceOp, int scatter_dim, str tag, int[] ranks, int group_size) -> Tensor
+  input: non_differentiable
+
+- name: wait_tensor(Tensor self) -> Tensor
+  self: non_differentiable
@@ -552,6 +552,11 @@ DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
     "_nested_tensor_size",
     "_nested_tensor_strides",
     "_nested_tensor_storage_offsets",
+    # Functional collectives keep an internal ref through the Work object
+    "all_reduce",
+    "all_gather_into_tensor",
+    "reduce_scatter_tensor",
+    "wait_tensor",
 }

 DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
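The new comment is the rationale for the exemption: the Python impls run the collective with async_op=True, and the returned c10d Work handle keeps a reference to the output tensor until the wait completes. A minimal sketch of that pattern (assumes torch.distributed is initialized; a synchronous wait is shown for brevity):

    import torch
    import torch.distributed as dist

    def async_all_reduce(t: torch.Tensor) -> torch.Tensor:
        out = t.clone()
        # async_op=True returns a Work handle that holds a ref to `out`
        # until completion, which is why use-count checks are skipped.
        work = dist.all_reduce(out, op=dist.ReduceOp.SUM, async_op=True)
        work.wait()
        return out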
@@ -4176,15 +4176,18 @@ class AllGatherIntoTensor(CollectiveKernel):


 class ReduceScatterTensor(CollectiveKernel):
-    def __init__(self, layout, inputs, constant_args, reduce_op):
+    def __init__(self, layout, inputs, constant_args, reduce_op, scatter_dim):
         super().__init__(layout, inputs, constant_args)
         self.reduce_op = reduce_op
-        # TODO support dim
+        self.scatter_dim = scatter_dim

     @classmethod
     def create(
         cls,
         x: "TensorBox",
         reduce_op: str,
+        scatter_dim: int,
         tag: str,
         ranks: List[int],
         group_size: int,

@@ -4194,7 +4197,7 @@ class ReduceScatterTensor(CollectiveKernel):
         # is there a difference between literally using x.data.layout below, vs
         # creating a new one that has the same properties?
         new_size = x.get_size()
-        new_size[0] /= group_size
+        new_size[scatter_dim] /= group_size
         new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)

         return ReduceScatterTensor(

@@ -4202,6 +4205,7 @@ class ReduceScatterTensor(CollectiveKernel):
             inputs=[x],
             constant_args=[tag, ranks, group_size],
             reduce_op=reduce_op,
+            scatter_dim=scatter_dim,
         )

     def codegen_collective(self, wrapper, output_name, input_names):
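With scatter_dim threaded through, the output layout shrinks along the scattered dimension rather than always along dim 0. A worked shape example mirroring new_size[scatter_dim] /= group_size above (shapes illustrative):

    # reduce_scatter output shape: the input shape with scatter_dim
    # divided by group_size.
    def reduce_scatter_out_size(size, scatter_dim, group_size):
        out = list(size)
        out[scatter_dim] //= group_size
        return out

    assert reduce_scatter_out_size([8, 4], 0, 4) == [2, 4]
    assert reduce_scatter_out_size([8, 4], 1, 4) == [8, 1]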
@@ -3894,28 +3894,28 @@ def _realize(x):
 try:
     import torch.distributed._functional_collectives

-    c10d_functional = torch.ops.c10d_functional
-
-    @register_lowering(c10d_functional.wait_tensor)
+    @register_lowering(aten.wait_tensor)
     def wait(input):
         return TensorBox.create(ir.Wait.create(input))

-    @register_lowering(c10d_functional.all_reduce)
+    @register_lowering(aten.all_reduce)
     def allreduce(input, reduce_op, tag, ranks, group_size):
         return TensorBox.create(
             ir.AllReduce.create(input, reduce_op, tag, ranks, group_size)
         )

-    @register_lowering(c10d_functional.all_gather_into_tensor)
+    @register_lowering(aten.all_gather_into_tensor)
     def all_gather_into_tensor(shard, tag, ranks, group_size):
         return TensorBox.create(
             ir.AllGatherIntoTensor.create(shard, tag, ranks, group_size)
         )

-    @register_lowering(c10d_functional.reduce_scatter_tensor)
-    def reduce_scatter_tensor(input, reduce_op, tag, ranks, group_size):
+    @register_lowering(aten.reduce_scatter_tensor)
+    def reduce_scatter_tensor(input, reduce_op, scatter_dim, tag, ranks, group_size):
         return TensorBox.create(
-            ir.ReduceScatterTensor.create(input, reduce_op, tag, ranks, group_size)
+            ir.ReduceScatterTensor.create(
+                input, reduce_op, scatter_dim, tag, ranks, group_size
+            )
         )

 except ImportError:
@@ -3089,4 +3089,28 @@ def activate_meta():
         _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


+@register_meta(aten.all_reduce)
+def all_reduce_meta(self, reduceOp, tag, rankset, group_size):
+    return torch.empty_like(self)
+
+
+@register_meta(aten.all_gather_into_tensor)
+def all_gather_into_tensor_meta(shard, tag, rankset, group_size):
+    out_size = list(shard.size())
+    out_size[0] *= group_size
+    return shard.new_empty(out_size)
+
+
+@register_meta(aten.reduce_scatter_tensor)
+def reduce_scatter_tensor_meta(input, reduce_op, scatter_dim, tag, rankset, group_size):
+    out_size = list(input.size())
+    out_size[scatter_dim] //= group_size
+    return input.new_empty(out_size)
+
+
+@register_meta(aten.wait_tensor)
+def wait_tensor_meta(self):
+    return torch.empty_like(self)
+
+
 activate_meta()
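These meta kernels let tracing and the meta device produce correct output shapes without performing any communication, which is what test_meta above relies on. A hedged sketch (assumes a build with these registrations; the tag/ranks/group_size values are illustrative):

    import torch

    shard = torch.empty(2, 3, 4, device="meta")
    out = torch.ops.aten.all_gather_into_tensor(shard, "", [0, 1], 2)
    # all_gather grows dim 0 by group_size: (2, 3, 4) -> (4, 3, 4)
    assert out.shape == (4, 3, 4)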
@@ -109,7 +109,7 @@ class AsyncCollectiveTensor(torch.Tensor):
     Use it inside functional collective pytorch wrappers like the following:
     def functional_collective(self, group, tag):
         tag, rankset, group_size = _expand_group(group, tag)
-        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
+        tensor = torch._C._nn.{collective}(self, tag, rankset, group_size)
         res = AsyncCollectiveTensor(tensor)
         _register_wrapper_tensor(res, tensor)
         return res
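For concreteness, the docstring's pattern with {collective} filled in as all_reduce would look like the following sketch (it relies on _expand_group, AsyncCollectiveTensor, and _register_wrapper_tensor from this module; the public all_reduce wrapper below does the same thing via _maybe_wrap_tensor):

    def functional_all_reduce(self, reduceOp, group, tag=""):
        # Launch the collective, then hand back a wrapper that defers
        # the wait until the tensor is actually used.
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)
        res = AsyncCollectiveTensor(tensor)
        _register_wrapper_tensor(res, tensor)
        return res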
@@ -188,16 +188,18 @@ def _all_gather_into_tensor(shard, tag, ranks, group_size):
 def _reduce_scatter_tensor(
     input: torch.Tensor,
     reduceOp: str,
+    scatter_dim: int,
     tag: str,
     ranks: List[int],
     group_size: int,
 ):
     # TODO add dim support?
+    assert scatter_dim == 0, "Only scatter_dim = 0 is supported for now."
     group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
     assert group is not None
     op = _str_to_reduce_op(reduceOp)
     out_size = list(input.size())
-    out_size[0] //= group_size
+    out_size[scatter_dim] //= group_size
     out_tensor = input.new_empty(out_size)
     work = dist.reduce_scatter_tensor(
         out_tensor, input, op=op, group=group, async_op=True
@@ -273,7 +275,7 @@ def wait_tensor(tensor):

     Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
     """
-    return torch.ops.c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
+    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]


 def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
@@ -294,7 +296,7 @@ def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
     that information and perform collective algebraic optimization. Use other forms of input for that.
     """
     tag, rankset, group_size = _expand_group(group, tag)
-    tensor = torch.ops.c10d_functional.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
+    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
     return _maybe_wrap_tensor(tensor)

@@ -323,7 +325,7 @@ def all_gather_tensor(
     """
     assert self.is_contiguous()
     tag, rankset, group_size = _expand_group(group, tag)
-    tensor = torch.ops.c10d_functional.all_gather_into_tensor(self, tag, rankset, group_size)  # type: ignore[attr-defined]
+    tensor = torch._C._nn.all_gather_into_tensor(self, tag, rankset, group_size)  # type: ignore[attr-defined]
     res: torch.Tensor = AsyncCollectiveTensor(tensor)
     _register_wrapper_tensor(res, tensor)
     # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
@@ -361,49 +363,26 @@ def reduce_scatter_tensor(
         tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
         self = torch.cat(tensor_list)

-    tensor = torch.ops.c10d_functional.reduce_scatter_tensor(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
+    tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, 0, tag, rankset, group_size)  # type: ignore[attr-defined]
     res = _maybe_wrap_tensor(tensor)
     return res

-c10_lib = torch.library.Library("c10d_functional", "DEF")
-c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
-
-# We now register meta kernels to deal with tracing
-def _all_reduce_meta(self, *args):
-    return torch.empty_like(self)
-
-def _wait_tensor_meta(self, *args):
-    return torch.empty_like(self)
-
-def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
-    out_size = list(shard.size())
-    out_size[0] *= group_size
-    return shard.new_empty(out_size)
-
-def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
-    out_size = list(input.size())
-    out_size[0] //= group_size
-    return input.new_empty(out_size)
-
+c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
+c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")

 def _register_ops():
-    ops_defs = [
-        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
-        "wait_tensor(Tensor self) -> Tensor",
-        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
-        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
-    ]
+    c10_lib_cpu.impl("all_reduce", _all_reduce)
+    c10_lib_cuda.impl("all_reduce", _all_reduce)

-    my_module = sys.modules[__name__]
-    for op_def in ops_defs:
-        op_name = op_def[0:op_def.index('(')]
-        backend_impl = getattr(my_module, f"_{op_name}")
-        meta_impl = getattr(my_module, f"_{op_name}_meta")
-        c10_lib.define(op_def)
-        c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd")
-        c10_lib_impl.impl(op_name, meta_impl, "Meta")
+    c10_lib_cpu.impl("wait_tensor", _wait_tensor)
+    c10_lib_cuda.impl("wait_tensor", _wait_tensor)
+
+    c10_lib_cpu.impl("all_gather_into_tensor", _all_gather_into_tensor)
+    c10_lib_cuda.impl("all_gather_into_tensor", _all_gather_into_tensor)
+
+    c10_lib_cpu.impl("reduce_scatter_tensor", _reduce_scatter_tensor)
+    c10_lib_cuda.impl("reduce_scatter_tensor", _reduce_scatter_tensor)

 if sys.executable != 'torch_deploy':
     _register_ops()
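Note that the public wrapper still normalizes the scatter dimension to 0 before calling the op (the aten schema accepts a dim, but the eager impl asserts scatter_dim == 0). The chunk-then-cat normalization it uses, as a standalone sketch:

    import torch

    # Move the scatter dimension to dim 0 by chunking along scatter_dim
    # and re-concatenating along dim 0, as reduce_scatter_tensor does.
    x = torch.arange(8).reshape(2, 4)   # scatter_dim=1, group_size=2
    chunks = torch.chunk(x, 2, dim=1)   # two (2, 2) chunks
    x0 = torch.cat(chunks)              # (4, 2): chunks stacked along dim 0
    assert x0.shape == (4, 2)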
@@ -21,9 +21,6 @@ from functorch import make_fx

 import torch
 import torch.distributed as dist
-
-# We need to import _functional_collectives to trigger op registration
-import torch.distributed._functional_collectives
 import torch.nn as nn
 import torch.utils._pytree as pytree


@@ -388,8 +385,8 @@ FOREACH_DECOMP_TABLE = {


 DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
-    torch.ops.c10d_functional.all_reduce.default,
-    torch.ops.c10d_functional.wait_tensor.default,
+    aten.all_reduce.default,
+    aten.wait_tensor.default,
 }

@@ -323,10 +323,6 @@ class ProcessLocalGroup(dist.ProcessGroup):
         ProcessLocalGroup._end_coll(coll, self)
         return res

-    def _reduce_scatter_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
-        tensor_list = list(torch.chunk(input_tensor, self._world_size))
-        return self.reduce_scatter([output_tensor], [tensor_list], opts)
-
     def __init__(self, rank, world_size):
         super().__init__(rank, world_size)
         self._rank = rank