mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Match closures by code ID (#109427)
Closes https://github.com/pytorch/pytorch/issues/107866 Pull Request resolved: https://github.com/pytorch/pytorch/pull/109427 Approved by: https://github.com/ezyang, https://github.com/jansel
This commit is contained in:
parent
09c598745c
commit
3de0857503
|
|
@ -7544,6 +7544,76 @@ ShapeEnv not equal: field values don't match:
|
|||
torch.set_default_dtype(torch.double)
|
||||
foo()
|
||||
|
||||
def test_no_recompile_inner_function(self):
|
||||
def forward(inp):
|
||||
def g(y):
|
||||
return inp + y
|
||||
|
||||
print("graph break")
|
||||
return g(torch.rand([1]))
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(forward)
|
||||
|
||||
input = torch.rand([2])
|
||||
_ = opt_fn(input)
|
||||
_ = opt_fn(input)
|
||||
_ = opt_fn(input)
|
||||
# Should not have recompiled
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_no_recompile_inner_lambda(self):
|
||||
def forward(inp):
|
||||
g = lambda y: inp + y
|
||||
print("graph break")
|
||||
return g(torch.rand([1]))
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(forward)
|
||||
|
||||
input = torch.rand([2])
|
||||
_ = opt_fn(input)
|
||||
_ = opt_fn(input)
|
||||
_ = opt_fn(input)
|
||||
# Should not have recompiled
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_complex_closure(self):
|
||||
@torch.compile
|
||||
def forward(y):
|
||||
def a():
|
||||
def x(z):
|
||||
return y + z
|
||||
|
||||
return x
|
||||
|
||||
return a()
|
||||
|
||||
input1 = torch.rand([2])
|
||||
input2 = torch.rand([2])
|
||||
res = forward(input1)(input2)
|
||||
self.assertTrue(same(res, input1 + input2))
|
||||
|
||||
def test_non_inlined_closure(self):
|
||||
@torch.compile()
|
||||
def program(x, y):
|
||||
one = lambda x, y: x + y
|
||||
|
||||
def inner():
|
||||
# Force no inlining
|
||||
torch._dynamo.graph_break()
|
||||
return one(x, y)
|
||||
|
||||
res = inner()
|
||||
one = lambda x, y: x - y
|
||||
res += inner()
|
||||
return res
|
||||
|
||||
input1 = torch.randn(1)
|
||||
input2 = torch.randn(1)
|
||||
|
||||
self.assertTrue(same(program(input1, input2), input1 + input1))
|
||||
|
||||
|
||||
class TestTracer(JitTestCase):
|
||||
def test_jit_save(self):
|
||||
|
|
|
|||
|
|
@ -1529,9 +1529,15 @@ def forward(self, arg0_1, arg1_1):
|
|||
inp = torch.ones(3, 4)
|
||||
exp_out = inp.sin()
|
||||
iter_n = torch._dynamo.config.cache_size_limit + 1
|
||||
|
||||
# Need this because Dynamo checks lambda code ID not object itself.
|
||||
def make_dummy_fn(op):
|
||||
exec(f"temp = lambda x: x.{op}()")
|
||||
return locals()["temp"]
|
||||
|
||||
for _ in range(iter_n):
|
||||
# each lambda has a different object id thus fails the guard
|
||||
self.assertEqual(foo(inp, lambda x: x.cos(), lambda x: x.sin()), exp_out)
|
||||
self.assertEqual(foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out)
|
||||
self.assertEqual(counters["stats"]["calls_captured"], iter_n)
|
||||
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
|
||||
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ class GuardCodeList:
|
|||
class GuardBuilder(GuardBuilderBase):
|
||||
def __init__(
|
||||
self,
|
||||
id_ref: Callable[[Type[object]], str],
|
||||
id_ref: Callable[[Any], str],
|
||||
source_ref: Callable[[Source], str],
|
||||
lookup_weakrefs: Callable[[Type[object]], ReferenceType[object]],
|
||||
user_scope: Optional[Dict[str, object]],
|
||||
|
|
@ -487,6 +487,20 @@ class GuardBuilder(GuardBuilderBase):
|
|||
if guard.is_local():
|
||||
return self.ID_MATCH(guard)
|
||||
|
||||
def CLOSURE_MATCH(self, guard: Guard):
|
||||
"""matches a closure by __code__ id."""
|
||||
if guard.is_local():
|
||||
val = self.get(guard.name)
|
||||
# Strictly only want user-defined functions
|
||||
if type(val) == types.FunctionType and hasattr(val, "__code__"):
|
||||
ref = self.arg_ref(guard)
|
||||
code = [
|
||||
f"___check_obj_id(getattr({ref}, '__code__', None), {self.id_ref(val.__code__)})",
|
||||
]
|
||||
self._produce_guard_code(guard, code)
|
||||
else:
|
||||
self.FUNCTION_MATCH(guard)
|
||||
|
||||
def BUILTIN_MATCH(self, guard: Guard):
|
||||
return self.FUNCTION_MATCH(guard)
|
||||
|
||||
|
|
|
|||
|
|
@ -338,7 +338,7 @@ class VariableBuilder:
|
|||
lambda self, value: LambdaVariable(
|
||||
InspectSignatureVariable.create,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
guards=self.make_guards(GuardBuilder.CLOSURE_MATCH),
|
||||
),
|
||||
),
|
||||
(comptime, lambda self, value: ComptimeVariable()),
|
||||
|
|
@ -562,7 +562,7 @@ class VariableBuilder:
|
|||
return UserFunctionVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
guards=make_guards(GuardBuilder.CLOSURE_MATCH),
|
||||
)
|
||||
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
|
||||
return PythonModuleVariable(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user