[aot] always lower the backward with a deepcopy (#149229)
FIXES https://github.com/pytorch/pytorch/issues/149105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149229
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
5ebc283f2c
commit
e481615bc7
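
The change in a nutshell, as a standalone hedged sketch (toy_lowering_pass, f, and the traced module below are illustrative stand-ins, not the real inductor or AOTAutograd internals): post-grad lowering passes mutate the GraphModule they are handed, so compiling a deepcopy keeps the stashed backward graph intact for compiled autograd to reuse.

import copy

import torch
import torch.fx as fx


def toy_lowering_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # Mutates the graph in place, the way a post-grad pass might:
    # rewrite every torch.add call into torch.sub.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.add:
            node.target = torch.sub
    gm.recompile()
    return gm


def f(x, y):
    return torch.add(x, y)


bw_module = fx.symbolic_trace(f)  # stand-in for the stashed AOT backward graph
lowered = toy_lowering_pass(copy.deepcopy(bw_module))  # lower a copy, not the original

x, y = torch.ones(2), torch.ones(2)
print(bw_module(x, y))  # still computes x + y: the stashed graph is untouched
print(lowered(x, y))    # the lowered copy was rewritten to compute x - y
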
@@ -3935,6 +3935,42 @@ class CompiledAutograd1(torch.nn.Module):
         self.assertTrue("aten.randn" in str(gm))
 
+    def test_aot_bwd_gm_runnable(self):
+        # This test ensures that the bw_module saved in
+        # CompiledFunction._lazy_backward_info is executable,
+        # by ensuring post grad passes have not run on it.
+
+        post_grad_graphs = []
+
+        def post_grad_pass(graph):
+            nonlocal post_grad_graphs
+            post_grad_graphs.append(graph)
+            return graph
+
+        x = torch.randn(10, 10, requires_grad=True)
+        y = torch.randn(10, 10, requires_grad=True)
+        # forces symints to be saved for backward
+        # and forces aot compilation of the backward
+        torch._dynamo.mark_dynamic(x, 0)
+        torch._dynamo.mark_dynamic(y, 1)
+
+        @torch.compile
+        def fn(x, y):
+            return torch.matmul(x, y).sum()
+
+        with inductor_config.patch(post_grad_custom_post_pass=post_grad_pass):
+            loss = fn(x, y)
+            self.assertEqual(len(post_grad_graphs), 2)  # 1 fwd and 1 bwd
+
+        self.assertEqual(loss.grad_fn.name(), "CompiledFunctionBackward")
+        self.assertIsNot(
+            post_grad_graphs[1],
+            loss.grad_fn._forward_cls._lazy_backward_info.bw_module.graph,
+        )
+
+        with compiled_autograd._enable(lambda gm: gm):
+            loss.backward()
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -1194,8 +1194,9 @@ def aot_dispatch_autograd(
     compiled_bw_func = None
     if num_symints_saved_for_bw > 0:
         try:
+            # See Note: [Backward graph lazy lowering]
             compiled_bw_func = aot_config.bw_compiler(
-                bw_module, placeholder_list
+                copy.deepcopy(bw_module), placeholder_list
             )
         except Exception as e:
             exc = e
@@ -8,6 +8,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
 """
 import builtins
 import collections
+import copy
 import itertools
 import pprint
 from contextlib import nullcontext
@@ -1459,6 +1460,13 @@ def merge_view_inputs(
     return args_to_functionalization, post_processed_calling_convention_meta
 
 
+# Note: [Backward graph lazy lowering]
+# After AOTDispatch traces the backward for graphs requiring autograd, we will lower the graph lazily,
+# unless we suspect that inductor might specialize and insert additional guards. When we do lazy
+# lowering, we stash the AOT backward graph (bw_module) in this class.
+#
+# Lowering passes are performed on a deepcopy of this bw_module due to compatibility
+# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645.
 @dataclass
 class AutogradLazyBackwardCompileInfo:
     bw_module: Callable
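
As a hedged aside, here is a minimal, generic sketch of the lazy-lowering pattern the note above describes (LazyBackwardInfo and its fields are illustrative, not the real AutogradLazyBackwardCompileInfo API): the traced backward graph is stashed unmodified, and lowering runs on a deepcopy only when the backward is first executed.

import copy
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class LazyBackwardInfo:
    # Traced, un-lowered backward graph; kept pristine so compiled autograd
    # can still run or retrace it later.
    bw_module: Callable
    compiled_bw: Optional[Callable] = None  # populated lazily, on first backward

    def run(self, compiler: Callable, *args):
        if self.compiled_bw is None:
            # Lower a copy: compiler passes may mutate the graph they receive.
            self.compiled_bw = compiler(copy.deepcopy(self.bw_module))
        return self.compiled_bw(*args)
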
@@ -2217,14 +2225,12 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
                     dynamo_compile_column_us="backward_cumulative_compile_time_us",
                 ):
                     CompileEventLogger.compilation_metric(is_forward=False)
+                    # See Note: [Backward graph lazy lowering]
                     CompiledFunction.compiled_bw = aot_config.bw_compiler(
-                        bw_module, placeholder_list
+                        copy.deepcopy(bw_module), placeholder_list
                     )
                     # Maybe save cache entry
                     if try_save_cache_entry is not None:
                         # CompiledFunction.metadata
                         # CompiledFunction.maybe_subclass_metadata
                         # bw_module
                         try_save_cache_entry(
                             CompiledFunction.compiled_bw,
                             fw_metadata,