[4/N] Apply py39 ruff and pyupgrade fixes (#143257)
`torch/fx/passes/annotate_getitem_nodes.py` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
parent a881954b0c · commit df458be4e5
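The hunks below are mechanical: with the Python 3.9 baseline, ruff's pyupgrade rules replace `typing.List`/`Dict`/`Set`/`Tuple`/`Type` with the PEP 585 builtin generics, move ABCs such as `Sequence`, `Iterable`, `Mapping`, and `Iterator` to `collections.abc` (or behind a `TYPE_CHECKING` guard), and fold in smaller cleanups such as f-strings and `text=True`. A minimal before/after sketch of the main rewrite (the `bucketize` functions are illustrative, not taken from the diff):

```python
# Before (pre-3.9 style): container generics and ABCs come from `typing`.
from typing import Dict, List, Sequence, Tuple


def old_bucketize(xs: Sequence[int], size: int) -> Dict[int, List[Tuple[int, int]]]:
    buckets: Dict[int, List[Tuple[int, int]]] = {}
    for i, x in enumerate(xs):
        buckets.setdefault(x // size, []).append((i, x))
    return buckets


# After (3.9+ style): builtin generics (PEP 585) plus `collections.abc` ABCs.
from collections.abc import Sequence


def new_bucketize(xs: Sequence[int], size: int) -> dict[int, list[tuple[int, int]]]:
    buckets: dict[int, list[tuple[int, int]]] = {}
    for i, x in enumerate(xs):
        buckets.setdefault(x // size, []).append((i, x))
    return buckets


# Only the annotations change; runtime behavior is identical, which is why
# the diff touches almost nothing but imports and signatures.
assert old_bucketize([1, 5, 9], 4) == new_bucketize([1, 5, 9], 4)
```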
```diff
@@ -1,8 +1,3 @@
-import dis
-import inspect
-from collections.abc import Sequence
-from typing import Union
-
 import functorch._C
 import torch
 from functorch._C import dim as _C
```
```diff
@@ -27,7 +27,7 @@ from __future__ import annotations
 import keyword
 import warnings
-from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union


 if TYPE_CHECKING:
@@ -73,11 +73,11 @@ class ParsedExpression:
         """
         self.has_ellipsis: bool = False
         self.has_ellipsis_parenthesized: Optional[bool] = None
-        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
+        self.identifiers: set[Union[str, AnonymousAxis]] = set()
         # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
         self.has_non_unitary_anonymous_axes: bool = False
         # composition keeps structure of composite axes, see how different corner cases are handled in tests
-        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
+        self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
         if "." in expression:
             if "..." not in expression:
                 raise ValueError(
@@ -90,7 +90,7 @@ class ParsedExpression:
             expression = expression.replace("...", _ellipsis)
             self.has_ellipsis = True

-        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
+        bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None

         def add_axis_name(x: str) -> None:
             if x in self.identifiers:
@@ -164,7 +164,7 @@ class ParsedExpression:
     @staticmethod
     def check_axis_name_return_reason(
         name: str, allow_underscore: bool = False
-    ) -> Tuple[bool, str]:
+    ) -> tuple[bool, str]:
         """Check if the given axis name is valid, and a message explaining why if not.

         Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@@ -174,7 +174,7 @@ class ParsedExpression:
             allow_underscore (bool): whether axis names are allowed to start with an underscore

         Returns:
-            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
+            tuple[bool, str]: whether the axis name is valid, a message explaining why if not
         """
         if not str.isidentifier(name):
             return False, "not a valid python identifier"
@@ -211,7 +211,7 @@ class ParsedExpression:

 def parse_pattern(
     pattern: str, axes_lengths: Mapping[str, int]
-) -> Tuple[ParsedExpression, ParsedExpression]:
+) -> tuple[ParsedExpression, ParsedExpression]:
     """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.

     Args:
@@ -219,7 +219,7 @@ def parse_pattern(
         axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

     Returns:
-        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
+        tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
     """
     # adapted from einops.einops._prepare_transformation_recipe
     # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
```
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import functools
-from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
+from typing import Callable, TYPE_CHECKING, Union

 import torch
 from functorch._C import dim as _C
@@ -18,7 +18,6 @@ from ._parsing import (
 if TYPE_CHECKING:
     from collections.abc import Sequence
-

 __all__ = ["rearrange"]

 dims = _C.dims
@@ -69,9 +68,9 @@ def _create_rearrange_callable(
         # an identity rearrangement on a 0-dimension tensor
         return lambda tensor: tensor

-    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
-    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
-    anon_axes: List[AnonymousAxis] = []
+    first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
+    identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
+    anon_axes: list[AnonymousAxis] = []

     # map the left-hand side identifiers to strings representing first class dims
     dims_i = 0
@@ -99,11 +98,11 @@ def _create_rearrange_callable(
         raise ValueError(f"Unexpected dimension: {dimension}")

     def composition_to_dims(
-        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
-    ) -> List[Union[str, Tuple[str, ...]]]:
+        composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
+    ) -> list[Union[str, tuple[str, ...]]]:
         """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
         class dims."""
-        dim_composition: List[Union[str, Tuple[str, ...]]] = []
+        dim_composition: list[Union[str, tuple[str, ...]]] = []
         for dimension in composition:
             if isinstance(dimension, list):
                 dim_composition.append(
@@ -152,7 +151,7 @@ def _create_rearrange_callable(


 def rearrange(
-    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
+    tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
     pattern: str,
     **axes_lengths: int,
 ) -> torch.Tensor:
```
```diff
@@ -2,7 +2,6 @@

 import copy
 import logging
-from typing import List

 import torch
 import torch.nn as nn
@@ -247,7 +246,7 @@ class TestActivationSparsifier(TestCase):
                 assert mask2 is None
             else:
                 assert type(mask1) == type(mask2)
-                if isinstance(mask1, List):
+                if isinstance(mask1, list):
                     assert len(mask1) == len(mask2)
                     for idx in range(len(mask1)):
                         assert torch.all(mask1[idx] == mask2[idx])
@@ -258,7 +257,7 @@ class TestActivationSparsifier(TestCase):
         for state in state_dict["state"].values():
             mask = state["mask"]
             if mask is not None:
-                if isinstance(mask, List):
+                if isinstance(mask, list):
                     for idx in range(len(mask)):
                         assert mask[idx].is_sparse
                 else:
```

```diff
@@ -3,7 +3,6 @@
 import copy
 import logging
 import warnings
-from typing import Tuple

 import torch
 from torch import nn
@@ -73,7 +72,7 @@ class TestBaseDataScheduler(TestCase):

     def _get_name_data_config(self, some_data, defaults):
         config = copy.deepcopy(defaults)
-        if isinstance(some_data, Tuple):
+        if isinstance(some_data, tuple):
             # dealing with data_list
             name, data = some_data
         else:
```

```diff
@@ -4,7 +4,6 @@ import copy
 import itertools
 import logging
 import math
-from typing import Tuple

 import torch
 from torch import nn
@@ -54,7 +53,7 @@ class _BaseDataSparsiferTestCase(TestCase):

     @staticmethod
     def _get_name_data_config(some_data, defaults=None):
-        if isinstance(some_data, Tuple):
+        if isinstance(some_data, tuple):
             # dealing with data_list
             name, data = some_data
             config = defaults
@@ -482,8 +481,9 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
             nn.Parameter(torch.randn(4, 4)),
             nn.Parameter(torch.randn(5, 5)),
         )
-        param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(1, 1)),
+            nn.Parameter(torch.randn(4, 4)),
         )
         data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
         defaults = {"test": 3}
@@ -585,8 +585,9 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
             nn.Parameter(torch.randn(4, 4)),
             nn.Parameter(torch.randn(5, 5)),
         )
-        param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(10, 10)),
+            nn.Parameter(torch.randn(4, 4)),
        )
         data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
         defaults = {
```
```diff
@@ -2,13 +2,17 @@
 from __future__ import annotations

 import typing
-from typing import List, Optional, Sequence, Union  # noqa: F401
+from typing import List, Optional, Union

 import torch
 from torch import Tensor, types
 from torch.testing._internal.common_utils import run_tests, TestCase


+if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 mutates_args = {}
```
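Several files in this commit move `Sequence` or `Iterable` imports behind an `if TYPE_CHECKING:` guard, as in the hunk just shown. The guard works because a name used only in annotations never needs to exist at runtime once annotation evaluation is postponed. A self-contained sketch of the pattern (the `total` function is illustrative, not from the diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; skipped at runtime, so it
    # cannot introduce an import cycle or slow down module import.
    from collections.abc import Sequence


def total(xs: Sequence[int]) -> int:
    # With PEP 563 postponed evaluation the annotation is stored as a string,
    # so this call works even though `Sequence` was never imported at runtime.
    return sum(xs)


print(total([1, 2, 3]))  # 6
```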
```diff
@@ -6,7 +6,8 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
```

```diff
@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]

-from typing import Any, Callable, Dict, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, Optional
 from unittest import skip

 import torch
```

```diff
@@ -10,8 +10,9 @@ import inspect
 import io
 import operator
 import unittest
+from collections.abc import Sequence
 from enum import Enum
-from typing import Dict, List, Sequence
+from typing import Dict, List
 from unittest.mock import patch

 import torch
```

```diff
@@ -21,10 +21,11 @@ import warnings
 import weakref
 from abc import ABC
 from collections import namedtuple
+from collections.abc import Iterator
 from copy import deepcopy
 from enum import Enum, IntEnum
 from functools import wraps
-from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict
+from typing import Any, Dict, List, Literal, Tuple, TypedDict
 from unittest import mock

 import numpy as np
```

```diff
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: export"]
 import unittest
-from typing import Any, Dict, Optional, OrderedDict, Tuple
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Tuple

 import torch
 from torch._export.passes.lift_constants_pass import (
```

```diff
@@ -14,8 +14,7 @@ import random
 import types
 import unittest
 import warnings
-from collections import namedtuple
-from typing import OrderedDict
+from collections import namedtuple, OrderedDict
 from unittest.case import skipIf

 from common_utils import (
```

```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: fx"]

 import unittest
-from typing import Mapping
+from collections.abc import Mapping

 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
```

```diff
@@ -3,8 +3,8 @@
 import functools
 import itertools
 import os
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Sequence
 from unittest import skip

 import yaml
```
```diff
@@ -9,7 +9,7 @@ from collections import namedtuple, OrderedDict
 from copy import deepcopy
 from functools import partial
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, List, Tuple
+from typing import Any

 import torch
 import torch.nn as nn
@@ -55,11 +55,11 @@ class ToyModel(nn.Module):

 def forward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
     out: torch.Tensor,
 ) -> None:
     fired_hooks.append(hook_id)
@@ -69,11 +69,11 @@ def forward_hook(

 def forward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -82,12 +82,12 @@ def forward_pre_hook(

 def full_backward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
-    grad_output: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
+    grad_output: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -97,11 +97,11 @@ def full_backward_hook(

 def full_backward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
 ) -> None:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
@@ -122,8 +122,8 @@ class KwargModel(nn.Module):
     def internal_forward_hook(
         self,
         module: nn.Module,
-        args: Tuple[torch.Tensor],
-        kwargs: Dict[str, Any],
+        args: tuple[torch.Tensor],
+        kwargs: dict[str, Any],
         out: torch.Tensor,
     ):
         return out + kwargs["bias"]
@@ -142,13 +142,13 @@ class FailsInForwardModel(nn.Module):

 def kwarg_forward_pre_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
-) -> Tuple[Any, Any]:
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
+) -> tuple[Any, Any]:
     fired_hooks.append(hook_id)
     self.assertEqual(id(module), id(expected_module))
     self.assertEqual(len(args), 1)
@@ -158,12 +158,12 @@ def kwarg_forward_pre_hook(

 def kwarg_forward_hook(
     self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
     expected_module: nn.Module,
     hook_id: int,
     module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
     out: torch.Tensor,
 ) -> Any:
     fired_hooks.append(hook_id)
@@ -188,7 +188,7 @@ class DummyContextManager:
 class TestModuleHooks(TestCase):
     @parametrize_test("named_tuple", (True, False))
     def test_forward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
@@ -210,7 +210,7 @@ class TestModuleHooks(TestCase):

     @parametrize_test("named_tuple", (True, False))
     def test_forward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
@@ -232,7 +232,7 @@ class TestModuleHooks(TestCase):

     @parametrize_test("named_tuple", (True, False))
     def test_full_backward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_hook, self, fired_hooks, model.net1)
@@ -254,7 +254,7 @@ class TestModuleHooks(TestCase):

     @parametrize_test("named_tuple", (True, False))
     def test_full_backward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
@@ -294,7 +294,7 @@ class TestModuleHooks(TestCase):

     @parametrize_test("named_tuple", (True, False))
     def test_mixed_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         model.register_forward_pre_hook(
@@ -319,7 +319,7 @@ class TestModuleHooks(TestCase):

     def test_kwarg_hooks(self):
         # 1. test forward pre hook
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -336,7 +336,7 @@ class TestModuleHooks(TestCase):
         self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)

         # 2. test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -372,7 +372,7 @@ class TestModuleHooks(TestCase):

     def test_remove_kwarg_hooks(self):
         # test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
         x: torch.Tensor = torch.ones(10, 10)
         bias: torch.Tensor = torch.ones(10, 10)
         model = KwargModel()
@@ -1217,8 +1217,8 @@ class TestModuleGlobalHooks(TestCase):
     def test_module_global_hooks_with_kwargs(self):
         def kwarg_global_forward_hook(
             module: nn.Module,
-            args: Tuple[torch.Tensor],
-            kwargs: Dict[str, Any],
+            args: tuple[torch.Tensor],
+            kwargs: dict[str, Any],
             out: torch.Tensor,
         ) -> Any:
             out = out + kwargs["bias"]
```
```diff
@@ -2,7 +2,6 @@

 import itertools
 import random
-from typing import List

 import torch
 import torch.nn.utils.rnn as rnn_utils
@@ -219,7 +218,7 @@ class PackedSequenceTest(TestCase):
         # more dimensions
         maxlen = 9
         for num_dim in (0, 1, 2, 3):
-            sequences: List[torch.Tensor] = []
+            sequences: list[torch.Tensor] = []
             trailing_dims = [4] * num_dim
             for i in range(1, maxlen + 1):
                 seq_len = i * i
```

```diff
@@ -1573,10 +1573,10 @@ torch.cuda.synchronize()
             return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)

         if adaptive:
-            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"AdaptiveMaxPool{num_dim}d"
         else:
             # FIXME(#105716): Test fails when using f-string
-            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"MaxPool{num_dim}d"
         module_cls = getattr(nn, cls_name)
         module = module_cls(2, return_indices=True).to(device, dtype=dtype)
         numel = 4 ** (num_dim + 1)
```
```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: onnx"]
 """Unit tests for the internal registration wrapper module."""

-from typing import Sequence
+from collections.abc import Sequence

 from torch.onnx import errors
 from torch.onnx._internal import registration
```

```diff
@@ -10,19 +10,8 @@ import logging
 import os
 import unittest
 import warnings
-from typing import (
-    Any,
-    Callable,
-    Collection,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-)
+from collections.abc import Collection, Iterable, Mapping, Sequence
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import numpy as np
 import onnxruntime
```

```diff
@@ -3,7 +3,8 @@
 import os
 import unittest
 from collections import OrderedDict
-from typing import List, Mapping, Tuple
+from collections.abc import Mapping
+from typing import List, Tuple

 import onnx_test_common
 import parameterized
```
```diff
@@ -10,7 +10,7 @@ import itertools
 import unittest
 import unittest.mock
 import warnings
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import numpy as np

@@ -26,6 +26,10 @@ from torch.onnx._internal import registration
 from torch.testing._internal import common_quantization, common_utils, jit_utils


+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
 def export_to_onnx(
     model: Union[torch.nn.Module, torch.jit.ScriptFunction],
     input: Union[torch.Tensor, Tuple[torch.Tensor]],
```
```diff
@@ -1,6 +1,6 @@
 # Owner(s): ["oncall: package/deploy"]

-from typing import Iterable
+from collections.abc import Iterable

 from torch.package import GlobGroup
 from torch.testing._internal.common_utils import run_tests
```

```diff
@@ -18,7 +18,7 @@ import os
 import sys
 import tempfile
 import unittest
-from typing import Any, Dict, List
+from typing import Any

 import torch
 import torch.nn as nn
@@ -51,7 +51,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._triton import has_triton


-Json = Dict[str, Any]
+Json = dict[str, Any]


 class TestExecutionTrace(TestCase):
@@ -97,7 +97,7 @@ class TestExecutionTrace(TestCase):
         nodes = et_graph["nodes"]
         return nodes

-    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
+    def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]:
         """Returns a sorted list of rf_id (record function ids) in execution trace"""

         def get_rf_id(node):
@@ -115,7 +115,7 @@ class TestExecutionTrace(TestCase):
         )
         return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)

-    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
+    def get_kineto_rf_ids(self, events: list[Json]) -> list[int]:
         """Returns a sorted list of Record function IDs for CPU operators and user annotations"""
         ops_and_annotations = (
             e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
```
```diff
@@ -5,7 +5,8 @@ import itertools as it
 import sys
 import textwrap
 import unittest
-from typing import Callable, Dict, Iterator, List, Optional, Tuple
+from collections.abc import Iterator
+from typing import Callable, Optional

 import torch
 from torch._C._profiler import _EventType, _TensorMetadata
@@ -309,9 +310,9 @@ class TestDataFlow(TestCase):
     @staticmethod
     def formatSchemas(
         prof: torch.profiler.profile, indent: int = 12
-    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
+    ) -> tuple[tuple[str, tuple[bool, ...]], ...]:
         tree = prof.profiler.kineto_results.experimental_event_tree()
-        out: List[Tuple[str, Tuple[bool, ...]]] = []
+        out: list[tuple[str, tuple[bool, ...]]] = []
         for node in _utils.traverse_dfs(tree):
             if node.tag == _EventType.TorchOp:
                 e = node.extra_fields
@@ -327,8 +328,8 @@ class TestDataFlow(TestCase):

     @staticmethod
     def _run_and_format_data_flow(
-        inputs: Dict[str, torch.Tensor],
-        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
+        inputs: dict[str, torch.Tensor],
+        f: Callable[..., Optional[dict[str, torch.Tensor]]],
         indent: int = 12,
     ) -> str:
         with profile() as prof:
@@ -339,7 +340,7 @@ class TestDataFlow(TestCase):
         graph = memory_profile._data_flow_graph
         storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}

-        lines: List[str] = []
+        lines: list[str] = []
         for name, t in it.chain(inputs.items(), outputs.items()):
             lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
             if t.grad is not None:
@@ -352,7 +353,7 @@ class TestDataFlow(TestCase):
         for node in graph.flow_nodes:
             destroyed = {k for k, v in node._edges.items() if v.is_deletion}

-            inputs: List[str] = []
+            inputs: list[str] = []
             for key, (_, v) in node.inputs.items():
                 inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")

@@ -833,7 +834,7 @@ class TestMemoryProfilerE2E(TestCase):
     @staticmethod
     def _lookup_tensor_categories(
         t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
-    ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
+    ) -> dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
         storage = t.storage()
         if storage is None:
             raise ValueError("Cannot look up uninitialized Tensor.")
@@ -889,7 +890,7 @@ class TestMemoryProfilerE2E(TestCase):
             fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))

         memory_profile = prof._memory_profile()
-        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
+        ptr_pair_to_key: dict[tuple[int, int], _memory_profiler.TensorKey] = {}
         snapshot = memory_profile._category_snapshot()

         # Build map from observed live Tensors to the memory profiler's
@@ -922,7 +923,7 @@ class TestMemoryProfilerE2E(TestCase):

             return f"{target_key.storage.allocation_id} ({','.join(categories)})"

-        out: List[str] = []
+        out: list[str] = []
         for name, inputs, outputs in record_ops.results:
             if inputs or outputs:
                 # PyTorch ops
```
```diff
@@ -17,7 +17,7 @@ import threading
 import time
 import unittest
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 from unittest.mock import patch

 import expecttest
@@ -2277,7 +2277,7 @@ class MockProfilerEvent:
     start_time_ns: int
     duration_time_ns: int
     correlation_id: int = 0
-    children: List["MockProfilerEvent"] = field(default_factory=list)
+    children: list["MockProfilerEvent"] = field(default_factory=list)
     parent: Optional["MockProfilerEvent"] = None

     @property
@@ -2301,7 +2301,7 @@ class MockNode:

 @unittest.skipIf(sys.version_info >= (3, 13), "segfaults")
 class TestExperimentalUtils(TestCase):
-    def make_tree(self) -> List[MockNode]:
+    def make_tree(self) -> list[MockNode]:
         tree = {
             "root_0": {
                 "1": {"2": {}},
```

```diff
@@ -14,7 +14,7 @@ try:
 except ImportError:
     None

-from typing import Any, Dict
+from typing import Any

 import torch
 import torch.optim
@@ -29,7 +29,7 @@ from torch.profiler import kineto_available, record_function
 from torch.testing._internal.common_utils import run_tests, TestCase


-Json = Dict[str, Any]
+Json = dict[str, Any]


 class TestRecordFunction(TestCase):
```
```diff
@@ -19,7 +19,7 @@ import sys
 import textwrap
 import unittest
 import weakref
-from typing import Any, Dict, List
+from typing import Any

 import torch
 import torch.nn as nn
@@ -30,7 +30,7 @@ from torch.profiler import _utils, profile
 from torch.testing._internal.common_utils import run_tests, TestCase


-Json = Dict[str, Any]
+Json = dict[str, Any]

 from torch._C._profiler import _ExtraFields_PyCall

@@ -455,7 +455,7 @@ class TestTorchTidyProfiler(TestCase):

         nodes = p.profiler.kineto_results.experimental_event_tree()

-        def find_chain(names: List[str]):
+        def find_chain(names: list[str]):
             out = []
             for name in names:
                 root = [out[-1]] if out else nodes
```
```diff
@@ -25,7 +25,7 @@ from copy import deepcopy
 from functools import partial, reduce
 from itertools import product
 from operator import mul
-from typing import List, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING

 import torch
 import torch.autograd._functions
@@ -10091,14 +10091,14 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {

     def test_multi_grad_any_hooks(self):
         hook_id = 0
-        any_hook_handles: List[RemovableHandle] = []
+        any_hook_handles: list[RemovableHandle] = []

         class MultiOutputModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.lin = nn.Linear(3, 3)

-            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
                 z = self.lin(x)
                 out = torch.sin(z), torch.cos(z)
                 nonlocal hook_id
@@ -10123,7 +10123,7 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
                 z = y[0] + y[1]
                 return self.mod2(z)

-        hook_order: List[int] = []
+        hook_order: list[int] = []
         hook_count = 0

         def hook(hook_id: int, *unused):
@@ -13975,7 +13975,7 @@ class TestSelectiveActivationCheckpoint(TestCase):
         counter = [0]

         @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
-        def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        def sin_with_extra(x: torch.Tensor) -> tuple[torch.Tensor, int]:
             counter[0] += 1
             return x.sin(), 2
```
```diff
@@ -4,7 +4,7 @@

 import io
 import textwrap
-from typing import Dict, List, Optional
+from typing import Optional

 import torch
 import torch.utils.bundled_inputs
@@ -32,7 +32,7 @@ class TestBundledInputs(TestCase):

         sm = torch.jit.script(SingleTensorModel())
         original_size = model_size(sm)
-        get_expr: List[str] = []
+        get_expr: list[str] = []
         samples = [
             # Tensor with small numel and small storage.
             (torch.tensor([1]),),
@@ -328,8 +328,8 @@ class TestBundledInputs(TestCase):
         class MyModel(torch.nn.Module):
             def forward(
                 self,
-                arg1: Optional[Dict[str, torch.Tensor]],
-                arg2: Optional[List[torch.Tensor]],
+                arg1: Optional[dict[str, torch.Tensor]],
+                arg2: Optional[list[torch.Tensor]],
                 arg3: torch.Tensor,
             ):
                 if arg1 is None:
@@ -393,7 +393,7 @@ class TestBundledInputs(TestCase):
             """,
         )

-        out: List[str] = []
+        out: list[str] = []
         sm = torch.jit.script(MyModel())
         original_size = model_size(sm)
         small_inputs = (
```
```diff
@@ -4054,11 +4054,7 @@ class TestCudaMallocAsync(TestCase):
        that the pytorch call is returning a correct list of UUIDs.
        """
        cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'"
-       uuids = (
-           subprocess.check_output(cmd, shell=True, universal_newlines=True)
-           .strip()
-           .split("\n")
-       )
+       uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
        uuids = [s.strip() for s in uuids]
        raw_uuids = torch.cuda._raw_device_uuid_amdsmi()
        for uuid in uuids:
@@ -4082,11 +4078,7 @@ import os
 print(f"{torch.cuda.device_count()}")
 """
        cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'"
-       uuids = (
-           subprocess.check_output(cmd, shell=True, universal_newlines=True)
-           .strip()
-           .split("\n")
-       )
+       uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
        uuids = [s.strip() for s in uuids]

        custom_envs = []
```
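The `text=True` rewrite above is behavior-preserving: since Python 3.7, `text` has been the preferred alias for `universal_newlines` in the `subprocess` API. A quick illustration using a portable `echo` command rather than the `rocminfo` pipeline from the test:

```python
import subprocess

# text=True is the modern spelling of universal_newlines=True: the child's
# output is decoded to str and newlines are normalized to "\n".
out_old = subprocess.check_output("echo hello", shell=True, universal_newlines=True)
out_new = subprocess.check_output("echo hello", shell=True, text=True)
assert out_old == out_new == "hello\n"
```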
```diff
@@ -3,7 +3,7 @@
 import sys
 import textwrap
 import traceback
-from typing import List, Optional
+from typing import Optional

 import torch
 import torch.cuda._sanitizer as csan
@@ -148,9 +148,9 @@ class TestEventHandler(TestCase):
     def kernel_launch(
         self,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
-    ) -> List[csan.SynchronizationError]:
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
+    ) -> list[csan.SynchronizationError]:
         if read_only is None:
             read_only = []
         if read_write is None:
@@ -167,8 +167,8 @@ class TestEventHandler(TestCase):
     def assert_good_kernel_launch(
         self,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
     ) -> None:
         self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])

@@ -176,8 +176,8 @@ class TestEventHandler(TestCase):
         self,
         number_of_errors: int,
         stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
     ) -> None:
         errors = self.kernel_launch(stream, read_only, read_write)
         self.assertEqual(len(errors), number_of_errors)
```
```diff
@@ -1262,14 +1262,12 @@ class TestDataLoader(TestCase):
         list(iter(loader))

     def test_typing(self):
-        from typing import List
-
         # Make sure there is no TypeError

-        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
+        class SomeDatasetClass(Dataset[list[torch.Tensor]]):
             pass

-        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
+        def _create_dataloader(is_train: bool) -> DataLoader[list[torch.Tensor]]:
             pass

     @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
```
```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed"]

 import unittest
-from typing import List, Optional, Tuple
+from typing import Optional

 import torch
 import torch.distributed
@@ -27,9 +27,9 @@ class MyModule(torch.nn.Module):
 class MyDummyFnOptimizer:
     def __init__(
         self,
-        params: List[Tensor],
+        params: list[Tensor],
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         _allow_empty_param_list: bool = False,
@@ -63,7 +63,7 @@ class MyDummyFnOptimizer:
                 "MyDummyFnOptimizer does not support step_param() as of now"
             )

-    def step(self, gradients: List[Optional[Tensor]]):
+    def step(self, gradients: list[Optional[Tensor]]):
         # call the custom optimizer step implementation
         with torch.no_grad():
             raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")
```
```diff
@@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
                 inp3_y = inp3.y
                 return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

+        class MyModule2(torch.nn.Module):
+            def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
+                inp_0 = inp[0]
+                inp_1 = inp[1]
+                inp2_0 = inp2[0]
+                inp3_x = inp3.x
+                inp3_y = inp3.y
+                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
+
         my_module = MyModule()
         my_module_traced = torch.fx.symbolic_trace(my_module)

@@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
             if node.target == operator.getitem:
                 self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

+        my_module = MyModule2()
+        my_module_traced = torch.fx.symbolic_trace(my_module)
+
+        # by default, fx transform loses type annotation of getitem nodes.
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                assert node.type is None
+
+        annotate_getitem_nodes(my_module_traced.graph)
+
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
+
     def test_subgraph_uniquename(self):
         class MyModule(torch.nn.Module):
             def __init__(self) -> None:
```
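This is the test for the pass named in the commit message, `torch/fx/passes/annotate_getitem_nodes.py`; the new `MyModule2` case exercises PEP 585 container annotations (`tuple[...]`, `list[...]`). A trimmed sketch of the same round trip, using plain tensor containers instead of the test's `CustomType` (the `Pair` module is illustrative, not from the diff):

```python
import operator

import torch
import torch.fx
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes


class Pair(torch.nn.Module):
    def forward(self, inp: tuple[torch.Tensor, torch.Tensor]):
        return inp[0] + inp[1]


traced = torch.fx.symbolic_trace(Pair())

# symbolic_trace leaves the getitem results untyped...
assert all(
    node.type is None for node in traced.graph.nodes if node.target == operator.getitem
)

# ...and the pass recovers the element type from the annotated container type,
# including the builtin generics this PR adds support for.
annotate_getitem_nodes(traced.graph)
for node in traced.graph.nodes:
    if node.target == operator.getitem:
        assert node.type is not None
```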
```diff
@@ -5,7 +5,7 @@

 import itertools
 import torch
-from typing import List, Any
+from typing import Any
 from functools import wraps
 import unittest
 from torch.testing._internal.common_utils import skipIfTorchDynamo
@@ -100,7 +100,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
     # dimensions along which the reduction operation is applied:
     dim_ = torch.masked._canonical_dim(dim, input.ndim)
     # slices in product(*ranges) define all elementary slices:
-    ranges: List[Any] = []
+    ranges: list[Any] = []
     # shape of output for the keepdim=True case:
     shape = []
     for i in range(input.ndim):
```

```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: mkldnn"]
 import itertools
 import unittest
-from typing import NamedTuple, List
+from typing import NamedTuple

 import torch
 from torch import nn
@@ -16,7 +16,7 @@ FUSION_GROUP = 'prim::TensorExprGroup'
 class PointwisePostOp(NamedTuple):
     attr : str
     pointwise_module : nn.Module
-    scalars : List = []
+    scalars : list = []
     algorithm : str = ""

 CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
```

```diff
@@ -1,6 +1,6 @@
 # Owner(s): ["module: unknown"]

-from typing import Optional, List
+from typing import Optional
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo

@@ -8,12 +8,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo


 class FloatListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[float]]):
+    def forward(self, values, incr: Optional[list[float]]):
         return torch._C._nn._test_optional_floatlist(values, incr)


 class IntListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[int]]):
+    def forward(self, values, incr: Optional[list[int]]):
         return torch._C._nn._test_optional_intlist(values, incr)
```
```diff
@@ -9,7 +9,7 @@ import sys
 import tempfile
 import unittest
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional

 import numpy as np

@@ -3545,7 +3545,7 @@ def get_tolerances(
     true_value: torch.Tensor,
     computed_value: torch.Tensor,
     fudge_factor: Optional[float] = None,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """Returns the absolute and relative tolerances for comparing two tensors."""
     fudge_factor = fudge_factor if fudge_factor is not None else 1.0
     atol = get_atol(true_value, computed_value)
```

```diff
@@ -4,7 +4,6 @@
 import ctypes
 import os
 import unittest
-from typing import Tuple

 import torch
 from torch.backends._nnapi.prepare import convert_model_to_nnapi
@@ -700,7 +699,7 @@ class TestNNAPI(TestCase):

     def test_multi_output(self):
         class MultiModel(torch.nn.Module):
-            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
+            def forward(self, lhs, rhs) -> tuple[torch.Tensor, torch.Tensor]:
                 the_sum = lhs + rhs
                 the_diff = lhs - rhs
                 return the_sum, the_diff
```

```diff
@@ -278,7 +278,7 @@ class TestNumPyInterop(TestCase):
     def test_from_numpy_no_leak_on_invalid_dtype(self):
         # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
         # object. See https://github.com/pytorch/pytorch/issues/121138
-        x = np.array("value".encode("ascii"))
+        x = np.array(b"value")
         for _ in range(1000):
             try:
                 torch.from_numpy(x)
```
```diff
@@ -11,7 +11,6 @@ from collections import defaultdict
 from collections.abc import Sequence
 from functools import partial
 from importlib import import_module
-from typing import Dict, List

 import torch
 import torch._prims as prims
@@ -1483,7 +1482,7 @@ class TestCommon(TestCase):
         unsupported_dtypes = set()
         supported_backward_dtypes = set()
         unsupported_backward_dtypes = set()
-        dtype_error: Dict[torch.dtype, Exception] = {}
+        dtype_error: dict[torch.dtype, Exception] = {}

         def unsupported(dtype, e):
             dtype_error[dtype] = e
@@ -1987,7 +1986,7 @@ class TestCompositeCompliance(TestCase):
         for sample in op.sample_inputs(device, dtype, requires_grad=False):
             inp = sample.input
             outs = op(inp, *sample.args, **sample.kwargs)
-            if not isinstance(outs, (tuple, List)):
+            if not isinstance(outs, (tuple, list)):
                 outs = [outs]

             # for all outputs that are views of the input, we should be able to replay the
```
```diff
@@ -4,7 +4,7 @@ import math
 import tempfile
 import unittest
 from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Any
 from unittest.mock import patch

 from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
@@ -1769,8 +1769,8 @@ class TestOptimRenewed(TestCase):

     @staticmethod
     def _state_dict_post_hook(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         if "test" in state_dict["state"]:
             state_dict["state"].pop("test")
             state_dict["ran_state_dict_pre_hook"] = True
@@ -1821,14 +1821,14 @@ class TestOptimRenewed(TestCase):

     @staticmethod
     def _load_state_dict_pre_hook1(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
+        optimizer: Optimizer, state_dict: dict[str, Any]
     ) -> None:
         state_dict["param_groups"][0]["lr"] = 0.002

     @staticmethod
     def _load_state_dict_pre_hook2(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         # The typical use case for returning a state dict is to drastically modify the state dict.
         # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
         my_state_dict = deepcopy(state_dict)
@@ -1906,7 +1906,7 @@ class TestOptimRenewed(TestCase):

     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
-        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data += 2

@@ -1938,7 +1938,7 @@ class TestOptimRenewed(TestCase):

     @optims(optim_db, dtypes=[torch.float32])
     def test_step_pre_hook(self, device, dtype, optim_info):
-        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data += 2

@@ -1970,19 +1970,19 @@ class TestOptimRenewed(TestCase):

     @optims(optim_db, dtypes=[torch.float32])
     def test_step_all_hooks(self, device, dtype, optim_info):
-        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(0)

-        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(5)

-        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(1)

-        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(2)
```
```diff
@@ -5,7 +5,7 @@ import torch
 import numpy as np

 import math
-from typing import Dict, List, Sequence
+from collections.abc import Sequence
 import random
 from functools import partial
 from itertools import product, combinations, permutations
@@ -736,7 +736,7 @@ class TestReductions(TestCase):

     # TODO: kill this ane replace with common creation ops
     def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
-                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
+                      use_complex=False) -> dict[str, list[torch.Tensor]]:
         float_types = [torch.double,
                        torch.float]
         int_types = [torch.int64,
@@ -778,7 +778,7 @@ class TestReductions(TestCase):
             types += int_types
         if use_complex:
             types += complex_types
-        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
+        tensors: dict[str, list[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
         for dtype in types:
             tensors["cont"].append(make_contiguous(shape, dtype))
             tensors["noncont"].append(make_non_contiguous(shape, dtype))
```

```diff
@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
     skipIfCrossRef
 from torch.testing._internal.common_cuda import TEST_CUDA
 from numbers import Number
-from typing import Dict, Any
+from typing import Any
 from packaging import version
 from torch.testing._internal.common_cuda import \
     (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
@@ -334,7 +334,7 @@ class TestSparse(TestSparseBase):
             self.assertEqual(t._values(), tc._values())
             return tc

        value_map: Dict[Any, Any] = {}
-        value_map: Dict[Any, Any] = {}
+        value_map: dict[Any, Any] = {}
         for idx, val in zip(t._indices().t(), t._values()):
             idx_tup = tuple(idx.tolist())
             if idx_tup in value_map:
```
```diff
@@ -21,7 +21,7 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_cuda import SM53OrLater
 from torch._prims_common import corresponding_complex_dtype

-from typing import Optional, List
+from typing import Optional
 from packaging import version

@@ -597,7 +597,7 @@ class TestFFT(TestCase):
         else:
             numpy_fn = getattr(np.fft, fname)

-        def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
+        def fn(t: torch.Tensor, s: Optional[list[int]], dim: list[int] = (-2, -1), norm: Optional[str] = None):
             return torch_fn(t, s, dim, norm)

         torch_fns = (torch_fn, torch.jit.script(fn))
```

```diff
@@ -2,14 +2,13 @@
 # ruff: noqa: F841

 import unittest
-from typing import Dict, Optional
+from typing import Optional

 import numpy as np
 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.static_module import StaticModule
-from typing import List


 def linear_shim(
@@ -108,7 +107,7 @@ def fork_wait_graph2(input1, input2):
     :param iters: number of future/wait pairs to be created
 """
 def fork_wait_graph3(input, iters: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(iters):
         futures.append(torch.jit.fork(torch.neg, input))
     results = []
@@ -123,7 +122,7 @@ def fork_wait_graph3(input, iters: int):
     :param num_child_forks: number of child forks per parent fork
 """
 def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(num_forks):
         futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
     results = []
@@ -150,7 +149,7 @@ def loop_graph(a, b, iters: int):
 def output_graph(a, b, c, iters: int):
     s = torch.tensor([[3, 3], [3, 3]])
     k = a + b * c + s
-    d: Dict[int, torch.Tensor] = {}
+    d: dict[int, torch.Tensor] = {}
     for i in range(iters):
         d[i] = k + i
     return d
```
```diff
@@ -5,7 +5,7 @@ import itertools
 import math
 import pickle
 import sys
-from typing import Callable, List, Tuple, Type
+from typing import Callable

 import sympy

@@ -594,7 +594,7 @@ class TestSympyInterp(TestCase):
             self.fail(f"Unexpected error for {fn}{args}: {str(e)}")


-def type_name_fn(type: Type) -> str:
+def type_name_fn(type: type) -> str:
     return type.__name__


@@ -606,7 +606,7 @@ def parametrize_relational_types(*types):


 class TestSympySolve(TestCase):
-    def _create_integer_symbols(self) -> List[sympy.Symbol]:
+    def _create_integer_symbols(self) -> list[sympy.Symbol]:
         return sympy.symbols("a b c", integer=True)

     def test_give_up(self):
@@ -665,9 +665,9 @@ class TestSympySolve(TestCase):

     def _test_cases(
         self,
-        cases: List[Tuple[sympy.Basic, sympy.Basic]],
+        cases: list[tuple[sympy.Basic, sympy.Basic]],
         thing: sympy.Basic,
-        op: Type[sympy.Rel],
+        op: type[sympy.Rel],
         **kwargs,
     ):
         for source, expected in cases:
@@ -761,7 +761,7 @@ class TestSympySolve(TestCase):
             Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
         }[op]

-        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
+        cases: list[tuple[sympy.Basic, sympy.Basic]] = [
             # 'b' is not strictly positive
             (op(FloorDiv(a, b), integer), None),
             # 'c' is not strictly positive
```
```diff
@@ -11,7 +11,7 @@ import unittest
 from itertools import product, combinations, combinations_with_replacement, permutations
 import random
 import tempfile
-from typing import Any, Dict, List, Tuple
+from typing import Any

 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
@@ -4125,7 +4125,7 @@ class TestAsArray(TestCase):
     def test_default_device(self, device):
         original = torch.arange(5)

-        examples: List[Tuple[Any, Dict]] = [
+        examples: list[tuple[Any, dict]] = [
             (3, {}),
             (original, {}),
             (to_numpy(original), {}),
```

```diff
@@ -12,7 +12,8 @@ import re
 import subprocess
 import sys
 import unittest.mock
-from typing import Any, Callable, Iterator, List, Tuple
+from typing import Any, Callable
+from collections.abc import Iterator

 import torch

@@ -496,7 +497,7 @@ if __name__ == '__main__':
     self.assertNotIn('OK', stderr.decode('ascii'))


-def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
+def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]:
     """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.

     Args:
```
@ -53,7 +53,6 @@ from torch.testing._internal.common_device_type import (
|
|||
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
|
||||
skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, skipCUDAIfNotRocm,
|
||||
get_all_device_types, skipXLA)
|
||||
from typing import Tuple
|
||||
import torch.backends.quantized
|
||||
import torch.testing._internal.data
|
||||
from torch.testing._internal.common_cuda import (
|
||||
|
|
@ -3511,7 +3510,7 @@ else:
|
|||
|
||||
def _prepare_data_for_index_copy_and_add_deterministic(
|
||||
self, dim: int, device: torch.device
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert (dim >= 0 and dim < 3)
|
||||
a = [5, 4, 3]
|
||||
a[dim] = 2000
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import math
|
|||
import itertools
|
||||
import torch.optim as optim
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from typing import Optional
|
||||
import torch.utils.cpp_extension
|
||||
from torch.testing._internal.common_nn import NNTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
|
|
@ -149,12 +149,12 @@ def _check_equal(
|
|||
|
||||
|
||||
def check_out_and_grad(
|
||||
out_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_query_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_key_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_value_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_attn_mask_tuple: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
fudge_factors: Optional[Dict[str, float]] = None
|
||||
out_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_query_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_key_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_value_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
grad_attn_mask_tuple: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
fudge_factors: Optional[dict[str, float]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Check output and gradients of attention mechanism tensors.
|
||||
|
|
@@ -2574,7 +2574,7 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
 
-        def test_attention(backend: SDPBackend, permute_order: List[List[int]]):
+        def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
             BHSqD = [4, 16, 256, 64]
             BHSkvD = [4, 16, 512, 64]
@@ -2602,7 +2602,7 @@ class TestSDPACudaOnly(NNTestCase):
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
-    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
+    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
         dtype = torch.float16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
         batch, num_heads, head_dim = 8, 8, 64
@@ -3307,7 +3307,7 @@ class TestSDPACudaOnly(NNTestCase):
     @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
-                                               scale: str, enable_gqa: bool, n_heads: List[int]):
+                                               scale: str, enable_gqa: bool, n_heads: list[int]):
         if isSM8XDevice and head_dim in range(193, 256 + 1):
             self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
@@ -3905,7 +3905,7 @@ class TestAttnBias(NNTestCase):
         "shape",
         [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
     )
-    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
@@ -3942,7 +3942,7 @@ class TestAttnBias(NNTestCase):
     )
     @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
     @skipIfTorchDynamo("This function already calls torch.compile.")
-    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
             self.skipTest("No support for LOWER_RIGHT variant for now")
             return
@@ -3975,7 +3975,7 @@ class TestAttnBias(NNTestCase):
         self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
 
     @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
-    def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
+    def test_is_causal_equals_upper_left(self, device, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
@@ -8,7 +8,7 @@ import shutil
 import unittest
 from collections import defaultdict
 from threading import Lock
-from typing import Dict, IO, List, Optional
+from typing import IO, Optional
 
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -49,12 +49,12 @@ def _strip_filename(msg: str) -> str:
     return tail.split(":", 1)[-1]
 
 
-def _run_mypy() -> Dict[str, List[str]]:
+def _run_mypy() -> dict[str, list[str]]:
     """Clears the cache and run mypy before running any of the typing tests."""
     if os.path.isdir(CACHE_DIR):
         shutil.rmtree(CACHE_DIR)
 
-    rc: Dict[str, List[str]] = {}
+    rc: dict[str, list[str]] = {}
     for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
         # Run mypy
         stdout, stderr, _ = api.run(
@@ -119,10 +119,10 @@ def _construct_format_dict():
 
 #: A dictionary with all supported format keys (as keys)
 #: and matching values
-FORMAT_DICT: Dict[str, str] = _construct_format_dict()
+FORMAT_DICT: dict[str, str] = _construct_format_dict()
 
 
-def _parse_reveals(file: IO[str]) -> List[str]:
+def _parse_reveals(file: IO[str]) -> list[str]:
     """Extract and parse all ``" # E: "`` comments from the passed file-like object.
 
     All format keys will be substituted for their respective value from `FORMAT_DICT`,
@@ -160,10 +160,10 @@ def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> None:
 @unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
 class TestTyping(TestCase):
     _lock = Lock()
-    _cached_output: Optional[Dict[str, List[str]]] = None
+    _cached_output: Optional[dict[str, list[str]]] = None
 
     @classmethod
-    def get_mypy_output(cls) -> Dict[str, List[str]]:
+    def get_mypy_output(cls) -> dict[str, list[str]]:
         with cls._lock:
             if cls._cached_output is None:
                 cls._cached_output = _run_mypy()
@@ -192,7 +192,7 @@ class TestTyping(TestCase):
         with open(path) as fin:
             lines = fin.readlines()
 
-        errors = defaultdict(lambda: "")
+        errors = defaultdict(str)
 
         output_mypy = self.get_mypy_output()
         self.assertIn(path, output_mypy)
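The `defaultdict` change above is behavior-preserving rather than a typing fix: calling `str()` returns `""`, so the class itself can serve as the default factory in place of a lambda. A quick check of the equivalence:

```python
from collections import defaultdict

errors: defaultdict[str, str] = defaultdict(str)  # same as defaultdict(lambda: "")
errors["test_foo"] += "E: wrong revealed type"

assert errors["never-seen"] == ""  # missing keys default to the empty string
```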
@@ -12,7 +12,7 @@ import textwrap
 import traceback
 import unittest
 import warnings
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 import torch.cuda
@@ -439,7 +439,7 @@ class TestCheckpoint(TestCase):
         # get de-allocated directly. So using cuda memory usage as a proxy
 
         def _do_test(fn, should_free):
-            stats: List[int] = []
+            stats: list[int] = []
 
             def track(x, idx):
                 # Track that at each step of the backward, some Tensor were
@@ -1203,7 +1203,7 @@ def f(x):
     return g(x) + 1
 """
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         scope = {"__compile_source__": source}
         exec(source, scope, out)
@@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
     """
     Annotate the type of getitem nodes, inferred from the type of sequence node.
     If sequence node is not annotated with a type, do nothing.
-    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
+    Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
 
     This is helpful since annotations on local names within function are lost during FX transforms.
     Adding back known type annotation for getitem nodes to improve jit scriptability.
@@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
         elif sequence_node.type._name == "List":
             assert len(parameterized_types) == 1
             node.type = parameterized_types[0]
+        # Generic Alias Type
+        elif hasattr(sequence_node.type, "__origin__"):
+            parameterized_types = sequence_node.type.__args__
+            if sequence_node.type.__origin__ is tuple:
+                if len(parameterized_types) == 2 and isinstance(
+                    parameterized_types[1], type(...)
+                ):
+                    node.type = parameterized_types[0]
+                else:
+                    assert len(parameterized_types) > index_node
+                    node_type = parameterized_types[index_node]
+                    node.type = node_type
+            elif sequence_node.type.__origin__ is list:
+                assert len(parameterized_types) == 1
+                node.type = parameterized_types[0]
         # NamedTuple type
         elif hasattr(sequence_node.type, "__annotations__"):
             if sequence_node.type == torch.Tensor:
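The new branch above handles PEP 585 generic alias types via runtime introspection; the `__origin__`/`__args__` attributes it relies on are standard Python. A short sketch of what they return, including the variadic `tuple[T, ...]` case that the `type(...)` check detects:

```python
# Generic aliases expose the concrete class and their type parameters:
pair = tuple[int, str]
assert pair.__origin__ is tuple
assert pair.__args__ == (int, str)

# Variadic tuples carry the Ellipsis object as their second argument, which is
# exactly what `isinstance(parameterized_types[1], type(...))` tests for:
variadic = tuple[float, ...]
assert isinstance(variadic.__args__[1], type(...))

# Homogeneous lists have a single parameter, matching the len == 1 assertion:
homogeneous = list[int]
assert homogeneous.__origin__ is list and homogeneous.__args__ == (int,)
```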