[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
This commit is contained in:
cyy 2025-01-04 10:47:51 +00:00 committed by PyTorch MergeBot
parent a881954b0c
commit df458be4e5
55 changed files with 247 additions and 227 deletions

View File

@ -1,8 +1,3 @@
import dis
import inspect
from collections.abc import Sequence
from typing import Union
import functorch._C
import torch
from functorch._C import dim as _C

View File

@ -27,7 +27,7 @@ from __future__ import annotations
import keyword
import warnings
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@ -73,11 +73,11 @@ class ParsedExpression:
"""
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[Union[str, AnonymousAxis]] = set()
self.identifiers: set[Union[str, AnonymousAxis]] = set()
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes: bool = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
if "." in expression:
if "..." not in expression:
raise ValueError(
@ -90,7 +90,7 @@ class ParsedExpression:
expression = expression.replace("...", _ellipsis)
self.has_ellipsis = True
bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
def add_axis_name(x: str) -> None:
if x in self.identifiers:
@ -164,7 +164,7 @@ class ParsedExpression:
@staticmethod
def check_axis_name_return_reason(
name: str, allow_underscore: bool = False
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
"""Check if the given axis name is valid, and a message explaining why if not.
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@ -174,7 +174,7 @@ class ParsedExpression:
allow_underscore (bool): whether axis names are allowed to start with an underscore
Returns:
Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
tuple[bool, str]: whether the axis name is valid, a message explaining why if not
"""
if not str.isidentifier(name):
return False, "not a valid python identifier"
@ -211,7 +211,7 @@ class ParsedExpression:
def parse_pattern(
pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
) -> tuple[ParsedExpression, ParsedExpression]:
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
Args:
@ -219,7 +219,7 @@ def parse_pattern(
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
Returns:
Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
"""
# adapted from einops.einops._prepare_transformation_recipe
# https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import functools
from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
from typing import Callable, TYPE_CHECKING, Union
import torch
from functorch._C import dim as _C
@ -18,7 +18,6 @@ from ._parsing import (
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["rearrange"]
dims = _C.dims
@ -69,9 +68,9 @@ def _create_rearrange_callable(
# an identity rearrangement on a 0-dimension tensor
return lambda tensor: tensor
first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
anon_axes: List[AnonymousAxis] = []
first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
anon_axes: list[AnonymousAxis] = []
# map the left-hand side identifiers to strings representing first class dims
dims_i = 0
@ -99,11 +98,11 @@ def _create_rearrange_callable(
raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
) -> List[Union[str, Tuple[str, ...]]]:
composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
) -> list[Union[str, tuple[str, ...]]]:
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
class dims."""
dim_composition: List[Union[str, Tuple[str, ...]]] = []
dim_composition: list[Union[str, tuple[str, ...]]] = []
for dimension in composition:
if isinstance(dimension, list):
dim_composition.append(
@ -152,7 +151,7 @@ def _create_rearrange_callable(
def rearrange(
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
pattern: str,
**axes_lengths: int,
) -> torch.Tensor:

View File

@ -2,7 +2,6 @@
import copy
import logging
from typing import List
import torch
import torch.nn as nn
@ -247,7 +246,7 @@ class TestActivationSparsifier(TestCase):
assert mask2 is None
else:
assert type(mask1) == type(mask2)
if isinstance(mask1, List):
if isinstance(mask1, list):
assert len(mask1) == len(mask2)
for idx in range(len(mask1)):
assert torch.all(mask1[idx] == mask2[idx])
@ -258,7 +257,7 @@ class TestActivationSparsifier(TestCase):
for state in state_dict["state"].values():
mask = state["mask"]
if mask is not None:
if isinstance(mask, List):
if isinstance(mask, list):
for idx in range(len(mask)):
assert mask[idx].is_sparse
else:

View File

@ -3,7 +3,6 @@
import copy
import logging
import warnings
from typing import Tuple
import torch
from torch import nn
@ -73,7 +72,7 @@ class TestBaseDataScheduler(TestCase):
def _get_name_data_config(self, some_data, defaults):
config = copy.deepcopy(defaults)
if isinstance(some_data, Tuple):
if isinstance(some_data, tuple):
# dealing with data_list
name, data = some_data
else:

View File

@ -4,7 +4,6 @@ import copy
import itertools
import logging
import math
from typing import Tuple
import torch
from torch import nn
@ -54,7 +53,7 @@ class _BaseDataSparsiferTestCase(TestCase):
@staticmethod
def _get_name_data_config(some_data, defaults=None):
if isinstance(some_data, Tuple):
if isinstance(some_data, tuple):
# dealing with data_list
name, data = some_data
config = defaults
@ -482,8 +481,9 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
nn.Parameter(torch.randn(4, 4)),
nn.Parameter(torch.randn(5, 5)),
)
param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
torch.randn(4, 4)
param4, param5 = (
nn.Parameter(torch.randn(1, 1)),
nn.Parameter(torch.randn(4, 4)),
)
data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
defaults = {"test": 3}
@ -585,8 +585,9 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
nn.Parameter(torch.randn(4, 4)),
nn.Parameter(torch.randn(5, 5)),
)
param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
torch.randn(4, 4)
param4, param5 = (
nn.Parameter(torch.randn(10, 10)),
nn.Parameter(torch.randn(4, 4)),
)
data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
defaults = {

View File

@ -2,13 +2,17 @@
from __future__ import annotations
import typing
from typing import List, Optional, Sequence, Union # noqa: F401
from typing import List, Optional, Union
import torch
from torch import Tensor, types
from torch.testing._internal.common_utils import run_tests, TestCase
if typing.TYPE_CHECKING:
from collections.abc import Sequence
mutates_args = {}

View File

@ -6,7 +6,8 @@ import functools
import itertools
import unittest
from collections import defaultdict
from typing import Any, Iterable, List, Optional, Tuple, Union
from collections.abc import Iterable
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.distributed as dist

View File

@ -1,7 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from typing import Any, Callable, Dict, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional
from unittest import skip
import torch

View File

@ -10,8 +10,9 @@ import inspect
import io
import operator
import unittest
from collections.abc import Sequence
from enum import Enum
from typing import Dict, List, Sequence
from typing import Dict, List
from unittest.mock import patch
import torch

View File

@ -21,10 +21,11 @@ import warnings
import weakref
from abc import ABC
from collections import namedtuple
from collections.abc import Iterator
from copy import deepcopy
from enum import Enum, IntEnum
from functools import wraps
from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict
from typing import Any, Dict, List, Literal, Tuple, TypedDict
from unittest import mock
import numpy as np

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: export"]
import unittest
from typing import Any, Dict, Optional, OrderedDict, Tuple
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple
import torch
from torch._export.passes.lift_constants_pass import (

View File

@ -14,8 +14,7 @@ import random
import types
import unittest
import warnings
from collections import namedtuple
from typing import OrderedDict
from collections import namedtuple, OrderedDict
from unittest.case import skipIf
from common_utils import (

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: fx"]
import unittest
from typing import Mapping
from collections.abc import Mapping
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

View File

@ -3,8 +3,8 @@
import functools
import itertools
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence
from unittest import skip
import yaml

View File

@ -9,7 +9,7 @@ from collections import namedtuple, OrderedDict
from copy import deepcopy
from functools import partial
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Tuple
from typing import Any
import torch
import torch.nn as nn
@ -55,11 +55,11 @@ class ToyModel(nn.Module):
def forward_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
inp: Tuple[torch.Tensor],
inp: tuple[torch.Tensor],
out: torch.Tensor,
) -> None:
fired_hooks.append(hook_id)
@ -69,11 +69,11 @@ def forward_hook(
def forward_pre_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
inp: Tuple[torch.Tensor],
inp: tuple[torch.Tensor],
) -> None:
fired_hooks.append(hook_id)
self.assertEqual(id(module), id(expected_module))
@ -82,12 +82,12 @@ def forward_pre_hook(
def full_backward_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
grad_input: Tuple[torch.Tensor],
grad_output: Tuple[torch.Tensor],
grad_input: tuple[torch.Tensor],
grad_output: tuple[torch.Tensor],
) -> None:
fired_hooks.append(hook_id)
self.assertEqual(id(module), id(expected_module))
@ -97,11 +97,11 @@ def full_backward_hook(
def full_backward_pre_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
grad_input: Tuple[torch.Tensor],
grad_input: tuple[torch.Tensor],
) -> None:
fired_hooks.append(hook_id)
self.assertEqual(id(module), id(expected_module))
@ -122,8 +122,8 @@ class KwargModel(nn.Module):
def internal_forward_hook(
self,
module: nn.Module,
args: Tuple[torch.Tensor],
kwargs: Dict[str, Any],
args: tuple[torch.Tensor],
kwargs: dict[str, Any],
out: torch.Tensor,
):
return out + kwargs["bias"]
@ -142,13 +142,13 @@ class FailsInForwardModel(nn.Module):
def kwarg_forward_pre_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
args: Tuple[torch.Tensor],
kwargs: Dict[str, Any],
) -> Tuple[Any, Any]:
args: tuple[torch.Tensor],
kwargs: dict[str, Any],
) -> tuple[Any, Any]:
fired_hooks.append(hook_id)
self.assertEqual(id(module), id(expected_module))
self.assertEqual(len(args), 1)
@ -158,12 +158,12 @@ def kwarg_forward_pre_hook(
def kwarg_forward_hook(
self: TestCase,
fired_hooks: List[int],
fired_hooks: list[int],
expected_module: nn.Module,
hook_id: int,
module: nn.Module,
args: Tuple[torch.Tensor],
kwargs: Dict[str, Any],
args: tuple[torch.Tensor],
kwargs: dict[str, Any],
out: torch.Tensor,
) -> Any:
fired_hooks.append(hook_id)
@ -188,7 +188,7 @@ class DummyContextManager:
class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_forward_hooks(self, named_tuple):
fired_hooks: List[int] = []
fired_hooks: list[int] = []
model = ToyModel(named_tuple)
x = torch.randn(10, 10)
hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
@ -210,7 +210,7 @@ class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_forward_pre_hooks(self, named_tuple):
fired_hooks: List[int] = []
fired_hooks: list[int] = []
model = ToyModel(named_tuple)
x = torch.randn(10, 10)
hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
@ -232,7 +232,7 @@ class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_full_backward_hooks(self, named_tuple):
fired_hooks: List[int] = []
fired_hooks: list[int] = []
model = ToyModel(named_tuple)
x = torch.randn(10, 10)
hook = partial(full_backward_hook, self, fired_hooks, model.net1)
@ -254,7 +254,7 @@ class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_full_backward_pre_hooks(self, named_tuple):
fired_hooks: List[int] = []
fired_hooks: list[int] = []
model = ToyModel(named_tuple)
x = torch.randn(10, 10)
hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
@ -294,7 +294,7 @@ class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_mixed_hooks(self, named_tuple):
fired_hooks: List[int] = []
fired_hooks: list[int] = []
model = ToyModel(named_tuple)
x = torch.randn(10, 10)
model.register_forward_pre_hook(
@ -319,7 +319,7 @@ class TestModuleHooks(TestCase):
def test_kwarg_hooks(self):
# 1. test forward pre hook
fired_hooks: List[int] = []
fired_hooks: list[int] = []
x: torch.Tensor = torch.ones(10, 10)
bias: torch.Tensor = torch.ones(10, 10)
model = KwargModel()
@ -336,7 +336,7 @@ class TestModuleHooks(TestCase):
self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
# 2. test forward pre and forward hooks
fired_hooks: List[int] = []
fired_hooks: list[int] = []
x: torch.Tensor = torch.ones(10, 10)
bias: torch.Tensor = torch.ones(10, 10)
model = KwargModel()
@ -372,7 +372,7 @@ class TestModuleHooks(TestCase):
def test_remove_kwarg_hooks(self):
# test forward pre and forward hooks
fired_hooks: List[int] = []
fired_hooks: list[int] = []
x: torch.Tensor = torch.ones(10, 10)
bias: torch.Tensor = torch.ones(10, 10)
model = KwargModel()
@ -1217,8 +1217,8 @@ class TestModuleGlobalHooks(TestCase):
def test_module_global_hooks_with_kwargs(self):
def kwarg_global_forward_hook(
module: nn.Module,
args: Tuple[torch.Tensor],
kwargs: Dict[str, Any],
args: tuple[torch.Tensor],
kwargs: dict[str, Any],
out: torch.Tensor,
) -> Any:
out = out + kwargs["bias"]

View File

@ -2,7 +2,6 @@
import itertools
import random
from typing import List
import torch
import torch.nn.utils.rnn as rnn_utils
@ -219,7 +218,7 @@ class PackedSequenceTest(TestCase):
# more dimensions
maxlen = 9
for num_dim in (0, 1, 2, 3):
sequences: List[torch.Tensor] = []
sequences: list[torch.Tensor] = []
trailing_dims = [4] * num_dim
for i in range(1, maxlen + 1):
seq_len = i * i

View File

@ -1573,10 +1573,10 @@ torch.cuda.synchronize()
return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)
if adaptive:
cls_name = "AdaptiveMaxPool{}d".format(num_dim) # noqa: UP032
cls_name = f"AdaptiveMaxPool{num_dim}d"
else:
# FIXME(#105716): Test fails when using f-string
cls_name = "MaxPool{}d".format(num_dim) # noqa: UP032
cls_name = f"MaxPool{num_dim}d"
module_cls = getattr(nn, cls_name)
module = module_cls(2, return_indices=True).to(device, dtype=dtype)
numel = 4 ** (num_dim + 1)

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""
from typing import Sequence
from collections.abc import Sequence
from torch.onnx import errors
from torch.onnx._internal import registration

View File

@ -10,19 +10,8 @@ import logging
import os
import unittest
import warnings
from typing import (
Any,
Callable,
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from collections.abc import Collection, Iterable, Mapping, Sequence
from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import onnxruntime

View File

@ -3,7 +3,8 @@
import os
import unittest
from collections import OrderedDict
from typing import List, Mapping, Tuple
from collections.abc import Mapping
from typing import List, Tuple
import onnx_test_common
import parameterized

View File

@ -10,7 +10,7 @@ import itertools
import unittest
import unittest.mock
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np
@ -26,6 +26,10 @@ from torch.onnx._internal import registration
from torch.testing._internal import common_quantization, common_utils, jit_utils
if TYPE_CHECKING:
from collections.abc import Iterable
def export_to_onnx(
model: Union[torch.nn.Module, torch.jit.ScriptFunction],
input: Union[torch.Tensor, Tuple[torch.Tensor]],

View File

@ -1,6 +1,6 @@
# Owner(s): ["oncall: package/deploy"]
from typing import Iterable
from collections.abc import Iterable
from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests

View File

@ -18,7 +18,7 @@ import os
import sys
import tempfile
import unittest
from typing import Any, Dict, List
from typing import Any
import torch
import torch.nn as nn
@ -51,7 +51,7 @@ from torch.testing._internal.common_utils import (
from torch.utils._triton import has_triton
Json = Dict[str, Any]
Json = dict[str, Any]
class TestExecutionTrace(TestCase):
@ -97,7 +97,7 @@ class TestExecutionTrace(TestCase):
nodes = et_graph["nodes"]
return nodes
def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]:
"""Returns a sorted list of rf_id (record function ids) in execution trace"""
def get_rf_id(node):
@ -115,7 +115,7 @@ class TestExecutionTrace(TestCase):
)
return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)
def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
def get_kineto_rf_ids(self, events: list[Json]) -> list[int]:
"""Returns a sorted list of Record function IDs for CPU operators and user annotations"""
ops_and_annotations = (
e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]

View File

@ -5,7 +5,8 @@ import itertools as it
import sys
import textwrap
import unittest
from typing import Callable, Dict, Iterator, List, Optional, Tuple
from collections.abc import Iterator
from typing import Callable, Optional
import torch
from torch._C._profiler import _EventType, _TensorMetadata
@ -309,9 +310,9 @@ class TestDataFlow(TestCase):
@staticmethod
def formatSchemas(
prof: torch.profiler.profile, indent: int = 12
) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
) -> tuple[tuple[str, tuple[bool, ...]], ...]:
tree = prof.profiler.kineto_results.experimental_event_tree()
out: List[Tuple[str, Tuple[bool, ...]]] = []
out: list[tuple[str, tuple[bool, ...]]] = []
for node in _utils.traverse_dfs(tree):
if node.tag == _EventType.TorchOp:
e = node.extra_fields
@ -327,8 +328,8 @@ class TestDataFlow(TestCase):
@staticmethod
def _run_and_format_data_flow(
inputs: Dict[str, torch.Tensor],
f: Callable[..., Optional[Dict[str, torch.Tensor]]],
inputs: dict[str, torch.Tensor],
f: Callable[..., Optional[dict[str, torch.Tensor]]],
indent: int = 12,
) -> str:
with profile() as prof:
@ -339,7 +340,7 @@ class TestDataFlow(TestCase):
graph = memory_profile._data_flow_graph
storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}
lines: List[str] = []
lines: list[str] = []
for name, t in it.chain(inputs.items(), outputs.items()):
lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
if t.grad is not None:
@ -352,7 +353,7 @@ class TestDataFlow(TestCase):
for node in graph.flow_nodes:
destroyed = {k for k, v in node._edges.items() if v.is_deletion}
inputs: List[str] = []
inputs: list[str] = []
for key, (_, v) in node.inputs.items():
inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")
@ -833,7 +834,7 @@ class TestMemoryProfilerE2E(TestCase):
@staticmethod
def _lookup_tensor_categories(
t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
) -> dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
storage = t.storage()
if storage is None:
raise ValueError("Cannot look up uninitialized Tensor.")
@ -889,7 +890,7 @@ class TestMemoryProfilerE2E(TestCase):
fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))
memory_profile = prof._memory_profile()
ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
ptr_pair_to_key: dict[tuple[int, int], _memory_profiler.TensorKey] = {}
snapshot = memory_profile._category_snapshot()
# Build map from observed live Tensors to the memory profiler's
@ -922,7 +923,7 @@ class TestMemoryProfilerE2E(TestCase):
return f"{target_key.storage.allocation_id} ({','.join(categories)})"
out: List[str] = []
out: list[str] = []
for name, inputs, outputs in record_ops.results:
if inputs or outputs:
# PyTorch ops

View File

@ -17,7 +17,7 @@ import threading
import time
import unittest
from dataclasses import dataclass, field
from typing import List, Optional
from typing import Optional
from unittest.mock import patch
import expecttest
@ -2277,7 +2277,7 @@ class MockProfilerEvent:
start_time_ns: int
duration_time_ns: int
correlation_id: int = 0
children: List["MockProfilerEvent"] = field(default_factory=list)
children: list["MockProfilerEvent"] = field(default_factory=list)
parent: Optional["MockProfilerEvent"] = None
@property
@ -2301,7 +2301,7 @@ class MockNode:
@unittest.skipIf(sys.version_info >= (3, 13), "segfaults")
class TestExperimentalUtils(TestCase):
def make_tree(self) -> List[MockNode]:
def make_tree(self) -> list[MockNode]:
tree = {
"root_0": {
"1": {"2": {}},

View File

@ -14,7 +14,7 @@ try:
except ImportError:
None
from typing import Any, Dict
from typing import Any
import torch
import torch.optim
@ -29,7 +29,7 @@ from torch.profiler import kineto_available, record_function
from torch.testing._internal.common_utils import run_tests, TestCase
Json = Dict[str, Any]
Json = dict[str, Any]
class TestRecordFunction(TestCase):

View File

@ -19,7 +19,7 @@ import sys
import textwrap
import unittest
import weakref
from typing import Any, Dict, List
from typing import Any
import torch
import torch.nn as nn
@ -30,7 +30,7 @@ from torch.profiler import _utils, profile
from torch.testing._internal.common_utils import run_tests, TestCase
Json = Dict[str, Any]
Json = dict[str, Any]
from torch._C._profiler import _ExtraFields_PyCall
@ -455,7 +455,7 @@ class TestTorchTidyProfiler(TestCase):
nodes = p.profiler.kineto_results.experimental_event_tree()
def find_chain(names: List[str]):
def find_chain(names: list[str]):
out = []
for name in names:
root = [out[-1]] if out else nodes

View File

@ -25,7 +25,7 @@ from copy import deepcopy
from functools import partial, reduce
from itertools import product
from operator import mul
from typing import List, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING
import torch
import torch.autograd._functions
@ -10091,14 +10091,14 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
def test_multi_grad_any_hooks(self):
hook_id = 0
any_hook_handles: List[RemovableHandle] = []
any_hook_handles: list[RemovableHandle] = []
class MultiOutputModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.lin = nn.Linear(3, 3)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
z = self.lin(x)
out = torch.sin(z), torch.cos(z)
nonlocal hook_id
@ -10123,7 +10123,7 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
z = y[0] + y[1]
return self.mod2(z)
hook_order: List[int] = []
hook_order: list[int] = []
hook_count = 0
def hook(hook_id: int, *unused):
@ -13975,7 +13975,7 @@ class TestSelectiveActivationCheckpoint(TestCase):
counter = [0]
@torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
def sin_with_extra(x: torch.Tensor) -> tuple[torch.Tensor, int]:
counter[0] += 1
return x.sin(), 2

View File

@ -4,7 +4,7 @@
import io
import textwrap
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.utils.bundled_inputs
@ -32,7 +32,7 @@ class TestBundledInputs(TestCase):
sm = torch.jit.script(SingleTensorModel())
original_size = model_size(sm)
get_expr: List[str] = []
get_expr: list[str] = []
samples = [
# Tensor with small numel and small storage.
(torch.tensor([1]),),
@ -328,8 +328,8 @@ class TestBundledInputs(TestCase):
class MyModel(torch.nn.Module):
def forward(
self,
arg1: Optional[Dict[str, torch.Tensor]],
arg2: Optional[List[torch.Tensor]],
arg1: Optional[dict[str, torch.Tensor]],
arg2: Optional[list[torch.Tensor]],
arg3: torch.Tensor,
):
if arg1 is None:
@ -393,7 +393,7 @@ class TestBundledInputs(TestCase):
""",
)
out: List[str] = []
out: list[str] = []
sm = torch.jit.script(MyModel())
original_size = model_size(sm)
small_inputs = (

View File

@ -4054,11 +4054,7 @@ class TestCudaMallocAsync(TestCase):
that the pytorch call is returning a correct list of UUIDs.
"""
cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'"
uuids = (
subprocess.check_output(cmd, shell=True, universal_newlines=True)
.strip()
.split("\n")
)
uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
uuids = [s.strip() for s in uuids]
raw_uuids = torch.cuda._raw_device_uuid_amdsmi()
for uuid in uuids:
@ -4082,11 +4078,7 @@ import os
print(f"{torch.cuda.device_count()}")
"""
cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'"
uuids = (
subprocess.check_output(cmd, shell=True, universal_newlines=True)
.strip()
.split("\n")
)
uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
uuids = [s.strip() for s in uuids]
custom_envs = []

View File

@ -3,7 +3,7 @@
import sys
import textwrap
import traceback
from typing import List, Optional
from typing import Optional
import torch
import torch.cuda._sanitizer as csan
@ -148,9 +148,9 @@ class TestEventHandler(TestCase):
def kernel_launch(
self,
stream: StreamId,
read_only: Optional[List[DataPtr]] = None,
read_write: Optional[List[DataPtr]] = None,
) -> List[csan.SynchronizationError]:
read_only: Optional[list[DataPtr]] = None,
read_write: Optional[list[DataPtr]] = None,
) -> list[csan.SynchronizationError]:
if read_only is None:
read_only = []
if read_write is None:
@ -167,8 +167,8 @@ class TestEventHandler(TestCase):
def assert_good_kernel_launch(
self,
stream: StreamId,
read_only: Optional[List[DataPtr]] = None,
read_write: Optional[List[DataPtr]] = None,
read_only: Optional[list[DataPtr]] = None,
read_write: Optional[list[DataPtr]] = None,
) -> None:
self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
@ -176,8 +176,8 @@ class TestEventHandler(TestCase):
self,
number_of_errors: int,
stream: StreamId,
read_only: Optional[List[DataPtr]] = None,
read_write: Optional[List[DataPtr]] = None,
read_only: Optional[list[DataPtr]] = None,
read_write: Optional[list[DataPtr]] = None,
) -> None:
errors = self.kernel_launch(stream, read_only, read_write)
self.assertEqual(len(errors), number_of_errors)

