mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
jit trace will fail for parameter check if it contains param whose ki… (#94032)
…nd is _ParameterKind.VAR_KEYWORD Pull Request resolved: https://github.com/pytorch/pytorch/pull/94032 Approved by: https://github.com/qihqi, https://github.com/davidberard98
This commit is contained in:
parent
4d6a4401f8
commit
d82c2b14c7
|
|
@ -37,7 +37,7 @@ class TestJitUtils(JitTestCase):
|
|||
|
||||
fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
|
||||
self.assertEqual(
|
||||
[],
|
||||
["y"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))
|
||||
|
||||
# Tests that VAR_POSITIONAL arguments are ignored.
|
||||
|
|
@ -46,7 +46,7 @@ class TestJitUtils(JitTestCase):
|
|||
def fn_var_positional_arg(x, *arg):
|
||||
return x + arg[0]
|
||||
self.assertEqual(
|
||||
[],
|
||||
["x"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))
|
||||
|
||||
# Tests that KEYWORD_ONLY arguments are ignored.
|
||||
|
|
@ -54,7 +54,7 @@ class TestJitUtils(JitTestCase):
|
|||
def fn_keyword_only_arg(x, *, y):
|
||||
return x + y
|
||||
self.assertEqual(
|
||||
[],
|
||||
["x"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))
|
||||
|
||||
# Tests that VAR_KEYWORD arguments are ignored.
|
||||
|
|
@ -74,7 +74,7 @@ class TestJitUtils(JitTestCase):
|
|||
''')
|
||||
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
|
||||
self.assertEqual(
|
||||
[],
|
||||
["y"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
|
||||
|
||||
def test_checkscriptassertraisesregex(self):
|
||||
|
|
|
|||
|
|
@ -1959,6 +1959,20 @@ class TestTracer(JitTestCase):
|
|||
FileCheck().check("first_arg").check_next("second_arg") \
|
||||
.run(str(traced_module.graph))
|
||||
|
||||
def test_trace_checking_with_deprecated_name(self):
|
||||
class MyClass(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyClass, self).__init__()
|
||||
|
||||
def forward(self, x, y, **deprecated_arguments):
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise RuntimeError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
return x + y
|
||||
|
||||
model = MyClass()
|
||||
m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
|
||||
m3 = torch.jit.trace(model, example_kwarg_inputs={'x': torch.ones(1), "y": torch.ones(1)}, strict=False)
|
||||
|
||||
|
||||
class TestMixTracingScripting(JitTestCase):
|
||||
def test_trace_script(self):
|
||||
|
|
|
|||
|
|
@ -320,7 +320,7 @@ def get_callable_argument_names(fn) -> List[str]:
|
|||
# All four other types of arguments do not map to individual values
|
||||
# with a keyword as name.
|
||||
if not param.kind == param.POSITIONAL_OR_KEYWORD:
|
||||
return []
|
||||
continue
|
||||
|
||||
argument_names.append(name)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user