[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

`torch/fx/passes/annotate_getitem_nodes.py` was changed to support the new-style (PEP 585) type hinting annotations, e.g. `tuple[...]` and `list[...]` in place of `typing.Tuple` and `typing.List`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
Author: cyy
Date: 2025-01-04 10:47:51 +00:00
Committed by: PyTorch MergeBot
Parent: a881954b0c
Commit: df458be4e5
55 changed files with 247 additions and 227 deletions
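For orientation, here is a minimal sketch (not taken from this diff; `bucket` is a made-up function) of the kind of rewrite the py39 ruff/pyupgrade rules (e.g. UP006/UP035) perform: builtin generics replace their `typing` aliases, and abstract container types move to `collections.abc`.

```python
# Before: container generics imported from typing (pre-PEP 585)
from typing import Dict, List, Sequence

def bucket(xs: Sequence[int]) -> Dict[str, List[int]]:
    return {"evens": [x for x in xs if x % 2 == 0]}

# After: builtin generics, valid at runtime on Python 3.9+;
# abstract containers now come from collections.abc
from collections.abc import Sequence

def bucket(xs: Sequence[int]) -> dict[str, list[int]]:
    return {"evens": [x for x in xs if x % 2 == 0]}
```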

View File

@@ -1,8 +1,3 @@
-import dis
-import inspect
-from collections.abc import Sequence
-from typing import Union
 import functorch._C
 import torch
 from functorch._C import dim as _C

View File

@@ -27,7 +27,7 @@ from __future__ import annotations
 import keyword
 import warnings
-from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union
 if TYPE_CHECKING:
@@ -73,11 +73,11 @@ class ParsedExpression:
         """
         self.has_ellipsis: bool = False
         self.has_ellipsis_parenthesized: Optional[bool] = None
-        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
+        self.identifiers: set[Union[str, AnonymousAxis]] = set()
         # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
         self.has_non_unitary_anonymous_axes: bool = False
         # composition keeps structure of composite axes, see how different corner cases are handled in tests
-        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
+        self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
         if "." in expression:
             if "..." not in expression:
                 raise ValueError(
@@ -90,7 +90,7 @@ class ParsedExpression:
         expression = expression.replace("...", _ellipsis)
         self.has_ellipsis = True
-        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
+        bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
         def add_axis_name(x: str) -> None:
             if x in self.identifiers:
@@ -164,7 +164,7 @@ class ParsedExpression:
     @staticmethod
     def check_axis_name_return_reason(
         name: str, allow_underscore: bool = False
-    ) -> Tuple[bool, str]:
+    ) -> tuple[bool, str]:
         """Check if the given axis name is valid, and a message explaining why if not.
         Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@@ -174,7 +174,7 @@ class ParsedExpression:
             allow_underscore (bool): whether axis names are allowed to start with an underscore
         Returns:
-            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
+            tuple[bool, str]: whether the axis name is valid, a message explaining why if not
         """
         if not str.isidentifier(name):
             return False, "not a valid python identifier"
@@ -211,7 +211,7 @@ class ParsedExpression:
 def parse_pattern(
     pattern: str, axes_lengths: Mapping[str, int]
-) -> Tuple[ParsedExpression, ParsedExpression]:
+) -> tuple[ParsedExpression, ParsedExpression]:
     """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
     Args:
@@ -219,7 +219,7 @@ def parse_pattern(
         axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
     Returns:
-        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
+        tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
     """
     # adapted from einops.einops._prepare_transformation_recipe
     # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations
 import functools
-from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
+from typing import Callable, TYPE_CHECKING, Union
 import torch
 from functorch._C import dim as _C
@@ -18,7 +18,6 @@ from ._parsing import (
 if TYPE_CHECKING:
     from collections.abc import Sequence
 __all__ = ["rearrange"]
 dims = _C.dims
@@ -69,9 +68,9 @@ def _create_rearrange_callable(
         # an identity rearrangement on a 0-dimension tensor
         return lambda tensor: tensor
-    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
-    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
-    anon_axes: List[AnonymousAxis] = []
+    first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
+    identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
+    anon_axes: list[AnonymousAxis] = []
     # map the left-hand side identifiers to strings representing first class dims
     dims_i = 0
@@ -99,11 +98,11 @@ def _create_rearrange_callable(
             raise ValueError(f"Unexpected dimension: {dimension}")
     def composition_to_dims(
-        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
-    ) -> List[Union[str, Tuple[str, ...]]]:
+        composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
+    ) -> list[Union[str, tuple[str, ...]]]:
         """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
         class dims."""
-        dim_composition: List[Union[str, Tuple[str, ...]]] = []
+        dim_composition: list[Union[str, tuple[str, ...]]] = []
         for dimension in composition:
             if isinstance(dimension, list):
                 dim_composition.append(
@@ -152,7 +151,7 @@ def _create_rearrange_callable(
 def rearrange(
-    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
+    tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
     pattern: str,
     **axes_lengths: int,
 ) -> torch.Tensor:

View File

@@ -2,7 +2,6 @@
 import copy
 import logging
-from typing import List
 import torch
 import torch.nn as nn
@@ -247,7 +246,7 @@ class TestActivationSparsifier(TestCase):
             assert mask2 is None
         else:
             assert type(mask1) == type(mask2)
-            if isinstance(mask1, List):
+            if isinstance(mask1, list):
                 assert len(mask1) == len(mask2)
                 for idx in range(len(mask1)):
                     assert torch.all(mask1[idx] == mask2[idx])
@@ -258,7 +257,7 @@ class TestActivationSparsifier(TestCase):
         for state in state_dict["state"].values():
             mask = state["mask"]
             if mask is not None:
-                if isinstance(mask, List):
+                if isinstance(mask, list):
                     for idx in range(len(mask)):
                         assert mask[idx].is_sparse
                 else:

View File

@@ -3,7 +3,6 @@
 import copy
 import logging
 import warnings
-from typing import Tuple
 import torch
 from torch import nn
@@ -73,7 +72,7 @@ class TestBaseDataScheduler(TestCase):
     def _get_name_data_config(self, some_data, defaults):
         config = copy.deepcopy(defaults)
-        if isinstance(some_data, Tuple):
+        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:

View File

@@ -4,7 +4,6 @@ import copy
 import itertools
 import logging
 import math
-from typing import Tuple
 import torch
 from torch import nn
@@ -54,7 +53,7 @@ class _BaseDataSparsiferTestCase(TestCase):
    @staticmethod
    def _get_name_data_config(some_data, defaults=None):
-       if isinstance(some_data, Tuple):
+       if isinstance(some_data, tuple):
           # dealing with data_list
           name, data = some_data
          config = defaults
@@ -482,8 +481,9 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
             nn.Parameter(torch.randn(4, 4)),
             nn.Parameter(torch.randn(5, 5)),
         )
-        param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(1, 1)),
+            nn.Parameter(torch.randn(4, 4)),
         )
         data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
         defaults = {"test": 3}
@@ -585,8 +585,9 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
             nn.Parameter(torch.randn(4, 4)),
             nn.Parameter(torch.randn(5, 5)),
         )
-        param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(10, 10)),
+            nn.Parameter(torch.randn(4, 4)),
        )
         data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
         defaults = {

View File

@@ -2,13 +2,17 @@
 from __future__ import annotations
 import typing
-from typing import List, Optional, Sequence, Union  # noqa: F401
+from typing import List, Optional, Union
 import torch
 from torch import Tensor, types
 from torch.testing._internal.common_utils import run_tests, TestCase
+if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
 mutates_args = {}

View File

@@ -6,7 +6,8 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

View File

@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from typing import Any, Callable, Dict, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, Optional
 from unittest import skip
 import torch

View File

@@ -10,8 +10,9 @@ import inspect
 import io
 import operator
 import unittest
+from collections.abc import Sequence
 from enum import Enum
-from typing import Dict, List, Sequence
+from typing import Dict, List
 from unittest.mock import patch
 import torch

View File

@@ -21,10 +21,11 @@ import warnings
 import weakref
 from abc import ABC
 from collections import namedtuple
+from collections.abc import Iterator
 from copy import deepcopy
 from enum import Enum, IntEnum
 from functools import wraps
-from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict
+from typing import Any, Dict, List, Literal, Tuple, TypedDict
 from unittest import mock
 import numpy as np

View File

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: export"]
 import unittest
-from typing import Any, Dict, Optional, OrderedDict, Tuple
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Tuple
 import torch
 from torch._export.passes.lift_constants_pass import (

View File

@@ -14,8 +14,7 @@ import random
 import types
 import unittest
 import warnings
-from collections import namedtuple
-from typing import OrderedDict
+from collections import namedtuple, OrderedDict
 from unittest.case import skipIf
 from common_utils import (

View File

@@ -1,7 +1,7 @@
 # Owner(s): ["module: fx"]
 import unittest
-from typing import Mapping
+from collections.abc import Mapping
 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

View File

@@ -3,8 +3,8 @@
 import functools
 import itertools
 import os
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Sequence
 from unittest import skip
 import yaml

View File

@@ -9,7 +9,7 @@ from collections import namedtuple, OrderedDict
 from copy import deepcopy
 from functools import partial
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, List, Tuple
+from typing import Any
 import torch
 import torch.nn as nn
@@ -55,11 +55,11 @@ class ToyModel(nn.Module):
 def forward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
     out: torch.Tensor,
 ) -> None:
     fired_hooks.append(hook_id)
@@ -69,11 +69,11 @@ def forward_hook(
 def forward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -82,12 +82,12 @@ def forward_pre_hook(
 def full_backward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
-    grad_output: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
+    grad_output: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -97,11 +97,11 @@ def full_backward_hook(
 def full_backward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -122,8 +122,8 @@ class KwargModel(nn.Module):
     def internal_forward_hook(
         self,
         module: nn.Module,
-        args: Tuple[torch.Tensor],
-        kwargs: Dict[str, Any],
+        args: tuple[torch.Tensor],
+        kwargs: dict[str, Any],
         out: torch.Tensor,
     ):
         return out + kwargs["bias"]
@@ -142,13 +142,13 @@ class FailsInForwardModel(nn.Module):
 def kwarg_forward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
-) -> Tuple[Any, Any]:
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
+) -> tuple[Any, Any]:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
     self.assertEqual(len(args), 1)
@@ -158,12 +158,12 @@ def kwarg_forward_pre_hook(
 def kwarg_forward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
     out: torch.Tensor,
 ) -> Any:
     fired_hooks.append(hook_id)
@@ -188,7 +188,7 @@ class DummyContextManager:
 class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_forward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
@@ -210,7 +210,7 @@ class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_forward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
@@ -232,7 +232,7 @@ class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_full_backward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_hook, self, fired_hooks, model.net1)
@@ -254,7 +254,7 @@ class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_full_backward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
@@ -294,7 +294,7 @@ class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_mixed_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         model.register_forward_pre_hook(
@@ -319,7 +319,7 @@ class TestModuleHooks(TestCase):
     def test_kwarg_hooks(self):
         # 1. test forward pre hook
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -336,7 +336,7 @@ class TestModuleHooks(TestCase):
         self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
         # 2. test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -372,7 +372,7 @@ class TestModuleHooks(TestCase):
     def test_remove_kwarg_hooks(self):
         # test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -1217,8 +1217,8 @@ class TestModuleGlobalHooks(TestCase):
     def test_module_global_hooks_with_kwargs(self):
         def kwarg_global_forward_hook(
             module: nn.Module,
-            args: Tuple[torch.Tensor],
-            kwargs: Dict[str, Any],
+            args: tuple[torch.Tensor],
+            kwargs: dict[str, Any],
             out: torch.Tensor,
         ) -> Any:
             out = out + kwargs["bias"]

View File

@@ -2,7 +2,6 @@
 import itertools
 import random
-from typing import List
 import torch
 import torch.nn.utils.rnn as rnn_utils
@@ -219,7 +218,7 @@ class PackedSequenceTest(TestCase):
         # more dimensions
         maxlen = 9
         for num_dim in (0, 1, 2, 3):
-            sequences: List[torch.Tensor] = []
+            sequences: list[torch.Tensor] = []
             trailing_dims = [4] * num_dim
             for i in range(1, maxlen + 1):
                 seq_len = i * i

View File

@@ -1573,10 +1573,10 @@ torch.cuda.synchronize()
             return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)
         if adaptive:
-            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"AdaptiveMaxPool{num_dim}d"
         else:
             # FIXME(#105716): Test fails when using f-string
-            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"MaxPool{num_dim}d"
         module_cls = getattr(nn, cls_name)
         module = module_cls(2, return_indices=True).to(device, dtype=dtype)
         numel = 4 ** (num_dim + 1)

View File

@@ -1,7 +1,7 @@
 # Owner(s): ["module: onnx"]
 """Unit tests for the internal registration wrapper module."""
-from typing import Sequence
+from collections.abc import Sequence
 from torch.onnx import errors
 from torch.onnx._internal import registration

View File

@@ -10,19 +10,8 @@ import logging
 import os
 import unittest
 import warnings
-from typing import (
-    Any,
-    Callable,
-    Collection,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-)
+from collections.abc import Collection, Iterable, Mapping, Sequence
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
 import numpy as np
 import onnxruntime

View File

@@ -3,7 +3,8 @@
 import os
 import unittest
 from collections import OrderedDict
-from typing import List, Mapping, Tuple
+from collections.abc import Mapping
+from typing import List, Tuple
 import onnx_test_common
 import parameterized

View File

@@ -10,7 +10,7 @@ import itertools
 import unittest
 import unittest.mock
 import warnings
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 import numpy as np
@@ -26,6 +26,10 @@ from torch.onnx._internal import registration
 from torch.testing._internal import common_quantization, common_utils, jit_utils
+if TYPE_CHECKING:
+    from collections.abc import Iterable
 def export_to_onnx(
     model: Union[torch.nn.Module, torch.jit.ScriptFunction],
     input: Union[torch.Tensor, Tuple[torch.Tensor]],

View File

@@ -1,6 +1,6 @@
 # Owner(s): ["oncall: package/deploy"]
-from typing import Iterable
+from collections.abc import Iterable
 from torch.package import GlobGroup
 from torch.testing._internal.common_utils import run_tests

View File

@@ -18,7 +18,7 @@ import os
 import sys
 import tempfile
 import unittest
-from typing import Any, Dict, List
+from typing import Any
 import torch
 import torch.nn as nn
@@ -51,7 +51,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._triton import has_triton
-Json = Dict[str, Any]
+Json = dict[str, Any]
 class TestExecutionTrace(TestCase):
@@ -97,7 +97,7 @@ class TestExecutionTrace(TestCase):
         nodes = et_graph["nodes"]
         return nodes
-    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
+    def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]:
         """Returns a sorted list of rf_id (record function ids) in execution trace"""
         def get_rf_id(node):
@@ -115,7 +115,7 @@ class TestExecutionTrace(TestCase):
         )
         return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)
-    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
+    def get_kineto_rf_ids(self, events: list[Json]) -> list[int]:
         """Returns a sorted list of Record function IDs for CPU operators and user annotations"""
         ops_and_annotations = (
             e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]

View File

@@ -5,7 +5,8 @@ import itertools as it
 import sys
 import textwrap
 import unittest
-from typing import Callable, Dict, Iterator, List, Optional, Tuple
+from collections.abc import Iterator
+from typing import Callable, Optional
 import torch
 from torch._C._profiler import _EventType, _TensorMetadata
@@ -309,9 +310,9 @@ class TestDataFlow(TestCase):
     @staticmethod
     def formatSchemas(
         prof: torch.profiler.profile, indent: int = 12
-    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
+    ) -> tuple[tuple[str, tuple[bool, ...]], ...]:
         tree = prof.profiler.kineto_results.experimental_event_tree()
-        out: List[Tuple[str, Tuple[bool, ...]]] = []
+        out: list[tuple[str, tuple[bool, ...]]] = []
         for node in _utils.traverse_dfs(tree):
             if node.tag == _EventType.TorchOp:
                 e = node.extra_fields
@@ -327,8 +328,8 @@ class TestDataFlow(TestCase):
     @staticmethod
     def _run_and_format_data_flow(
-        inputs: Dict[str, torch.Tensor],
-        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
+        inputs: dict[str, torch.Tensor],
+        f: Callable[..., Optional[dict[str, torch.Tensor]]],
         indent: int = 12,
     ) -> str:
         with profile() as prof:
@@ -339,7 +340,7 @@ class TestDataFlow(TestCase):
         graph = memory_profile._data_flow_graph
         storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}
-        lines: List[str] = []
+        lines: list[str] = []
         for name, t in it.chain(inputs.items(), outputs.items()):
             lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
             if t.grad is not None:
@@ -352,7 +353,7 @@ class TestDataFlow(TestCase):
         for node in graph.flow_nodes:
            destroyed = {k for k, v in node._edges.items() if v.is_deletion}
-           inputs: List[str] = []
+           inputs: list[str] = []
            for key, (_, v) in node.inputs.items():
                inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")
@@ -833,7 +834,7 @@ class TestMemoryProfilerE2E(TestCase):
     @staticmethod
     def _lookup_tensor_categories(
         t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
-    ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
+    ) -> dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
         storage = t.storage()
         if storage is None:
             raise ValueError("Cannot look up uninitialized Tensor.")
@@ -889,7 +890,7 @@ class TestMemoryProfilerE2E(TestCase):
            fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))
            memory_profile = prof._memory_profile()
-        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
+        ptr_pair_to_key: dict[tuple[int, int], _memory_profiler.TensorKey] = {}
         snapshot = memory_profile._category_snapshot()
         # Build map from observed live Tensors to the memory profiler's
@@ -922,7 +923,7 @@ class TestMemoryProfilerE2E(TestCase):
            return f"{target_key.storage.allocation_id} ({','.join(categories)})"
-        out: List[str] = []
+        out: list[str] = []
         for name, inputs, outputs in record_ops.results:
             if inputs or outputs:
                 # PyTorch ops

View File

@@ -17,7 +17,7 @@ import threading
 import time
 import unittest
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 from unittest.mock import patch
 import expecttest
@@ -2277,7 +2277,7 @@ class MockProfilerEvent:
     start_time_ns: int
     duration_time_ns: int
     correlation_id: int = 0
-    children: List["MockProfilerEvent"] = field(default_factory=list)
+    children: list["MockProfilerEvent"] = field(default_factory=list)
     parent: Optional["MockProfilerEvent"] = None
     @property
@@ -2301,7 +2301,7 @@ class MockNode:
 @unittest.skipIf(sys.version_info >= (3, 13), "segfaults")
 class TestExperimentalUtils(TestCase):
-    def make_tree(self) -> List[MockNode]:
+    def make_tree(self) -> list[MockNode]:
         tree = {
             "root_0": {
                 "1": {"2": {}},

View File

@@ -14,7 +14,7 @@ try:
 except ImportError:
     None
-from typing import Any, Dict
+from typing import Any
 import torch
 import torch.optim
@@ -29,7 +29,7 @@ from torch.profiler import kineto_available, record_function
 from torch.testing._internal.common_utils import run_tests, TestCase
-Json = Dict[str, Any]
+Json = dict[str, Any]
 class TestRecordFunction(TestCase):

View File

@@ -19,7 +19,7 @@ import sys
 import textwrap
 import unittest
 import weakref
-from typing import Any, Dict, List
+from typing import Any
 import torch
 import torch.nn as nn
@@ -30,7 +30,7 @@ from torch.profiler import _utils, profile
 from torch.testing._internal.common_utils import run_tests, TestCase
-Json = Dict[str, Any]
+Json = dict[str, Any]
 from torch._C._profiler import _ExtraFields_PyCall
@@ -455,7 +455,7 @@ class TestTorchTidyProfiler(TestCase):
         nodes = p.profiler.kineto_results.experimental_event_tree()
-        def find_chain(names: List[str]):
+        def find_chain(names: list[str]):
             out = []
             for name in names:
                 root = [out[-1]] if out else nodes

View File

@@ -25,7 +25,7 @@ from copy import deepcopy
 from functools import partial, reduce
 from itertools import product
 from operator import mul
-from typing import List, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING
 import torch
 import torch.autograd._functions
@@ -10091,14 +10091,14 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
    def test_multi_grad_any_hooks(self):
        hook_id = 0
-       any_hook_handles: List[RemovableHandle] = []
+       any_hook_handles: list[RemovableHandle] = []
        class MultiOutputModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(3, 3)
-           def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+           def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
                z = self.lin(x)
                out = torch.sin(z), torch.cos(z)
                nonlocal hook_id
@@ -10123,7 +10123,7 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
                z = y[0] + y[1]
                return self.mod2(z)
-       hook_order: List[int] = []
+       hook_order: list[int] = []
        hook_count = 0
        def hook(hook_id: int, *unused):
@@ -13975,7 +13975,7 @@ class TestSelectiveActivationCheckpoint(TestCase):
        counter = [0]
        @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
-       def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+       def sin_with_extra(x: torch.Tensor) -> tuple[torch.Tensor, int]:
            counter[0] += 1
            return x.sin(), 2

View File

@@ -4,7 +4,7 @@
 import io
 import textwrap
-from typing import Dict, List, Optional
+from typing import Optional
 import torch
 import torch.utils.bundled_inputs
@@ -32,7 +32,7 @@ class TestBundledInputs(TestCase):
         sm = torch.jit.script(SingleTensorModel())
         original_size = model_size(sm)
-        get_expr: List[str] = []
+        get_expr: list[str] = []
         samples = [
             # Tensor with small numel and small storage.
             (torch.tensor([1]),),
@@ -328,8 +328,8 @@ class TestBundledInputs(TestCase):
         class MyModel(torch.nn.Module):
             def forward(
                 self,
-                arg1: Optional[Dict[str, torch.Tensor]],
-                arg2: Optional[List[torch.Tensor]],
+                arg1: Optional[dict[str, torch.Tensor]],
+                arg2: Optional[list[torch.Tensor]],
                 arg3: torch.Tensor,
             ):
                 if arg1 is None:
@@ -393,7 +393,7 @@ class TestBundledInputs(TestCase):
             """,
         )
-        out: List[str] = []
+        out: list[str] = []
         sm = torch.jit.script(MyModel())
         original_size = model_size(sm)
         small_inputs = (

View File

@@ -4054,11 +4054,7 @@ class TestCudaMallocAsync(TestCase):
         that the pytorch call is returning a correct list of UUIDs.
         """
         cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'"
-        uuids = (
-            subprocess.check_output(cmd, shell=True, universal_newlines=True)
-            .strip()
-            .split("\n")
-        )
+        uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
         uuids = [s.strip() for s in uuids]
         raw_uuids = torch.cuda._raw_device_uuid_amdsmi()
         for uuid in uuids:
@@ -4082,11 +4078,7 @@ import os
 print(f"{torch.cuda.device_count()}")
 """
         cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'"
-        uuids = (
-            subprocess.check_output(cmd, shell=True, universal_newlines=True)
-            .strip()
-            .split("\n")
-        )
+        uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
         uuids = [s.strip() for s in uuids]
         custom_envs = []

View File

@@ -3,7 +3,7 @@
 import sys
 import textwrap
 import traceback
-from typing import List, Optional
+from typing import Optional
 import torch
 import torch.cuda._sanitizer as csan
@@ -148,9 +148,9 @@ class TestEventHandler(TestCase):
     def kernel_launch(
         self,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
-    ) -> List[csan.SynchronizationError]:
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
+    ) -> list[csan.SynchronizationError]:
         if read_only is None:
             read_only = []
         if read_write is None:
@@ -167,8 +167,8 @@ class TestEventHandler(TestCase):
     def assert_good_kernel_launch(
         self,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
     ) -> None:
         self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
@@ -176,8 +176,8 @@ class TestEventHandler(TestCase):
         self,
         number_of_errors: int,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
     ) -> None:
         errors = self.kernel_launch(stream, read_only, read_write)
         self.assertEqual(len(errors), number_of_errors)

View File

@@ -1262,14 +1262,12 @@ class TestDataLoader(TestCase):
             list(iter(loader))
     def test_typing(self):
-        from typing import List
         # Make sure there is no TypeError
-        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
+        class SomeDatasetClass(Dataset[list[torch.Tensor]]):
             pass
-        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
+        def _create_dataloader(is_train: bool) -> DataLoader[list[torch.Tensor]]:
             pass
     @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")

View File

@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 import unittest
-from typing import List, Optional, Tuple
+from typing import Optional
 import torch
 import torch.distributed
@@ -27,9 +27,9 @@ class MyModule(torch.nn.Module):
 class MyDummyFnOptimizer:
     def __init__(
         self,
-        params: List[Tensor],
+        params: list[Tensor],
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         _allow_empty_param_list: bool = False,
@@ -63,7 +63,7 @@ class MyDummyFnOptimizer:
             "MyDummyFnOptimizer does not support step_param() as of now"
         )
-    def step(self, gradients: List[Optional[Tensor]]):
+    def step(self, gradients: list[Optional[Tensor]]):
         # call the custom optimizer step implementation
         with torch.no_grad():
             raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")

View File

@@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
                 inp3_y = inp3.y
                 return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
+        class MyModule2(torch.nn.Module):
+            def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
+                inp_0 = inp[0]
+                inp_1 = inp[1]
+                inp2_0 = inp2[0]
+                inp3_x = inp3.x
+                inp3_y = inp3.y
+                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
         my_module = MyModule()
         my_module_traced = torch.fx.symbolic_trace(my_module)
@@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
             if node.target == operator.getitem:
                 self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
+        my_module = MyModule2()
+        my_module_traced = torch.fx.symbolic_trace(my_module)
+        # by default, fx transform loses type annotation of getitem nodes.
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                assert node.type is None
+        annotate_getitem_nodes(my_module_traced.graph)
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
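As a usage sketch of the pass exercised by the test above (assuming `annotate_getitem_nodes` is importable from `torch.fx.passes.annotate_getitem_nodes`, the module named in the commit message; `Pair` is a made-up module):

```python
import operator

import torch
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes

class Pair(torch.nn.Module):
    def forward(self, inp: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return inp[0] + inp[1]

traced = torch.fx.symbolic_trace(Pair())
# Tracing drops element types on getitem nodes; the pass re-derives them
# from the PEP 585 container annotation on the placeholder node.
annotate_getitem_nodes(traced.graph)
for node in traced.graph.nodes:
    if node.target == operator.getitem:
        print(node.name, node.type)  # expected: torch.Tensor for each getitem
```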

View File

@@ -5,7 +5,7 @@
 import itertools
 import torch
-from typing import List, Any
+from typing import Any
 from functools import wraps
 import unittest
 from torch.testing._internal.common_utils import skipIfTorchDynamo
@@ -100,7 +100,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
     # dimensions along which the reduction operation is applied:
     dim_ = torch.masked._canonical_dim(dim, input.ndim)
     # slices in product(*ranges) define all elementary slices:
-    ranges: List[Any] = []
+    ranges: list[Any] = []
     # shape of output for the keepdim=True case:
     shape = []
     for i in range(input.ndim):

View File

@@ -1,7 +1,7 @@
 # Owner(s): ["module: mkldnn"]
 import itertools
 import unittest
-from typing import NamedTuple, List
+from typing import NamedTuple
 import torch
 from torch import nn
@@ -16,7 +16,7 @@ FUSION_GROUP = 'prim::TensorExprGroup'
 class PointwisePostOp(NamedTuple):
     attr : str
     pointwise_module : nn.Module
-    scalars : List = []
+    scalars : list = []
     algorithm : str = ""
 CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

View File

@@ -1,6 +1,6 @@
 # Owner(s): ["module: unknown"]
-from typing import Optional, List
+from typing import Optional
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
@@ -8,12 +8,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc
 class FloatListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[float]]):
+    def forward(self, values, incr: Optional[list[float]]):
         return torch._C._nn._test_optional_floatlist(values, incr)
 class IntListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[int]]):
+    def forward(self, values, incr: Optional[list[int]]):
         return torch._C._nn._test_optional_intlist(values, incr)

View File

@@ -9,7 +9,7 @@ import sys
 import tempfile
 import unittest
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional
 import numpy as np
@@ -3545,7 +3545,7 @@ def get_tolerances(
     true_value: torch.Tensor,
     computed_value: torch.Tensor,
     fudge_factor: Optional[float] = None,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """Returns the absolute and relative tolerances for comparing two tensors."""
     fudge_factor = fudge_factor if fudge_factor is not None else 1.0
     atol = get_atol(true_value, computed_value)

View File

@@ -4,7 +4,6 @@
 import ctypes
 import os
 import unittest
-from typing import Tuple
 import torch
 from torch.backends._nnapi.prepare import convert_model_to_nnapi
@@ -700,7 +699,7 @@ class TestNNAPI(TestCase):
     def test_multi_output(self):
         class MultiModel(torch.nn.Module):
-            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
+            def forward(self, lhs, rhs) -> tuple[torch.Tensor, torch.Tensor]:
                 the_sum = lhs + rhs
                 the_diff = lhs - rhs
                 return the_sum, the_diff

View File

@@ -278,7 +278,7 @@ class TestNumPyInterop(TestCase):
     def test_from_numpy_no_leak_on_invalid_dtype(self):
         # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
         # object. See https://github.com/pytorch/pytorch/issues/121138
-        x = np.array("value".encode("ascii"))
+        x = np.array(b"value")
         for _ in range(1000):
             try:
                 torch.from_numpy(x)

View File

@@ -11,7 +11,6 @@ from collections import defaultdict
 from collections.abc import Sequence
 from functools import partial
 from importlib import import_module
-from typing import Dict, List
 import torch
 import torch._prims as prims
@@ -1483,7 +1482,7 @@ class TestCommon(TestCase):
         unsupported_dtypes = set()
         supported_backward_dtypes = set()
         unsupported_backward_dtypes = set()
-        dtype_error: Dict[torch.dtype, Exception] = {}
+        dtype_error: dict[torch.dtype, Exception] = {}
         def unsupported(dtype, e):
             dtype_error[dtype] = e
@@ -1987,7 +1986,7 @@ class TestCompositeCompliance(TestCase):
         for sample in op.sample_inputs(device, dtype, requires_grad=False):
             inp = sample.input
             outs = op(inp, *sample.args, **sample.kwargs)
-            if not isinstance(outs, (tuple, List)):
+            if not isinstance(outs, (tuple, list)):
                 outs = [outs]
             # for all outputs that are views of the input, we should be able to replay the

View File

@@ -4,7 +4,7 @@ import math
 import tempfile
 import unittest
 from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Any
 from unittest.mock import patch
 from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
@@ -1769,8 +1769,8 @@ class TestOptimRenewed(TestCase):
     @staticmethod
     def _state_dict_post_hook(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         if "test" in state_dict["state"]:
             state_dict["state"].pop("test")
             state_dict["ran_state_dict_pre_hook"] = True
@@ -1821,14 +1821,14 @@ class TestOptimRenewed(TestCase):
     @staticmethod
     def _load_state_dict_pre_hook1(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
+        optimizer: Optimizer, state_dict: dict[str, Any]
     ) -> None:
         state_dict["param_groups"][0]["lr"] = 0.002
     @staticmethod
     def _load_state_dict_pre_hook2(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         # The typical use case for returning a state dict is to drastically modify the state dict.
         # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
         my_state_dict = deepcopy(state_dict)
@@ -1906,7 +1906,7 @@ class TestOptimRenewed(TestCase):
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
-        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data += 2
@@ -1938,7 +1938,7 @@ class TestOptimRenewed(TestCase):
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_pre_hook(self, device, dtype, optim_info):
-        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data += 2
@@ -1970,19 +1970,19 @@ class TestOptimRenewed(TestCase):
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_all_hooks(self, device, dtype, optim_info):
-        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data.append(0)
-        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data.append(5)
-        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data.append(1)
-        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
            nonlocal data
            data.append(2)

View File

@@ -5,7 +5,7 @@ import torch
 import numpy as np
 import math
-from typing import Dict, List, Sequence
+from collections.abc import Sequence
 import random
 from functools import partial
 from itertools import product, combinations, permutations
@@ -736,7 +736,7 @@ class TestReductions(TestCase):
     # TODO: kill this ane replace with common creation ops
     def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
-                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
+                      use_complex=False) -> dict[str, list[torch.Tensor]]:
         float_types = [torch.double,
                        torch.float]
         int_types = [torch.int64,
@@ -778,7 +778,7 @@ class TestReductions(TestCase):
             types += int_types
         if use_complex:
             types += complex_types
-        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
+        tensors: dict[str, list[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
         for dtype in types:
             tensors["cont"].append(make_contiguous(shape, dtype))
             tensors["noncont"].append(make_non_contiguous(shape, dtype))
@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
     skipIfCrossRef
 from torch.testing._internal.common_cuda import TEST_CUDA
 from numbers import Number
-from typing import Dict, Any
+from typing import Any
 from packaging import version
 from torch.testing._internal.common_cuda import \
     (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
@@ -334,7 +334,7 @@ class TestSparse(TestSparseBase):
             self.assertEqual(t._values(), tc._values())
             return tc

-        value_map: Dict[Any, Any] = {}
+        value_map: dict[Any, Any] = {}
         for idx, val in zip(t._indices().t(), t._values()):
             idx_tup = tuple(idx.tolist())
             if idx_tup in value_map:
@@ -21,7 +21,7 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_cuda import SM53OrLater
 from torch._prims_common import corresponding_complex_dtype
-from typing import Optional, List
+from typing import Optional
 from packaging import version
@@ -597,7 +597,7 @@ class TestFFT(TestCase):
         else:
             numpy_fn = getattr(np.fft, fname)

-        def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
+        def fn(t: torch.Tensor, s: Optional[list[int]], dim: list[int] = (-2, -1), norm: Optional[str] = None):
             return torch_fn(t, s, dim, norm)

         torch_fns = (torch_fn, torch.jit.script(fn))
@@ -2,14 +2,13 @@
 # ruff: noqa: F841
 import unittest
-from typing import Dict, Optional
+from typing import Optional

 import numpy as np

 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.static_module import StaticModule
-from typing import List


 def linear_shim(
@@ -108,7 +107,7 @@ def fork_wait_graph2(input1, input2):
     :param iters: number of future/wait pairs to be created
     """
 def fork_wait_graph3(input, iters: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(iters):
         futures.append(torch.jit.fork(torch.neg, input))
     results = []
@@ -123,7 +122,7 @@ def fork_wait_graph3(input, iters: int):
     :param num_child_forks: number of child forks per parent fork
     """
 def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(num_forks):
         futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
     results = []
@@ -150,7 +149,7 @@ def loop_graph(a, b, iters: int):
 def output_graph(a, b, c, iters: int):
     s = torch.tensor([[3, 3], [3, 3]])
     k = a + b * c + s
-    d: Dict[int, torch.Tensor] = {}
+    d: dict[int, torch.Tensor] = {}
     for i in range(iters):
         d[i] = k + i
     return d
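
The `fork_wait_graph*` helpers above exercise TorchScript's asynchronous fork/wait primitives; only the annotation on the futures container changes. A self-contained sketch of the same pattern (the function name is illustrative):

```python
import torch

def fan_out_neg(x: torch.Tensor, iters: int) -> list[torch.Tensor]:
    # Launch `iters` asynchronous tasks, then wait on each future in order.
    futures: list[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, x))
    return [torch.jit.wait(fut) for fut in futures]

print(fan_out_neg(torch.ones(2), 3))  # three tensors of [-1., -1.]
```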
@@ -5,7 +5,7 @@ import itertools
 import math
 import pickle
 import sys
-from typing import Callable, List, Tuple, Type
+from typing import Callable

 import sympy
@@ -594,7 +594,7 @@ class TestSympyInterp(TestCase):
                 self.fail(f"Unexpected error for {fn}{args}: {str(e)}")

-def type_name_fn(type: Type) -> str:
+def type_name_fn(type: type) -> str:
     return type.__name__
@@ -606,7 +606,7 @@ def parametrize_relational_types(*types):
 class TestSympySolve(TestCase):
-    def _create_integer_symbols(self) -> List[sympy.Symbol]:
+    def _create_integer_symbols(self) -> list[sympy.Symbol]:
         return sympy.symbols("a b c", integer=True)

     def test_give_up(self):
@@ -665,9 +665,9 @@ class TestSympySolve(TestCase):
     def _test_cases(
         self,
-        cases: List[Tuple[sympy.Basic, sympy.Basic]],
+        cases: list[tuple[sympy.Basic, sympy.Basic]],
         thing: sympy.Basic,
-        op: Type[sympy.Rel],
+        op: type[sympy.Rel],
         **kwargs,
     ):
         for source, expected in cases:
@@ -761,7 +761,7 @@ class TestSympySolve(TestCase):
             Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
         }[op]

-        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
+        cases: list[tuple[sympy.Basic, sympy.Basic]] = [
             # 'b' is not strictly positive
             (op(FloorDiv(a, b), integer), None),
             # 'c' is not strictly positive
@@ -11,7 +11,7 @@ import unittest
 from itertools import product, combinations, combinations_with_replacement, permutations
 import random
 import tempfile
-from typing import Any, Dict, List, Tuple
+from typing import Any

 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
@@ -4125,7 +4125,7 @@ class TestAsArray(TestCase):
     def test_default_device(self, device):
         original = torch.arange(5)

-        examples: List[Tuple[Any, Dict]] = [
+        examples: list[tuple[Any, dict]] = [
             (3, {}),
             (original, {}),
             (to_numpy(original), {}),
@@ -12,7 +12,8 @@ import re
 import subprocess
 import sys
 import unittest.mock
-from typing import Any, Callable, Iterator, List, Tuple
+from typing import Any, Callable
+from collections.abc import Iterator

 import torch
@@ -496,7 +497,7 @@ if __name__ == '__main__':
         self.assertNotIn('OK', stderr.decode('ascii'))

-def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
+def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]:
     """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.

     Args:
@@ -53,7 +53,6 @@ from torch.testing._internal.common_device_type import (
     dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
     skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, skipCUDAIfNotRocm,
     get_all_device_types, skipXLA)
-from typing import Tuple
 import torch.backends.quantized
 import torch.testing._internal.data
 from torch.testing._internal.common_cuda import (
@@ -3511,7 +3510,7 @@ else:
     def _prepare_data_for_index_copy_and_add_deterministic(
         self, dim: int, device: torch.device
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert (dim >= 0 and dim < 3)
         a = [5, 4, 3]
         a[dim] = 2000
@@ -18,7 +18,7 @@ import math
 import itertools
 import torch.optim as optim
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
-from typing import List, Tuple, Optional, Dict
+from typing import Optional
 import torch.utils.cpp_extension
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
@@ -149,12 +149,12 @@ def _check_equal(

 def check_out_and_grad(
-    out_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_query_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_key_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_value_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_attn_mask_tuple: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
-    fudge_factors: Optional[Dict[str, float]] = None
+    out_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_query_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_key_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_value_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_attn_mask_tuple: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
+    fudge_factors: Optional[dict[str, float]] = None
 ) -> None:
     """
     Check output and gradients of attention mechanism tensors.
@@ -2574,7 +2574,7 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):

-        def test_attention(backend: SDPBackend, permute_order: List[List[int]]):
+        def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
             BHSqD = [4, 16, 256, 64]
             BHSkvD = [4, 16, 512, 64]
@@ -2602,7 +2602,7 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
-    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
+    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
         dtype = torch.float16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
         batch, num_heads, head_dim = 8, 8, 64
@@ -3307,7 +3307,7 @@ class TestSDPACudaOnly(NNTestCase):
     @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
-                                               scale: str, enable_gqa: bool, n_heads: List[int]):
+                                               scale: str, enable_gqa: bool, n_heads: list[int]):
         if isSM8XDevice and head_dim in range(193, 256 + 1):
             self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
@@ -3905,7 +3905,7 @@ class TestAttnBias(NNTestCase):
         "shape",
         [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
     )
-    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
@@ -3942,7 +3942,7 @@ class TestAttnBias(NNTestCase):
     )
     @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
     @skipIfTorchDynamo("This function already calls torch.compile.")
-    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
             self.skipTest("No support for LOWER_RIGHT variant for now")
             return
@@ -3975,7 +3975,7 @@ class TestAttnBias(NNTestCase):
         self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")

     @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
-    def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
+    def test_is_causal_equals_upper_left(self, device, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
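
`CausalVariant` in the tests above distinguishes where the causal mask's diagonal is anchored, which only matters when the query and key/value sequence lengths differ. A sketch of the distinction using the public `torch.nn.attention.bias` helpers (assumes PyTorch >= 2.2; all shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right, causal_upper_left

B, H, Sq, Skv, D = 2, 4, 8, 16, 32
q = torch.randn(B, H, Sq, D)
k = torch.randn(B, H, Skv, D)
v = torch.randn(B, H, Skv, D)

# UPPER_LEFT anchors the mask diagonal at (0, 0); LOWER_RIGHT anchors it at
# the bottom-right corner. For Sq == Skv the two masks coincide.
out_ul = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_upper_left(Sq, Skv))
out_lr = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_lower_right(Sq, Skv))
print(torch.allclose(out_ul, out_lr))  # False in general when Sq != Skv
```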
@@ -8,7 +8,7 @@ import shutil
 import unittest
 from collections import defaultdict
 from threading import Lock
-from typing import Dict, IO, List, Optional
+from typing import IO, Optional

 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -49,12 +49,12 @@ def _strip_filename(msg: str) -> str:
     return tail.split(":", 1)[-1]

-def _run_mypy() -> Dict[str, List[str]]:
+def _run_mypy() -> dict[str, list[str]]:
     """Clears the cache and run mypy before running any of the typing tests."""
     if os.path.isdir(CACHE_DIR):
         shutil.rmtree(CACHE_DIR)

-    rc: Dict[str, List[str]] = {}
+    rc: dict[str, list[str]] = {}
     for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
         # Run mypy
         stdout, stderr, _ = api.run(
@@ -119,10 +119,10 @@ def _construct_format_dict():
 #: A dictionary with all supported format keys (as keys)
 #: and matching values
-FORMAT_DICT: Dict[str, str] = _construct_format_dict()
+FORMAT_DICT: dict[str, str] = _construct_format_dict()

-def _parse_reveals(file: IO[str]) -> List[str]:
+def _parse_reveals(file: IO[str]) -> list[str]:
     """Extract and parse all ``" # E: "`` comments from the passed file-like object.

     All format keys will be substituted for their respective value from `FORMAT_DICT`,
@@ -160,10 +160,10 @@ def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> N
 @unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
 class TestTyping(TestCase):
     _lock = Lock()
-    _cached_output: Optional[Dict[str, List[str]]] = None
+    _cached_output: Optional[dict[str, list[str]]] = None

     @classmethod
-    def get_mypy_output(cls) -> Dict[str, List[str]]:
+    def get_mypy_output(cls) -> dict[str, list[str]]:
         with cls._lock:
             if cls._cached_output is None:
                 cls._cached_output = _run_mypy()
@@ -192,7 +192,7 @@ class TestTyping(TestCase):
         with open(path) as fin:
             lines = fin.readlines()

-        errors = defaultdict(lambda: "")
+        errors = defaultdict(str)
         output_mypy = self.get_mypy_output()
         self.assertIn(path, output_mypy)
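
One non-typing cleanup rides along in the last hunk: `defaultdict(lambda: "")` becomes `defaultdict(str)`. The two are equivalent because the default factory is called with no arguments and `str()` returns `""`; a quick check:

```python
from collections import defaultdict

a = defaultdict(lambda: "")
b = defaultdict(str)
print(a["missing"] == b["missing"] == "")  # True
```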
@@ -12,7 +12,7 @@ import textwrap
 import traceback
 import unittest
 import warnings
-from typing import Any, Dict, List
+from typing import Any

 import torch
 import torch.cuda
@@ -439,7 +439,7 @@ class TestCheckpoint(TestCase):
         # get de-allocated directly. So using cuda memory usage as a proxy
         def _do_test(fn, should_free):
-            stats: List[int] = []
+            stats: list[int] = []

             def track(x, idx):
                 # Track that at each step of the backward, some Tensor were
@@ -1203,7 +1203,7 @@ def f(x):
     return g(x) + 1
 """

-    out: Dict[str, Any] = {}
+    out: dict[str, Any] = {}
     scope = {"__compile_source__": source}
     exec(source, scope, out)
@@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
     """
     Annotate the type of getitem nodes, inferred from the type of sequence node.
     If sequence node is not annotated with a type, do nothing.
-    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
+    Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
     This is helpful since annotations on local names within function are lost during FX transforms.
     Adding back known type annotation for getitem nodes to improve jit scriptability.
@@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
             elif sequence_node.type._name == "List":
                 assert len(parameterized_types) == 1
                 node.type = parameterized_types[0]
+            # Generic Alias Type
+            elif hasattr(sequence_node.type, "__origin__"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type.__origin__ is tuple:
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type.__origin__ is list:
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
             # NamedTuple type
             elif hasattr(sequence_node.type, "__annotations__"):
                 if sequence_node.type == torch.Tensor:
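
The generic-alias branch above is the functional change called out in the commit message: previously the pass only recognized `typing.Tuple`/`typing.List` annotations (which expose a `_name` attribute), so getitem nodes drawn from PEP 585 annotations such as `tuple[torch.Tensor, torch.Tensor]` were left untyped. Builtin generics expose `__origin__`/`__args__` instead, which the new branch inspects (including the `tuple[T, ...]` ellipsis form). A sketch of the effect (the traced function is illustrative, not from the patch):

```python
import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes

def split_pair(pair: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    return pair[0] + pair[1]

gm = symbolic_trace(split_pair)   # the placeholder keeps the tuple[...] annotation
annotate_getitem_nodes(gm.graph)  # fills in types for the getitem nodes
for node in gm.graph.nodes:
    if node.target is operator.getitem:
        print(node.name, node.type)  # both getitem nodes now typed as torch.Tensor
```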