[Dynamo] Match closures by code ID (#109427)

Closes https://github.com/pytorch/pytorch/issues/107866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109427
Approved by: https://github.com/ezyang, https://github.com/jansel
This commit is contained in:
Ken Jin 2023-09-25 19:10:31 +00:00 committed by PyTorch MergeBot
parent 09c598745c
commit 3de0857503
4 changed files with 94 additions and 4 deletions

View File

@ -7544,6 +7544,76 @@ ShapeEnv not equal: field values don't match:
torch.set_default_dtype(torch.double)
foo()
def test_no_recompile_inner_function(self):
def forward(inp):
def g(y):
return inp + y
print("graph break")
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)
input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)
def test_no_recompile_inner_lambda(self):
def forward(inp):
g = lambda y: inp + y
print("graph break")
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)
input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)
def test_complex_closure(self):
@torch.compile
def forward(y):
def a():
def x(z):
return y + z
return x
return a()
input1 = torch.rand([2])
input2 = torch.rand([2])
res = forward(input1)(input2)
self.assertTrue(same(res, input1 + input2))
def test_non_inlined_closure(self):
@torch.compile()
def program(x, y):
one = lambda x, y: x + y
def inner():
# Force no inlining
torch._dynamo.graph_break()
return one(x, y)
res = inner()
one = lambda x, y: x - y
res += inner()
return res
input1 = torch.randn(1)
input2 = torch.randn(1)
self.assertTrue(same(program(input1, input2), input1 + input1))
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -1529,9 +1529,15 @@ def forward(self, arg0_1, arg1_1):
inp = torch.ones(3, 4)
exp_out = inp.sin()
iter_n = torch._dynamo.config.cache_size_limit + 1
# Need this because Dynamo checks lambda code ID not object itself.
def make_dummy_fn(op):
exec(f"temp = lambda x: x.{op}()")
return locals()["temp"]
for _ in range(iter_n):
# each lambda has a different object id thus fails the guard
self.assertEqual(foo(inp, lambda x: x.cos(), lambda x: x.sin()), exp_out)
self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out)
self.assertEqual(counters["stats"]["calls_captured"], iter_n)
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)

View File

@ -182,7 +182,7 @@ class GuardCodeList:
class GuardBuilder(GuardBuilderBase):
def __init__(
self,
id_ref: Callable[[Type[object]], str],
id_ref: Callable[[Any], str],
source_ref: Callable[[Source], str],
lookup_weakrefs: Callable[[Type[object]], ReferenceType[object]],
user_scope: Optional[Dict[str, object]],
@ -487,6 +487,20 @@ class GuardBuilder(GuardBuilderBase):
if guard.is_local():
return self.ID_MATCH(guard)
def CLOSURE_MATCH(self, guard: Guard):
"""matches a closure by __code__ id."""
if guard.is_local():
val = self.get(guard.name)
# Strictly only want user-defined functions
if type(val) == types.FunctionType and hasattr(val, "__code__"):
ref = self.arg_ref(guard)
code = [
f"___check_obj_id(getattr({ref}, '__code__', None), {self.id_ref(val.__code__)})",
]
self._produce_guard_code(guard, code)
else:
self.FUNCTION_MATCH(guard)
def BUILTIN_MATCH(self, guard: Guard):
return self.FUNCTION_MATCH(guard)

View File

@ -338,7 +338,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable(
InspectSignatureVariable.create,
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
guards=self.make_guards(GuardBuilder.CLOSURE_MATCH),
),
),
(comptime, lambda self, value: ComptimeVariable()),
@ -562,7 +562,7 @@ class VariableBuilder:
return UserFunctionVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
guards=make_guards(GuardBuilder.CLOSURE_MATCH),
)
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
return PythonModuleVariable(