Fx collectives bucketing: add bucket all_reduce (#165351)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351
Approved by: https://github.com/eellison
IvanKobzarev 2025-10-16 03:51:46 -07:00 committed by PyTorch MergeBot
parent f06e669f6c
commit 9272437cde
2 changed files with 171 additions and 0 deletions
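In short: the PR extends the existing fx-passes bucketing machinery (which already handled all_gather and reduce_scatter) to greedily bucket functional all_reduce calls that share a process group, reduce op, and dtype. A minimal sketch of wiring the new pass into Inductor, mirroring the test below (the 2000 MB-per-bucket cap is an arbitrary choice for illustration):

import torch
import torch._inductor.config

def _bucket_all_reduce_pass(graph: torch.fx.Graph) -> None:
    # Lazy import, as in the test; the callable maps bucket index -> cap in MB.
    from torch._inductor.fx_passes.bucketing import bucket_all_reduce

    bucket_all_reduce(graph.owning_module, lambda _bucket_idx: 2000)

torch._inductor.config.post_grad_custom_post_pass = _bucket_all_reduce_pass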


@@ -1743,6 +1743,67 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        correct = f(*inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(not SM80OrLater, "bfloat16")
    @parametrize("bucket_mode", ["all"])
    def test_all_reduce_bucket(self, bucket_mode):
        def func(x, w, ar_0, ar_1, tag, ranks, group_size):
            y = torch.mm(x, w)
            group_name = (
                torch.distributed.distributed_c10d._get_default_group().group_name
            )
            ar_0_out = torch.ops._c10d_functional.all_reduce.default(
                ar_0, "sum", group_name
            )
            ar_1_out = torch.ops._c10d_functional.all_reduce.default(
                ar_1, "sum", group_name
            )
            ar_0_w = torch.ops.c10d_functional.wait_tensor(ar_0_out)
            ar_1_w = torch.ops.c10d_functional.wait_tensor(ar_1_out)
            return y, ar_0_w, ar_1_w

        f = func
        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
        ar_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
        ar_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
        inputs = [x, w, ar_0, ar_1]
        f(*inputs, **self.get_world_trs())

        def _pass(g):
            from torch._inductor.fx_passes.bucketing import bucket_all_reduce

            bucket_all_reduce(g.owning_module, lambda _: 2000)

        torch._inductor.config.post_grad_custom_post_pass = _pass
        with torch._inductor.config.patch(
            {
                "reorder_for_compute_comm_overlap": False,
            }
        ):
            compiled = torch.compile(f)
            compiled(*inputs, **self.get_world_trs())
            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
            # NOTE: After bucketing, the two functional all_reduce calls above
            # should be fused into exactly one in-place all_reduce_ in the
            # generated code.
            (
                FileCheck()
                .check_count(
                    "torch.ops._c10d_functional.all_reduce_.default(",
                    count=1,
                    exactly=True,
                )
                .run(code)
            )
            out = compiled(*inputs, **self.get_world_trs())
            correct = f(*inputs, **self.get_world_trs())
            assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(not SM80OrLater, "bfloat16")
    @parametrize("bucket_mode", ["all", "all_custom_ops"])


@@ -34,11 +34,21 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    return (group_name, reduce_op, dtype)


def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    _, reduce_op, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def bucket_key(node: torch.fx.Node) -> object | None:
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    elif is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    elif is_all_reduce_tensor(node):
        return _ar_group_key(node)
    else:
        return None
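The dispatcher above means only collectives with identical (group_name, reduce_op, dtype) keys can land in the same bucket. A toy, pure-Python analogue of the grouping step (stand-in tuples instead of fx nodes; the real pass computes keys via _ar_group_key):

from collections import defaultdict

# Stand-ins for all_reduce nodes: their (group_name, reduce_op, dtype) keys.
keys = [
    ("default_pg", "sum", "float32"),
    ("default_pg", "sum", "float32"),
    ("default_pg", "sum", "bfloat16"),  # different dtype -> separate bucket
]

groups: defaultdict = defaultdict(list)
for k in keys:
    groups[k].append(k)

assert len(groups) == 2  # the bf16 all_reduce is never merged with the fp32 ones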
@@ -111,6 +121,13 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
    )


def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_reduce.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
@@ -293,6 +310,38 @@ def bucket_reduce_scatter_by_mb(
    )


def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_reduce_tensor,
        _ar_group_key,
        filter_wait_node,
    )
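bucket_cap_mb_by_bucket_idx lets the cap vary per bucket index; the test simply passes `lambda _: 2000`. A hypothetical schedule that keeps the first bucket small (a DDP-style heuristic, shown purely as an illustration, not something this PR prescribes):

def cap_mb(bucket_idx: int) -> float:
    # Hypothetical: flush the first bucket early, then use large buckets.
    return 25.0 if bucket_idx == 0 else 2000.0

ar_buckets = bucket_all_reduce_by_mb(gm, cap_mb)  # gm: a post-grad GraphModule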

def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(ar_buckets) == 0:
        return
    for bucket in ar_buckets:
        merge_all_reduce_bucket(gm.graph, bucket, mode)


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
@@ -364,6 +413,24 @@ def reduce_scatter_merge_fn_to_trace(
    return new_outs


def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    ar_ins_flattened = [x.view(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs
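The merge fn is a flatten/cat -> one collective -> split/view round trip. The data-movement part can be checked in isolation on a single process (no process group needed; shapes borrowed from the test):

import torch

ins = [torch.randn(384, 512), torch.randn(384, 256)]
flat = torch.cat([x.view(-1) for x in ins])  # one contiguous bucket buffer
# ... a single all_reduce would run on `flat` here ...
outs = [
    chunk.view(x.shape)
    for chunk, x in zip(flat.split([x.numel() for x in ins]), ins)
]
assert all(torch.equal(o, x) for o, x in zip(outs, ins))  # lossless round trip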

@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
@@ -713,6 +780,49 @@ def merge_reduce_scatter_bucket(
    )


def merge_all_reduce_bucket(
    g: torch.fx.Graph,
    ar_nodes: list[torch.fx.Node],
    mode: str | None = None,
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    ar0 = ar_nodes[0]
    ar0_val = ar0.meta["val"]
    _, reduce_op, group_name = ar0.args
    reduce_dtype = ar0_val.dtype
    device = ar0_val.device
    for n in ar_nodes:
        ar_val = n.meta["val"]
        assert (
            n.args[1] == reduce_op
            and n.args[2] == group_name
            and ar_val.device == device
            and ar_val.dtype == reduce_dtype
        )
    ar_merge_fn = all_reduce_merge_fn_to_trace

    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_name,
            reduce_op,
            reduce_dtype,
            device,
        )

    return process_collective_bucket(
        g,
        ar_nodes,
        ar_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )
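The assert loop guards the torch.cat in the merge fn: buckets must be homogeneous, because concatenation would otherwise silently type-promote and the split outputs would come back with the wrong dtype. A short demonstration of the failure mode the (group, op, dtype) keying avoids:

import torch

mixed = torch.cat([torch.ones(4, dtype=torch.float32),
                   torch.ones(4, dtype=torch.bfloat16)])
print(mixed.dtype)  # torch.float32 -- the bf16 input was silently promoted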

def merge_all_gather_bucket(
    g: torch.fx.Graph,
    ag_nodes: list[torch.fx.Node],