diff --git a/test/jit/test_scriptmod_ann.py b/test/jit/test_scriptmod_ann.py index 5d9856744d2..65e34f8584c 100644 --- a/test/jit/test_scriptmod_ann.py +++ b/test/jit/test_scriptmod_ann.py @@ -2,6 +2,7 @@ import os import sys +import unittest import warnings from typing import Dict, List, Optional @@ -150,6 +151,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): ): torch.jit.script(M()) + @unittest.skipIf( + sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" + ) + def test_annotated_empty_list_lowercase(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.x: list[int] = [] + + def forward(self, x: list[int]): + self.x = x + return 1 + + with self.assertRaisesRegexWithHighlight( + RuntimeError, "Tried to set nonexistent attribute", "self.x = x" + ): + with self.assertWarnsRegex( + UserWarning, + "doesn't support " + "instance-level annotations on " + "empty non-base types", + ): + torch.jit.script(M()) + def test_annotated_empty_dict(self): class M(torch.nn.Module): def __init__(self): @@ -171,6 +196,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): ): torch.jit.script(M()) + @unittest.skipIf( + sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" + ) + def test_annotated_empty_dict_lowercase(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.x: dict[str, int] = {} + + def forward(self, x: dict[str, int]): + self.x = x + return 1 + + with self.assertRaisesRegexWithHighlight( + RuntimeError, "Tried to set nonexistent attribute", "self.x = x" + ): + with self.assertWarnsRegex( + UserWarning, + "doesn't support " + "instance-level annotations on " + "empty non-base types", + ): + torch.jit.script(M()) + def test_annotated_empty_optional(self): class M(torch.nn.Module): def __init__(self): @@ -213,6 +262,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): ): torch.jit.script(M()) + @unittest.skipIf( + sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" + ) + def test_annotated_with_jit_empty_list_lowercase(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.jit.annotate(list[int], []) + + def forward(self, x: list[int]): + self.x = x + return 1 + + with self.assertRaisesRegexWithHighlight( + RuntimeError, "Tried to set nonexistent attribute", "self.x = x" + ): + with self.assertWarnsRegex( + UserWarning, + "doesn't support " + "instance-level annotations on " + "empty non-base types", + ): + torch.jit.script(M()) + def test_annotated_with_jit_empty_dict(self): class M(torch.nn.Module): def __init__(self): @@ -234,6 +307,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): ): torch.jit.script(M()) + @unittest.skipIf( + sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" + ) + def test_annotated_with_jit_empty_dict_lowercase(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.jit.annotate(dict[str, int], {}) + + def forward(self, x: dict[str, int]): + self.x = x + return 1 + + with self.assertRaisesRegexWithHighlight( + RuntimeError, "Tried to set nonexistent attribute", "self.x = x" + ): + with self.assertWarnsRegex( + UserWarning, + "doesn't support " + "instance-level annotations on " + "empty non-base types", + ): + torch.jit.script(M()) + def test_annotated_with_jit_empty_optional(self): class M(torch.nn.Module): def __init__(self): diff --git a/torch/jit/_check.py b/torch/jit/_check.py index 790da30e511..0dc2cb6d37b 100644 --- a/torch/jit/_check.py +++ b/torch/jit/_check.py @@ -156,7 +156,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): # cannot be reassigned later to a non-empty tuple. Same # deal with `NamedTuple` - containers = {"List", "Dict", "Optional"} + containers = {"List", "list", "Dict", "dict", "Optional"} # If we're not evaluating one of the specified problem types try: