jit trace will fail for parameter check if it contains param whose ki… (#94032)

…nd is _ParameterKind.VAR_KEYWORD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94032
Approved by: https://github.com/qihqi, https://github.com/davidberard98
This commit is contained in:
Wang, Yi A 2023-02-13 20:33:26 +00:00 committed by PyTorch MergeBot
parent 4d6a4401f8
commit d82c2b14c7
3 changed files with 19 additions and 5 deletions

View File

@ -37,7 +37,7 @@ class TestJitUtils(JitTestCase):
fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
self.assertEqual(
[],
["y"],
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))
# Tests that VAR_POSITIONAL arguments are ignored.
@ -46,7 +46,7 @@ class TestJitUtils(JitTestCase):
def fn_var_positional_arg(x, *arg):
return x + arg[0]
self.assertEqual(
[],
["x"],
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))
# Tests that KEYWORD_ONLY arguments are ignored.
@ -54,7 +54,7 @@ class TestJitUtils(JitTestCase):
def fn_keyword_only_arg(x, *, y):
return x + y
self.assertEqual(
[],
["x"],
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))
# Tests that VAR_KEYWORD arguments are ignored.
@ -74,7 +74,7 @@ class TestJitUtils(JitTestCase):
''')
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
self.assertEqual(
[],
["y"],
torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
def test_checkscriptassertraisesregex(self):

View File

@ -1959,6 +1959,20 @@ class TestTracer(JitTestCase):
FileCheck().check("first_arg").check_next("second_arg") \
.run(str(traced_module.graph))
def test_trace_checking_with_deprecated_name(self):
class MyClass(torch.nn.Module):
def __init__(self):
super(MyClass, self).__init__()
def forward(self, x, y, **deprecated_arguments):
if len(deprecated_arguments) > 0:
raise RuntimeError(f"Got unexpected arguments: {deprecated_arguments}")
return x + y
model = MyClass()
m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
m3 = torch.jit.trace(model, example_kwarg_inputs={'x': torch.ones(1), "y": torch.ones(1)}, strict=False)
class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):

View File

@ -320,7 +320,7 @@ def get_callable_argument_names(fn) -> List[str]:
# All four other types of arguments do not map to individual values
# with a keyword as name.
if not param.kind == param.POSITIONAL_OR_KEYWORD:
return []
continue
argument_names.append(name)