[Dynamo] Fix guards for script_if_tracing or lru_cache fn with default args (#120390)

Fixes #120387

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120390
Approved by: https://github.com/anijain2305
This commit is contained in:
Yanbo Liang 2024-02-26 19:40:14 +00:00 committed by PyTorch MergeBot
parent 55b5908427
commit 5a0a964444
3 changed files with 40 additions and 5 deletions

View File

@ -91,6 +91,16 @@ def inline_unused(x):
return x + 5.6
@functools.lru_cache
def inline_lru_cache_fn_with_default_args(x, y, _=None):
return torch.sin(x * y)
@torch.jit.script_if_tracing
def inline_script_if_tracing_fn_with_default_args(x, y, _=None):
return torch.cos(x * y)
class FunctionTests(torch._dynamo.test_case.TestCase):
@make_test
def test_inline_jit_annotations(x):
@ -99,6 +109,14 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
x = inline_unused(x)
return
@make_test
def test_inline_script_if_tracing_fn_with_default_args(a, b):
return inline_script_if_tracing_fn_with_default_args(a, 2, b)
@make_test
def test_inline_lru_cache_fn_with_default_args(a, b):
return inline_lru_cache_fn_with_default_args(a, 2, b)
@make_test
def test_add(a, b):
return a + b

View File

@ -555,14 +555,27 @@ def is_function(value):
def unwrap_if_wrapper(fn):
return unwrap_with_attr_name_if_wrapper(fn)[0]
def unwrap_with_attr_name_if_wrapper(fn):
# unpack @functools.lru_cache wrapped function
if isinstance(fn, functools._lru_cache_wrapper):
fn = inspect.getattr_static(fn, "__wrapped__")
attr_name = "__wrapped__"
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False):
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
attr_name = "_torchdynamo_inline"
# unpack torch.jit.script_if_tracing
if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
elif is_function(fn) and inspect.getattr_static(
fn, "__script_if_tracing_wrapper", False
):
fn = inspect.getattr_static(fn, "__original_fn", fn)
return fn
attr_name = "__original_fn"
else:
attr_name = None
return fn, attr_name
def is_numpy_ndarray(value):

View File

@ -77,7 +77,7 @@ from ..utils import (
tuple_iterator,
tuple_iterator_getitem,
tuple_iterator_len,
unwrap_if_wrapper,
unwrap_with_attr_name_if_wrapper,
wrap_fake_exception,
)
@ -709,7 +709,11 @@ class VariableBuilder:
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchCtxManagerClassVariable(value, source=self.source)
elif is_function_or_wrapper(value):
value = unwrap_if_wrapper(value)
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
# For these wrappers, Dynamo points to the wrapped function,
# so source needs to be updated as well.
if attr_name is not None:
self.source = AttrSource(self.source, attr_name)
return trace_rules.lookup(value).create_with_source(
value, source=self.source
)