**Motivation:** We want torch.cond to use torch.compile automatically, so that we can error out when there are side effects in the branches and handle closures correctly. Before this PR, the following code produced a warning unless the raise_on_backend_change config was turned on (turning it on gives an error instead):

```python
def foo(x):
    # Inside torch.cond, we'd like to do something like
    # torch.compile(foo, backend="eager", fullgraph=True)(...)
    ...

# Users may then call torch.compile somewhere else.
# Dynamo will use the cached code of foo for the "eager" backend,
# but we expect dynamo to recompile with the "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard, which effectively implements a per-backend cache. In the example above, the code cached for "eager" fails the guard check under "inductor", so the second torch.compile triggers a recompilation.

In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compiling a function with fullgraph=False and then compiling it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with the cached code. We also remove the raise_on_backend_change flag.

Note: the debug log prints more lines due to the newly added context manager and guard additions.

**Test Plan:** Removed the original tests that raise on a backend change, and added a new test that checks the BACKEND_MATCH guard can guard against backend changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107337
Approved by: https://github.com/jansel
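To make the new guard's effect concrete, here is a minimal sketch (my illustration, not code from this PR; the toy function `bar` is hypothetical) showing the behavior described above: compiling the same function under two backends now triggers a recompilation instead of reusing the first backend's cache entry.

```python
import torch

def bar(x):
    return x.sin() + 1

x = torch.randn(4)

torch._dynamo.reset()  # start from an empty code cache

# First compilation: the cached code for bar is now associated with
# the "eager" backend via the new BACKEND_MATCH guard.
torch.compile(bar, backend="eager")(x)

# Second compilation with a different backend: the BACKEND_MATCH guard
# fails for the "eager" cache entry, so Dynamo recompiles bar for
# "inductor" rather than warning or raising as before this PR.
torch.compile(bar, backend="inductor")(x)
```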
# Owner(s): ["module: dynamo"]

import re
import sys
from io import StringIO

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.comptime import comptime

# Because we don't support free variables in comptime at the moment,
# we have to communicate via globals. This also means these tests cannot
# be run in parallel in a single process (not that you'd... ever want
# to do that?)
FILE = None
SELF = None


class ComptimeTests(torch._dynamo.test_case.TestCase):
    def test_print_graph(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_graph(verbose=False, file=FILE)

            # Test the compact notation doesn't error or graph break;
            # you'll have to visually inspect to see that it printed
            comptime.print_graph()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2; l_x_ = None""",
        )

    def test_print_disas(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_disas(file=FILE)

            comptime.print_disas()

            return y + 3

        def munge_disas(s):
            # Strip line numbers and offsets from the disassembly,
            # keeping only the "--> OPNAME" part (currently unused).
            return re.sub(
                r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
                r"\1 \2",
                s,
                flags=re.MULTILINE,
            )

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        out = FILE.getvalue()
        # Check that the instruction offset is working
        self.assertIn("-->", out)
        # Check that the bytecode resembles what we expect
        self.assertIn("STORE_FAST", out)
        if sys.version_info < (3, 11):
            self.assertIn("BINARY_MULTIPLY", out)
        else:
            self.assertIn("BINARY_OP", out)

    def test_print_value_stack(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_value_stack(file=FILE, stacklevel=1)

            return x

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x + g(x)

            return y + comptime.print_value_stack_and_return(y * 2)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
- TensorVariable()
""",
        )

    def test_print_locals(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_locals(file=FILE)

            comptime.print_locals()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
x = TensorVariable()
y = TensorVariable()
""",
        )

    def test_print_bt(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_bt(file=FILE)

            comptime.print_bt()

            return x + 3

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            y = g(y)
            return y + 3

        def munge_filenames(s):
            return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        bt = FILE.getvalue()
        self.assertIn("y = g(y)", bt)

    def test_print_guards(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_guards(file=FILE)

            comptime.print_guards()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
            """\

        local "L['x']" TENSOR_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' GRAD_MODE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DETERMINISTIC_ALGORITHMS
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' TORCH_FUNCTION_STATE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DEFAULT_DEVICE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' BACKEND_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        shape_env '' SHAPE_ENV
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }""",
        )

    def test_graph_break(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                pass

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        cnt.frame_count = 0

        @torch._dynamo.optimize(cnt)
        def g(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.graph_break()

            y = y + 2

            comptime.graph_break()

            return y * 3

        g(torch.randn(2))
        self.assertEqual(cnt.frame_count, 3)

    def test_get_local(self):
        global SELF, FILE
        SELF = self
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            lit = 2

            @comptime
            def _(ctx):
                y = ctx.get_local("y")
                SELF.assertEqual(y.as_fake().size(0), 2)
                SELF.assertEqual(y.size(0), 2)
                # Trigger a graph write (TODO: this is not so
                # useful right now as there's no way to make use
                # of the output proxy; maybe it's useful for inserting
                # side-effectful operations into the graph)
                y.as_proxy() + 4
                ctx.print_graph(verbose=False, file=FILE)
                SELF.assertIs(y.python_type(), torch.Tensor)
                lit = ctx.get_local("lit")
                SELF.assertEqual(lit.as_python_constant(), 2)

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2; l_x_ = None
    add = y + 4; y = None""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()