PEP585: Add noqa to necessary tests (#146391)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146391
Approved by: https://github.com/justinchuby, https://github.com/Skylion007

commit 1f8ff94d4f (parent b61032fcf7)
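Background for the diff below: PEP 585 (Python 3.9) made the builtin containers generic, so `typing.List[int]` can be spelled `list[int]`. ruff's UP006 rule flags the old spellings and UP035 flags deprecated imports from `typing`; this commit suppresses those rules only where a test deliberately exercises the pre-PEP585 spellings. A minimal sketch of what the linter flags, using hypothetical functions:

    from typing import List  # UP035: deprecated alias import

    def pre_pep585(xs: List[int]) -> List[int]:  # UP006: use `list`
        return xs

    def post_pep585(xs: list[int]) -> list[int]:  # preferred spelling
        return xs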
@@ -2,7 +2,7 @@
 """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""

-from typing import List
+from typing import Sequence

 import onnx_test_common
 import onnxscript
@@ -90,7 +90,11 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
         @onnxscript.script(custom_opset)
         def layer_norm(
-            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
+            X,
+            axes: Sequence[int],
+            weight: FLOAT[...],
+            bias: FLOAT[...],
+            eps: float,
         ):
             mean = op.ReduceMean(X, axes=axes)
             D = X - mean  # op.Sub(X, mean)
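Note that `List[int]` → `Sequence[int]` is not a pure spelling change: `Sequence` is the more permissive read-only annotation, and it lets the deprecated `typing.List` import go away. A sketch with a hypothetical function:

    from typing import Sequence

    def scale(values: Sequence[int], factor: int) -> list[int]:
        # Any read-only sequence is accepted: list, tuple, range, ...
        return [v * factor for v in values]

    scale([1, 2, 3], 2)  # OK
    scale((1, 2, 3), 2)  # OK too; type checkers would reject the tuple for List[int]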
@@ -231,7 +231,7 @@ class TestPybindTypeCasters(common.TestCase):
         Our Pybind functions have a signature of the form `() -> return_type`.
         """
        # Imports needed for the `eval` below.
-        from typing import List, Tuple  # noqa: F401
+        from typing import List, Tuple  # noqa: F401, UP035

         return eval(re.search("-> (.*)\n", func.__doc__).group(1))
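The imports above look unused (hence the F401 suppression), but `eval` resolves them by name when it parses the return annotation out of the pybind docstring, so they must stay. A self-contained sketch with a hypothetical docstring:

    import re
    from typing import List, Tuple  # noqa: F401 - resolved inside eval()

    doc = "some_func() -> List[int]\n"  # hypothetical pybind11 docstring
    return_type = eval(re.search("-> (.*)\n", doc).group(1))
    assert return_type == List[int]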
@@ -22,7 +22,7 @@ import warnings
 from collections import namedtuple
 from copy import deepcopy
 from math import sqrt
-from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, NamedTuple, Optional, Union

 import torch
 import torch.fx._pytree as fx_pytree
@@ -2270,10 +2270,19 @@ class TestFX(JitTestCase):
         graph: torch.fx.Graph = torch.fx.Graph()
         x: torch.fx.Node = graph.create_node("placeholder", "x")
         b: torch.fx.Node = graph.create_node(
-            "call_function", target=torch.relu, args=(x,), type_expr=List[float]
+            "call_function", target=torch.relu, args=(x,), type_expr=list[float]
         )
         output: torch.fx.Node = graph.output(b)

+        self.assertTrue('list[float]' in str(graph))
+
+    def test_typename_print_pre_pep585(self):
+        graph : torch.fx.Graph = torch.fx.Graph()
+        x : torch.fx.Node = graph.create_node('placeholder', 'x')
+        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
+                                              type_expr=typing.List[float])  # noqa: UP006
+        output : torch.fx.Node = graph.output(b)
+
         self.assertTrue("typing.List[float]" in str(graph))

     def test_layout(self):
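A standalone repro of what the updated test asserts: the FX graph printer now preserves the PEP 585 spelling passed as `type_expr` (sketch; requires torch):

    import torch

    g = torch.fx.Graph()
    x = g.create_node("placeholder", "x")
    b = g.create_node(
        "call_function", target=torch.relu, args=(x,), type_expr=list[float]
    )
    g.output(b)
    assert "list[float]" in str(g)  # the old behavior rendered typing.List[float]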
@@ -2922,6 +2931,19 @@ class TestFX(JitTestCase):
             def forward(self, x: list[str]) -> list[str]:
                 return self.other(x)

+        traced = symbolic_trace(ReturnTypeModule())
+        self.assertIn("-> list[str]", traced._code)
+        scripted = torch.jit.script(traced)
+        self.assertIn("-> List[str]", scripted.code)
+
+    def test_return_type_exists_pre_pep585(self):
+        class ReturnTypeModule(torch.nn.Module):
+            def other(self, x: typing.List[str]) -> typing.List[str]:  # noqa: UP006
+                return x
+
+            def forward(self, x: typing.List[str]) -> typing.List[str]:  # noqa: UP006
+                return self.other(x)
+
         traced = symbolic_trace(ReturnTypeModule())
         self.assertIn("-> typing_List[str]", traced._code)
         scripted = torch.jit.script(traced)
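Sketch of the split these two tests pin down: FX generated code keeps the PEP 585 spelling, while TorchScript still renders the `typing` spelling (requires torch):

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x: list[str]) -> list[str]:
            return x

    traced = symbolic_trace(M())
    assert "-> list[str]" in traced.code                    # FX keeps PEP 585
    assert "-> List[str]" in torch.jit.script(traced).code  # TorchScript does not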
@@ -3735,7 +3757,7 @@ class TestFX(JitTestCase):
     @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
     def test_annotations_empty_tuple(self):
         class Foo(torch.nn.Module):
-            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
+            def forward(self, x: typing.Tuple[()], y: typing.Tuple[str, typing.Tuple[()]]):  # noqa: UP006
                 return "foo"

         traced = torch.fx.symbolic_trace(Foo())
@@ -4320,10 +4342,10 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
             tuple,
             type,
             typing.Callable,
-            typing.Dict,
-            typing.List,
-            typing.Tuple,
-            typing.Type,
+            typing.Dict,  # noqa: UP006
+            typing.List,  # noqa: UP006
+            typing.Tuple,  # noqa: UP006
+            typing.Type,  # noqa: UP006
             typing.Union,
         }
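Only the aliases with direct builtin replacements (`Dict`, `List`, `Tuple`, `Type`) need suppression here; `typing.Union` has no PEP 585 builtin form, since the `int | str` spelling comes from PEP 604 and is covered by ruff's separate UP007 rule. A quick contrast with hypothetical annotations:

    import typing

    d_old: typing.Dict[str, int] = {}  # UP006: use dict[str, int]
    d_new: dict[str, int] = {}
    u: typing.Union[int, str] = 0      # not UP006; `int | str` would be UP007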
@@ -12,7 +12,7 @@ import tempfile
 import typing
 import unittest
 from types import BuiltinFunctionType
-from typing import Callable, List, NamedTuple, Optional, Union
+from typing import Callable, NamedTuple, Optional, Union

 import torch
 import torch.fx.experimental.meta_tracer
@@ -1548,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
             (Optional[list[int]], list[int]),
         ] + [
             # pre-PEP585 signatures
-            (typing.List[int], int),
-            (typing.List[int], create_type_hint([int, int])),
-            (typing.List[int], create_type_hint((int, int))),
-            (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
+            (typing.List[int], int),  # noqa: UP006
+            (typing.List[int], create_type_hint([int, int])),  # noqa: UP006
+            (typing.List[int], create_type_hint((int, int))),  # noqa: UP006
+            (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),  # noqa: UP006
             (
-                typing.List[torch.Tensor],
+                typing.List[torch.Tensor],  # noqa: UP006
                 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
             ),
-            (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
-            (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
-            (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
+            (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),  # noqa: UP006
+            (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),  # noqa: UP006
+            (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),  # noqa: UP006
             (
-                typing.List[torch.Tensor],
+                typing.List[torch.Tensor],  # noqa: UP006
                 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
             ),
-            (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
-            (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
-            (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
-            (Optional[typing.List[int]], typing.List[int]),
+            (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),  # noqa: UP006
+            (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),  # noqa: UP006
+            (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),  # noqa: UP006
+            (Optional[typing.List[int]], typing.List[int]),  # noqa: UP006
         ]

         for sig_type, arg_type in should_be_equal:
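These pairs feed a `type_matches(sig_type, arg_type)` check; the point of keeping both blocks is that pre- and post-PEP585 spellings must behave identically. A sketch, assuming the helpers come from torch.fx.operator_schemas as in this test file:

    import typing
    import torch
    from torch.fx.operator_schemas import create_type_hint, type_matches

    hint = create_type_hint([int, int])  # infers a list-of-int type hint
    assert type_matches(typing.List[int], hint)  # noqa: UP006
    assert type_matches(list[int], hint)         # PEP 585 spelling matches too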
@@ -1575,7 +1575,7 @@ class {test_classname}(torch.nn.Module):
         should_fail = [
             (int, float),
             (Union[int, float], str),
-            (list[torch.Tensor], List[int]),
+            (list[torch.Tensor], typing.List[int]),  # noqa: UP006
         ] + [
             # pre-PEP585 signatures
             (list[torch.Tensor], list[int]),
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Tuple, Union
+from typing import Any, Union

 import torch
 from torch.utils._contextlib import (
@@ -386,7 +386,7 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):

     """

-    def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
+    def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
         self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
         assert isinstance(self.tensors, tuple)
         self.prev_versions = tuple(t._version for t in self.tensors)
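Usage sketch for the context manager whose signature changed above (assuming it is imported from torch.autograd.grad_mode): it snapshots each tensor's version counter on entry and restores it on exit.

    import torch
    from torch.autograd.grad_mode import _unsafe_preserve_version_counter

    a, b = torch.zeros(2), torch.zeros(2)
    before = (a._version, b._version)
    with _unsafe_preserve_version_counter((a, b)):  # a plain tuple now type-checks
        a.add_(1)  # in-place ops normally bump the version counter
        b.add_(1)
    assert (a._version, b._version) == before  # counters restored on exit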
@@ -455,9 +455,13 @@ class CodeGen:

         typename = _type_repr(o)

-        if hasattr(o, "__origin__"):
-            # This is a generic type, e.g. typing.List[torch.Tensor]
-            origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+        if origin_type := getattr(o, "__origin__", None):
+            # list[...], typing.List[...], TensorType[...]
+
+            if isinstance(o, typing._GenericAlias):  # type: ignore[attr-defined]
+                # This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor]
+                origin_type = _origin_type_map.get(origin_type, origin_type)
+
             origin_typename = add_global(_type_repr(origin_type), origin_type)

             if hasattr(o, "__args__"):
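What the reworked branch relies on: both spellings expose `__origin__`, but only the pre-PEP585 alias is a `typing._GenericAlias`, so only it is mapped back through `_origin_type_map`:

    import typing

    assert typing.List[int].__origin__ is list
    assert list[int].__origin__ is list

    # The gate used above: True for typing.List[int], False for list[int]
    assert isinstance(typing.List[int], typing._GenericAlias)
    assert not isinstance(list[int], typing._GenericAlias)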
@@ -126,7 +126,9 @@ def _type_repr(obj: object) -> str:
     typically enough to uniquely identify a type. For everything
     else, we fall back on repr(obj).
     """
-    if isinstance(obj, type):
+    # Extension: If we don't ignore GenericAlias then `list[int]` will print
+    # simply "list".
+    if isinstance(obj, type) and not isinstance(obj, types.GenericAlias):
         if obj.__module__ == "builtins":
             return obj.__qualname__
         return f"{obj.__module__}.{obj.__qualname__}"
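The quirk the new comment describes, in isolation: a `types.GenericAlias` such as `list[int]` forwards attribute lookups to its origin class, so on versions where it also passes `isinstance(obj, type)`, the class branch would print it as just `list`:

    import types

    alias = list[int]
    assert isinstance(alias, types.GenericAlias)
    assert alias.__qualname__ == "list"  # forwarded from the origin class
    assert repr(alias) == "list[int]"    # repr() keeps the parameters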