mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Bug fix for _torchdynamo_inline source handling (#135612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135612 Approved by: https://github.com/drisspg
This commit is contained in:
parent
f5f1d0a753
commit
eaba287adb
|
|
@ -5912,6 +5912,22 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||||
actual[1].untyped_storage().data_ptr(),
|
actual[1].untyped_storage().data_ptr(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_torch_compile_in_compile_frame(self):
|
||||||
|
# TODO(anijain2305/yanboliang) - Dont graph break on torch.compile.
|
||||||
|
def gn(x, c=None):
|
||||||
|
if c is None:
|
||||||
|
c = 2
|
||||||
|
return c * x
|
||||||
|
|
||||||
|
def outer_func(x):
|
||||||
|
return torch.compile(gn)(x)
|
||||||
|
|
||||||
|
compile_outer = torch.compile(outer_func, backend="eager")
|
||||||
|
x = torch.randn(4)
|
||||||
|
ref = outer_func(x)
|
||||||
|
res = compile_outer(x)
|
||||||
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(ReproTests)
|
instantiate_parametrized_tests(ReproTests)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -547,11 +547,6 @@ class VariableBuilder:
|
||||||
if id_dispatch is not None:
|
if id_dispatch is not None:
|
||||||
return id_dispatch(self, value)
|
return id_dispatch(self, value)
|
||||||
|
|
||||||
# Note - There are some nested values where types mismatch!
|
|
||||||
# We want to get those out and wrap those.
|
|
||||||
if is_function_or_wrapper(value):
|
|
||||||
value = inspect.getattr_static(value, "_torchdynamo_inline", value)
|
|
||||||
|
|
||||||
# Everything else (NB: order matters!)
|
# Everything else (NB: order matters!)
|
||||||
if is_traceable_wrapper_subclass(value) or istype(
|
if is_traceable_wrapper_subclass(value) or istype(
|
||||||
value, config.traceable_tensor_subclasses
|
value, config.traceable_tensor_subclasses
|
||||||
|
|
@ -988,6 +983,13 @@ class VariableBuilder:
|
||||||
elif is_lru_cache_wrapped_function(value):
|
elif is_lru_cache_wrapped_function(value):
|
||||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
|
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
|
||||||
|
elif is_function_or_wrapper(value) and inspect.getattr_static(
|
||||||
|
value, "_torchdynamo_inline", False
|
||||||
|
):
|
||||||
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
|
return WrapperUserFunctionVariable(
|
||||||
|
value, "_torchdynamo_inline", source=self.source
|
||||||
|
)
|
||||||
elif is_function_or_wrapper(value):
|
elif is_function_or_wrapper(value):
|
||||||
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
|
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
|
||||||
# For these wrappers, Dynamo points to the wrapped function,
|
# For these wrappers, Dynamo points to the wrapped function,
|
||||||
|
|
|
||||||
|
|
@ -152,6 +152,8 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
fn, (types.FunctionType, torch.jit.ScriptFunction)
|
fn, (types.FunctionType, torch.jit.ScriptFunction)
|
||||||
), f"expected FunctionType found {typestr(fn)} {fn}"
|
), f"expected FunctionType found {typestr(fn)} {fn}"
|
||||||
|
# TODO(anijain2305) - Replace directly calling UserFunctionVariable with
|
||||||
|
# VariableBuilder, which handles the wrapping of _torchdynamo_inline.
|
||||||
# unpack @torch._dynamo.optimize()(fn) wrapped function
|
# unpack @torch._dynamo.optimize()(fn) wrapped function
|
||||||
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
|
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
|
||||||
self.fn: types.FunctionType = fn
|
self.fn: types.FunctionType = fn
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user