Fx collectives bucketing: add bucket all_reduce (#165351)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351
Approved by: https://github.com/eellison
This commit is contained in:
parent f06e669f6c
commit 9272437cde
@@ -1743,6 +1743,67 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        correct = f(*inputs, **self.get_world_trs())
        assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(not SM80OrLater, "bfloat16")
    @parametrize("bucket_mode", ["all"])
    def test_all_reduce_bucket(self, bucket_mode):
        def func(x, w, ar_0, ar_1, tag, ranks, group_size):
            y = torch.mm(x, w)

            group_name = (
                torch.distributed.distributed_c10d._get_default_group().group_name
            )
            ar_0_out = torch.ops._c10d_functional.all_reduce.default(
                ar_0, "sum", group_name
            )
            ar_1_out = torch.ops._c10d_functional.all_reduce.default(
                ar_1, "sum", group_name
            )

            ar_0_w = torch.ops.c10d_functional.wait_tensor(ar_0_out)
            ar_1_w = torch.ops.c10d_functional.wait_tensor(ar_1_out)

            return y, ar_0_w, ar_1_w

        f = func

        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
        ar_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
        ar_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
        inputs = [x, w, ar_0, ar_1]
        f(*inputs, **self.get_world_trs())

        def _pass(g):
            from torch._inductor.fx_passes.bucketing import bucket_all_reduce

            bucket_all_reduce(g.owning_module, lambda _: 2000)

        torch._inductor.config.post_grad_custom_post_pass = _pass

        with torch._inductor.config.patch(
            {
                "reorder_for_compute_comm_overlap": False,
            }
        ):
            compiled = torch.compile(f)
            compiled(*inputs, **self.get_world_trs())
            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
            # NOTE: The first return value should be the output of the first wait_tensor.
            # We want to make sure no unnecessary copy is made.
            (
                FileCheck()
                .check_count(
                    "torch.ops._c10d_functional.all_reduce_.default(",
                    count=1,
                    exactly=True,
                )
                .run(code)
            )
            out = compiled(*inputs, **self.get_world_trs())
            correct = f(*inputs, **self.get_world_trs())
            assert same(out, correct), f"{out} vs {correct}"

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(not SM80OrLater, "bfloat16")
    @parametrize("bucket_mode", ["all", "all_custom_ops"])
@@ -34,11 +34,21 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    return (group_name, reduce_op, dtype)


def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    _, reduce_op, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def bucket_key(node: torch.fx.Node) -> object | None:
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    elif is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    elif is_all_reduce_tensor(node):
        return _ar_group_key(node)
    else:
        return None
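For illustration only (not part of this commit): bucket_key reduces each collective to a hashable tuple, and only nodes with identical tuples may share a bucket. A hedged sketch of that partitioning, with tuples standing in for real (group_name, reduce_op, dtype) keys:

    from collections import defaultdict

    # ar2's reduce_op differs, so it can never share a bucket with ar0/ar1.
    fake_nodes = [
        ("ar0", ("default_pg", "sum", "float32")),
        ("ar1", ("default_pg", "sum", "float32")),
        ("ar2", ("default_pg", "avg", "float32")),
    ]

    groups = defaultdict(list)
    for name, key in fake_nodes:
        groups[key].append(name)

    print(dict(groups))
    # {('default_pg', 'sum', 'float32'): ['ar0', 'ar1'],
    #  ('default_pg', 'avg', 'float32'): ['ar2']}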
@@ -111,6 +121,13 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
    )


def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_reduce.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
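For illustration only (not part of this commit): predicates like is_all_reduce_tensor are applied node by node while walking the FX graph. A minimal sketch, with torch.add standing in for the collective so no process group is needed:

    import torch
    import torch.fx

    def f(x):
        return torch.add(torch.add(x, 1), 2)

    gm = torch.fx.symbolic_trace(f)

    def is_add(node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target == torch.add

    print(sum(is_add(n) for n in gm.graph.nodes))  # 2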
@@ -293,6 +310,38 @@ def bucket_reduce_scatter_by_mb(
    )


def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_reduce_tensor,
        _ar_group_key,
        filter_wait_node,
    )


def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(ar_buckets) == 0:
        return
    for bucket in ar_buckets:
        merge_all_reduce_bucket(gm.graph, bucket, mode)


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
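For illustration only (not part of this commit): the test above registers the pass through torch._inductor.config.post_grad_custom_post_pass with a flat cap of lambda _: 2000. A hedged sketch of the same registration with a per-bucket cap policy; the specific sizes are hypothetical:

    import torch

    def bucket_cap_mb(bucket_idx: int) -> float:
        # Hypothetical policy: keep the first bucket small so its fused
        # all_reduce can be issued early, then switch to large buckets.
        return 25.0 if bucket_idx == 0 else 2000.0

    def _bucket_pass(g):
        from torch._inductor.fx_passes.bucketing import bucket_all_reduce

        bucket_all_reduce(g.owning_module, bucket_cap_mb)

    torch._inductor.config.post_grad_custom_post_pass = _bucket_pass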
@@ -364,6 +413,24 @@ def reduce_scatter_merge_fn_to_trace(
    return new_outs


def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    ar_ins_flattened = [x.view(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs


@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
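For illustration only (not part of this commit): the merge function restores per-tensor outputs by splitting the fused buffer on each input's numel and viewing each chunk back to its original shape. A quick check of that round trip with the collective replaced by the identity:

    import torch

    ar_ins = [torch.randn(4, 384), torch.randn(384, 256), torch.randn(7)]
    flat = torch.cat([x.view(-1) for x in ar_ins])  # one fused buffer
    split_sizes = [x.numel() for x in ar_ins]
    outs = [
        chunk.view(x.shape)
        for chunk, x in zip(flat.split(split_sizes), ar_ins)
    ]
    for x, y in zip(ar_ins, outs):
        assert y.shape == x.shape and torch.equal(x, y)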
@@ -713,6 +780,49 @@ def merge_reduce_scatter_bucket(
    )


def merge_all_reduce_bucket(
    g: torch.fx.Graph,
    ar_nodes: list[torch.fx.Node],
    mode: str | None = None,
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    ar0 = ar_nodes[0]
    ar0_val = ar0.meta["val"]
    _, reduce_op, group_name = ar0.args
    reduce_dtype = ar0_val.dtype
    device = ar0_val.device

    for n in ar_nodes:
        ar_val = n.meta["val"]
        assert (
            n.args[1] == reduce_op
            and n.args[2] == group_name
            and ar_val.device == device
            and ar_val.dtype == reduce_dtype
        )

    ar_merge_fn = all_reduce_merge_fn_to_trace

    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_name,
            reduce_op,
            reduce_dtype,
            device,
        )

    return process_collective_bucket(
        g,
        ar_nodes,
        ar_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_gather_bucket(
    g: torch.fx.Graph,
    ag_nodes: list[torch.fx.Node],
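For illustration only (not part of this commit): create_trace_args uses pytree.tree_map to swap each fx.Node for its meta["val"] (a fake tensor) before the merge function is traced. A sketch with a hypothetical FakeNode as the leaf type:

    import torch.utils._pytree as pytree

    class FakeNode:
        # Hypothetical stand-in for torch.fx.Node; unregistered types are
        # pytree leaves, so tree_map visits each node object directly.
        def __init__(self, val):
            self.meta = {"val": val}

    bucket_ins = [FakeNode(1), FakeNode(2)]
    print(pytree.tree_map(lambda node: node.meta["val"], bucket_ins))  # [1, 2]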