[partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/. The argument here is that:

(1) In general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above).

(2) Later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks).

(3) Even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
This commit is contained in:
parent 420a9be743
commit 3646d4dbc8
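For illustration only (not part of the commit): a minimal sketch of the user-facing behavior described in point (3) above. Wrapping the collective in a `torch.utils.checkpoint()` region inside the compiled function opts that region back into recompute, while the default path keeps the collective's output saved for backward. The helper names (`region`, `fwd_with_ac`) and the use of functional collectives are assumptions made for this example; it presumes an already-initialized process group on every rank and that the annotation is applied identically across ranks.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.utils.checkpoint import checkpoint

def region(x):
    # a collective inside the region; by default the partitioner saves its
    # output for backward and never recomputes it
    y = funcol.all_reduce(x * 2, "sum", dist.group.WORLD)
    return y.relu()

def fwd_with_ac(x):
    # user-annotated AC region: the partitioner honors it, so the all_reduce
    # may be recomputed in the backward (safe only if every rank does this)
    return checkpoint(region, x, use_reentrant=False)

compiled_default = torch.compile(region)
compiled_with_ac = torch.compile(fwd_with_ac)

x = torch.randn(8, 8, device="cuda", requires_grad=True)
compiled_default(x).sum().backward()  # backward reuses the saved all_reduce output
compiled_with_ac(x).sum().backward()  # backward may re-issue the all_reduce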
@@ -2,6 +2,7 @@
 import datetime
 import functools
 import unittest
+from collections import defaultdict
 from unittest.mock import patch
 
 import torch
@@ -32,6 +33,7 @@ from torch.testing._internal.common_utils import (
     skipIfRocm,
 )
 from torch.testing._internal.inductor_utils import HAS_GPU
+from torch.utils._python_dispatch import TorchDispatchMode
 
 
 def _tolist_with_constrain_as_size(tensor):
@@ -42,6 +44,7 @@ def _tolist_with_constrain_as_size(tensor):
 
 
 @requires_nccl()
+@instantiate_parametrized_tests
 class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
     """
     Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
@@ -550,6 +553,192 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             inductor_out = compiled_fn(*inputs, **trs)
             self.assertTrue(same(eager_out, inductor_out, tol=0.001))
 
+    # The goal of this test is that when `unsafe_allow_optimization_of_collectives=False`,
+    # the partitioner will *never* recompute collectives in the backward, even
+    # if the activation_memory_budget partitioner is being used,
+    # unless there is a manual user checkpoint() region (which we know makes it safe
+    # to recompute the collective, since we assume that the user applied the AC
+    # region consistently across all ranks)
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_lt_x_gpu(2)
+    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
+    @patch.object(torch._functorch.config, "activation_memory_budget", 0.01)
+    @parametrize("override_with_ac", [False, True])
+    def test_all_to_all_recompute_is_always_banned(self, override_with_ac):
+        @torch.library.custom_op("custom_ns::foo", mutates_args=())
+        def foo(x: torch.Tensor) -> torch.Tensor:
+            return x + 1
+
+        @foo.register_fake
+        def _(x):
+            return torch.empty_like(x)
+
+        def setup_context(ctx, inputs, output):
+            ctx.save_for_backward(inputs[0])
+            return
+
+        def backward(ctx, grad):
+            (x,) = ctx.saved_tensors
+            return grad * x
+
+        foo.register_autograd(backward, setup_context=setup_context)
+
+        class AllToAllSingle(torch.autograd.Function):
+            @staticmethod
+            def forward(
+                ctx,
+                input: torch.Tensor,
+                output_split_sizes,
+                input_split_sizes,
+                tag,
+                ranks,
+                group_size: int,
+            ) -> torch.Tensor:
+                ctx.output_split_sizes = input_split_sizes
+                ctx.input_split_sizes = output_split_sizes
+                ctx.group_size = group_size
+                a2a = torch.ops._c10d_functional.all_to_all_single.default(
+                    input,
+                    output_split_sizes,
+                    input_split_sizes,
+                    "0",
+                )
+                a2a = torch.ops.c10d_functional.wait_tensor(a2a)
+                return a2a
+
+            @staticmethod
+            def backward(ctx, grad):
+                grad = torch.ops._c10d_functional.all_to_all_single.default(
+                    grad,
+                    ctx.output_split_sizes,
+                    ctx.input_split_sizes,
+                    "0",
+                )
+
+                return (
+                    torch.ops.c10d_functional.wait_tensor(grad),
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                )
+
+        def alltoall_autograd(
+            inp,
+            output_split_sizes,
+            input_split_sizes,
+            tag,
+            ranks,
+            group_size,
+        ):
+            out = AllToAllSingle.apply(
+                inp, output_split_sizes, input_split_sizes, tag, ranks, group_size
+            )
+            return out
+
+        # simple mode to track how many collective ops we saw in the backward
+        class TrackingMode(TorchDispatchMode):
+            def __init__(self):
+                super().__init__()
+                self.ops_counter = defaultdict(int)
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                rs = func(*args, **kwargs)
+                self.ops_counter[func] += 1
+                return rs
+
+        def example(
+            inp,
+            input_split_sizes_tensor,
+            output_split_sizes_tensor,
+            *,
+            tag,
+            ranks,
+            group_size,
+        ):
+            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
+            output_split_sizes = _tolist_with_constrain_as_size(
+                output_split_sizes_tensor
+            )
+            a2a = torch.ops.custom_ns.alltoall_autograd.default(
+                inp,
+                output_split_sizes,
+                input_split_sizes,
+                tag,
+                ranks,
+                group_size,
+            )
+
+            return torch.ops.custom_ns.foo(a2a)
+
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size
+        ), torch._dynamo.config.patch(
+            dynamic_shapes=True,
+            capture_dynamic_output_shape_ops=True,
+            capture_scalar_outputs=True,
+        ), torch.library._scoped_library(
+            "custom_ns", "FRAGMENT"
+        ) as lib:
+            lib.define(
+                "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor"  # noqa: B950
+            )
+            lib.impl("alltoall_autograd", alltoall_autograd, "Autograd")
+            lib.impl("alltoall_autograd", alltoall_autograd, "Meta")
+
+            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
+            input_split_sizes_tensor = torch.tensor(
+                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
+                dtype=torch.int64,
+            )
+            output_split_sizes_tensor = torch.tensor(
+                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
+                dtype=torch.int64,
+            )
+            inputs = (
+                torch.ones(int(row), 5, device="cuda", requires_grad=True)
+                * (self.rank + 1),
+                input_split_sizes_tensor,
+                output_split_sizes_tensor,
+            )
+            trs = self.get_world_trs()
+
+            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
+
+            if override_with_ac:
+
+                def compiled_fn_wrapper(*args):
+                    return example(*inputs, **trs)
+
+                out = torch.utils.checkpoint.checkpoint(
+                    compiled_fn_wrapper, *inputs, use_reentrant=False
+                )
+            else:
+                out = compiled_fn(*inputs, **trs)
+
+            # track how many all_to_alls we saw in the backward
+            with TrackingMode() as m:
+                out.sum().backward()
+            if override_with_ac:
+                # We wrapped our test in AC, which overrides the partitioner decision
+                # of never recomputing collectives.
+                # So we should properly see the all2all be recomputed in the backward
+                self.assertEqual(
+                    m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
+                    2,
+                )
+            else:
+                # there is 1 all2all in the fw, and 1 all2all in the backward.
+                # notably: even though activation_memory_budget == 0 ("recompute_everything"),
+                # we are still choosing *not* to recompute the all2all from the fw
+                self.assertEqual(
+                    m.ops_counter[torch.ops._c10d_functional.all_to_all_single.default],
+                    1,
+                )
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
     def test_all_to_all_single_inductor_split_sizes_none(self):
@@ -223,6 +223,30 @@ graphsafe_rng_functionalization = True
 # Used for tests
 strict_autograd_cache = False
 
+# Note [Recomputing collectives in the partitioner]
+# The purpose of this config is as follows:
+# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
+#   which can reorder, delete, or duplicate nodes in the graph
+# - If any of these passes reorder/delete/duplicate a collective
+#   in a setting where the compiler is being run independently on multiple
+#   ranks, we run the risk that the compiler will make a different decision on
+#   different ranks, resulting in an NCCL hang when using torch.compile
+# To handle this, we will (by default) ensure that collectives are not modified
+# by the compiler.
+#
+# A few examples:
+# - don't dead-code-eliminate collectives
+#   (in case they are dead on rank i but not rank j)
+# - don't recompute collectives in partitioning
+#   (in case we recompute on rank i but not rank j)
+#
+# Today this flag **must** be set to false, but eventually
+# we want the option to set it to true.
+# In order to potentially optimize collectives, we'll need the compiler
+# to broadcast information across ranks at compile time to ensure
+# that any decisions on collectives are made consistently.
+unsafe_allow_optimization_of_collectives = False
+
 # See Note [AOTAutograd Tangent Subclassness for mutated inputs]
 # TODO(ivankobzarev): Remove this config, being able to deduce it compile time.
 disable_guess_zero_tangent_for_mutated_input_subclass = False
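As a usage note (not part of the diff): a small sketch of how this flag would be toggled, assuming it lands in the `torch._functorch.config` module suggested by the hunk context and that the module exposes the standard `patch` context manager. Flipping it is unsafe today unless something outside the compiler guarantees that every rank makes identical recompute decisions.

import torch
import torch._functorch.config as functorch_config

def f(x):
    return x.sin().cos()

# Default: collectives are force-saved for backward; the partitioner and
# related passes may not recompute them.
assert functorch_config.unsafe_allow_optimization_of_collectives is False

# Hypothetical opt-in (unsafe without cross-rank agreement at compile time):
with functorch_config.patch(unsafe_allow_optimization_of_collectives=True):
    compiled = torch.compile(f)  # the partitioner may now recompute collectives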
@@ -920,6 +920,21 @@ def functionalize_rng_ops(
     return fw_module, bw_module
 
 
+def force_save_collectives(joint_module: fx.GraphModule) -> None:
+    """
+    By default, the partitioner is not allowed to recompute collectives
+    unless they come from a user-annotated AC region.
+    See Note [Recomputing collectives in the partitioner]
+    """
+    for node in joint_module.graph.nodes:
+        if (
+            isinstance(node.target, torch._ops.OpOverload)
+            and node.target.namespace == "_c10d_functional"
+            and not must_recompute(node)
+        ):
+            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+
+
 def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
     """
     If there are two consecutive checkpointed blocks with no operator in
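For context (not part of the diff): the `CheckpointPolicy.MUST_SAVE` tag that `force_save_collectives` writes into `node.meta` is the same policy a user can express directly through the selective activation checkpointing API. A rough sketch, with made-up helper names, assuming `torch.utils.checkpoint.create_selective_checkpoint_contexts` and a non-reentrant checkpoint:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def _always_save_collectives(ctx, op, *args, **kwargs):
    # mirror force_save_collectives: pin functional-collective ops as MUST_SAVE,
    # let everything else keep the default recompute preference
    if isinstance(op, torch._ops.OpOverload) and op.namespace == "_c10d_functional":
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpoint_but_save_collectives(fn, *args):
    return checkpoint(
        fn,
        *args,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_always_save_collectives),
    )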
@@ -1134,7 +1149,14 @@ def solve_min_cut(
         if op_types.is_view(node):
             return False
         if node in dont_ban:
-            return False
+            # collectives are *always* banned from recompute, overriding `dont_ban`
+            # (in particular, the activation memory budget logic is not allowed to recompute collectives)
+            is_collective = (
+                isinstance(node.target, torch._ops.OpOverload)
+                and node.target.namespace == "_c10d_functional"
+            )
+            if config.unsafe_allow_optimization_of_collectives or not is_collective:
+                return False
         # This bans recomputation of the node unless we've been forced not to by
         # user annotation
         if must_recompute(node):
@@ -1142,7 +1164,6 @@
 
         if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
             return False
-
         banned_nodes.add(node)
         # A node will only ever be recomputed if there is a path from an
         # ancestor of this node to the backwards path through this node that
@@ -1926,6 +1947,8 @@ def min_cut_rematerialization_partition(
     graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
     if graph_has_recomputable_ops:
         joint_module = cleanup_recompute_tags(joint_module)
+    if not config.unsafe_allow_optimization_of_collectives:
+        force_save_collectives(joint_module)
 
     def classify_nodes(joint_module):
         name_to_node = get_name_to_node(joint_module.graph)