respect aten planned overlap in inductor (#164569)

Now that we have a HOP (higher-order op) for adding implicit deps, use those deps for comm/compute overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164569
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164568
Authored by eellison on 2025-10-03 16:18:33 -07:00; committed by PyTorch MergeBot
parent 4a39820e5e
commit 35f66b83f8
6 changed files with 154 additions and 13 deletions
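
For context, the ordering mechanism works roughly like this (a minimal sketch, not code from this commit: it assumes `preserve_node_ordering` from #164568 accepts an `fx.Graph` plus a node-to-dependencies mapping over plain call_function nodes, which is how the diffs below invoke it; node names like "add"/"mul" assume symbolic_trace's default naming):

from collections import defaultdict

import torch
import torch.fx as fx
from torch.utils._ordered_set import OrderedSet
from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering


def f(x):
    start = x + 1     # stand-in for an all_gather start
    hide = x * 2      # independent compute that should overlap the collective
    wait = start + 3  # stand-in for wait_tensor
    return hide + wait


gm = fx.symbolic_trace(f)
nodes = {n.name: n for n in gm.graph.nodes}

# The hiding compute must stay after the collective start,
# and the wait must stay after the hiding compute.
additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
additional_deps[nodes["mul"]].add(nodes["add"])    # hide -> start
additional_deps[nodes["add_1"]].add(nodes["mul"])  # wait -> hide

preserve_node_ordering(gm.graph, additional_deps)
gm.graph.lint()

After the pass, each pinned op is wrapped in a control_deps HOP subgraph so the extra edges survive later graph passes; the test below checks for exactly those nodes.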

View File

@@ -12,7 +12,7 @@ import torch._dynamo.test_case
 import torch.distributed._functional_collectives as _functional_collectives
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
 from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     at_least_x_gpu,
@@ -67,6 +67,8 @@ def get_patches():
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # Messes up existing test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
@@ -358,6 +360,8 @@ def get_bucket_patches(compute_multiplier=1.0):
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # messes up test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
@@ -750,6 +754,85 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
         correct = func(a, b, c, ranks=ranks)
         self.assertTrue(same(out, correct))
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches(2.0))
+    def test_bucketing_split_for_overlap_blocking_deps_inductor(self):
+        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
+
+        # check that ordering is preserved in inductor
+        def func(a, b, c, d, *, ranks):
+            # All 4 all-gathers are independent - COULD be bucketed together
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
+            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
+
+            # First compute - can hide ag1 and ag2
+            e = a * 5  # Use a to avoid fusion
+            mm1 = torch.matmul(e, e.T)
+
+            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
+            # Use first 8x8 elements to match mm1's shape
+            intermediate = ag1[:8, :8] + ag2[:8, :8]
+
+            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
+            mm2 = torch.matmul(mm1 + intermediate, c[:8])
+
+            # Use all results
+            result = (
+                ag1.sum() * 1.1
+                + ag2.sum() * 1.2
+                + ag3.sum() * 1.3
+                + ag4.sum() * 1.4
+                + mm1.sum()
+                + mm2.sum()
+            )
+            return result
+
+        li = []
+        apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
+
+        with (
+            _dynamo_dist_per_rank_init(
+                self.rank,
+                self.world_size,
+                self.backend(device_type),
+                fake_pg=not at_least_x_gpu(2),
+            ),
+            torch._inductor.config.patch(
+                "test_configs.aten_fx_overlap_insert_overlap_deps", True
+            ),
+            torch._inductor.config.patch(post_grad_custom_post_pass=apply),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
+            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            compiled = torch.compile(func_c)
+            test_out, (code,) = run_and_get_code(compiled, a, b, c, d)
+
+            # Check that right deps are added
+            f = FileCheck()
+            for _ in range(2):
+                f.check("control_deps").check_same("all_gather").check_same(
+                    "subgraph_mm"
+                )
+                f.check("control_deps").check_same("mm").check_same("subgraph_wait")
+            f.run(li[0])
+
+            f = FileCheck()
+            for _ in range(2):
+                f.check_count("all_gather_into_tensor_out.default(", 1, exactly=True)
+                f.check_count("extern_kernels.mm(", 1, exactly=True)
+                f.check_count("wait_tensor.default(", 1, exactly=True)
+            f.run(code)
+
+            correct = func(a, b, c, d, ranks=ranks)
+            self.assertTrue(same(test_out, correct))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
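
Note that the two FileCheck passes inspect different artifacts: li[0] is the post-grad FX graph captured by apply_reordering_and_get_graph, while code is the generated wrapper code. Schematically, the second check pins down this interleaving (illustrative shape only, not literal inductor output):

all_gather_into_tensor_out.default(...)  # bucket 1 (ag1 + ag2) starts
extern_kernels.mm(...)                   # mm1 runs while bucket 1 is in flight
wait_tensor.default(...)                 # bucket 1 completes
all_gather_into_tensor_out.default(...)  # bucket 2 (ag3 + ag4) starts
extern_kernels.mm(...)                   # mm2 runs while bucket 2 is in flight
wait_tensor.default(...)                 # bucket 2 completes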

View File

@@ -2029,6 +2029,9 @@ class test_configs:
     # to be migrated when ready for use
     aten_fx_overlap_scheduling = False
 
+    # insert ordering deps for overlap
+    aten_fx_overlap_insert_overlap_deps = True
+
     # to be migrated when ready for use
     aten_fx_overlap_preserving_bucketing = False
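
The new knob defaults to on; the tests above toggle it via the dotted-path form of config.patch. A minimal sketch of doing the same outside the test suite (fn is a placeholder; the knob only has an effect when the aten overlap scheduling passes actually run):

import torch

def fn(x):
    return x @ x + 1

with torch._inductor.config.patch(
    "test_configs.aten_fx_overlap_insert_overlap_deps", True
):
    compiled = torch.compile(fn)
    out = compiled(torch.randn(8, 8))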

View File

@@ -86,6 +86,10 @@ class OverlapPreservingBucketer:
         from torch._dynamo.graph_deduplication import _stable_topological_sort
 
         _stable_topological_sort(self.graph, additional_deps)
+
+        # After topological sort, preserve dependencies using effect tokens
+        self._preserve_dependencies_with_tokens(additional_deps)
+
         self.graph.lint()
 
     def _find_buckets(
@@ -254,3 +258,19 @@ class OverlapPreservingBucketer:
                 overlap_deps[new_wait].add(info.hiding_node)
 
         return overlap_deps
+
+    def _preserve_dependencies_with_tokens(
+        self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]]
+    ) -> None:
+        """
+        Preserve dependencies using effect tokens and with_effects higher-order op.
+
+        Uses the standalone token_dependencies utility for consistent behavior
+        across different overlap scheduling approaches.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        if torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            preserve_node_ordering(self.graph, additional_deps)
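
The net effect on the graph, borrowing the naming from the cycle-avoidance comment in the lowering change below (schematic, not the pass's literal output):

# a = all_gather(x)
# b = control_deps((a,), subgraph_mm, ...)    # the hiding mm cannot be hoisted above the collective start
# c = control_deps((b,), subgraph_wait, ...)  # the wait cannot be scheduled before the mm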

View File

@@ -378,10 +378,40 @@ class OverlapScheduler:
                 self._handle_other(node)
 
         self._reorder_graph()
 
         if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
             self._bucket_collectives()
+        elif torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            # If not bucketing, add effect tokens to preserve hiding dependencies
+            self._add_effect_tokens_for_overlap()
 
         return self.gm
 
+    def _add_effect_tokens_for_overlap(self) -> None:
+        """
+        Add effect tokens to preserve hiding dependency relationships when not bucketing.
+
+        This ensures that communication-compute overlap is preserved through effect
+        tokens when overlap preserving bucketing is not enabled.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node
+        additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        for start_node, info in self.collective_info.items():
+            if info.hiding_node and not info.is_exposed:
+                # Compute depends on collective start (compute must wait for collective to start)
+                additional_deps[info.hiding_node].add(start_node)
+                # Wait depends on compute (wait must wait for compute to finish)
+                additional_deps[info.wait_node].add(info.hiding_node)
+
+        # Apply effect tokens to preserve these dependencies
+        if additional_deps:
+            preserve_node_ordering(self.graph, additional_deps)
+
     def _handle_other(self, node: fx.Node) -> None:
         self._schedule(node)
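
Note the guard: only collectives that are actually hidden (info.hiding_node set and not is_exposed) get extra edges; an exposed collective keeps its natural data dependencies. For a hypothetical hidden collective, the mapping built above reads:

# additional_deps[mm1]      = {ag1_start}  # the hiding compute must follow the collective start
# additional_deps[ag1_wait] = {mm1}        # the wait must follow the hiding compute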

View File

@@ -7255,29 +7255,34 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
     output = None
+    operation_len = len(V.graph.operations)
     assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
 
     for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
         if node.op == "placeholder":
             assert node not in V.graph.env
             V.graph.env[node] = args[i]
             continue
         elif node.op == "output":
             args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
             output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
         else:
             assert node not in V.graph.env
             V.graph.env[node] = V.graph.run_node(node)
 
     assert output is not None and additional_deps
     output_list = output if isinstance(output, (list, tuple)) else [output]
 
     for out in output_list:
         if not isinstance(out, IRNode):
             continue
         # need to realize in order to add the dep
         out.realize()
-        out_name = out.get_name()
 
+    # some operators, like wait_tensor, just return their input,
+    # so its more robust to add dep to the operation itself,
+    # otherwise you can have a cycle of
+    # a = coll
+    # b = control_deps(a, mm, ...)
+    # c = control_deps(b, wait, ...)
+    # if c == a, then you have a cycle.
+    for op in V.graph.operations[operation_len:]:
         for dep_name in dep_names:
-            V.graph.additional_buffer_deps[out_name].add(dep_name)
+            op_name = op.operation_name
+            assert op_name is not None
+            V.graph.additional_buffer_deps[op_name].add(dep_name)
 
     return output

View File

@@ -2684,9 +2684,9 @@ class Scheduler:
                 )
                 add_user(other_name, node, is_weak=True)
 
-            for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
-                add_user(add_dep, node, is_weak=True)
-                node.add_fake_dep(WeakDep(add_dep, node.get_name()))
+        for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
+            add_user(add_dep, node, is_weak=True)
+            node.add_fake_dep(WeakDep(add_dep, node.get_name()))
 
         # add normal non-mutation dependencies
         for read in node.read_writes.reads: