[dynamo] run-only recursively on recompile limit exceeded (#148021)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148021
Approved by: https://github.com/anijain2305
This commit is contained in:
William Wen 2025-02-28 17:15:22 -08:00 committed by PyTorch MergeBot
parent 1bbe57336b
commit 8f361c808b
2 changed files with 29 additions and 1 deletions

View File

@ -293,6 +293,34 @@ len(L['x']) == 3""".split(
):
self.assertIn(line, filter_reasons())
@torch._dynamo.config.patch(recompile_limit=1)
def test_recompile_child_run_only(self):
def f(x, n):
if torch.compiler.is_compiling():
x = x + 1
x = g(x)
return h(x) + n
def g(x):
if torch.compiler.is_compiling():
return x + 2
return x
def h(x):
if torch.compiler.is_compiling():
return x + 4
return x
torch.compile(g, backend="eager")(torch.randn(3))
inp = torch.randn(3)
opt_f = torch.compile(f, backend="eager")
opt_f(inp, 0)
# expect f to run eager, g compiled (from previous invocatino), h eager
res = opt_f(inp, 1)
self.assertEqual(res, inp + 3)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1278,7 +1278,7 @@ class ConvertFrame:
elif isinstance(e, RecompileLimitExceeded):
return ConvertFrameReturn(
frame_exec_strategy=FrameExecStrategy(
FrameAction.RUN_ONLY, FrameAction.SKIP
FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
)
)