# Owner(s): ["oncall: jit"] # ruff: noqa: F841 import os import sys from collections import namedtuple from typing import Dict, List, NamedTuple, Tuple import torch from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.jit_utils import JitTestCase, make_global # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestTyping(JitTestCase): def test_dict_in_not_in(self): def test_in_dict(x): # type: (Dict[str, int]) -> bool return "hi" in x self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},)) self.checkScript(test_in_dict, ({"bye": 3},)) # Check evaluation order @torch.jit.script def a(): print("a") return 3 @torch.jit.script def b(): print("b") return {3: 2, 4: 1} @torch.jit.script def fn(): return a() in b() with self.capture_stdout() as captured: self.assertTrue(fn()) if not IS_WINDOWS: # no stdout capturing on windows self.assertEqual(captured[0], "a\nb\n") def test_not_in_dict(a): # type: (Dict[str, int]) -> bool if "hello" not in a: return False else: return True self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},)) self.checkScript(test_not_in_dict, ({"world": 2},)) def test_dict_tensor_key(a, t): # type: (Dict[Tensor, int], Tensor) -> bool if t in a: return True else: return False inp1 = torch.tensor(3) inp2 = torch.tensor(5) dict_a = {inp1: 1, inp2: 3} self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4))) self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3))) self.checkScript(test_dict_tensor_key, (dict_a, inp1)) self.checkScript(test_dict_tensor_key, (dict_a, inp2)) def test_list_type_refinement_annotation_element_mismatch(self): def fn(): l: List[int] = [1, 2, "foo", 3] return l with self.assertRaisesRegex( RuntimeError, "List type annotation" r" `List\[int\]` did not match the " "types of the given list elements", ): torch.jit.script(fn) def test_dict_type_refinement_annotation_key_mismatch(self): def fn(): l1 = [1, 2, "foo", 3] l2 = ["foo", "bar", "baz", "qux"] d: Dict[int, str] = dict(zip(l1, l2)) return d with self.assertRaisesRegex( RuntimeError, "Dicts may only " "contain homogeneous keys, but the " "type of the first generated key " r"was Union\[int, str\]", ): torch.jit.script(fn) def test_dict_type_refinement_annotation_value_mismatch(self): def fn(): l1 = ["foo", "bar", "baz", "qux"] l2 = [1, 2, "foo", 3] d: Dict[str, int] = dict(zip(l1, l2)) return d with self.assertRaisesRegex( RuntimeError, "Dict type annotation" r" `Dict\[str, int\]` did not match" " the type of an actual value type" r" `Union\[int, str\]`", ): torch.jit.script(fn) def test_dict_invalid_annotations(self): # Check for invalid value type annotation def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]): return with self.assertRaisesRegex(ValueError, "Unknown type annotation"): torch.jit.script(wrong_value_type) # Check for invalid key type annotation def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): return with self.assertRaisesRegex(ValueError, "Unknown type annotation"): torch.jit.script(wrong_key_type) # Check for invalid key and value type annotation def wrong_key_value_type( dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule] ): return with self.assertRaisesRegex(ValueError, "Unknown type annotation"): torch.jit.script(wrong_key_value_type) def test_tuple_specialization(self): @torch.jit.script def f(t, s): # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor x, t2 = t _, y = t2 return x + y t = ( torch.randn(2, 2), (1, torch.randn(2, 2)), ) f(t, "hi") graph = f.graph_for(t, "hi") input_types = list(next(graph.inputs()).type().elements()) w = input_types[0] self.assertEqual(input_types[0].kind(), "TensorType") self.assertEqual(input_types[1].elements()[1].kind(), "TensorType") def test_tuple_io(self): def stuff(x): # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor] a, b = x return b, a a = (torch.rand(3), torch.rand(3)) self.checkScript(stuff, (a,)) def test_tuple_keyword(self): def bar(): f = tuple((1, 2)) # noqa: C409 return f self.checkScript(bar, ()) def foo(): return tuple(1, 2) self.checkScriptRaisesRegex(foo, (), Exception, "1 argument") def cant_infer_size(): return tuple([1, 2, 3]) # noqa: C409 with self.assertRaisesRegex(Exception, "cannot statically infer the expected"): torch.jit.script(cant_infer_size) def test_tuple_create_return(self): def stuff2(x): # type: (int) -> Tuple[Tensor, Tensor] a = (torch.ones(x), torch.zeros(x)) return a self.checkScript(stuff2, (3,)) def test_list_io(self): def stuff3(x): # type: (List[int]) -> Tuple[Tensor, List[int]] return torch.ones(x), x self.checkScript(stuff3, ([3, 2],)) def test_bool_list_io(self): @torch.jit.script def stuff4(x): # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]] return x, [True, False], [[True]] li_1, li_2, li_3 = stuff4([True]) li_3 = li_3[0] for li in [li_1, li_2, li_3]: self.assertTrue(type(li[0]) == bool) def test_nested_list(self): def foo(z): # type: (Tuple[int, List[List[int]]]) -> int x, y = z return y[0][1] self.checkScript(foo, ((1, [[1, 2], [3, 4]]),)) def test_list_sum(self): def fn(x: List[int]) -> int: return sum(x) def fn1(x: List[float]): return sum(x) def fn2(x: List[bool]): return sum(x) self.checkScript(fn, ([1, 2, 3],)) self.checkScript(fn1, ([1.0, 2.0, 3.0],)) self.checkScript(fn1, ([1, 2.8, 3],)) self.checkScript(fn2, ([True, False, False],)) self.checkScript(fn2, ([False, False, False],)) self.checkScript(fn2, ([0, 1, 1, 0],)) def test_list_unification(self): def fn(): return [1, None, 2] def fn2(x): return [torch.ones(2, 2), None, x] self.checkScript(fn, []) self.checkScript(fn2, (torch.ones(2, 2),)) # to avoid defining sum_list in multiple tests def get_sum_list_fn(self): def sum_list(a): # type: (List[int]) -> int sum = 0 for i in a: sum += i return sum return sum_list def test_sum_list_diff_elms(self): self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) def test_sum_list_empty(self): self.checkScript(self.get_sum_list_fn(), ([],)) def test_sum_list_one(self): self.checkScript(self.get_sum_list_fn(), ([1],)) def test_sum_list_literal(self): def sum_list(): # type: () -> int sum = 0 for i in [1, 2, 3, 4, 5]: sum += i return sum self.checkScript(sum_list, ()) def test_sum_list_wrong_type(self): with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): @torch.jit.script def sum_list(a): # type: (int) -> int sum = 0 for i in a: # noqa: T484 sum += i return sum sum_list(1) def test_list_iterables(self): with self.assertRaisesRegex( RuntimeError, "List of iterables is not supported currently" ): cu = torch.jit.CompilationUnit( """ def list_iterables(x): for i, j in [2, 3, 4], [5, 6, 7]: x += i x += j return x """ ) def test_for_in_string(self): def test_strings(x): # type: (str) -> str reverse = "" for c in x: reverse = c + reverse return reverse self.checkScript(test_strings, ("hello",)) self.checkScript(test_strings, ("",)) def test_list_strings(x): # type: (List[str]) -> str result = "" for sub_str in x: result += sub_str return result self.checkScript(test_list_strings, (["hello", "world"],)) self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) def test_for_in_dict(self): def test_dicts(x): # type: (Dict[str, int]) -> int sum = 0 for key in x: sum += x[key] return sum self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) def test_dict_keys_values(x): # type: (Dict[str, int]) -> Tuple[str, int] key_str = "" sum = 0 for key in x.keys(): key_str += key for val in x.values(): sum += val return key_str, sum self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) def test_for_tuple_unpack(self): def for_tuple_unpack(x, y): for i, j in [[3, 4], [5, 6], [7, 8]]: x += i y += j return x, y self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) def nested_tuple_unpack(x, y): # type: (List[int], List[int]) -> int sum = 0 for i, (j, k), v in zip(x, enumerate(x), y): sum += i + j + k + v return sum self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) def test_dict_comprehension(self): def fn(): return {i: chr(i + 65) for i in range(4)} self.checkScript(fn, ()) def test_dict_comprehension_with_type_annotation(self): def fn(): d: Dict[int, str] = {i: chr(i + 65) for i in range(4)} return d self.checkScript(fn, ()) with self.assertRaisesRegex(RuntimeError, ""): with self.assertRaisesRegex( AssertionError, "Expected Dict " "type annotation for dict " "comprehension, found " "Tuple[int, str]", ): @torch.jit.script def fn(): d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)} return d def test_dict_comprehension_scope(self): def comprehension_can_access_outer_scope_variables(): lst = ["foo", "bar", "baz"] return {l: len(l) for l in lst} self.checkScript(comprehension_can_access_outer_scope_variables, ()) with self.assertRaisesRegex(RuntimeError, "undefined value i"): @torch.jit.script def outer_scope_cannot_access_comprehension_variables(): d = {i: chr(i + 65) for i in range(4)} i = i + 1 # noqa: F821 def test_for_tuple_assign(self): def test_simple_assign(x): # type: (Tuple[int, float]) -> float sum = 0.0 for a in x: sum += float(a) return sum self.checkScript(test_simple_assign, ((1, 2.5),)) def test_tuple_assign(x): # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int sum = 0 for a in x: sum += a[0] sum += a[1] return sum self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),)) def test_single_starred_lhs(self): with self.assertRaisesRegex( RuntimeError, "A Starred expression may only appear on the lhs within the presence" " of another non-starred expression", ): cu = torch.jit.CompilationUnit( """ def single_starred_lhs(x): a = (x, x, x) *b, = a return b """ ) def test_singleton_tuple_unpack(self): def foo(a): (b,) = (a,) return b + 1 self.checkScript(foo, (torch.rand(3),)) def test_tuple_assignments(self): def var_tuple_assign(x, y): # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor (a, b), c = x, y return a + b + c tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) def nested_tuple_assign(x, y, z): # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int a, (b, (c, d)), (e, f) = x, y, z return a + b + c + d + e + f self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) def subscript_tuple_assign(a, x, i): # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] a[i], (x[i], b) = 1, (2, 3) return a[i] + 1, x + 5, b self.checkScript( subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0) ) def star_tuple_assign(): # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] a, (b, *c), *d = 1, (2, 3, 4), 5, 6 return a, b, c, d self.checkScript(star_tuple_assign, ()) def subscript_tuple_augmented_assign(a): # type: (Tuple[int, int]) -> Tuple[int, int] a[0] += 1 return a with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"): scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) def test_multiple_assign(self): def test(): a = b, c = d, f = (1, 1) # side effect ten = torch.tensor(1) ten1 = ten2 = ten.add_(1) # ordering x = 1 y = 3 x, y = y, x + y return a, b, c, d, f, ten, ten1, ten2, x, y self.checkScript(test, ()) def test_opt_opt_refinement(self): @torch.jit.script def test_unify(weight, bias): # type: (Optional[int], Optional[int]) -> Optional[int] if weight is not None: opt = None else: if bias is not None: opt = 1 else: opt = None return opt def test_optional_refinement(self): @torch.jit.script def test_if_none_assignment(x): # type: (Optional[int]) -> int if x is None: x = 1 return x + 1 self.assertEqual(test_if_none_assignment(1), 2) def test_optional_conversion(self): @torch.jit.script def other_fn(x=None): # type: (Optional[int]) -> int return torch.jit._unwrap_optional(x) @torch.jit.script def fn(x): # type: (int) -> int return other_fn(x) self.assertEqual(fn(2), 2) @torch.jit.script def unify_to_optional(x): # type: (bool) -> Optional[int] if x: a = None else: a = 2 return a self.assertEqual(unify_to_optional(True), None) self.assertEqual(unify_to_optional(False), 2) @torch.jit.script def opt_list(x): # type: (Optional[List[float]]) -> int return 2 @torch.jit.script def broadcast_opt_list(x): # type: (Optional[BroadcastingList2[float]]) -> int return 2 @torch.jit.script def opt_list_tuple_caller(x): # type: (Tuple[float, float]) -> int return opt_list(x) + broadcast_opt_list(x) self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4) def test_optional_tuple(self): def fn(x=None): # type: (Optional[Tuple[int, int]]) -> Tuple[int, int] if x is None: new_x = (1, 2) else: new_x = x return new_x self.checkScript(fn, ((3, 4),)) self.checkScript(fn, ()) def test_namedtuple_redefine(self): global _1, _2 _1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) _2 = namedtuple("GoogLeNetOutputs", ["different"]) with self.assertRaisesRegex(RuntimeError, r"redefine"): @torch.jit.script def foo(x, y): # type: (_1, _2) -> _1 return x def test_namedtuple_py2(self): global _GoogLeNetOutputs # see [local resolution in python] _GoogLeNetOutputs = namedtuple( "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"] ) @torch.jit.script def foo(x): # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs return x vals = torch.rand(3), torch.rand(4), torch.rand(5) out = foo( _GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]) ) self.assertEqual(out.logits, vals[0]) self.assertEqual(out.aux_logits2, vals[1]) self.assertEqual(out.aux_logits1, vals[2]) def test_namedtuple_good_error(self): global _GoogLeNetOutputs # see [local resolution in python] _GoogLeNetOutputs = namedtuple( "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"] ) @torch.jit.script def foo(x): # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs return x with self.assertRaisesRegex( RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)" ): out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5")) def test_namedtuple_error_source_attribution(self): class _NamedTupleBadMemberType(NamedTuple): f1: torch.Tensor f2: "ABadForwardRefType" # noqa: F821 make_global(_NamedTupleBadMemberType) # see [local resolution in python] def fn(x: _NamedTupleBadMemberType) -> torch.Tensor: return x.f1.relu() # assert that this has a location associated with the error. # note the " +" is regex (i.e. "at least one space") with self.assertRaisesRegex(ValueError, "at +File"): torch.jit.script(fn) def test_inherited_annotations_python_310(self): # See #104484 # In python >=3.10, inspect.get_annotations doesn't always return the same values. # Sometimes it will show all annotations; other times it will show only annotations # that show in that class, not classes it inherits fro. class BaseModule(torch.nn.Module): state: List[int] def forward(self, x): pass def do_something_with_list(x: List[int]): if x: return x[-1] return 5 class Submodule(BaseModule): def __init__(self, self_x_value): super().__init__() self.x = self_x_value self.state = [] def forward(self, x): return self.x + x + do_something_with_list(self.state) class LowestModule(Submodule): def __init__(self) -> None: super().__init__(123) mod = LowestModule() mod2 = LowestModule() mod_s = torch.jit.script(mod) mod2_s = torch.jit.script(mod2)