[partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) In general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above).

(2) Later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks).

(3) Even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
This commit is contained in:
parent 420a9be743
commit 3646d4dbc8
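For illustration only, here is a minimal sketch of the user-level override described in point (3) above. It is not part of this commit: the toy model, the shapes, and the use of the autograd-aware functional collective all_gather_tensor_autograd are assumptions chosen for brevity; an initialized NCCL process group (e.g. launched with torchrun) and one GPU per rank are assumed.

# Hypothetical sketch of the behavior described above; none of this code is from the PR.
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor_autograd


def layer(x, w):
    # The all_gather below is a collective. By default the partitioner now saves its
    # output for backward rather than recomputing it, even when the activation
    # memory budget partitioner would otherwise prefer recompute.
    y = torch.matmul(x, w)
    return all_gather_tensor_autograd(y, gather_dim=0, group=dist.group.WORLD).relu()


def layer_with_user_ac(x, w):
    # User-driven override: wrapping the region in torch.utils.checkpoint marks
    # everything inside (including the collective) as recompute-in-backward.
    # This is considered safe only because the user is expected to apply the AC
    # region identically on every rank.
    return torch.utils.checkpoint.checkpoint(layer, x, w, use_reentrant=False)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    x = torch.randn(8, 8, device="cuda", requires_grad=True)
    w = torch.randn(8, 8, device="cuda", requires_grad=True)

    # Default: the collective's result is saved for backward.
    torch.compile(layer)(x, w).sum().backward()

    # With user checkpointing: the collective is recomputed in backward.
    torch.compile(layer_with_user_ac)(x, w).sum().backward()

    dist.destroy_process_group()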
@@ -2,6 +2,7 @@
import datetime
import functools
import unittest
from collections import defaultdict
from unittest.mock import patch

import torch

@@ -32,6 +33,7 @@ from torch.testing._internal.common_utils import (
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._python_dispatch import TorchDispatchMode


def _tolist_with_constrain_as_size(tensor):

@@ -42,6 +44,7 @@ def _tolist_with_constrain_as_size(tensor):


@requires_nccl()
@instantiate_parametrized_tests
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under

@@ -550,6 +553,192 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    # The goal of this test is that when `unsafe_allow_recompute_of_collectives=False`,
    # The partitioner will *never* recompute collectives in the backward, even
    # if the activation_memory_budget partitioner is being used,
    # unless there is a manual user checkpoint() region (which we know makes it safe
    # to recompute the collective, since we assume that the user applied the AC
    # region consistently across all ranks)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    @patch.object(torch._functorch.config, "activation_memory_budget", 0.01)
    @parametrize("override_with_ac", [False, True])
    def test_all_to_all_recompute_is_always_banned(self, override_with_ac):
        @torch.library.custom_op("custom_ns::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x + 1

        @foo.register_fake
        def _(x):
            return torch.empty_like(x)

        def setup_context(ctx, inputs, output):
            ctx.save_for_backward(inputs[0])
            return

        def backward(ctx, grad):
            (x,) = ctx.saved_tensors
            return grad * x

        foo.register_autograd(backward, setup_context=setup_context)

        class AllToAllSingle(torch.autograd.Function):
            @staticmethod
            def forward(
                ctx,
                input: torch.Tensor,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size: int,
            ) -> torch.Tensor:
                ctx.output_split_sizes = input_split_sizes
                ctx.input_split_sizes = output_split_sizes
                ctx.group_size = group_size
                a2a = torch.ops._c10d_functional.all_to_all_single.default(
                    input,
                    output_split_sizes,
                    input_split_sizes,
                    "0",
                )
                a2a = torch.ops.c10d_functional.wait_tensor(a2a)
                return a2a

            @staticmethod
            def backward(ctx, grad):
                grad = torch.ops._c10d_functional.all_to_all_single.default(
                    grad,
                    ctx.output_split_sizes,
                    ctx.input_split_sizes,
                    "0",
                )

                return (
                    torch.ops.c10d_functional.wait_tensor(grad),
                    None,
                    None,
                    None,
                    None,
                    None,
                )

        def alltoall_autograd(
            inp,
            output_split_sizes,
            input_split_sizes,
            tag,
            ranks,
            group_size,
        ):
            out = AllToAllSingle.apply(
                inp, output_split_sizes, input_split_sizes, tag, ranks, group_size
            )
            return out

        # simple mode to track how many collective ops we saw in the backward
        class TrackingMode(TorchDispatchMode):
            def __init__(self):
                super().__init__()
                self.ops_counter = defaultdict(int)

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                rs = func(*args, **kwargs)
                self.ops_counter[func] += 1
                return rs

        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.custom_ns.alltoall_autograd.default(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )

            return torch.ops.custom_ns.foo(a2a)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ), torch.library._scoped_library(
            "custom_ns", "FRAGMENT"
        ) as lib:
            lib.define(
                "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor"  # noqa: B950
            )
            lib.impl("alltoall_autograd", alltoall_autograd, "Autograd")
            lib.impl("alltoall_autograd", alltoall_autograd, "Meta")

            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(int(row), 5, device="cuda", requires_grad=True)
                * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)

            if override_with_ac:

                def compiled_fn_wrapper(*args):
                    return example(*inputs, **trs)

                out = torch.utils.checkpoint.checkpoint(
                    compiled_fn_wrapper, *inputs, use_reentrant=False
                )
            else:
                out = compiled_fn(*inputs, **trs)

            # track how many all_to_alls we saw in the backward
            with TrackingMode() as m:
                out.sum().backward()
            if override_with_ac:
                # We wrapped our test in AC, which overrides the partitioner decision
                # of never recomputing collectives.
                # So we should properly see the all2all be recomputed in the backward
                self.assertEqual(
                    m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
                    2,
                )
            else:
                # there is 1 all2all in the fw, and 1 all2all in the backward.
                # notably: even though activation_memory_budget == 0 ("recompute_everything"),
                # we are still choosing *not* to recompute the all2all from the fw
                self.assertEqual(
                    m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
                    1,
                )

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):

@@ -223,6 +223,30 @@ graphsafe_rng_functionalization = True
# Used for tests
strict_autograd_cache = False

# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
#   which can reorder, delete, or duplicate nodes in the graph
# - If any of these passes reorder/delete/duplicate a collective
#   in a setting where the compiler is being run independently on multiple
#   ranks, we run the risk that the compiler will make a different decision on
#   different ranks, resulting in an NCCL hang when using torch.compile
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
#   (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
#   (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to false, but eventually
# we want the option to set it to true.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False

# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config, being able to deduce it compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False

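A hypothetical usage sketch, not part of this diff: scoping the new flag with the functorch config module's patch() helper. The toy model and input are placeholders; per the note above, enabling the flag is only safe if every rank is guaranteed to reach identical decisions about collectives.

# Hypothetical sketch, not from this commit.
import torch

model = torch.nn.Linear(8, 8)
inp = torch.randn(4, 8, requires_grad=True)

# The flag defaults to False. Patching it to True for the duration of this block
# re-enables compiler optimization of collectives (partitioner recompute, DCE, ...),
# which is only safe if all ranks are guaranteed to make the same choices.
with torch._functorch.config.patch(unsafe_allow_optimization_of_collectives=True):
    out = torch.compile(model)(inp)
    out.sum().backward()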
@@ -920,6 +920,21 @@ def functionalize_rng_ops(
    return fw_module, bw_module


def force_save_collectives(joint_module: fx.GraphModule) -> None:
    """
    By default, the partitioner is not allowed to recompute collectives
    unless they come from a user-annotated AC region.
    See Note [Recomputing collectives in the partitioner]
    """
    for node in joint_module.graph.nodes:
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target.namespace == "_c10d_functional"
            and not must_recompute(node)
        ):
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE


def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
    """
    If there are two consecutive checkpointed blocks with no operator in

@@ -1134,7 +1149,14 @@ def solve_min_cut(
        if op_types.is_view(node):
            return False
        if node in dont_ban:
            return False
        # collectives are *always* banned from recompute, overriding `dont_ban`
        # (in particular, the activation memory budget logic is not allowed to recompute collectives)
        is_collective = (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target.namespace == "_c10d_functional"
        )
        if config.unsafe_allow_optimization_of_collectives or not is_collective:
            return False
        # This bans recomputation of the node unless we've been forced not to by
        # user annotation
        if must_recompute(node):

@@ -1142,7 +1164,6 @@

        if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
            return False

        banned_nodes.add(node)
        # A node will only ever be recomputed if there is a path from an
        # ancestor of this node to the backwards path through this node that

@@ -1926,6 +1947,8 @@ def min_cut_rematerialization_partition(
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        joint_module = cleanup_recompute_tags(joint_module)
    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(joint_module)

    def classify_nodes(joint_module):
        name_to_node = get_name_to_node(joint_module.graph)