View File

@ -1262,14 +1262,12 @@ class TestDataLoader(TestCase):
list(iter(loader))
def test_typing(self):
from typing import List
# Make sure there is no TypeError
class SomeDatasetClass(Dataset[List[torch.Tensor]]):
class SomeDatasetClass(Dataset[list[torch.Tensor]]):
pass
def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
def _create_dataloader(is_train: bool) -> DataLoader[list[torch.Tensor]]:
pass
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: distributed"]
import unittest
from typing import List, Optional, Tuple
from typing import Optional
import torch
import torch.distributed
@ -27,9 +27,9 @@ class MyModule(torch.nn.Module):
class MyDummyFnOptimizer:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
_allow_empty_param_list: bool = False,
@ -63,7 +63,7 @@ class MyDummyFnOptimizer:
"MyDummyFnOptimizer does not support step_param() as of now"
)
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
# call the custom optimizer step implementation
with torch.no_grad():
raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")

View File

@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
inp3_y = inp3.y
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
class MyModule2(torch.nn.Module):
def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
inp_0 = inp[0]
inp_1 = inp[1]
inp2_0 = inp2[0]
inp3_x = inp3.x
inp3_y = inp3.y
return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
my_module = MyModule()
my_module_traced = torch.fx.symbolic_trace(my_module)
@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
if node.target == operator.getitem:
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
my_module = MyModule2()
my_module_traced = torch.fx.symbolic_trace(my_module)
# by default, fx transform loses type annotation of getitem nodes.
for node in my_module_traced.graph.nodes:
if node.target == operator.getitem:
assert node.type is None
annotate_getitem_nodes(my_module_traced.graph)
for node in my_module_traced.graph.nodes:
if node.target == operator.getitem:
self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
def test_subgraph_uniquename(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -5,7 +5,7 @@
import itertools
import torch
from typing import List, Any
from typing import Any
from functools import wraps
import unittest
from torch.testing._internal.common_utils import skipIfTorchDynamo
@ -100,7 +100,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
# dimensions along which the reduction operation is applied:
dim_ = torch.masked._canonical_dim(dim, input.ndim)
# slices in product(*ranges) define all elementary slices:
ranges: List[Any] = []
ranges: list[Any] = []
# shape of output for the keepdim=True case:
shape = []
for i in range(input.ndim):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: mkldnn"]
import itertools
import unittest
from typing import NamedTuple, List
from typing import NamedTuple
import torch
from torch import nn
@ -16,7 +16,7 @@ FUSION_GROUP = 'prim::TensorExprGroup'
class PointwisePostOp(NamedTuple):
attr : str
pointwise_module : nn.Module
scalars : List = []
scalars : list = []
algorithm : str = ""
CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

View File

@ -1,6 +1,6 @@
# Owner(s): ["module: unknown"]
from typing import Optional, List
from typing import Optional
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
@ -8,12 +8,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc
class FloatListWrapperModule(torch.nn.Module):
def forward(self, values, incr: Optional[List[float]]):
def forward(self, values, incr: Optional[list[float]]):
return torch._C._nn._test_optional_floatlist(values, incr)
class IntListWrapperModule(torch.nn.Module):
def forward(self, values, incr: Optional[List[int]]):
def forward(self, values, incr: Optional[list[int]]):
return torch._C._nn._test_optional_intlist(values, incr)

View File

@ -9,7 +9,7 @@ import sys
import tempfile
import unittest
from functools import partial
from typing import Optional, Tuple
from typing import Optional
import numpy as np
@ -3545,7 +3545,7 @@ def get_tolerances(
true_value: torch.Tensor,
computed_value: torch.Tensor,
fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
) -> tuple[float, float]:
"""Returns the absolute and relative tolerances for comparing two tensors."""
fudge_factor = fudge_factor if fudge_factor is not None else 1.0
atol = get_atol(true_value, computed_value)

View File

@ -4,7 +4,6 @@
import ctypes
import os
import unittest
from typing import Tuple
import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
@ -700,7 +699,7 @@ class TestNNAPI(TestCase):
def test_multi_output(self):
class MultiModel(torch.nn.Module):
def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, lhs, rhs) -> tuple[torch.Tensor, torch.Tensor]:
the_sum = lhs + rhs
the_diff = lhs - rhs
return the_sum, the_diff

View File

@ -278,7 +278,7 @@ class TestNumPyInterop(TestCase):
def test_from_numpy_no_leak_on_invalid_dtype(self):
# This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
# object. See https://github.com/pytorch/pytorch/issues/121138
x = np.array("value".encode("ascii"))
x = np.array(b"value")
for _ in range(1000):
try:
torch.from_numpy(x)

View File

@ -11,7 +11,6 @@ from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from importlib import import_module
from typing import Dict, List
import torch
import torch._prims as prims
@ -1483,7 +1482,7 @@ class TestCommon(TestCase):
unsupported_dtypes = set()
supported_backward_dtypes = set()
unsupported_backward_dtypes = set()
dtype_error: Dict[torch.dtype, Exception] = {}
dtype_error: dict[torch.dtype, Exception] = {}
def unsupported(dtype, e):
dtype_error[dtype] = e
@ -1987,7 +1986,7 @@ class TestCompositeCompliance(TestCase):
for sample in op.sample_inputs(device, dtype, requires_grad=False):
inp = sample.input
outs = op(inp, *sample.args, **sample.kwargs)
if not isinstance(outs, (tuple, List)):
if not isinstance(outs, (tuple, list)):
outs = [outs]
# for all outputs that are views of the input, we should be able to replay the

View File

@ -4,7 +4,7 @@ import math
import tempfile
import unittest
from copy import deepcopy
from typing import Any, Dict, Tuple
from typing import Any
from unittest.mock import patch
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
@ -1769,8 +1769,8 @@ class TestOptimRenewed(TestCase):
@staticmethod
def _state_dict_post_hook(
optimizer: Optimizer, state_dict: Dict[str, Any]
) -> Dict[str, Any]:
optimizer: Optimizer, state_dict: dict[str, Any]
) -> dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
@ -1821,14 +1821,14 @@ class TestOptimRenewed(TestCase):
@staticmethod
def _load_state_dict_pre_hook1(
optimizer: Optimizer, state_dict: Dict[str, Any]
optimizer: Optimizer, state_dict: dict[str, Any]
) -> None:
state_dict["param_groups"][0]["lr"] = 0.002
@staticmethod
def _load_state_dict_pre_hook2(
optimizer: Optimizer, state_dict: Dict[str, Any]
) -> Dict[str, Any]:
optimizer: Optimizer, state_dict: dict[str, Any]
) -> dict[str, Any]:
# The typical use case for returning a state dict is to drastically modify the state dict.
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
my_state_dict = deepcopy(state_dict)
@ -1906,7 +1906,7 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_step_post_hook(self, device, dtype, optim_info):
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data += 2
@ -1938,7 +1938,7 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_step_pre_hook(self, device, dtype, optim_info):
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data += 2
@ -1970,19 +1970,19 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_step_all_hooks(self, device, dtype, optim_info):
def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data.append(0)
def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data.append(5)
def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data.append(1)
def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data.append(2)

