diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 66ca040dd41..74f5e3749c6 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5912,6 +5912,22 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): actual[1].untyped_storage().data_ptr(), ) + def test_torch_compile_in_compile_frame(self): + # TODO(anijain2305/yanboliang) - Dont graph break on torch.compile. + def gn(x, c=None): + if c is None: + c = 2 + return c * x + + def outer_func(x): + return torch.compile(gn)(x) + + compile_outer = torch.compile(outer_func, backend="eager") + x = torch.randn(4) + ref = outer_func(x) + res = compile_outer(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f7ac4d073b5..bf1cbed10b8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -547,11 +547,6 @@ class VariableBuilder: if id_dispatch is not None: return id_dispatch(self, value) - # Note - There are some nested values where types mismatch! - # We want to get those out and wrap those. - if is_function_or_wrapper(value): - value = inspect.getattr_static(value, "_torchdynamo_inline", value) - # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -988,6 +983,13 @@ class VariableBuilder: elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 90c2d0665ae..c7324960fb9 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -152,6 +152,8 @@ class UserFunctionVariable(BaseUserFunctionVariable): assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. # unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn