[dynamo] Bug fix for _torchdynamo_inline source handling (#135612)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612
Approved by: https://github.com/drisspg
This commit is contained in:
Animesh Jain 2024-09-11 14:18:34 -07:00 committed by PyTorch MergeBot
parent f5f1d0a753
commit eaba287adb
3 changed files with 25 additions and 5 deletions

View File

@ -5912,6 +5912,22 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
actual[1].untyped_storage().data_ptr(),
)
def test_torch_compile_in_compile_frame(self):
# TODO(anijain2305/yanboliang) - Dont graph break on torch.compile.
def gn(x, c=None):
if c is None:
c = 2
return c * x
def outer_func(x):
return torch.compile(gn)(x)
compile_outer = torch.compile(outer_func, backend="eager")
x = torch.randn(4)
ref = outer_func(x)
res = compile_outer(x)
self.assertEqual(ref, res)
instantiate_parametrized_tests(ReproTests)

View File

@ -547,11 +547,6 @@ class VariableBuilder:
if id_dispatch is not None:
return id_dispatch(self, value)
# Note - There are some nested values where types mismatch!
# We want to get those out and wrap those.
if is_function_or_wrapper(value):
value = inspect.getattr_static(value, "_torchdynamo_inline", value)
# Everything else (NB: order matters!)
if is_traceable_wrapper_subclass(value) or istype(
value, config.traceable_tensor_subclasses
@ -988,6 +983,13 @@ class VariableBuilder:
elif is_lru_cache_wrapped_function(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
elif is_function_or_wrapper(value) and inspect.getattr_static(
value, "_torchdynamo_inline", False
):
self.install_guards(GuardBuilder.TYPE_MATCH)
return WrapperUserFunctionVariable(
value, "_torchdynamo_inline", source=self.source
)
elif is_function_or_wrapper(value):
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
# For these wrappers, Dynamo points to the wrapped function,

View File

@ -152,6 +152,8 @@ class UserFunctionVariable(BaseUserFunctionVariable):
assert isinstance(
fn, (types.FunctionType, torch.jit.ScriptFunction)
), f"expected FunctionType found {typestr(fn)} {fn}"
# TODO(anijain2305) - Replace directly calling UserFunctionVariable with
# VariableBuilder, which handles the wrapping of _torchdynamo_inline.
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
self.fn: types.FunctionType = fn