mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[TorchScript] Expand TorchScript __init__ annotation warning (#127045)
Summary: Expand TorchScript `__init__` annotation warning to `list` and `dict` with reference to GSD task T187638414 and annotation warning reproduction D56834720. Currently, the TorchScript compiler ignores and throws `UserWarning`s for the following annotation types for empty values within the `__init__` function: `List`, `Dict`, `Optional`. However, the compiler should additionally cover warnings for `list` and `dict`. This diff adds support for `list` and `dict`. Test Plan: Added 4 new unit tests: `test_annotated_empty_list_lowercase` and `test_annotated_empty_dict_lowercase` verify that TorchScript throws UserWarnings for the list and dict type annotations on empty values. ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_empty_list_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_empty_dict_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` `test_annotated_with_jit_empty_list_lowercase` and `test_annotated_with_jit_empty_dict_lowercase` verify that TorchScript throws UserWarnings for the list and dict type annotations on empty values with the jit annotation. ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_with_jit_empty_list_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_with_jit_empty_dict_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D57752002 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127045 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
1be7e4086a
commit
f4cbcff8ef
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
|
@ -150,6 +151,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
|||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
|
||||
)
|
||||
def test_annotated_empty_list_lowercase(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x: list[int] = []
|
||||
|
||||
def forward(self, x: list[int]):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_empty_dict(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -171,6 +196,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
|||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
|
||||
)
|
||||
def test_annotated_empty_dict_lowercase(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x: dict[str, int] = {}
|
||||
|
||||
def forward(self, x: dict[str, int]):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_empty_optional(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -213,6 +262,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
|||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
|
||||
)
|
||||
def test_annotated_with_jit_empty_list_lowercase(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = torch.jit.annotate(list[int], [])
|
||||
|
||||
def forward(self, x: list[int]):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_jit_empty_dict(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -234,6 +307,30 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
|||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
|
||||
)
|
||||
def test_annotated_with_jit_empty_dict_lowercase(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = torch.jit.annotate(dict[str, int], {})
|
||||
|
||||
def forward(self, x: dict[str, int]):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_jit_empty_optional(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
|||
# cannot be reassigned later to a non-empty tuple. Same
|
||||
# deal with `NamedTuple`
|
||||
|
||||
containers = {"List", "Dict", "Optional"}
|
||||
containers = {"List", "list", "Dict", "dict", "Optional"}
|
||||
|
||||
# If we're not evaluating one of the specified problem types
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user