Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
respect aten planned overlap in inductor (#164569)

Now that we have a higher-order op to add implicit dependencies, use those deps for comm/compute overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164569
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164568
This commit is contained in:
parent 4a39820e5e
commit 35f66b83f8
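At a high level: the overlap scheduler already plans which compute hides which collective, and this PR materializes that plan as explicit control dependencies (via the control_deps higher-order op) so that inductor's scheduler respects it. As a rough sketch of the underlying utility, assuming preserve_node_ordering can be exercised on a plain symbolically traced graph (the toy function below is hypothetical; in this PR the pass runs on post-grad aten graphs):

    import torch
    from torch import fx
    from torch.utils._ordered_set import OrderedSet
    from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering

    def f(x):
        coll = x + 1  # stand-in for an async collective start
        mm = x @ x    # compute intended to hide the collective
        return coll * mm

    gm = fx.symbolic_trace(f)
    _, coll_node, mm_node, *_ = gm.graph.nodes

    # Pin mm after the "collective" start even though no data dependency
    # connects them; the utility rewrites the graph with control_deps ops.
    preserve_node_ordering(gm.graph, {mm_node: OrderedSet([coll_node])})
    print(gm.graph)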
@@ -12,7 +12,7 @@ import torch._dynamo.test_case
 import torch.distributed._functional_collectives as _functional_collectives
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
 from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     at_least_x_gpu,
@@ -67,6 +67,8 @@ def get_patches():
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # Messes up existing test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
 
 
@@ -358,6 +360,8 @@ def get_bucket_patches(compute_multiplier=1.0):
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # messes up test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
 
 
@@ -750,6 +754,85 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
             correct = func(a, b, c, ranks=ranks)
             self.assertTrue(same(out, correct))
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches(2.0))
+    def test_bucketing_split_for_overlap_blocking_deps_inductor(self):
+        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
+
+        # check that ordering is preserved in inductor
+
+        def func(a, b, c, d, *, ranks):
+            # All 4 all-gathers are independent - COULD be bucketed together
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
+            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
+
+            # First compute - can hide ag1 and ag2
+            e = a * 5  # Use a to avoid fusion
+            mm1 = torch.matmul(e, e.T)
+
+            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
+            # Use first 8x8 elements to match mm1's shape
+            intermediate = ag1[:8, :8] + ag2[:8, :8]
+
+            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
+            mm2 = torch.matmul(mm1 + intermediate, c[:8])
+
+            # Use all results
+            result = (
+                ag1.sum() * 1.1
+                + ag2.sum() * 1.2
+                + ag3.sum() * 1.3
+                + ag4.sum() * 1.4
+                + mm1.sum()
+                + mm2.sum()
+            )
+            return result
+
+        li = []
+        apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
+        with (
+            _dynamo_dist_per_rank_init(
+                self.rank,
+                self.world_size,
+                self.backend(device_type),
+                fake_pg=not at_least_x_gpu(2),
+            ),
+            torch._inductor.config.patch(
+                "test_configs.aten_fx_overlap_insert_overlap_deps", True
+            ),
+            torch._inductor.config.patch(post_grad_custom_post_pass=apply),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
+            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            compiled = torch.compile(func_c)
+            test_out, (code,) = run_and_get_code(compiled, a, b, c, d)
+
+            # Check that right deps are added
+            f = FileCheck()
+            for _ in range(2):
+                f.check("control_deps").check_same("all_gather").check_same(
+                    "subgraph_mm"
+                )
+                f.check("control_deps").check_same("mm").check_same("subgraph_wait")
+            f.run(li[0])
+
+            f = FileCheck()
+            for _ in range(2):
+                f.check_count("all_gather_into_tensor_out.default(", 1, exactly=True)
+                f.check_count("extern_kernels.mm(", 1, exactly=True)
+                f.check_count("wait_tensor.default(", 1, exactly=True)
+            f.run(code)
+
+            correct = func(a, b, c, d, ranks=ranks)
+            self.assertTrue(same(test_out, correct))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
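A note on the FileCheck assertions used above: check() scans forward for a substring, check_same() requires the next match on the same line as the previous one, and check_count(..., exactly=True) asserts an exact number of occurrences. A minimal standalone example (the checked text here is made up):

    from torch._C import FileCheck

    text = "control_deps -> all_gather -> subgraph_mm\nextern_kernels.mm(buf0, buf1)"

    FileCheck().check("control_deps").check_same("all_gather").run(text)
    FileCheck().check_count("extern_kernels.mm(", 1, exactly=True).run(text)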
@@ -2029,6 +2029,9 @@ class test_configs:
     # to be migrated when ready for use
     aten_fx_overlap_scheduling = False
 
+    # insert ordering deps for overlap
+    aten_fx_overlap_insert_overlap_deps = True
+
     # to be migrated when ready for use
     aten_fx_overlap_preserving_bucketing = False
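Both flags live under torch._inductor.config.test_configs and can be flipped per run with config.patch, as the new test does. A small sketch (the function being compiled is hypothetical):

    import torch

    def fn(x):
        return x * 2

    with torch._inductor.config.patch(
        "test_configs.aten_fx_overlap_insert_overlap_deps", True
    ):
        compiled = torch.compile(fn)
        out = compiled(torch.ones(4))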
@@ -86,6 +86,10 @@ class OverlapPreservingBucketer:
         from torch._dynamo.graph_deduplication import _stable_topological_sort
 
         _stable_topological_sort(self.graph, additional_deps)
+
+        # After topological sort, preserve dependencies using effect tokens
+        self._preserve_dependencies_with_tokens(additional_deps)
+
         self.graph.lint()
 
     def _find_buckets(
@@ -254,3 +258,19 @@ class OverlapPreservingBucketer:
                 overlap_deps[new_wait].add(info.hiding_node)
 
         return overlap_deps
+
+    def _preserve_dependencies_with_tokens(
+        self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]]
+    ) -> None:
+        """
+        Preserve dependencies using effect tokens and with_effects higher-order op.
+
+        Uses the standalone token_dependencies utility for consistent behavior
+        across different overlap scheduling approaches.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        if torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            preserve_node_ordering(self.graph, additional_deps)
@@ -378,10 +378,40 @@ class OverlapScheduler:
                 self._handle_other(node)
 
         self._reorder_graph()
 
         if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
             self._bucket_collectives()
+        elif torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            # If not bucketing, add effect tokens to preserve hiding dependencies
+            self._add_effect_tokens_for_overlap()
 
         return self.gm
 
+    def _add_effect_tokens_for_overlap(self) -> None:
+        """
+        Add effect tokens to preserve hiding dependency relationships when not bucketing.
+
+        This ensures that communication-compute overlap is preserved through effect tokens
+        when overlap preserving bucketing is not enabled.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node
+        additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        for start_node, info in self.collective_info.items():
+            if info.hiding_node and not info.is_exposed:
+                # Compute depends on collective start (compute must wait for collective to start)
+                additional_deps[info.hiding_node].add(start_node)
+                # Wait depends on compute (wait must wait for compute to finish)
+                additional_deps[info.wait_node].add(info.hiding_node)
+
+        # Apply effect tokens to preserve these dependencies
+        if additional_deps:
+            preserve_node_ordering(self.graph, additional_deps)
+
     def _handle_other(self, node: fx.Node) -> None:
         self._schedule(node)
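The dependency bookkeeping above has a simple shape: for each hidden collective, the hiding compute is pinned after the collective's start, and the wait is pinned after the hiding compute. A self-contained sketch with stand-in nodes (hypothetical; in the pass the three nodes come from self.collective_info):

    from collections import defaultdict

    import torch
    from torch import fx
    from torch.utils._ordered_set import OrderedSet

    def f(x):
        ag_start = x + 1           # stand-in: collective start
        mm = x @ x                 # stand-in: hiding compute
        ag_wait = ag_start.relu()  # stand-in: wait on the collective
        return mm + ag_wait

    graph = fx.symbolic_trace(f).graph
    _, ag_start, mm, ag_wait, *_ = graph.nodes

    additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
    additional_deps[mm].add(ag_start)  # compute stays after the collective start
    additional_deps[ag_wait].add(mm)   # wait stays after the hiding compute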
@@ -7255,29 +7255,34 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
 
     output = None
 
+    operation_len = len(V.graph.operations)
     assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
     for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
         if node.op == "placeholder":
             assert node not in V.graph.env
             V.graph.env[node] = args[i]
             continue
         elif node.op == "output":
             args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
             output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
         else:
             assert node not in V.graph.env
             V.graph.env[node] = V.graph.run_node(node)
 
     assert output is not None and additional_deps
     output_list = output if isinstance(output, (list, tuple)) else [output]
 
     for out in output_list:
         if not isinstance(out, IRNode):
             continue
 
         # need to realize in order to add the dep
         out.realize()
-        out_name = out.get_name()
+        # some operators, like wait_tensor, just return their input,
+        # so it's more robust to add the dep to the operation itself;
+        # otherwise you can have a cycle of
+        #   a = coll
+        #   b = control_deps(a, mm, ...)
+        #   c = control_deps(b, wait, ...)
+        # if c == a, then you have a cycle.
+        for op in V.graph.operations[operation_len:]:
             for dep_name in dep_names:
-                V.graph.additional_buffer_deps[out_name].add(dep_name)
+                op_name = op.operation_name
+                assert op_name is not None
+                V.graph.additional_buffer_deps[op_name].add(dep_name)
 
     return output
@@ -2684,9 +2684,9 @@ class Scheduler:
             )
             add_user(other_name, node, is_weak=True)
 
-            for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
-                add_user(add_dep, node, is_weak=True)
-                node.add_fake_dep(WeakDep(add_dep, node.get_name()))
+        for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
+            add_user(add_dep, node, is_weak=True)
+            node.add_fake_dep(WeakDep(add_dep, node.get_name()))
 
         # add normal non-mutation dependencies
         for read in node.read_writes.reads: