# pytorch/test/dynamo/test_nested_graph_breaks.py
# (file-listing header from the source viewer: 882 lines, 24 KiB, Python)

# Owner(s): ["module: dynamo"]
import sys
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
# Optional sibling-test-suite imports. All are currently disabled (commented
# out); the try/except scaffolding is kept so they can be re-enabled for both
# package-relative and direct-script execution.
try:
    # from . import test_ctx_manager
    pass
except ImportError:
    # import test_aot_autograd
    # import test_ctx_manager
    # import test_export
    # import test_functions
    # import test_higher_order_ops
    # import test_misc
    # import test_modules
    # import test_repros
    # import test_sdpa
    # import test_subgraphs
    pass
# Registry of generated test classes, keyed by class name.
test_classes = {}


def make_nested_cls(cls):
    """Create and register a nested-graph-breaks variant of test class `cls`.

    Wraps `cls` via ``make_test_cls_with_patches`` so every test runs with the
    dynamo nested-graph-break debug flags enabled, and registers the generated
    class both in ``test_classes`` and in this module's globals so the test
    runner can discover it.

    Returns the generated test class.
    """
    suffix = "_nested_graph_breaks"
    cls_prefix = "NestedGraphBreaks"
    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "debug_force_nested_calls", True),
        (config, "debug_force_graph_break_on_leaf_return", True),
        (config, "debug_disable_compile_counter", True),
        xfail_prop="_expected_failure_nested_graph_breaks",
    )
    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class
# Test suites to be wrapped by make_nested_cls; every candidate is currently
# disabled (commented out), so the loop below is a no-op for now.
tests = [
    # test_ctx_manager.CtxManagerTests,
    # test_functions.FunctionTests,
    # test_misc.MiscTests,
    # test_repros.ReproTests,
    # test_modules.NNModuleTests,
    # test_subgraphs.SubGraphTests,
    # test_higher_order_ops.HigherOrderOpTests,
    # test_higher_order_ops.FuncTorchHigherOrderOpTests,
    # test_aot_autograd.AotAutogradFallbackTests,
    # test_sdpa.TestSDPA,
]

# Pre-bind the loop variable so the trailing `del` is valid even when the
# list is empty, and delete it afterwards so no stray module-level name is
# left behind for the test collector to pick up.
_t = None
for _t in tests:
    make_nested_cls(_t)
del _t
# for use in test_side_effects_globals
# NOTE: the tuple repeats ONE tensor object, so all four names initially
# alias the same zero tensor; the test reassigns each with a fresh tensor
# before depending on their values.
global1, global2, global3, global4 = (torch.zeros(3),) * 4
class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
    def test_single_graph_break(self):
        """A graph break at the innermost of 3 nested calls yields exactly 2 graphs."""
        # NOTE marking f1, f2, f3 as global
        # prevents them from being freevars
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            return f1(x2 + 4) + 8

        def f3(x3):
            return f2(x3 + 16) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)
    def test_single_graph_break_repeat(self):
        """Calling the graph-breaking leaf twice from one frame produces 3 graphs."""
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            tmp1 = f1(x2 + 4)
            tmp2 = f1(x2 + 8) << 4
            return tmp1 + tmp2

        def f3(x3):
            return f2(x3 + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        # integer dtype so the << shift is valid
        x = torch.zeros(3, dtype=torch.long)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 10)
    def test_doubly_nested_graph_break(self):
        """Graph breaks at two different nesting depths split tracing into 3 graphs."""
        global f1, f2, f3

        def f1(x1):
            x1 = x1 + 1
            torch._dynamo.graph_break()
            return x1 + 2

        def f2(x2):
            x2 = x2 + 4
            torch._dynamo.graph_break()
            return f1(x2 + 8) + 16

        def f3(x3):
            return f2(x3 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 7)
    def test_differing_arg_nums(self):
        """Nested frames with different argument counts resume correctly across a break."""
        global f1, f2, f3, f4

        def f1(x1, x2):
            x = x1 + x2
            torch._dynamo.graph_break()
            return x + 1

        def f2(x3, x4, x5, x6):
            return f1(x3 + x4, x5 + x6) + 2

        def f3(x7, x8):
            return f2(x7, x7 + 4, x8, x8 + 8) + 16

        def f4(x9):
            return f3(x9, x9 + 32) + 64

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
        x = torch.zeros(3)
        res = f4(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 10)
    def test_differing_locals_nums(self):
        """Nested frames with different numbers of live locals resume correctly."""
        global f1, f2, f3

        def f1(x1):
            loc1 = x1 + 1
            torch._dynamo.graph_break()
            return loc1 + 2

        def f2(x2):
            loc1 = x2 + 4
            loc2 = x2 + 8
            return f1(x2) + loc1 + loc2

        def f3(x3):
            loc1 = x3 + 16
            loc2 = x3 + 32
            loc3 = x3 + 64
            loc4 = x3 + 128
            return f2(x3) + loc1 + loc2 + loc3 + loc4

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 14)
    def test_counters(self):
        """The graph_break counter records 2 distinct break sites for a nested break in a loop."""
        global f1, f2, f3, f4

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            return f1(x + 4) + 8

        def f3(x):
            x = x + 16
            # single-iteration loop so the nested break happens inside a loop body
            for _ in range(1):
                x = f2(x)
            return x + 32

        @torch.compile(backend="eager")
        def f4(x):
            return f3(x + 64) + 128

        self.assertEqual(f4(torch.zeros(3)), torch.zeros(3) + 255)
        self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 2)
    def test_supported_ctx_manager(self):
        """Grad-mode context managers are tracked correctly across nested frames.

        `check` runs both traced and (via `check_disabled`) eagerly, asserting
        the grad-enabled state dynamo believes matches the real runtime state.
        """
        global check, check_disabled, f1, f2, f3

        @torch._dynamo.disable
        def check_disabled(value):
            assert torch.is_grad_enabled() == value

        def check(value):
            assert torch.is_grad_enabled() == value

        def f1(x):
            with torch.no_grad():
                x = x + 1
                check(False)
                check_disabled(False)
                check(False)
            return x + 2

        def f2(x):
            with torch.enable_grad():
                x = x + 4
                check(True)
                check_disabled(True)
                check(True)
            return f1(x) + 8

        def f3(x):
            with torch.no_grad():
                x = x + 16
                check(False)
                check_disabled(False)
                check(False)
            return f2(x) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 4)
        # includes set_grad_enabled ops
        self.assertEqual(cnts.op_count, 14)
    def test_inactive_ctx_manager(self):
        """Context-manager objects created but not yet entered survive graph breaks.

        ctx objects are constructed before a break (as locals, arguments, or
        even values on the interpreter stack) and only entered afterwards.
        """
        global check, f1, f2, f3

        def check(value):
            assert torch.is_grad_enabled() == value

        def f1(x, ctx1):
            x = x + 1
            ctx2 = torch.no_grad()
            # torch.no_grad() is a stack value at the time of graph break
            ctx3 = (torch.no_grad(), torch._dynamo.graph_break())[0]
            x = x + 64
            torch._dynamo.graph_break()
            with ctx1:
                check(False)
            with ctx2:
                check(False)
            with ctx3:
                check(False)
            return x + 2

        def f2(x, ctx1):
            x = x + 4
            ctx2 = torch.no_grad()
            x = f1(x, torch.no_grad())
            with ctx1:
                check(False)
            with ctx2:
                check(False)
            return x + 8

        def f3(x):
            x = x + 16
            ctx = torch.no_grad()
            x = f2(x, torch.no_grad())
            with ctx:
                check(False)
            return x + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 7)
    @torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
    def test_no_recompiles(self):
        """Nested graph breaks must not trigger recompiles (limit=1 makes any recompile fail)."""
        global f1, f2, f3

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            x = x + 4
            x = f1(x)
            torch._dynamo.graph_break()
            return x + 8

        def f3(x):
            x = x + 16
            return f2(x) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
    def test_cells(self):
        """Cell variables (nonlocal reads/writes across 3 closure levels) survive a nested break."""

        def f1(x1):
            cell1 = x1 + 1
            cell2 = x1 + 2

            def f2(x2, x3):
                nonlocal cell1
                cell3 = x2 + x3 + 4
                cell1 += 8

                def f3(x4):
                    nonlocal cell2, cell3
                    cell2 += 16
                    cell3 += 32
                    torch._dynamo.graph_break()
                    return x4 + cell1 + cell2 + cell3

                return f3(x2 + x3), cell3

            return f2(x1 + 64, x1 + 128) + (cell1, cell2)

        def outer(x):
            return f1(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 13)
    def test_dead_nested_cells(self):
        """A cell captured by a nested function but dead at the break must not block resumption."""
        global f1, f2, f3

        def f3(x, cell1):
            cell1 += 2
            x = x + cell1
            torch._dynamo.graph_break()
            return x + cell1

        def f1(cell1=0):
            def inner(x):
                x += 4
                x = f3(x, cell1)
                return x + 8

            return inner

        def f2(x):
            return f1()(x + 16) + 32

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)
        x = torch.zeros(3)
        res = f2(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # If we don't handle dead cells in nested functions correctly,
        # frame_count will increase since we also
        # graph break when we attempt to codegen inner.
        # The exact issue was that side_effects was failing to codegen inner's cell's creation.
        # So when we try to codegen cells for resume functions, we end up trying to codegen
        # a CellVariable without a source, which leads to a graph break we can't resume from.
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)
    def test_cells_double_graph_break(self):
        """Two back-to-back graph breaks in a closure mutating a cell resume correctly."""

        def f1(x1):
            cell1 = x1 + 1

            def f2(x2):
                nonlocal cell1
                cell1 += 2
                torch._dynamo.graph_break()
                torch._dynamo.graph_break()
                return x2 + cell1

            return f2(x1 + 4), cell1

        def outer(x):
            return f1(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)
    def test_side_effects_cells(self):
        """In-place mutations of cell variables around a nested break are applied exactly once."""
        cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

        def f1():
            nonlocal cell1
            cell1 += 1
            torch._dynamo.graph_break()
            return cell1 + cell2

        def f2():
            nonlocal cell3
            cell3 += 2
            return f1() + cell3 + cell4

        def f3():
            return f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        # eager run: reset cells, then snapshot their post-run values
        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        res = f3()
        res = (res,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))
        # compiled run: same fresh cells, same snapshot
        cell1 = torch.zeros(3)
        cell2 = torch.zeros(3) + 4
        cell3 = torch.zeros(3)
        cell4 = torch.zeros(3) + 8
        ref = opt_fn()
        ref = (ref,) + tuple(x.clone() for x in (cell1, cell2, cell3, cell4))
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 5)
    def test_side_effects_globals(self):
        """In-place mutations of module globals around a nested break are applied exactly once."""
        global f1, f2, f3
        global global1, global2, global3, global4

        def f1():
            global global1
            global1 += 1
            torch._dynamo.graph_break()
            return global1 + global2

        def f2():
            global global3
            global3 += 2
            return f1() + global3 + global4

        def f3(x):
            return x + f2()

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.ones(3)
        # eager run against fresh globals; clone the mutated ones for comparison
        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        res = (f3(x), global1.clone(), global2, global3.clone(), global4)
        # compiled run against fresh globals
        global1 = torch.zeros(3)
        global2 = torch.zeros(3) + 4
        global3 = torch.zeros(3)
        global4 = torch.zeros(3) + 8
        ref = (opt_fn(x), global1.clone(), global2, global3.clone(), global4)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)
    def test_side_effects_globals_different_module(self):
        """Global mutations in a DIFFERENT module than the graph break are handled correctly."""
        global f1, f2, _test_nested_graph_breaks_helper
        try:
            from . import _test_nested_graph_breaks_helper
        except ImportError:
            import _test_nested_graph_breaks_helper

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            # helper mutates its own module-level global, then calls f1
            x = _test_nested_graph_breaks_helper.fn(x, f1)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f2)
        _test_nested_graph_breaks_helper.reset_state()
        x = torch.zeros(3)
        res = (f2(x), _test_nested_graph_breaks_helper.global1.clone())
        _test_nested_graph_breaks_helper.reset_state()
        ref = (opt_fn(x), _test_nested_graph_breaks_helper.global1.clone())
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 7)
    def test_nested_graph_break_in_loop(self):
        """A data-dependent nested graph break inside a loop skips only the looping frame."""
        global f1, f2, f3, f4, f5

        def f1(x, i):
            x = x + 1
            if i == 5:
                torch._dynamo.graph_break()
            return x + 1

        def f2(x, i):
            x = x + 1
            x = f1(x, i)
            return x + 1

        def f3(x):
            for i in range(8):
                x = f2(x, i)
            return x

        def f4(x):
            x = x + 1
            x = f3(x)
            return x + 1

        def f5(x):
            x = x + 1
            x = f4(x)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        # dynamic=True to prevent unnecessary recompiles
        opt_fn = torch._dynamo.optimize(backend=cnts, dynamic=True)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to nested graph break in for loop
        # 2 frames from f5+f4, 2 frames from f2+f1 (i == 5), 1 frame from f2+f1 (i != 5)
        self.assertEqual(cnts.frame_count, 5)
        # 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
        self.assertEqual(cnts.op_count, 12)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)
    def test_nested_graph_break_in_try_block(self):
        """A nested graph break inside a try/finally causes a step graph break in that frame."""
        # NOTE: this also tests nested step_graph_break
        global f1, f2, f3, f4, f5

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        def f2(x):
            x = x + 1
            x = f1(x)
            return x + 1

        def f3(x):
            x = x + 1
            try:
                x = x + 1
                x = f2(x)
                x = x + 1
            finally:
                pass
            return x + 1

        def f4(x):
            x = x + 1
            x = f3(x)
            return x + 1

        def f5(x):
            x = x + 1
            x = f4(x)
            return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # skip frame due to graph break in try block
        # 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
        self.assertEqual(cnts.frame_count, 4)
        # 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
        self.assertEqual(cnts.op_count, 9)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)
    def test_nested_step_unsupported(self):
        """step_unsupported() in a nested frame ends the graph and resumes at instruction level."""
        global f1, f2, f3

        def f1(x):
            return x + 1

        def f2(x):
            x = x + 2
            torch._dynamo.step_unsupported()
            return f1(x) + 4

        def f3(x):
            x = x + 8
            return f2(x) + 16

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
        x = torch.zeros(3)
        res = f3(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
        self.assertEqual(cnts.frame_count, 3)
        # all ops except + 4
        self.assertEqual(cnts.op_count, 4)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
    def test_generator_nested_graph_break(self):
        """A graph break inside an inlined generator skips the calling frame, not its callers."""

        def gen(x):
            yield x + 1
            torch._dynamo.graph_break()
            yield x + 2

        def fn(x):
            x = x + 4
            return list(gen(x))

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(fn)
        x = torch.zeros(3)
        res = fn(x)
        # NOTE: if we enable nested graph breaks on inlined generators, we expect
        # some sort of internal dynamo failure
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # fn should be skipped
        self.assertEqual(cnts.frame_count, 0)

        def outer(x):
            x = x + 8
            return fn(x)[0] + 16

        cnts.clear()
        torch.compiler.reset()
        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # only outer should be traced
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 2)
    def test_return_after_graph_break_nested(self):
        """Returning immediately after a nested graph break still runs the rest of the caller."""
        # With improper implementation, returning immediately after a nested graph
        # break may skip the rest of the top-level frame.

        def f2(inner, x):
            x += 2
            return inner(x)

        @torch.compile(backend="eager")
        def f3(inner, x):
            result = f2(inner, x)
            x += 4
            if result is not None:
                x += result
            return x

        # test normal graph break
        x = torch.zeros(3)

        def inner1(x):
            x += 1
            return torch._dynamo.graph_break()

        ref = f3(inner1, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test step graph break
        x = torch.zeros(3)

        def inner2(x):
            x += 1
            return torch._dynamo.step_unsupported()

        ref = f3(inner2, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test store attr graph break
        # NOTE: we do this manual bytecode generation hack since the only RETURN_*
        # instruction that can follow STORE_ATTR is RETURN_CONST, which was removed in 3.14+.
        # make sure inner3's code options are compatible with the instructions below
        global y

        def y():
            pass

        # inner3's body is a placeholder; its __code__ is replaced below with
        # hand-assembled bytecode ending in STORE_ATTR; RETURN_VALUE.
        def inner3(x):
            x.attr = 1000
            y.attr = 2000

        new_inst = torch._dynamo.bytecode_transformation.create_instruction
        insts = [
            new_inst("LOAD_CONST", argval=1000),
            new_inst("LOAD_CONST", argval=2000),
            new_inst("LOAD_GLOBAL", argval="y"),
            # NOTE: this should cause a graph break - change y if it doesn't work!
            new_inst("STORE_ATTR", argval="attr"),
            new_inst("RETURN_VALUE"),
        ]
        if sys.version_info >= (3, 11):
            insts = [new_inst("RESUME", arg=0)] + insts
        code_keys = torch._dynamo.bytecode_transformation.get_code_keys()
        code_options = {k: getattr(inner3.__code__, k) for k in code_keys}
        _, inner3_code = (
            torch._dynamo.bytecode_transformation.clean_and_assemble_instructions(
                insts, code_keys, code_options
            )
        )
        inner3.__code__ = inner3_code
        torch._dynamo.utils.counters.clear()
        x = torch.zeros(3)
        ref = f3(inner3, x)
        self.assertEqual(ref, torch.zeros(3) + 1006)
        # make sure we're actually STORE_ATTR graph breaking
        self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 1)
        # dynamic branching is harder to test - the other tests should be enough cover

        # test every function returning
        @torch.compiler.disable
        def inner5(x):
            x += 8
            return x

        def inner4(x):
            x += 1
            return inner5(x)

        @torch.compile(backend="eager")
        def f4(x):
            x += 4
            return f2(inner4, x)

        x = torch.zeros(3)
        ref = f4(x)
        self.assertEqual(ref, torch.zeros(3) + 15)
    def test_return_after_graph_break_deep_nested(self):
        """A break 6 frames deep creates resume functions only for f4 and f7 (frames with live work)."""

        @torch.compiler.disable
        def f1(x):
            return x + 1

        def f2(x):
            return f1(x + 2)

        def f3(x):
            return f2(x + 4)

        def f4(x):
            x = f3(x + 8)
            return x + 16

        def f5(x):
            return f4(x + 32)

        def f6(x):
            return f5(x + 64)

        def f7(x):
            x = f6(x + 128)
            return x + 256

        @torch.compile(backend="eager")
        def f8(x):
            return f7(x + 512)

        x = torch.zeros(3)
        ref = f8(x)
        self.assertEqual(ref, torch.zeros(3) + 1023)
        # check that only 2 resume functions are created
        self.assertEqual(len(torch._dynamo.utils.counters["resumes"]), 2)
        for name in ("resume_in_f4", "resume_in_f7"):
            self.assertTrue(
                any(
                    name in key
                    for key in torch._dynamo.utils.counters["resumes"].keys()
                )
            )
    def test_disable_nested_graph_breaks(self):
        """disable_nested_graph_breaks makes breaks below f3 behave like top-level breaks (8 graphs)."""
        global f1, f2, f3, f4, f5

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            return f1(x + 4) + 8

        # NOTE since the disable_nested_graph_breaks decorator is implemented as a
        # context manager, we don't need to separately test context manager usage.
        @torch._dynamo.disable_nested_graph_breaks
        def f3(x):
            return f2(x + 16) + 32

        def f4(x):
            return f3(x + 64) + 128

        def f5(x):
            return f4(x + 256) + 512

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f5)
        x = torch.zeros(3)
        res = f5(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # 2 frames from each of f5+f4, f3, f2, f1
        self.assertEqual(cnts.frame_count, 8)
        self.assertEqual(cnts.op_count, 10)
# Standard dynamo test entry point: run via run_tests so dynamo-specific
# setup/teardown is applied.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()