[inductor][comms] skip reorder_for_locality for wait nodes (#150074)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150074
Approved by: https://github.com/eellison, https://github.com/bdhirsh
ghstack dependencies: #150258
Author: Simon Fan, 2025-04-15 11:43:42 -07:00 (committed by PyTorch MergeBot)
Parent: 159d8a14a6
Commit: c7400d0026
2 changed files with 69 additions and 0 deletions
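
In essence, the change teaches `reorder_for_locality` to leave functional-collective wait nodes where they are, so the relative order of the `all_reduce` calls they wait on cannot change across ranks. Below is a minimal sketch of that predicate, paraphrasing the guard in the second hunk; the standalone `is_reorderable` helper is illustrative, not code from the PR:

```python
import torch


def is_reorderable(node: torch.fx.Node) -> bool:
    # Paraphrase of the new `check()` guard: wait nodes stay pinned, because
    # hoisting them could reshuffle the collectives they wait on, and ranks
    # that disagree on collective order can hang.
    return node.target != torch.ops._c10d_functional.wait_tensor.default
```

When `torch.distributed` is not available, the guard degenerates to always-true, as the `else` branch in the diff shows.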


@@ -379,6 +379,59 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_inductor_default_comms_ordering(self):
        pg_info = self.get_world_trs()
        tag = pg_info["tag"]
        ranks = pg_info["ranks"]
        group_size = pg_info["group_size"]

        g1 = torch.ones(10, 10, device="cuda")
        g2 = torch.ones(11, 11, device="cuda")
        g3 = torch.ones(12, 12, device="cuda")

        def assert_pass(graph):
            # all_reduces need to remain in order!
            self.assertExpectedInline(
                graph,
                """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg0_1, avg, 0), kwargs = {})
    %all_reduce_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg1_1, avg, 0), kwargs = {})
    %all_reduce_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg2_1, avg, 0), kwargs = {})
    %wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_2,), kwargs = {})
    %wait_tensor_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_1,), kwargs = {})
    %wait_tensor_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
    return (wait_tensor, wait_tensor_1, wait_tensor_2)""",  # noqa: B950
            )

        torch._inductor.config.post_grad_custom_post_pass = assert_pass

        @torch.compile
        def fn(g1, g2, g3):
            handle1 = torch.ops.c10d_functional.all_reduce(
                g1, "avg", tag, ranks, group_size
            )
            handle2 = torch.ops.c10d_functional.all_reduce(
                g2, "avg", tag, ranks, group_size
            )
            handle3 = torch.ops.c10d_functional.all_reduce(
                g3, "avg", tag, ranks, group_size
            )
            # wait on them in a different order
            grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
            grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
            grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
            return grad3, grad2, grad1

        with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True):
            fn(g1, g2, g3)

    def test_nccl_heuristics(self):
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)
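
For reference, the assertion above hinges on Inductor's custom post-grad pass hook: the test assigns `assert_pass` to `torch._inductor.config.post_grad_custom_post_pass`, and Inductor invokes it on the post-grad FX graph after passes such as `reorder_for_locality` have run. A minimal standalone use of the same hook is sketched below; the toy function and the print-only pass are illustrative, and it assumes a working Inductor toolchain (no GPU or process group needed):

```python
import torch
import torch._inductor.config as inductor_config


def print_post_grad_graph(graph):
    # Called by Inductor with the post-grad FX graph; here we only inspect it.
    print(graph)


inductor_config.post_grad_custom_post_pass = print_post_grad_graph


@torch.compile
def f(x):
    return (x + 1).relu().sum()


f(torch.randn(8))
```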


@@ -485,6 +485,21 @@ def lazy_init():
def reorder_for_locality(graph: torch.fx.Graph):
    if torch.distributed.is_available():

        def check():
            # Here `node` is a wait node and `other_node` is some collective node.
            # Eager semantics allow waits to be issued in a different order than
            # the collectives. Reordering this wait node might reorder the
            # collectives themselves, which can cause hangs. Once we have SPMD
            # mode, we can safely reorder them. However, increasing the locality
            # between a collective and its wait node is generally worse for
            # performance.
            return node.target != torch.ops._c10d_functional.wait_tensor.default

    else:

        def check():
            return True

    def visit(other_node):
        if (
            other_node.op == "call_function"
@@ -492,6 +507,7 @@ def reorder_for_locality(graph: torch.fx.Graph):
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
            and check()
        ):
            # move node's producers right before it
            node.prepend(other_node)
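
To see the base pass in isolation, here is a hedged sketch that runs `reorder_for_locality` on an ordinary traced graph with no collectives, assuming the function remains importable from `torch._inductor.fx_passes.post_grad` (the module that defines it in current PyTorch); the toy function is illustrative:

```python
import torch
# Assumption: reorder_for_locality lives in this internal Inductor module.
from torch._inductor.fx_passes.post_grad import reorder_for_locality


def toy(x):
    a = x + 1      # producer of c
    b = x * 2      # producer of d
    c = a.sum()
    d = b.sum()
    return c + d


gm = torch.fx.symbolic_trace(toy)
reorder_for_locality(gm.graph)  # moves each node's producers right before it
gm.recompile()
print(gm.graph)
```

With this PR, the same pass skips `wait_tensor` nodes, so graphs like the one in the test above keep their collective order intact.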