mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Dynamo] Fix guards for script_if_tracing or lru_cache fn with default args (#120390)
Fixes #120387 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120390 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
55b5908427
commit
5a0a964444
|
|
@ -91,6 +91,16 @@ def inline_unused(x):
|
|||
return x + 5.6
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def inline_lru_cache_fn_with_default_args(x, y, _=None):
|
||||
return torch.sin(x * y)
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def inline_script_if_tracing_fn_with_default_args(x, y, _=None):
|
||||
return torch.cos(x * y)
|
||||
|
||||
|
||||
class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
@make_test
|
||||
def test_inline_jit_annotations(x):
|
||||
|
|
@ -99,6 +109,14 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
x = inline_unused(x)
|
||||
return
|
||||
|
||||
@make_test
|
||||
def test_inline_script_if_tracing_fn_with_default_args(a, b):
|
||||
return inline_script_if_tracing_fn_with_default_args(a, 2, b)
|
||||
|
||||
@make_test
|
||||
def test_inline_lru_cache_fn_with_default_args(a, b):
|
||||
return inline_lru_cache_fn_with_default_args(a, 2, b)
|
||||
|
||||
@make_test
|
||||
def test_add(a, b):
|
||||
return a + b
|
||||
|
|
|
|||
|
|
@ -555,14 +555,27 @@ def is_function(value):
|
|||
|
||||
|
||||
def unwrap_if_wrapper(fn):
|
||||
return unwrap_with_attr_name_if_wrapper(fn)[0]
|
||||
|
||||
|
||||
def unwrap_with_attr_name_if_wrapper(fn):
|
||||
# unpack @functools.lru_cache wrapped function
|
||||
if isinstance(fn, functools._lru_cache_wrapper):
|
||||
fn = inspect.getattr_static(fn, "__wrapped__")
|
||||
attr_name = "__wrapped__"
|
||||
# unpack @torch._dynamo.optimize()(fn) wrapped function
|
||||
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
|
||||
elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False):
|
||||
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
|
||||
attr_name = "_torchdynamo_inline"
|
||||
# unpack torch.jit.script_if_tracing
|
||||
if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
|
||||
elif is_function(fn) and inspect.getattr_static(
|
||||
fn, "__script_if_tracing_wrapper", False
|
||||
):
|
||||
fn = inspect.getattr_static(fn, "__original_fn", fn)
|
||||
return fn
|
||||
attr_name = "__original_fn"
|
||||
else:
|
||||
attr_name = None
|
||||
return fn, attr_name
|
||||
|
||||
|
||||
def is_numpy_ndarray(value):
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ from ..utils import (
|
|||
tuple_iterator,
|
||||
tuple_iterator_getitem,
|
||||
tuple_iterator_len,
|
||||
unwrap_if_wrapper,
|
||||
unwrap_with_attr_name_if_wrapper,
|
||||
wrap_fake_exception,
|
||||
)
|
||||
|
||||
|
|
@ -709,7 +709,11 @@ class VariableBuilder:
|
|||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return TorchCtxManagerClassVariable(value, source=self.source)
|
||||
elif is_function_or_wrapper(value):
|
||||
value = unwrap_if_wrapper(value)
|
||||
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
|
||||
# For these wrappers, Dynamo points to the wrapped function,
|
||||
# so source needs to be updated as well.
|
||||
if attr_name is not None:
|
||||
self.source = AttrSource(self.source, attr_name)
|
||||
return trace_rules.lookup(value).create_with_source(
|
||||
value, source=self.source
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user