[aot] always lower the backward with a deepcopy (#149229)

FIXES https://github.com/pytorch/pytorch/issues/149105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149229
Approved by: https://github.com/bdhirsh
Simon Fan 2025-03-20 11:53:41 -07:00 committed by PyTorch MergeBot
parent 5ebc283f2c
commit e481615bc7
3 changed files with 48 additions and 5 deletions
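
For context, the core of this change is that Inductor's post-grad lowering passes mutate the FX graph in place, so lowering a deepcopy keeps the stashed bw_module executable for compiled autograd later. The snippet below is a minimal standalone sketch of that idea, not the PyTorch-internal code path; the traced function and the in-place "pass" are invented for illustration.

import copy
import torch
import torch.fx

def f(x):
    return x.relu() + 1

gm = torch.fx.symbolic_trace(f)  # stand-in for the stashed AOT backward graph (bw_module)
lowered = copy.deepcopy(gm)      # what bw_compiler receives after this change

# stand-in for a lowering pass that rewrites the graph in place
for node in lowered.graph.nodes:
    if node.op == "call_method" and node.target == "relu":
        node.target = "sigmoid"
lowered.recompile()

assert lowered.graph is not gm.graph  # mirrors the test's assertIsNot check
print(gm(torch.randn(3)))             # the untouched original graph still runs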

View File

@@ -3935,6 +3935,42 @@ class CompiledAutograd1(torch.nn.Module):
         self.assertTrue("aten.randn" in str(gm))
 
+    def test_aot_bwd_gm_runnable(self):
+        # This test ensures that the bw_module saved in
+        # CompiledFunction._lazy_backward_info is executable,
+        # by ensuring post grad passes have not run on it.
+        post_grad_graphs = []
+
+        def post_grad_pass(graph):
+            nonlocal post_grad_graphs
+            post_grad_graphs.append(graph)
+            return graph
+
+        x = torch.randn(10, 10, requires_grad=True)
+        y = torch.randn(10, 10, requires_grad=True)
+
+        # forces symints to be saved for backward
+        # and forces aot compilation of the backward
+        torch._dynamo.mark_dynamic(x, 0)
+        torch._dynamo.mark_dynamic(y, 1)
+
+        @torch.compile
+        def fn(x, y):
+            return torch.matmul(x, y).sum()
+
+        with inductor_config.patch(post_grad_custom_post_pass=post_grad_pass):
+            loss = fn(x, y)
+
+        self.assertEqual(len(post_grad_graphs), 2)  # 1 fwd and 1 bwd
+        self.assertEqual(loss.grad_fn.name(), "CompiledFunctionBackward")
+        self.assertIsNot(
+            post_grad_graphs[1],
+            loss.grad_fn._forward_cls._lazy_backward_info.bw_module.graph,
+        )
+
+        with compiled_autograd._enable(lambda gm: gm):
+            loss.backward()
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent

View File

@@ -1194,8 +1194,9 @@ def aot_dispatch_autograd(
     compiled_bw_func = None
     if num_symints_saved_for_bw > 0:
         try:
+            # See Note: [Backward graph lazy lowering]
             compiled_bw_func = aot_config.bw_compiler(
-                bw_module, placeholder_list
+                copy.deepcopy(bw_module), placeholder_list
             )
         except Exception as e:
             exc = e

View File

@@ -8,6 +8,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
 """
 import builtins
 import collections
+import copy
 import itertools
 import pprint
 from contextlib import nullcontext
@@ -1459,6 +1460,13 @@ def merge_view_inputs(
     return args_to_functionalization, post_processed_calling_convention_meta
 
 
+# Note: [Backward graph lazy lowering]
+# After AOTDispatch traces the backward for graphs requiring autograd, we will lower the graph lazily,
+# unless we suspect that inductor might specialize and insert additional guards. When we do lazy
+# lowering, we stash the AOT backward graph (bw_module) in this class.
+#
+# Lowering passes are performed on a deepcopy of this bw_module due to compatibility
+# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645.
 @dataclass
 class AutogradLazyBackwardCompileInfo:
     bw_module: Callable
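
To make the note above concrete, the following is a hedged sketch of the lazy-lowering flow it describes. Only AutogradLazyBackwardCompileInfo and the deepcopy-before-bw_compiler step come from this diff; the simplified class and driver function here are illustrative stand-ins, not the real implementation.

import copy
from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass
class LazyBackwardCompileInfo:  # simplified stand-in for AutogradLazyBackwardCompileInfo
    bw_module: Callable         # pristine AOT backward graph, kept runnable for compiled autograd
    placeholder_list: List[Any]

def lower_backward_lazily(info: LazyBackwardCompileInfo, bw_compiler: Callable) -> Callable:
    # Lower a deepcopy so post-grad passes never mutate info.bw_module in place.
    return bw_compiler(copy.deepcopy(info.bw_module), info.placeholder_list)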
@@ -2217,14 +2225,12 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
                 dynamo_compile_column_us="backward_cumulative_compile_time_us",
             ):
                 CompileEventLogger.compilation_metric(is_forward=False)
+                # See Note: [Backward graph lazy lowering]
                 CompiledFunction.compiled_bw = aot_config.bw_compiler(
-                    bw_module, placeholder_list
+                    copy.deepcopy(bw_module), placeholder_list
                 )
                 # Maybe save cache entry
                 if try_save_cache_entry is not None:
-                    # CompiledFunction.metadata
-                    # CompiledFunction.maybe_subclass_metadata
-                    # bw_module
                     try_save_cache_entry(
                         CompiledFunction.compiled_bw,
                         fw_metadata,