[inductor][comms] skip reorder_for_locality for wait nodes (#150074)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150074
Approved by: https://github.com/eellison, https://github.com/bdhirsh
ghstack dependencies: #150258

parent 159d8a14a6
commit c7400d0026
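
For context, the hazard this change guards against is the pattern exercised by the new test: every rank must issue collectives in the same order, while eager semantics allow the matching waits to be consumed in any order. Below is a minimal sketch of that pattern (illustrative only, not part of the diff; it assumes an initialized process group, and `tag`, `ranks`, and `group_size` are the same process-group details the test pulls from `get_world_trs()`):

# Minimal sketch: collectives issued in a fixed order, waits consumed in
# reverse. A pass that pulls a wait node upward could drag its collective
# past the other one and desynchronize ranks.
import torch

def issue_then_wait_out_of_order(g1, g2, tag, ranks, group_size):
    h1 = torch.ops.c10d_functional.all_reduce(g1, "avg", tag, ranks, group_size)
    h2 = torch.ops.c10d_functional.all_reduce(g2, "avg", tag, ranks, group_size)
    # Waiting out of order is fine in eager mode...
    out2 = torch.ops._c10d_functional.wait_tensor.default(h2)
    out1 = torch.ops._c10d_functional.wait_tensor.default(h1)
    # ...as long as no graph pass reorders the all_reduce launches themselves.
    return out1, out2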
@@ -379,6 +379,59 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_inductor_default_comms_ordering(self):
        pg_info = self.get_world_trs()
        tag = pg_info["tag"]
        ranks = pg_info["ranks"]
        group_size = pg_info["group_size"]

        g1 = torch.ones(10, 10, device="cuda")
        g2 = torch.ones(11, 11, device="cuda")
        g3 = torch.ones(12, 12, device="cuda")

        def assert_pass(graph):
            # all_reduces need to remain in order!
            self.assertExpectedInline(
                graph,
                """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg0_1, avg, 0), kwargs = {})
    %all_reduce_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg1_1, avg, 0), kwargs = {})
    %all_reduce_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg2_1, avg, 0), kwargs = {})
    %wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_2,), kwargs = {})
    %wait_tensor_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_1,), kwargs = {})
    %wait_tensor_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
    return (wait_tensor, wait_tensor_1, wait_tensor_2)""", # noqa: B950
            )

        torch._inductor.config.post_grad_custom_post_pass = assert_pass

        @torch.compile
        def fn(g1, g2, g3):
            handle1 = torch.ops.c10d_functional.all_reduce(
                g1, "avg", tag, ranks, group_size
            )
            handle2 = torch.ops.c10d_functional.all_reduce(
                g2, "avg", tag, ranks, group_size
            )
            handle3 = torch.ops.c10d_functional.all_reduce(
                g3, "avg", tag, ranks, group_size
            )

            # wait on them in a different order
            grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
            grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
            grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
            return grad3, grad2, grad1

        with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True):
            fn(g1, g2, g3)

    def test_nccl_heuristics(self):
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)
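
The test drives its assertion through Inductor's `post_grad_custom_post_pass` hook, which receives the post-grad FX graph. Here is a minimal single-process sketch of that hook, with a purely illustrative printing pass standing in for the test's `assert_pass` (no distributed ops assumed):

# Minimal sketch: register a custom post-grad pass and trigger it by compiling.
import torch
import torch._inductor.config as inductor_config

def show_graph(graph):
    # Inductor invokes this with the post-grad FX graph; passes may also
    # mutate the graph in place, as reorder_for_locality does.
    print(graph)

inductor_config.post_grad_custom_post_pass = show_graph

@torch.compile
def f(x):
    return (x.sin() + 1).cos()

f(torch.randn(8))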
@@ -485,6 +485,21 @@ def lazy_init():


def reorder_for_locality(graph: torch.fx.Graph):
    if torch.distributed.is_available():

        def check():
            # This is a wait node, and `other_node` is some collective node.
            # Eager semantics allow waits to be issued in a different order than
            # the collectives. Reordering this wait node might reorder the
            # collectives, which can cause hangs. Once we have SPMD mode, we can
            # safely reorder them. However, increasing the locality between a
            # collective and its wait node is generally worse for performance.
            return node.target != torch.ops._c10d_functional.wait_tensor.default

    else:

        def check():
            return True

    def visit(other_node):
        if (
            other_node.op == "call_function"
@@ -492,6 +507,7 @@ def reorder_for_locality(graph: torch.fx.Graph):
            and all((n in seen_nodes) for n in other_node.users)
            and get_mutation_region_id(graph, node)
            == get_mutation_region_id(graph, other_node)
            and check()
        ):
            # move node's producers right before it
            node.prepend(other_node)
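
Read together, the two hunks gate the existing producer-hoisting logic: `visit()` still pulls a node's producers right before it, but `check()` now vetoes the move whenever the node being visited is a `wait_tensor`. A simplified, self-contained sketch of that shape (this is not the Inductor implementation; among other things it omits the mutation-region comparison shown above):

# Simplified sketch of the gated reordering idea.
import torch
from torch import fx

def reorder_for_locality_sketch(graph: fx.Graph) -> None:
    wait_target = torch.ops._c10d_functional.wait_tensor.default
    seen_nodes = set()
    for node in reversed(list(graph.nodes)):  # snapshot order, then mutate
        seen_nodes.add(node)
        if node.target == wait_target:
            # Never hoist producers of a wait node: that could move one
            # collective past another and change the cross-rank issue order.
            continue
        for producer in node.all_input_nodes:
            if producer.op == "call_function" and all(
                user in seen_nodes for user in producer.users
            ):
                # Move the producer right before its last consumer.
                node.prepend(producer)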