diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 1b9bf3c83a8..a3c4226ed6a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5912,22 +5912,6 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): actual[1].untyped_storage().data_ptr(), ) - def test_torch_compile_in_compile_frame(self): - # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. - def gn(x, c=None): - if c is None: - c = 2 - return c * x - - def outer_func(x): - return torch.compile(gn)(x) - - compile_outer = torch.compile(outer_func, backend="eager") - x = torch.randn(4) - ref = outer_func(x) - res = compile_outer(x) - self.assertEqual(ref, res) - instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index bf1cbed10b8..f7ac4d073b5 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -547,6 +547,11 @@ class VariableBuilder: if id_dispatch is not None: return id_dispatch(self, value) + # Note - There are some nested values where types mismatch! + # We want to get those out and wrap those. + if is_function_or_wrapper(value): + value = inspect.getattr_static(value, "_torchdynamo_inline", value) + # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -983,13 +988,6 @@ class VariableBuilder: elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) - elif is_function_or_wrapper(value) and inspect.getattr_static( - value, "_torchdynamo_inline", False - ): - self.install_guards(GuardBuilder.TYPE_MATCH) - return WrapperUserFunctionVariable( - value, "_torchdynamo_inline", source=self.source - ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 77662e1f4b5..90c2d0665ae 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -152,6 +152,8 @@ class UserFunctionVariable(BaseUserFunctionVariable): assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" + # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn def as_python_constant(self):