mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] run-only recursively on recompile limit exceeded (#148021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148021 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
1bbe57336b
commit
8f361c808b
|
|
@ -293,6 +293,34 @@ len(L['x']) == 3""".split(
|
|||
):
|
||||
self.assertIn(line, filter_reasons())
|
||||
|
||||
@torch._dynamo.config.patch(recompile_limit=1)
|
||||
def test_recompile_child_run_only(self):
|
||||
def f(x, n):
|
||||
if torch.compiler.is_compiling():
|
||||
x = x + 1
|
||||
x = g(x)
|
||||
return h(x) + n
|
||||
|
||||
def g(x):
|
||||
if torch.compiler.is_compiling():
|
||||
return x + 2
|
||||
return x
|
||||
|
||||
def h(x):
|
||||
if torch.compiler.is_compiling():
|
||||
return x + 4
|
||||
return x
|
||||
|
||||
torch.compile(g, backend="eager")(torch.randn(3))
|
||||
inp = torch.randn(3)
|
||||
opt_f = torch.compile(f, backend="eager")
|
||||
opt_f(inp, 0)
|
||||
|
||||
# expect f to run eager, g compiled (from previous invocatino), h eager
|
||||
res = opt_f(inp, 1)
|
||||
|
||||
self.assertEqual(res, inp + 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1278,7 +1278,7 @@ class ConvertFrame:
|
|||
elif isinstance(e, RecompileLimitExceeded):
|
||||
return ConvertFrameReturn(
|
||||
frame_exec_strategy=FrameExecStrategy(
|
||||
FrameAction.RUN_ONLY, FrameAction.SKIP
|
||||
FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user