[partitioner] always ban compiler-driven recompute of collectives by default (#147561)

This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)
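
To illustrate the failure mode (with made-up schedules, not real compiler output): NCCL matches collectives across ranks purely by issue order, so divergent per-rank recompute choices produce mismatched schedules and a hang:

# Hypothetical per-rank backward schedules if the partitioner recomputes the
# forward all_to_all on rank 0 but saves it for backward on rank 1:
rank0_backward_collectives = ["all_to_all_single (recomputed fw)", "all_to_all_single (grad)"]
rank1_backward_collectives = ["all_to_all_single (grad)"]
# Rank 0's first all_to_all_single has no matching call on rank 1, so both
# ranks block inside NCCL forever.
assert rank0_backward_collectives != rank1_backward_collectives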

(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)
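
A minimal, hypothetical sketch of what that cross-rank agreement could look like (neither `spmd_mode` nor this helper exists in this PR): each rank contributes its local preference and all ranks adopt the most conservative choice:

import torch
import torch.distributed as dist

def agree_on_recompute(local_wants_recompute: bool) -> bool:
    # Assumes the default process group is already initialized; with the NCCL
    # backend the vote tensor would need to live on this rank's GPU.
    vote = torch.tensor([1 if local_wants_recompute else 0])
    # Reduce with MIN: recompute only if *every* rank wants to; otherwise all
    # ranks fall back to saving, so the backward collective schedules match.
    dist.all_reduce(vote, op=dist.ReduceOp.MIN)
    return bool(vote.item())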

(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint.checkpoint()` in their user code. User checkpointing always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks
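
As a hedged sketch (not code from this PR; `collective_block` is a made-up user function, and it assumes the default process group is registered under the group name "0", as in the test below), a user can force recompute of a collective like this:

import torch
import torch.utils.checkpoint as checkpoint

def collective_block(x):
    # Made-up user region containing a functional collective; the user is
    # expected to checkpoint this region identically on every rank.
    y = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
    return torch.ops._c10d_functional.wait_tensor(y)

def forward(x):
    # Under torch.compile, a non-reentrant checkpoint() region is tagged as
    # must-recompute, which the partitioner honors even though it bans
    # compiler-driven recompute of collectives by default.
    return checkpoint.checkpoint(collective_block, x, use_reentrant=False)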

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
Author: Brian Hirsh (2025-03-12 11:04:38 -07:00), committed by PyTorch MergeBot
commit 3646d4dbc8 (parent 420a9be743)
3 changed files with 238 additions and 2 deletions


@@ -2,6 +2,7 @@
import datetime
import functools
import unittest
from collections import defaultdict
from unittest.mock import patch
import torch
@@ -32,6 +33,7 @@ from torch.testing._internal.common_utils import (
skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._python_dispatch import TorchDispatchMode
def _tolist_with_constrain_as_size(tensor):
@@ -42,6 +44,7 @@ def _tolist_with_constrain_as_size(tensor):
@requires_nccl()
@instantiate_parametrized_tests
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
@@ -550,6 +553,192 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_out = compiled_fn(*inputs, **trs)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
# The goal of this test is to check that when `unsafe_allow_optimization_of_collectives=False`,
# the partitioner will *never* recompute collectives in the backward, even
# if the activation_memory_budget partitioner is being used,
# unless there is a manual user checkpoint() region (which we know makes it safe
# to recompute the collective, since we assume that the user applied the AC
# region consistently across all ranks)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
@patch.object(torch._functorch.config, "activation_memory_budget", 0.01)
@parametrize("override_with_ac", [False, True])
def test_all_to_all_recompute_is_always_banned(self, override_with_ac):
@torch.library.custom_op("custom_ns::foo", mutates_args=())
def foo(x: torch.Tensor) -> torch.Tensor:
return x + 1
@foo.register_fake
def _(x):
return torch.empty_like(x)
def setup_context(ctx, inputs, output):
ctx.save_for_backward(inputs[0])
return
def backward(ctx, grad):
(x,) = ctx.saved_tensors
return grad * x
foo.register_autograd(backward, setup_context=setup_context)
class AllToAllSingle(torch.autograd.Function):
@staticmethod
def forward(
ctx,
input: torch.Tensor,
output_split_sizes,
input_split_sizes,
tag,
ranks,
group_size: int,
) -> torch.Tensor:
ctx.output_split_sizes = input_split_sizes
ctx.input_split_sizes = output_split_sizes
ctx.group_size = group_size
a2a = torch.ops._c10d_functional.all_to_all_single.default(
input,
output_split_sizes,
input_split_sizes,
"0",
)
a2a = torch.ops.c10d_functional.wait_tensor(a2a)
return a2a
@staticmethod
def backward(ctx, grad):
grad = torch.ops._c10d_functional.all_to_all_single.default(
grad,
ctx.output_split_sizes,
ctx.input_split_sizes,
"0",
)
return (
torch.ops.c10d_functional.wait_tensor(grad),
None,
None,
None,
None,
None,
)
def alltoall_autograd(
inp,
output_split_sizes,
input_split_sizes,
tag,
ranks,
group_size,
):
out = AllToAllSingle.apply(
inp, output_split_sizes, input_split_sizes, tag, ranks, group_size
)
return out
# simple mode to track how many collective ops we saw in the backward
class TrackingMode(TorchDispatchMode):
def __init__(self):
super().__init__()
self.ops_counter = defaultdict(int)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
self.ops_counter[func] += 1
return rs
def example(
inp,
input_split_sizes_tensor,
output_split_sizes_tensor,
*,
tag,
ranks,
group_size,
):
input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
output_split_sizes = _tolist_with_constrain_as_size(
output_split_sizes_tensor
)
a2a = torch.ops.custom_ns.alltoall_autograd.default(
inp,
output_split_sizes,
input_split_sizes,
tag,
ranks,
group_size,
)
return torch.ops.custom_ns.foo(a2a)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size
), torch._dynamo.config.patch(
dynamic_shapes=True,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
), torch.library._scoped_library(
"custom_ns", "FRAGMENT"
) as lib:
lib.define(
"alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor" # noqa: B950
)
lib.impl("alltoall_autograd", alltoall_autograd, "Autograd")
lib.impl("alltoall_autograd", alltoall_autograd, "Meta")
row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
input_split_sizes_tensor = torch.tensor(
[(i + 1) * (self.rank + 1) for i in range(self.world_size)],
dtype=torch.int64,
)
output_split_sizes_tensor = torch.tensor(
[(i + 1) * (self.rank + 1) for i in range(self.world_size)],
dtype=torch.int64,
)
inputs = (
torch.ones(int(row), 5, device="cuda", requires_grad=True)
* (self.rank + 1),
input_split_sizes_tensor,
output_split_sizes_tensor,
)
trs = self.get_world_trs()
compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
if override_with_ac:
def compiled_fn_wrapper(*args):
return example(*inputs, **trs)
out = torch.utils.checkpoint.checkpoint(
compiled_fn_wrapper, *inputs, use_reentrant=False
)
else:
out = compiled_fn(*inputs, **trs)
# track how many all_to_alls we saw in the backward
with TrackingMode() as m:
out.sum().backward()
if override_with_ac:
# We wrapped the region in AC, which overrides the partitioner's decision
# to never recompute collectives, so we should see the fw all2all recomputed
# in the backward (in addition to the all2all used for the gradient).
self.assertEqual(
m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
2,
)
else:
# there is 1 all2all in the fw, and 1 all2all in the backward.
# Notably: even though activation_memory_budget is ~0 (0.01, i.e. "recompute everything"),
# we are still choosing *not* to recompute the all2all from the fw
self.assertEqual(
m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
1,
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_all_to_all_single_inductor_split_sizes_none(self):


@@ -223,6 +223,30 @@ graphsafe_rng_functionalization = True
# Used for tests
strict_autograd_cache = False
# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc.)
# which can reorder, delete, or duplicate nodes in the graph
# - If any of these passes reorders/deletes/duplicates a collective
# in a setting where the compiler is being run independently on multiple
# ranks, we run the risk that the compiler will make a different decision on
# different ranks, resulting in an NCCL hang when using torch.compile
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
# (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
# (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to false, but eventually
# we want the option to set it to true.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False
# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config, being able to deduce it compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False
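
A hedged sketch of how this flag could be flipped for experimentation (unsafe without some form of cross-rank agreement, which is why it defaults to False); `fn` is a placeholder model, and the sketch uses the config module's patch() context manager:

import torch
import torch._functorch.config

def fn(x):
    # Placeholder; in a real job this region would contain collectives.
    return x * 2

x = torch.randn(4)
# Temporarily allow the partitioner to optimize (e.g. recompute) collectives.
# Only safe if every rank is guaranteed to make identical decisions.
with torch._functorch.config.patch(unsafe_allow_optimization_of_collectives=True):
    out = torch.compile(fn)(x)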


@@ -920,6 +920,21 @@ def functionalize_rng_ops(
return fw_module, bw_module
def force_save_collectives(joint_module: fx.GraphModule) -> None:
"""
By default, the partitioner is not allowed to recompute collectives
unless they come from a user-annotated AC region.
See Note [Recomputing collectives in the partitioner]
"""
for node in joint_module.graph.nodes:
if (
isinstance(node.target, torch._ops.OpOverload)
and node.target.namespace == "_c10d_functional"
and not must_recompute(node)
):
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
@@ -1134,7 +1149,14 @@ def solve_min_cut(
if op_types.is_view(node):
return False
if node in dont_ban:
return False
# Collectives are banned from recompute even if they appear in `dont_ban`
# (in particular, the activation memory budget logic is not allowed to recompute collectives),
# unless config.unsafe_allow_optimization_of_collectives is set.
is_collective = (
isinstance(node.target, torch._ops.OpOverload)
and node.target.namespace == "_c10d_functional"
)
if config.unsafe_allow_optimization_of_collectives or not is_collective:
return False
# This bans recomputation of the node unless we've been forced not to by
# user annotation
if must_recompute(node):
@@ -1142,7 +1164,6 @@ def solve_min_cut(
if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
return False
banned_nodes.add(node)
# A node will only ever be recomputed if there is a path from an
# ancestor of this node to the backwards path through this node that
@@ -1926,6 +1947,8 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
def classify_nodes(joint_module):
name_to_node = get_name_to_node(joint_module.graph)