# Owner(s): ["module: dynamo"]
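# Tests for how dynamo resumes from graph breaks raised inside nested
# (inlined) function calls.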
import sys

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches


try:
    # from . import test_ctx_manager
    pass
except ImportError:
    # import test_aot_autograd
    # import test_ctx_manager

    # import test_export
    # import test_functions
    # import test_higher_order_ops
    # import test_misc
    # import test_modules
    # import test_repros
    # import test_sdpa
    # import test_subgraphs
    pass


test_classes = {}


def make_nested_cls(cls):
    suffix = "_nested_graph_breaks"

    cls_prefix = "NestedGraphBreaks"

    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "debug_force_nested_calls", True),
        (config, "debug_force_graph_break_on_leaf_return", True),
        (config, "debug_disable_compile_counter", True),
        xfail_prop="_expected_failure_nested_graph_breaks",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    # globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    # test_ctx_manager.CtxManagerTests,
    # test_functions.FunctionTests,
    # test_misc.MiscTests,
    # test_repros.ReproTests,
    # test_modules.NNModuleTests,
    # test_subgraphs.SubGraphTests,
    # test_higher_order_ops.HigherOrderOpTests,
    # test_higher_order_ops.FuncTorchHigherOrderOpTests,
    # test_aot_autograd.AotAutogradFallbackTests,
    # test_sdpa.TestSDPA,
]
test = None
for test in tests:
    make_nested_cls(test)
del test


# for use in test_side_effects_globals
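# NOTE: tuple repetition aliases all four names to a single tensor here; the
# test reassigns fresh tensors before each use.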
global1, global2, global3, global4 = (torch.zeros(3),) * 4


class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
    def test_single_graph_break(self):
        # NOTE: marking f1, f2, f3 as global prevents them from being freevars
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            return f1(x2 + 4) + 8

        def f3(x3):
            return f2(x3 + 16) + 32
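
        # the added constants are distinct powers of two, so the final value is
        # a bitmask recording exactly which additions executed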

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)

    def test_single_graph_break_repeat(self):
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            tmp1 = f1(x2 + 4)
            tmp2 = f1(x2 + 8) << 4
            return tmp1 + tmp2

        def f3(x3):
            return f2(x3 + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
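        # integer input required since f2 uses a left shift (<<)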
        x = torch.zeros(3, dtype=torch.long)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 10)

    def test_doubly_nested_graph_break(self):
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            x2 = x2 + 4
            torch._dynamo.graph_break()
            return f1(x2 + 8) + 16

        def f3(x3):
            return f2(x3 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 7)

    def test_differing_arg_nums(self):
        global f1, f2, f3, f4

        def f1(x1, x2):
            x = x1 + x2
            torch._dynamo.graph_break()
            return x + 1

        def f2(x3, x4, x5, x6):
            return f1(x3 + x4, x5 + x6) + 2

        def f3(x7, x8):
            return f2(x7, x7 + 4, x8, x8 + 8) + 16

        def f4(x9):
            return f3(x9, x9 + 32) + 64
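
        # each frame in the call chain takes a different number of arguments,
        # exercising resume-function signature generation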

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 10)

    def test_differing_locals_nums(self):
        global f1, f2, f3

        def f1(x1):
            loc1 = x1 + 1
            torch._dynamo.graph_break()
            return loc1 + 2

        def f2(x2):
            loc1 = x2 + 4
            loc2 = x2 + 8
            return f1(x2) + loc1 + loc2

        def f3(x3):
            loc1 = x3 + 16
            loc2 = x3 + 32
            loc3 = x3 + 64
            loc4 = x3 + 128
            return f2(x3) + loc1 + loc2 + loc3 + loc4

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 14)

    def test_counters(self):
        global f1, f2, f3, f4

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            return f1(x + 4) + 8

        def f3(x):
            x = x + 16
            for _ in range(1):
                x = f2(x)
            return x + 32

        @torch.compile(backend="eager")
        def f4(x):
            return f3(x + 64) + 128

        self.assertEqual(f4(torch.zeros(3)), torch.zeros(3) + 255)
        self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 2)

    def test_supported_ctx_manager(self):
        global check, check_disabled, f1, f2, f3

        @torch._dynamo.disable
        def check_disabled(value):
            assert torch.is_grad_enabled() == value

        def check(value):
            assert torch.is_grad_enabled() == value
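
        # check should be traced through, while check_disabled always runs
        # eagerly, so grad mode is validated both in and out of the compiled
        # region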

        def f1(x):
            with torch.no_grad():
                x = x + 1
                check(False)
                check_disabled(False)
                check(False)
                return x + 2

        def f2(x):
            with torch.enable_grad():
                x = x + 4
                check(True)
                check_disabled(True)
                check(True)
                return f1(x) + 8

        def f3(x):
            with torch.no_grad():
                x = x + 16
                check(False)
                check_disabled(False)
                check(False)
                return f2(x) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 4)
        # includes set_grad_enabled ops
        self.assertEqual(cnts.op_count, 14)

    def test_inactive_ctx_manager(self):
        global check, f1, f2, f3

        def check(value):
            assert torch.is_grad_enabled() == value

        def f1(x, ctx1):
            x = x + 1
            ctx2 = torch.no_grad()
            # torch.no_grad() is a stack value at the time of graph break
            ctx3 = (torch.no_grad(), torch._dynamo.graph_break())[0]
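            # the tuple expression pins a fresh, unentered no_grad() on the
            # interpreter stack while graph_break() runs, so dynamo has to
            # reconstruct it from the stack in the resume function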
            x = x + 64
            torch._dynamo.graph_break()
            with ctx1:
                check(False)
            with ctx2:
                check(False)
            with ctx3:
                check(False)
            return x + 2

        def f2(x, ctx1):
            x = x + 4
            ctx2 = torch.no_grad()
            x = f1(x, torch.no_grad())
            with ctx1:
                check(False)
            with ctx2:
                check(False)
            return x + 8

        def f3(x):
            x = x + 16
            ctx = torch.no_grad()
            x = f2(x, torch.no_grad())
            with ctx:
                check(False)
            return x + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 7)

    @torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
    def test_no_recompiles(self):
        global f1, f2, f3

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            x = x + 4
            x = f1(x)
            torch._dynamo.graph_break()
            return x + 8

        def f3(x):
            x = x + 16
            return f2(x) + 32
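
        # recompile_limit=1 with fail_on_recompile_limit_hit turns any
        # recompile into a hard error, so this passes only if every frame
        # (including resume frames) compiles exactly once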

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    def test_cells(self):
        def f1(x1):
            cell1 = x1 + 1
            cell2 = x1 + 2

            def f2(x2, x3):
                nonlocal cell1
                cell3 = x2 + x3 + 4
                cell1 += 8

                def f3(x4):
                    nonlocal cell2, cell3
                    cell2 += 16
                    cell3 += 32
                    torch._dynamo.graph_break()
                    return x4 + cell1 + cell2 + cell3

                return f3(x2 + x3), cell3

            return f2(x1 + 64, x1 + 128) + (cell1, cell2)

        def outer(x):
            return f1(x)
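
        # the graph break in f3 forces resume functions to carry cells that
        # are read and mutated across all three frames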

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 13)

    def test_dead_nested_cells(self):
        global f1, f2, f3

        def f3(x, cell1):
            cell1 += 2
            x = x + cell1
            torch._dynamo.graph_break()
            return x + cell1

        def f1(cell1=0):
            def inner(x):
                x += 4
                x = f3(x, cell1)
                return x + 8

            return inner

        def f2(x):
            return f1()(x + 16) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)
        x = torch.zeros(3)
        res = f2(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # If we don't handle dead cells in nested functions correctly,
        # frame_count will increase, since we also graph break when we attempt
        # to codegen inner. The exact issue was that side_effects failed to
        # codegen the creation of inner's cell, so when codegen-ing cells for
        # resume functions, we ended up trying to codegen a CellVariable
        # without a source, leading to a graph break we can't resume from.
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)

    def test_cells_double_graph_break(self):
        def f1(x1):
            cell1 = x1 + 1

            def f2(x2):
                nonlocal cell1
                cell1 += 2
                torch._dynamo.graph_break()
                torch._dynamo.graph_break()
                return x2 + cell1

            return f2(x1 + 4), cell1

        def outer(x):
            return f1(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

    def test_side_effects_cells(self):
        cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

        def f1():
            nonlocal cell1
            cell1 += 1
            torch._dynamo.graph_break()
            return cell1 + cell2

        def f2():
            nonlocal cell3
            cell3 += 2
            return f1() + cell3 + cell4

        def f3():
            return f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
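
        # run the eager and compiled versions against identical fresh state,
        # snapshotting the mutated cells after each run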

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        res = f3()
        res = (res,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        ref = opt_fn()
        ref = (ref,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 5)

    def test_side_effects_globals(self):
        global f1, f2, f3
        global global1, global2, global3, global4

        def f1():
            global global1
            global1 += 1
            torch._dynamo.graph_break()
            return global1 + global2

        def f2():
            global global3
            global3 += 2
            return f1() + global3 + global4

        def f3(x):
            return x + f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.ones(3)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
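        # only global1 and global3 are mutated, so only they need cloning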
        res = (f3(x), global1.clone(), global2, global3.clone(), global4)

        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        ref = (opt_fn(x), global1.clone(), global2, global3.clone(), global4)

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)

    def test_side_effects_globals_different_module(self):
        global f1, f2, _test_nested_graph_breaks_helper
        try:
            from . import _test_nested_graph_breaks_helper
        except ImportError:
            import _test_nested_graph_breaks_helper

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            x = _test_nested_graph_breaks_helper.fn(x, f1)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)

        _test_nested_graph_breaks_helper.reset_state()
        x = torch.zeros(3)
        res = (f2(x), _test_nested_graph_breaks_helper.global1.clone())

        _test_nested_graph_breaks_helper.reset_state()
        ref = (opt_fn(x), _test_nested_graph_breaks_helper.global1.clone())

        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 7)

    def test_nested_graph_break_in_loop(self):
        global f1, f2, f3, f4, f5

        def f1(x, i):
            x = x + 1
            if i == 5:
                torch._dynamo.graph_break()
            return x + 1

        def f2(x, i):
            x = x + 1
            x = f1(x, i)
            return x + 1

        def f3(x):
            for i in range(8):
                x = f2(x, i)
            return x

        def f4(x):
            x = x + 1
            x = f3(x)
            return x + 1

        def f5(x):
            x = x + 1
            x = f4(x)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        # dynamic=True to prevent unnecessary recompiles
        opt_fn = torch._dynamo.optimize(backend=cnts, dynamic=True)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to nested graph break in for loop:
        # 2 frames from f5+f4, 2 frames from f2+f1 (i == 5), 1 frame from f2+f1 (i != 5)
        self.assertEqual(cnts.frame_count, 5)
        # 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
        self.assertEqual(cnts.op_count, 12)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)

    def test_nested_graph_break_in_try_block(self):
        # NOTE: this also tests nested step_graph_break
        global f1, f2, f3, f4, f5

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            x = f1(x)
            return x + 1

        def f3(x):
            x = x + 1
            try:
                x = x + 1
                x = f2(x)
                x = x + 1
            finally:
                pass
            return x + 1

        def f4(x):
            x = x + 1
            x = f3(x)
            return x + 1

        def f5(x):
            x = x + 1
            x = f4(x)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to graph break in try block:
        # 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
        self.assertEqual(cnts.frame_count, 4)
        # 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
        self.assertEqual(cnts.op_count, 9)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)

    def test_nested_step_unsupported(self):
        global f1, f2, f3

        def f1(x):
            return x + 1

        def f2(x):
            x = x + 2
            torch._dynamo.step_unsupported()
            return f1(x) + 4

        def f3(x):
            x = x + 8
            return f2(x) + 16

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
        self.assertEqual(cnts.frame_count, 3)
        # all additions are compiled except + 4, which runs in the skipped part of f2
        self.assertEqual(cnts.op_count, 4)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)

    def test_generator_nested_graph_break(self):
        def gen(x):
            yield x + 1
            torch._dynamo.graph_break()
            yield x + 2

        def fn(x):
            x = x + 4
            return list(gen(x))

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(fn)
        x = torch.zeros(3)
        res = fn(x)
        # NOTE: if we enable nested graph breaks on inlined generators, we
        # expect some sort of internal dynamo failure
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # fn should be skipped
        self.assertEqual(cnts.frame_count, 0)

        def outer(x):
            x = x + 8
            return fn(x)[0] + 16

        cnts.clear()
        torch.compiler.reset()

        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # only outer should be traced
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 2)

    def test_return_after_graph_break_nested(self):
        # With an improper implementation, returning immediately after a nested
        # graph break may skip the rest of the top-level frame.
        def f2(inner, x):
            x += 2
            return inner(x)

        @torch.compile(backend="eager")
        def f3(inner, x):
            result = f2(inner, x)
            x += 4
            if result is not None:
                x += result
            return x

        # test normal graph break
        x = torch.zeros(3)

        def inner1(x):
            x += 1
            return torch._dynamo.graph_break()

        ref = f3(inner1, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test step graph break
        x = torch.zeros(3)

        def inner2(x):
            x += 1
            return torch._dynamo.step_unsupported()

        ref = f3(inner2, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test store attr graph break
        # NOTE: we do this manual bytecode generation hack since the only
        # RETURN_* instruction that can follow STORE_ATTR is RETURN_CONST,
        # which was removed in 3.14+.

        # make sure inner3's code options are compatible with the instructions below
        global y

        def y():
            pass

        def inner3(x):
            x.attr = 1000
            y.attr = 2000

        new_inst = torch._dynamo.bytecode_transformation.create_instruction
        insts = [
            new_inst("LOAD_CONST", argval=1000),
            new_inst("LOAD_CONST", argval=2000),
            new_inst("LOAD_GLOBAL", argval="y"),
            # NOTE: this should cause a graph break - change y if it doesn't work!
            new_inst("STORE_ATTR", argval="attr"),
            new_inst("RETURN_VALUE"),
        ]
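        # the assembled code is equivalent to `y.attr = 2000; return 1000`,
        # with the STORE_ATTR graph break landing immediately before the return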
        if sys.version_info >= (3, 11):
            insts = [new_inst("RESUME", arg=0)] + insts
        code_keys = torch._dynamo.bytecode_transformation.get_code_keys()
        code_options = {k: getattr(inner3.__code__, k) for k in code_keys}
        _, inner3_code = (
            torch._dynamo.bytecode_transformation.clean_and_assemble_instructions(
                insts, code_keys, code_options
            )
        )
        inner3.__code__ = inner3_code

        torch._dynamo.utils.counters.clear()
        x = torch.zeros(3)
        ref = f3(inner3, x)
        self.assertEqual(ref, torch.zeros(3) + 1006)
        # make sure we're actually graph breaking on STORE_ATTR
        self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 1)

        # dynamic branching is harder to test - the other tests should provide
        # enough coverage

        # test every function returning
        @torch.compiler.disable
        def inner5(x):
            x += 8
            return x

        def inner4(x):
            x += 1
            return inner5(x)

        @torch.compile(backend="eager")
        def f4(x):
            x += 4
            return f2(inner4, x)

        x = torch.zeros(3)
        ref = f4(x)
        self.assertEqual(ref, torch.zeros(3) + 15)

    def test_return_after_graph_break_deep_nested(self):
        @torch.compiler.disable
        def f1(x):
            return x + 1

        def f2(x):
            return f1(x + 2)

        def f3(x):
            return f2(x + 4)

        def f4(x):
            x = f3(x + 8)
            return x + 16

        def f5(x):
            return f4(x + 32)

        def f6(x):
            return f5(x + 64)

        def f7(x):
            x = f6(x + 128)
            return x + 256

        @torch.compile(backend="eager")
        def f8(x):
            return f7(x + 512)

        x = torch.zeros(3)
        ref = f8(x)
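        # 1023 == 1 + 2 + ... + 512, so every addition in the chain executed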
        self.assertEqual(ref, torch.zeros(3) + 1023)

        # check that only 2 resume functions are created
        self.assertEqual(len(torch._dynamo.utils.counters["resumes"]), 2)
        for name in ("resume_in_f4", "resume_in_f7"):
            self.assertTrue(
                any(
                    name in key
                    for key in torch._dynamo.utils.counters["resumes"].keys()
                )
            )

    def test_disable_nested_graph_breaks(self):
        global f1, f2, f3, f4, f5

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            return f1(x + 4) + 8

        # NOTE: since the disable_nested_graph_breaks decorator is implemented as a
        # context manager, we don't need to separately test context manager usage.
        @torch._dynamo.disable_nested_graph_breaks
        def f3(x):
            return f2(x + 16) + 32

        def f4(x):
            return f3(x + 64) + 128

        def f5(x):
            return f4(x + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # 2 frames from each of f5+f4, f3, f2, f1
        self.assertEqual(cnts.frame_count, 8)
        self.assertEqual(cnts.op_count, 10)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()