View File

@ -5,7 +5,7 @@ import torch
import numpy as np
import math
from typing import Dict, List, Sequence
from collections.abc import Sequence
import random
from functools import partial
from itertools import product, combinations, permutations
@ -736,7 +736,7 @@ class TestReductions(TestCase):
# TODO: kill this ane replace with common creation ops
def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
use_complex=False) -> Dict[str, List[torch.Tensor]]:
use_complex=False) -> dict[str, list[torch.Tensor]]:
float_types = [torch.double,
torch.float]
int_types = [torch.int64,
@ -778,7 +778,7 @@ class TestReductions(TestCase):
types += int_types
if use_complex:
types += complex_types
tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
tensors: dict[str, list[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
for dtype in types:
tensors["cont"].append(make_contiguous(shape, dtype))
tensors["noncont"].append(make_non_contiguous(shape, dtype))

View File

@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
skipIfCrossRef
from torch.testing._internal.common_cuda import TEST_CUDA
from numbers import Number
from typing import Dict, Any
from typing import Any
from packaging import version
from torch.testing._internal.common_cuda import \
(SM53OrLater, SM80OrLater, TEST_MULTIGPU)
@ -334,7 +334,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(t._values(), tc._values())
return tc
value_map: Dict[Any, Any] = {}
value_map: dict[Any, Any] = {}
for idx, val in zip(t._indices().t(), t._values()):
idx_tup = tuple(idx.tolist())
if idx_tup in value_map:

View File

@ -21,7 +21,7 @@ from torch.testing._internal.common_methods_invocations import (
from torch.testing._internal.common_cuda import SM53OrLater
from torch._prims_common import corresponding_complex_dtype
from typing import Optional, List
from typing import Optional
from packaging import version
@ -597,7 +597,7 @@ class TestFFT(TestCase):
else:
numpy_fn = getattr(np.fft, fname)
def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
def fn(t: torch.Tensor, s: Optional[list[int]], dim: list[int] = (-2, -1), norm: Optional[str] = None):
return torch_fn(t, s, dim, norm)
torch_fns = (torch_fn, torch.jit.script(fn))

View File

@ -2,14 +2,13 @@
# ruff: noqa: F841
import unittest
from typing import Dict, Optional
from typing import Optional
import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule
from typing import List
def linear_shim(
@ -108,7 +107,7 @@ def fork_wait_graph2(input1, input2):
:param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
futures : List[torch.jit.Future[torch.Tensor]] = []
futures : list[torch.jit.Future[torch.Tensor]] = []
for _ in range(iters):
futures.append(torch.jit.fork(torch.neg, input))
results = []
@ -123,7 +122,7 @@ def fork_wait_graph3(input, iters: int):
:param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
futures : List[torch.jit.Future[torch.Tensor]] = []
futures : list[torch.jit.Future[torch.Tensor]] = []
for _ in range(num_forks):
futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
results = []
@ -150,7 +149,7 @@ def loop_graph(a, b, iters: int):
def output_graph(a, b, c, iters: int):
s = torch.tensor([[3, 3], [3, 3]])
k = a + b * c + s
d: Dict[int, torch.Tensor] = {}
d: dict[int, torch.Tensor] = {}
for i in range(iters):
d[i] = k + i
return d

View File

@ -5,7 +5,7 @@ import itertools
import math
import pickle
import sys
from typing import Callable, List, Tuple, Type
from typing import Callable
import sympy
@ -594,7 +594,7 @@ class TestSympyInterp(TestCase):
self.fail(f"Unexpected error for {fn}{args}: {str(e)}")
def type_name_fn(type: Type) -> str:
def type_name_fn(type: type) -> str:
return type.__name__
@ -606,7 +606,7 @@ def parametrize_relational_types(*types):
class TestSympySolve(TestCase):
def _create_integer_symbols(self) -> List[sympy.Symbol]:
def _create_integer_symbols(self) -> list[sympy.Symbol]:
return sympy.symbols("a b c", integer=True)
def test_give_up(self):
@ -665,9 +665,9 @@ class TestSympySolve(TestCase):
def _test_cases(
self,
cases: List[Tuple[sympy.Basic, sympy.Basic]],
cases: list[tuple[sympy.Basic, sympy.Basic]],
thing: sympy.Basic,
op: Type[sympy.Rel],
op: type[sympy.Rel],
**kwargs,
):
for source, expected in cases:
@ -761,7 +761,7 @@ class TestSympySolve(TestCase):
Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
}[op]
cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
cases: list[tuple[sympy.Basic, sympy.Basic]] = [
# 'b' is not strictly positive
(op(FloorDiv(a, b), integer), None),
# 'c' is not strictly positive

View File

@ -11,7 +11,7 @@ import unittest
from itertools import product, combinations, combinations_with_replacement, permutations
import random
import tempfile
from typing import Any, Dict, List, Tuple
from typing import Any
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
@ -4125,7 +4125,7 @@ class TestAsArray(TestCase):
def test_default_device(self, device):
original = torch.arange(5)
examples: List[Tuple[Any, Dict]] = [
examples: list[tuple[Any, dict]] = [
(3, {}),
(original, {}),
(to_numpy(original), {}),

View File

@ -12,7 +12,8 @@ import re
import subprocess
import sys
import unittest.mock
from typing import Any, Callable, Iterator, List, Tuple
from typing import Any, Callable
from collections.abc import Iterator
import torch
@ -496,7 +497,7 @@ if __name__ == '__main__':
self.assertNotIn('OK', stderr.decode('ascii'))
def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]:
"""Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.
Args:

View File

@ -53,7 +53,6 @@ from torch.testing._internal.common_device_type import (
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, skipCUDAIfNotRocm,
get_all_device_types, skipXLA)
from typing import Tuple
import torch.backends.quantized
import torch.testing._internal.data
from torch.testing._internal.common_cuda import (
@ -3511,7 +3510,7 @@ else:
def _prepare_data_for_index_copy_and_add_deterministic(
self, dim: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert (dim >= 0 and dim < 3)
a = [5, 4, 3]
a[dim] = 2000

View File

@ -18,7 +18,7 @@ import math
import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional, Dict
from typing import Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
@ -149,12 +149,12 @@ def _check_equal(
def check_out_and_grad(
out_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_query_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_key_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_value_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_attn_mask_tuple: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
fudge_factors: Optional[Dict[str, float]] = None
out_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_query_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_key_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_value_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
grad_attn_mask_tuple: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
fudge_factors: Optional[dict[str, float]] = None
) -> None:
"""
Check output and gradients of attention mechanism tensors.
@ -2574,7 +2574,7 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_preserves_query_layout(self, device):
def test_attention(backend: SDPBackend, permute_order: List[List[int]]):
def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
BHSqD = [4, 16, 256, 64]
BHSkvD = [4, 16, 512, 64]
@ -2602,7 +2602,7 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("mask_dim", [1, 2, 3, 4])
def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
dtype = torch.float16
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
batch, num_heads, head_dim = 8, 8, 64
@ -3307,7 +3307,7 @@ class TestSDPACudaOnly(NNTestCase):
@tf32_enabled()
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: List[int]):
scale: str, enable_gqa: bool, n_heads: list[int]):
if isSM8XDevice and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
@ -3905,7 +3905,7 @@ class TestAttnBias(NNTestCase):
"shape",
[(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
)
def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
)
@ -3942,7 +3942,7 @@ class TestAttnBias(NNTestCase):
)
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
@skipIfTorchDynamo("This function already calls torch.compile.")
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
@ -3975,7 +3975,7 @@ class TestAttnBias(NNTestCase):
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
def test_is_causal_equals_upper_left(self, device, shape: list[tuple[int]]):
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
)

View File

@ -8,7 +8,7 @@ import shutil
import unittest
from collections import defaultdict
from threading import Lock
from typing import Dict, IO, List, Optional
from typing import IO, Optional
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -49,12 +49,12 @@ def _strip_filename(msg: str) -> str:
return tail.split(":", 1)[-1]
def _run_mypy() -> Dict[str, List[str]]:
def _run_mypy() -> dict[str, list[str]]:
"""Clears the cache and run mypy before running any of the typing tests."""
if os.path.isdir(CACHE_DIR):
shutil.rmtree(CACHE_DIR)
rc: Dict[str, List[str]] = {}
rc: dict[str, list[str]] = {}
for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
# Run mypy
stdout, stderr, _ = api.run(
@ -119,10 +119,10 @@ def _construct_format_dict():
#: A dictionary with all supported format keys (as keys)
#: and matching values
FORMAT_DICT: Dict[str, str] = _construct_format_dict()
FORMAT_DICT: dict[str, str] = _construct_format_dict()
def _parse_reveals(file: IO[str]) -> List[str]:
def _parse_reveals(file: IO[str]) -> list[str]:
"""Extract and parse all ``" # E: "`` comments from the passed file-like object.
All format keys will be substituted for their respective value from `FORMAT_DICT`,
@ -160,10 +160,10 @@ def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> N
@unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
class TestTyping(TestCase):
_lock = Lock()
_cached_output: Optional[Dict[str, List[str]]] = None
_cached_output: Optional[dict[str, list[str]]] = None
@classmethod
def get_mypy_output(cls) -> Dict[str, List[str]]:
def get_mypy_output(cls) -> dict[str, list[str]]:
with cls._lock:
if cls._cached_output is None:
cls._cached_output = _run_mypy()
@ -192,7 +192,7 @@ class TestTyping(TestCase):
with open(path) as fin:
lines = fin.readlines()
errors = defaultdict(lambda: "")
errors = defaultdict(str)
output_mypy = self.get_mypy_output()
self.assertIn(path, output_mypy)

View File

@ -12,7 +12,7 @@ import textwrap
import traceback
import unittest
import warnings
from typing import Any, Dict, List
from typing import Any
import torch
import torch.cuda
@ -439,7 +439,7 @@ class TestCheckpoint(TestCase):
# get de-allocated directly. So using cuda memory usage as a proxy
def _do_test(fn, should_free):
stats: List[int] = []
stats: list[int] = []
def track(x, idx):
# Track that at each step of the backward, some Tensor were
@ -1203,7 +1203,7 @@ def f(x):
return g(x) + 1
"""
out: Dict[str, Any] = {}
out: dict[str, Any] = {}
scope = {"__compile_source__": source}
exec(source, scope, out)

View File

@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
"""
Annotate the type of getitem nodes, inferred from the type of sequence node.
If sequence node is not annotated with a type, do nothing.
Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
This is helpful since annotations on local names within function are lost during FX transforms.
Adding back known type annotation for getitem nodes to improve jit scriptability.
@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
elif sequence_node.type._name == "List":
assert len(parameterized_types) == 1
node.type = parameterized_types[0]
# Generic Alias Type
elif hasattr(sequence_node.type, "__origin__"):
parameterized_types = sequence_node.type.__args__
if sequence_node.type.__origin__ is tuple:
if len(parameterized_types) == 2 and isinstance(
parameterized_types[1], type(...)
):
node.type = parameterized_types[0]
else:
assert len(parameterized_types) > index_node
node_type = parameterized_types[index_node]
node.type = node_type
elif sequence_node.type.__origin__ is list:
assert len(parameterized_types) == 1
node.type = parameterized_types[0]
# NamedTuple type
elif hasattr(sequence_node.type, "__annotations__"):
if sequence_node.type == torch.Tensor: