Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)

[4/N] Apply py39 ruff and pyupgrade fixes (#143257)

`torch/fx/passes/annotate_getitem_nodes.py` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD

Commit df458be4e5, parent a881954b0c.
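The diff below is mechanical modernization for Python 3.9. As a quick orientation, here is a minimal sketch of the rewrite patterns applied throughout; the module and function names are invented for illustration, and the UP006/UP035 rule codes in the comments are our assumption about which ruff pyupgrade rules drive these changes:

```python
# Hypothetical before-state (pre-PEP 585), shown as comments:
#   from typing import Dict, Iterable, List, Tuple
#   def bucket(xs: List[int]) -> Dict[str, Tuple[int, ...]]: ...

# After the py39 rewrites applied in this series:
from collections.abc import Iterable  # ABCs import from collections.abc, not typing (UP035)
from typing import Optional, Union  # names with no builtin equivalent stay in typing


def bucket(xs: list[int]) -> dict[str, tuple[int, ...]]:
    # PEP 585: builtin generics replace typing.List/Dict/Tuple (UP006)
    return {"values": tuple(xs)}


def first(xs: Iterable[int]) -> Optional[Union[int, float]]:
    # Optional/Union still come from typing on 3.9 (the X | Y syntax needs 3.10).
    return next(iter(xs), None)
```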
```diff
@@ -1,8 +1,3 @@
-import dis
-import inspect
-from collections.abc import Sequence
-from typing import Union
-
 import functorch._C
 import torch
 from functorch._C import dim as _C
```
```diff
@@ -27,7 +27,7 @@ from __future__ import annotations
 
 import keyword
 import warnings
-from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union
 
 
 if TYPE_CHECKING:
@@ -73,11 +73,11 @@ class ParsedExpression:
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
-        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
+        self.identifiers: set[Union[str, AnonymousAxis]] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
-        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
+        self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
@@ -90,7 +90,7 @@ class ParsedExpression:
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True
 
-        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
+        bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None
 
        def add_axis_name(x: str) -> None:
            if x in self.identifiers:
@@ -164,7 +164,7 @@ class ParsedExpression:
    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
-    ) -> Tuple[bool, str]:
+    ) -> tuple[bool, str]:
        """Check if the given axis name is valid, and a message explaining why if not.
 
        Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@@ -174,7 +174,7 @@ class ParsedExpression:
            allow_underscore (bool): whether axis names are allowed to start with an underscore
 
        Returns:
-            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
+            tuple[bool, str]: whether the axis name is valid, a message explaining why if not
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
@@ -211,7 +211,7 @@ class ParsedExpression:
 
 def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
-) -> Tuple[ParsedExpression, ParsedExpression]:
+) -> tuple[ParsedExpression, ParsedExpression]:
    """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
 
    Args:
@@ -219,7 +219,7 @@ def parse_pattern(
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
 
    Returns:
-        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
+        tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
    """
    # adapted from einops.einops._prepare_transformation_recipe
    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
```
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
+from typing import Callable, TYPE_CHECKING, Union
 
 import torch
 from functorch._C import dim as _C
@@ -18,7 +18,6 @@ from ._parsing import (
 if TYPE_CHECKING:
    from collections.abc import Sequence
 
-
 __all__ = ["rearrange"]
 
 dims = _C.dims
@@ -69,9 +68,9 @@ def _create_rearrange_callable(
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor
 
-    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
-    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
-    anon_axes: List[AnonymousAxis] = []
+    first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
+    identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {}
+    anon_axes: list[AnonymousAxis] = []
 
    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
@@ -99,11 +98,11 @@ def _create_rearrange_callable(
            raise ValueError(f"Unexpected dimension: {dimension}")
 
    def composition_to_dims(
-        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]],
-    ) -> List[Union[str, Tuple[str, ...]]]:
+        composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]],
+    ) -> list[Union[str, tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
-        dim_composition: List[Union[str, Tuple[str, ...]]] = []
+        dim_composition: list[Union[str, tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
@@ -152,7 +151,7 @@ def _create_rearrange_callable(
 
 
 def rearrange(
-    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
+    tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
 ) -> torch.Tensor:
```
```diff
@@ -2,7 +2,6 @@
 
 import copy
 import logging
-from typing import List
 
 import torch
 import torch.nn as nn
@@ -247,7 +246,7 @@ class TestActivationSparsifier(TestCase):
                assert mask2 is None
            else:
                assert type(mask1) == type(mask2)
-                if isinstance(mask1, List):
+                if isinstance(mask1, list):
                    assert len(mask1) == len(mask2)
                    for idx in range(len(mask1)):
                        assert torch.all(mask1[idx] == mask2[idx])
@@ -258,7 +257,7 @@ class TestActivationSparsifier(TestCase):
        for state in state_dict["state"].values():
            mask = state["mask"]
            if mask is not None:
-                if isinstance(mask, List):
+                if isinstance(mask, list):
                    for idx in range(len(mask)):
                        assert mask[idx].is_sparse
                else:
```
```diff
@@ -3,7 +3,6 @@
 import copy
 import logging
 import warnings
-from typing import Tuple
 
 import torch
 from torch import nn
@@ -73,7 +72,7 @@ class TestBaseDataScheduler(TestCase):
 
    def _get_name_data_config(self, some_data, defaults):
        config = copy.deepcopy(defaults)
-        if isinstance(some_data, Tuple):
+        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:
```
```diff
@@ -4,7 +4,6 @@ import copy
 import itertools
 import logging
 import math
-from typing import Tuple
 
 import torch
 from torch import nn
@@ -54,7 +53,7 @@ class _BaseDataSparsiferTestCase(TestCase):
 
    @staticmethod
    def _get_name_data_config(some_data, defaults=None):
-        if isinstance(some_data, Tuple):
+        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
            config = defaults
@@ -482,8 +481,9 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
            nn.Parameter(torch.randn(4, 4)),
            nn.Parameter(torch.randn(5, 5)),
        )
-        param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(1, 1)),
+            nn.Parameter(torch.randn(4, 4)),
        )
        data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
        defaults = {"test": 3}
@@ -585,8 +585,9 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
            nn.Parameter(torch.randn(4, 4)),
            nn.Parameter(torch.randn(5, 5)),
        )
-        param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
-            torch.randn(4, 4)
+        param4, param5 = (
+            nn.Parameter(torch.randn(10, 10)),
+            nn.Parameter(torch.randn(4, 4)),
        )
        data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
        defaults = {
```
```diff
@@ -2,13 +2,17 @@
 from __future__ import annotations
 
 import typing
-from typing import List, Optional, Sequence, Union  # noqa: F401
+from typing import List, Optional, Union
 
 import torch
 from torch import Tensor, types
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 
+if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 mutates_args = {}
```
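The hunk above, and several later ones, introduce the same pattern: imports needed only for annotations move behind a `TYPE_CHECKING` guard. A minimal self-contained sketch of why that works (our own example, not code from the PR):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; costs nothing at runtime.
    from collections.abc import Sequence


def head(xs: Sequence[int]) -> int:
    # With `from __future__ import annotations`, annotations are stored as
    # strings and never evaluated at runtime, so the guarded import suffices.
    return xs[0]
```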
```diff
@@ -6,7 +6,8 @@ import functools
 import itertools
 import unittest
 from collections import defaultdict
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
```
```diff
@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
 
-from typing import Any, Callable, Dict, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, Optional
 from unittest import skip
 
 import torch
```
```diff
@@ -10,8 +10,9 @@ import inspect
 import io
 import operator
 import unittest
+from collections.abc import Sequence
 from enum import Enum
-from typing import Dict, List, Sequence
+from typing import Dict, List
 from unittest.mock import patch
 
 import torch
```
```diff
@@ -21,10 +21,11 @@ import warnings
 import weakref
 from abc import ABC
 from collections import namedtuple
+from collections.abc import Iterator
 from copy import deepcopy
 from enum import Enum, IntEnum
 from functools import wraps
-from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict
+from typing import Any, Dict, List, Literal, Tuple, TypedDict
 from unittest import mock
 
 import numpy as np
```
```diff
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: export"]
 import unittest
-from typing import Any, Dict, Optional, OrderedDict, Tuple
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 from torch._export.passes.lift_constants_pass import (
```
```diff
@@ -14,8 +14,7 @@ import random
 import types
 import unittest
 import warnings
-from collections import namedtuple
-from typing import OrderedDict
+from collections import namedtuple, OrderedDict
 from unittest.case import skipIf
 
 from common_utils import (
```
```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: fx"]
 
 import unittest
-from typing import Mapping
+from collections.abc import Mapping
 
 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
```
```diff
@@ -3,8 +3,8 @@
 import functools
 import itertools
 import os
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Sequence
 from unittest import skip
 
 import yaml
```
```diff
@@ -9,7 +9,7 @@ from collections import namedtuple, OrderedDict
 from copy import deepcopy
 from functools import partial
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -55,11 +55,11 @@ class ToyModel(nn.Module):
 
 def forward_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
    out: torch.Tensor,
 ) -> None:
    fired_hooks.append(hook_id)
@@ -69,11 +69,11 @@ def forward_hook(
 
 def forward_pre_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    inp: Tuple[torch.Tensor],
+    inp: tuple[torch.Tensor],
 ) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
@@ -82,12 +82,12 @@ def forward_pre_hook(
 
 def full_backward_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
-    grad_output: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
+    grad_output: tuple[torch.Tensor],
 ) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
@@ -97,11 +97,11 @@ def full_backward_hook(
 
 def full_backward_pre_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    grad_input: Tuple[torch.Tensor],
+    grad_input: tuple[torch.Tensor],
 ) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
@@ -122,8 +122,8 @@ class KwargModel(nn.Module):
    def internal_forward_hook(
        self,
        module: nn.Module,
-        args: Tuple[torch.Tensor],
-        kwargs: Dict[str, Any],
+        args: tuple[torch.Tensor],
+        kwargs: dict[str, Any],
        out: torch.Tensor,
    ):
        return out + kwargs["bias"]
@@ -142,13 +142,13 @@ class FailsInForwardModel(nn.Module):
 
 def kwarg_forward_pre_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
-) -> Tuple[Any, Any]:
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
+) -> tuple[Any, Any]:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(args), 1)
@@ -158,12 +158,12 @@ def kwarg_forward_pre_hook(
 
 def kwarg_forward_hook(
    self: TestCase,
-    fired_hooks: List[int],
+    fired_hooks: list[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
-    args: Tuple[torch.Tensor],
-    kwargs: Dict[str, Any],
+    args: tuple[torch.Tensor],
+    kwargs: dict[str, Any],
    out: torch.Tensor,
 ) -> Any:
    fired_hooks.append(hook_id)
@@ -188,7 +188,7 @@ class DummyContextManager:
 class TestModuleHooks(TestCase):
    @parametrize_test("named_tuple", (True, False))
    def test_forward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
@@ -210,7 +210,7 @@ class TestModuleHooks(TestCase):
 
    @parametrize_test("named_tuple", (True, False))
    def test_forward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
@@ -232,7 +232,7 @@ class TestModuleHooks(TestCase):
 
    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_hook, self, fired_hooks, model.net1)
@@ -254,7 +254,7 @@ class TestModuleHooks(TestCase):
 
    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_pre_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
@@ -294,7 +294,7 @@ class TestModuleHooks(TestCase):
 
    @parametrize_test("named_tuple", (True, False))
    def test_mixed_hooks(self, named_tuple):
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        model.register_forward_pre_hook(
@@ -319,7 +319,7 @@ class TestModuleHooks(TestCase):
 
    def test_kwarg_hooks(self):
        # 1. test forward pre hook
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
@@ -336,7 +336,7 @@ class TestModuleHooks(TestCase):
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
 
        # 2. test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
@@ -372,7 +372,7 @@ class TestModuleHooks(TestCase):
 
    def test_remove_kwarg_hooks(self):
        # test forward pre and forward hooks
-        fired_hooks: List[int] = []
+        fired_hooks: list[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
@@ -1217,8 +1217,8 @@ class TestModuleGlobalHooks(TestCase):
    def test_module_global_hooks_with_kwargs(self):
        def kwarg_global_forward_hook(
            module: nn.Module,
-            args: Tuple[torch.Tensor],
-            kwargs: Dict[str, Any],
+            args: tuple[torch.Tensor],
+            kwargs: dict[str, Any],
            out: torch.Tensor,
        ) -> Any:
            out = out + kwargs["bias"]
```
```diff
@@ -2,7 +2,6 @@
 
 import itertools
 import random
-from typing import List
 
 import torch
 import torch.nn.utils.rnn as rnn_utils
@@ -219,7 +218,7 @@ class PackedSequenceTest(TestCase):
        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
-            sequences: List[torch.Tensor] = []
+            sequences: list[torch.Tensor] = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
```
```diff
@@ -1573,10 +1573,10 @@ torch.cuda.synchronize()
            return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)
 
        if adaptive:
-            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"AdaptiveMaxPool{num_dim}d"
        else:
            # FIXME(#105716): Test fails when using f-string
-            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
+            cls_name = f"MaxPool{num_dim}d"
        module_cls = getattr(nn, cls_name)
        module = module_cls(2, return_indices=True).to(device, dtype=dtype)
        numel = 4 ** (num_dim + 1)
```
```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["module: onnx"]
 """Unit tests for the internal registration wrapper module."""
 
-from typing import Sequence
+from collections.abc import Sequence
 
 from torch.onnx import errors
 from torch.onnx._internal import registration
```
```diff
@@ -10,19 +10,8 @@ import logging
 import os
 import unittest
 import warnings
-from typing import (
-    Any,
-    Callable,
-    Collection,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-)
+from collections.abc import Collection, Iterable, Mapping, Sequence
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import onnxruntime
```
```diff
@@ -3,7 +3,8 @@
 import os
 import unittest
 from collections import OrderedDict
-from typing import List, Mapping, Tuple
+from collections.abc import Mapping
+from typing import List, Tuple
 
 import onnx_test_common
 import parameterized
```
```diff
@@ -10,7 +10,7 @@ import itertools
 import unittest
 import unittest.mock
 import warnings
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import numpy as np
 
@@ -26,6 +26,10 @@ from torch.onnx._internal import registration
 from torch.testing._internal import common_quantization, common_utils, jit_utils
 
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
 def export_to_onnx(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction],
    input: Union[torch.Tensor, Tuple[torch.Tensor]],
```
```diff
@@ -1,6 +1,6 @@
 # Owner(s): ["oncall: package/deploy"]
 
-from typing import Iterable
+from collections.abc import Iterable
 
 from torch.package import GlobGroup
 from torch.testing._internal.common_utils import run_tests
```
```diff
@@ -18,7 +18,7 @@ import os
 import sys
 import tempfile
 import unittest
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -51,7 +51,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._triton import has_triton
 
 
-Json = Dict[str, Any]
+Json = dict[str, Any]
 
 
 class TestExecutionTrace(TestCase):
@@ -97,7 +97,7 @@ class TestExecutionTrace(TestCase):
        nodes = et_graph["nodes"]
        return nodes
 
-    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
+    def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]:
        """Returns a sorted list of rf_id (record function ids) in execution trace"""
 
        def get_rf_id(node):
@@ -115,7 +115,7 @@ class TestExecutionTrace(TestCase):
        )
        return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)
 
-    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
+    def get_kineto_rf_ids(self, events: list[Json]) -> list[int]:
        """Returns a sorted list of Record function IDs for CPU operators and user annotations"""
        ops_and_annotations = (
            e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
```
```diff
@@ -5,7 +5,8 @@ import itertools as it
 import sys
 import textwrap
 import unittest
-from typing import Callable, Dict, Iterator, List, Optional, Tuple
+from collections.abc import Iterator
+from typing import Callable, Optional
 
 import torch
 from torch._C._profiler import _EventType, _TensorMetadata
@@ -309,9 +310,9 @@ class TestDataFlow(TestCase):
    @staticmethod
    def formatSchemas(
        prof: torch.profiler.profile, indent: int = 12
-    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
+    ) -> tuple[tuple[str, tuple[bool, ...]], ...]:
        tree = prof.profiler.kineto_results.experimental_event_tree()
-        out: List[Tuple[str, Tuple[bool, ...]]] = []
+        out: list[tuple[str, tuple[bool, ...]]] = []
        for node in _utils.traverse_dfs(tree):
            if node.tag == _EventType.TorchOp:
                e = node.extra_fields
@@ -327,8 +328,8 @@ class TestDataFlow(TestCase):
 
    @staticmethod
    def _run_and_format_data_flow(
-        inputs: Dict[str, torch.Tensor],
-        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
+        inputs: dict[str, torch.Tensor],
+        f: Callable[..., Optional[dict[str, torch.Tensor]]],
        indent: int = 12,
    ) -> str:
        with profile() as prof:
@@ -339,7 +340,7 @@ class TestDataFlow(TestCase):
        graph = memory_profile._data_flow_graph
        storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}
 
-        lines: List[str] = []
+        lines: list[str] = []
        for name, t in it.chain(inputs.items(), outputs.items()):
            lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
            if t.grad is not None:
@@ -352,7 +353,7 @@ class TestDataFlow(TestCase):
        for node in graph.flow_nodes:
            destroyed = {k for k, v in node._edges.items() if v.is_deletion}
 
-            inputs: List[str] = []
+            inputs: list[str] = []
            for key, (_, v) in node.inputs.items():
                inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")
 
@@ -833,7 +834,7 @@ class TestMemoryProfilerE2E(TestCase):
    @staticmethod
    def _lookup_tensor_categories(
        t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
-    ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
+    ) -> dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
        storage = t.storage()
        if storage is None:
            raise ValueError("Cannot look up uninitialized Tensor.")
@@ -889,7 +890,7 @@ class TestMemoryProfilerE2E(TestCase):
        fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))
 
        memory_profile = prof._memory_profile()
-        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
+        ptr_pair_to_key: dict[tuple[int, int], _memory_profiler.TensorKey] = {}
        snapshot = memory_profile._category_snapshot()
 
        # Build map from observed live Tensors to the memory profiler's
@@ -922,7 +923,7 @@ class TestMemoryProfilerE2E(TestCase):
 
            return f"{target_key.storage.allocation_id} ({','.join(categories)})"
 
-        out: List[str] = []
+        out: list[str] = []
        for name, inputs, outputs in record_ops.results:
            if inputs or outputs:
                # PyTorch ops
```
```diff
@@ -17,7 +17,7 @@ import threading
 import time
 import unittest
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 from unittest.mock import patch
 
 import expecttest
@@ -2277,7 +2277,7 @@ class MockProfilerEvent:
    start_time_ns: int
    duration_time_ns: int
    correlation_id: int = 0
-    children: List["MockProfilerEvent"] = field(default_factory=list)
+    children: list["MockProfilerEvent"] = field(default_factory=list)
    parent: Optional["MockProfilerEvent"] = None
 
    @property
@@ -2301,7 +2301,7 @@ class MockNode:
 
 @unittest.skipIf(sys.version_info >= (3, 13), "segfaults")
 class TestExperimentalUtils(TestCase):
-    def make_tree(self) -> List[MockNode]:
+    def make_tree(self) -> list[MockNode]:
        tree = {
            "root_0": {
                "1": {"2": {}},
```
```diff
@@ -14,7 +14,7 @@ try:
 except ImportError:
    None
 
-from typing import Any, Dict
+from typing import Any
 
 import torch
 import torch.optim
@@ -29,7 +29,7 @@ from torch.profiler import kineto_available, record_function
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 
-Json = Dict[str, Any]
+Json = dict[str, Any]
 
 
 class TestRecordFunction(TestCase):
```
```diff
@@ -19,7 +19,7 @@ import sys
 import textwrap
 import unittest
 import weakref
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -30,7 +30,7 @@ from torch.profiler import _utils, profile
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 
-Json = Dict[str, Any]
+Json = dict[str, Any]
 
 from torch._C._profiler import _ExtraFields_PyCall
 
@@ -455,7 +455,7 @@ class TestTorchTidyProfiler(TestCase):
 
        nodes = p.profiler.kineto_results.experimental_event_tree()
 
-        def find_chain(names: List[str]):
+        def find_chain(names: list[str]):
            out = []
            for name in names:
                root = [out[-1]] if out else nodes
```
```diff
@@ -25,7 +25,7 @@ from copy import deepcopy
 from functools import partial, reduce
 from itertools import product
 from operator import mul
-from typing import List, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import torch
 import torch.autograd._functions
@@ -10091,14 +10091,14 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
 
    def test_multi_grad_any_hooks(self):
        hook_id = 0
-        any_hook_handles: List[RemovableHandle] = []
+        any_hook_handles: list[RemovableHandle] = []
 
        class MultiOutputModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(3, 3)
 
-            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
                z = self.lin(x)
                out = torch.sin(z), torch.cos(z)
                nonlocal hook_id
@@ -10123,7 +10123,7 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) {
                z = y[0] + y[1]
                return self.mod2(z)
 
-        hook_order: List[int] = []
+        hook_order: list[int] = []
        hook_count = 0
 
        def hook(hook_id: int, *unused):
@@ -13975,7 +13975,7 @@ class TestSelectiveActivationCheckpoint(TestCase):
        counter = [0]
 
        @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
-        def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        def sin_with_extra(x: torch.Tensor) -> tuple[torch.Tensor, int]:
            counter[0] += 1
            return x.sin(), 2
```
```diff
@@ -4,7 +4,7 @@
 
 import io
 import textwrap
-from typing import Dict, List, Optional
+from typing import Optional
 
 import torch
 import torch.utils.bundled_inputs
@@ -32,7 +32,7 @@ class TestBundledInputs(TestCase):
 
        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
-        get_expr: List[str] = []
+        get_expr: list[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
@@ -328,8 +328,8 @@ class TestBundledInputs(TestCase):
        class MyModel(torch.nn.Module):
            def forward(
                self,
-                arg1: Optional[Dict[str, torch.Tensor]],
-                arg2: Optional[List[torch.Tensor]],
+                arg1: Optional[dict[str, torch.Tensor]],
+                arg2: Optional[list[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
@@ -393,7 +393,7 @@ class TestBundledInputs(TestCase):
            """,
        )
 
-        out: List[str] = []
+        out: list[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
```
```diff
@@ -4054,11 +4054,7 @@ class TestCudaMallocAsync(TestCase):
        that the pytorch call is returning a correct list of UUIDs.
        """
        cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'"
-        uuids = (
-            subprocess.check_output(cmd, shell=True, universal_newlines=True)
-            .strip()
-            .split("\n")
-        )
+        uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
        uuids = [s.strip() for s in uuids]
        raw_uuids = torch.cuda._raw_device_uuid_amdsmi()
        for uuid in uuids:
@@ -4082,11 +4078,7 @@ import os
 print(f"{torch.cuda.device_count()}")
 """
        cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'"
-        uuids = (
-            subprocess.check_output(cmd, shell=True, universal_newlines=True)
-            .strip()
-            .split("\n")
-        )
+        uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n")
        uuids = [s.strip() for s in uuids]
 
        custom_envs = []
```
```diff
@@ -3,7 +3,7 @@
 import sys
 import textwrap
 import traceback
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.cuda._sanitizer as csan
@@ -148,9 +148,9 @@ class TestEventHandler(TestCase):
    def kernel_launch(
        self,
        stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
-    ) -> List[csan.SynchronizationError]:
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
+    ) -> list[csan.SynchronizationError]:
        if read_only is None:
            read_only = []
        if read_write is None:
@@ -167,8 +167,8 @@ class TestEventHandler(TestCase):
    def assert_good_kernel_launch(
        self,
        stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
    ) -> None:
        self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
 
@@ -176,8 +176,8 @@ class TestEventHandler(TestCase):
        self,
        number_of_errors: int,
        stream: StreamId,
-        read_only: Optional[List[DataPtr]] = None,
-        read_write: Optional[List[DataPtr]] = None,
+        read_only: Optional[list[DataPtr]] = None,
+        read_write: Optional[list[DataPtr]] = None,
    ) -> None:
        errors = self.kernel_launch(stream, read_only, read_write)
        self.assertEqual(len(errors), number_of_errors)
```
```diff
@@ -1262,14 +1262,12 @@ class TestDataLoader(TestCase):
        list(iter(loader))
 
    def test_typing(self):
-        from typing import List
-
        # Make sure there is no TypeError
 
-        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
+        class SomeDatasetClass(Dataset[list[torch.Tensor]]):
            pass
 
-        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
+        def _create_dataloader(is_train: bool) -> DataLoader[list[torch.Tensor]]:
            pass
 
    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
```
```diff
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import unittest
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.distributed
@@ -27,9 +27,9 @@ class MyModule(torch.nn.Module):
 class MyDummyFnOptimizer:
    def __init__(
        self,
-        params: List[Tensor],
+        params: list[Tensor],
        lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        _allow_empty_param_list: bool = False,
@@ -63,7 +63,7 @@ class MyDummyFnOptimizer:
                "MyDummyFnOptimizer does not support step_param() as of now"
            )
 
-    def step(self, gradients: List[Optional[Tensor]]):
+    def step(self, gradients: list[Optional[Tensor]]):
        # call the custom optimizer step implementation
        with torch.no_grad():
            raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")
```
```diff
@@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module):
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
 
+        class MyModule2(torch.nn.Module):
+            def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
+                inp_0 = inp[0]
+                inp_1 = inp[1]
+                inp2_0 = inp2[0]
+                inp3_x = inp3.x
+                inp3_y = inp3.y
+                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
+
        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)
@@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module):
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
 
+        my_module = MyModule2()
+        my_module_traced = torch.fx.symbolic_trace(my_module)
+
+        # by default, fx transform loses type annotation of getitem nodes.
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                assert node.type is None
+
+        annotate_getitem_nodes(my_module_traced.graph)
+
+        for node in my_module_traced.graph.nodes:
+            if node.target == operator.getitem:
+                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
+
    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
```
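The two hunks above add a `MyModule2` case exercising `annotate_getitem_nodes` with PEP 585 builtin generics (`tuple[...]`, `list[...]`), which is the `torch/fx/passes/annotate_getitem_nodes.py` change called out in the commit message. A condensed sketch of the workflow the test verifies, using a simplified module of our own and assuming the import path used by the test:

```python
import operator

import torch
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes


class M(torch.nn.Module):
    def forward(self, inp: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return inp[0] + inp[1]


traced = torch.fx.symbolic_trace(M())
getitems = [n for n in traced.graph.nodes if n.target == operator.getitem]

# Symbolic tracing drops the element type of getitem results...
assert all(n.type is None for n in getitems)

# ...and the pass restores it from the annotated container type on the
# placeholder, now understanding tuple[...]/list[...] as well as typing.Tuple.
annotate_getitem_nodes(traced.graph)
assert all(n.type is not None for n in getitems)
```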
@@ -5,7 +5,7 @@
 import itertools
 import torch
-from typing import List, Any
+from typing import Any
 from functools import wraps
 import unittest
 from torch.testing._internal.common_utils import skipIfTorchDynamo
 
@@ -100,7 +100,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
     # dimensions along which the reduction operation is applied:
     dim_ = torch.masked._canonical_dim(dim, input.ndim)
     # slices in product(*ranges) define all elementary slices:
-    ranges: List[Any] = []
+    ranges: list[Any] = []
     # shape of output for the keepdim=True case:
     shape = []
     for i in range(input.ndim):

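The `ranges` list annotated above feeds `itertools.product` to enumerate elementary slices. A rough illustration of that technique (a sketch using a plain unmasked reduction, not the actual helper):

```python
import itertools

import torch


def reduce_along_dims(op, input, dims):
    # Apply `op` separately to every elementary slice spanned by `dims`,
    # iterating over all index combinations of the remaining dimensions.
    dims = sorted(d % input.ndim for d in dims)
    ranges = [
        [slice(None)] if i in dims else list(range(input.shape[i]))
        for i in range(input.ndim)
    ]
    out = torch.empty([1 if i in dims else input.shape[i] for i in range(input.ndim)])
    for idx in itertools.product(*ranges):
        pos = tuple(0 if isinstance(j, slice) else j for j in idx)
        out[pos] = op(input[idx])
    return out


# Row sums of a 2x3 tensor, keepdim=True style: result has shape (2, 1).
print(reduce_along_dims(torch.sum, torch.arange(6.0).reshape(2, 3), dims=[1]))
```
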
@@ -1,7 +1,7 @@
 # Owner(s): ["module: mkldnn"]
 import itertools
 import unittest
-from typing import NamedTuple, List
+from typing import NamedTuple
 
 import torch
 from torch import nn
@@ -16,7 +16,7 @@ FUSION_GROUP = 'prim::TensorExprGroup'
 class PointwisePostOp(NamedTuple):
     attr : str
     pointwise_module : nn.Module
-    scalars : List = []
+    scalars : list = []
     algorithm : str = ""
 
 CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

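For context, `PointwisePostOp` is a plain record describing a pointwise op to fuse after a convolution, and `CONV_MODULES` selects the conv by spatial rank. A hypothetical construction (values are illustrative, not from the test):

```python
from typing import NamedTuple

import torch
from torch import nn


class PointwisePostOp(NamedTuple):
    attr: str
    pointwise_module: nn.Module
    scalars: list = []
    algorithm: str = ""


CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

# Describe a conv2d -> ReLU fusion candidate.
post_op = PointwisePostOp(attr="relu", pointwise_module=nn.ReLU())
model = nn.Sequential(CONV_MODULES[2](3, 8, kernel_size=3), post_op.pointwise_module)
print(model(torch.randn(1, 3, 16, 16)).shape)
```
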
@@ -1,6 +1,6 @@
 # Owner(s): ["module: unknown"]
 
-from typing import Optional, List
+from typing import Optional
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
 
@@ -8,12 +8,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc
 
 
 class FloatListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[float]]):
+    def forward(self, values, incr: Optional[list[float]]):
         return torch._C._nn._test_optional_floatlist(values, incr)
 
 
 class IntListWrapperModule(torch.nn.Module):
-    def forward(self, values, incr: Optional[List[int]]):
+    def forward(self, values, incr: Optional[list[int]]):
         return torch._C._nn._test_optional_intlist(values, incr)
 
 

@@ -9,7 +9,7 @@ import sys
 import tempfile
 import unittest
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional
 
 import numpy as np
 
@@ -3545,7 +3545,7 @@ def get_tolerances(
     true_value: torch.Tensor,
     computed_value: torch.Tensor,
     fudge_factor: Optional[float] = None,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """Returns the absolute and relative tolerances for comparing two tensors."""
     fudge_factor = fudge_factor if fudge_factor is not None else 1.0
     atol = get_atol(true_value, computed_value)

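`get_atol` (and presumably a matching `get_rtol`) are helpers defined elsewhere in the same file; the diff does not show them. One plausible shape, under the common max-deviation convention (an assumption, not the file's actual definition):

```python
import torch


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    # Largest absolute deviation from the reference.
    return (true_value - computed_value).abs().max().item()


def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    # Largest deviation relative to the reference magnitude.
    deviation = (true_value - computed_value).abs()
    return (deviation / true_value.abs().clamp(min=1e-12)).max().item()
```
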
@@ -4,7 +4,6 @@
 import ctypes
 import os
 import unittest
-from typing import Tuple
 
 import torch
 from torch.backends._nnapi.prepare import convert_model_to_nnapi
@@ -700,7 +699,7 @@ class TestNNAPI(TestCase):
 
     def test_multi_output(self):
         class MultiModel(torch.nn.Module):
-            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
+            def forward(self, lhs, rhs) -> tuple[torch.Tensor, torch.Tensor]:
                 the_sum = lhs + rhs
                 the_diff = lhs - rhs
                 return the_sum, the_diff

@@ -278,7 +278,7 @@ class TestNumPyInterop(TestCase):
     def test_from_numpy_no_leak_on_invalid_dtype(self):
         # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
         # object. See https://github.com/pytorch/pytorch/issues/121138
-        x = np.array("value".encode("ascii"))
+        x = np.array(b"value")
         for _ in range(1000):
             try:
                 torch.from_numpy(x)

@@ -11,7 +11,6 @@ from collections import defaultdict
 from collections.abc import Sequence
 from functools import partial
 from importlib import import_module
-from typing import Dict, List
 
 import torch
 import torch._prims as prims
@@ -1483,7 +1482,7 @@ class TestCommon(TestCase):
         unsupported_dtypes = set()
         supported_backward_dtypes = set()
         unsupported_backward_dtypes = set()
-        dtype_error: Dict[torch.dtype, Exception] = {}
+        dtype_error: dict[torch.dtype, Exception] = {}
 
         def unsupported(dtype, e):
             dtype_error[dtype] = e
@@ -1987,7 +1986,7 @@ class TestCompositeCompliance(TestCase):
         for sample in op.sample_inputs(device, dtype, requires_grad=False):
             inp = sample.input
             outs = op(inp, *sample.args, **sample.kwargs)
-            if not isinstance(outs, (tuple, List)):
+            if not isinstance(outs, (tuple, list)):
                 outs = [outs]
 
             # for all outputs that are views of the input, we should be able to replay the

@@ -4,7 +4,7 @@ import math
 import tempfile
 import unittest
 from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Any
 from unittest.mock import patch
 
 from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
@@ -1769,8 +1769,8 @@ class TestOptimRenewed(TestCase):
 
     @staticmethod
    def _state_dict_post_hook(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         if "test" in state_dict["state"]:
             state_dict["state"].pop("test")
             state_dict["ran_state_dict_pre_hook"] = True
@@ -1821,14 +1821,14 @@ class TestOptimRenewed(TestCase):
 
     @staticmethod
     def _load_state_dict_pre_hook1(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
+        optimizer: Optimizer, state_dict: dict[str, Any]
     ) -> None:
         state_dict["param_groups"][0]["lr"] = 0.002
 
     @staticmethod
     def _load_state_dict_pre_hook2(
-        optimizer: Optimizer, state_dict: Dict[str, Any]
-    ) -> Dict[str, Any]:
+        optimizer: Optimizer, state_dict: dict[str, Any]
+    ) -> dict[str, Any]:
         # The typical use case for returning a state dict is to drastically modify the state dict.
         # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
         my_state_dict = deepcopy(state_dict)
@@ -1906,7 +1906,7 @@ class TestOptimRenewed(TestCase):
 
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
-        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data += 2
 
@@ -1938,7 +1938,7 @@ class TestOptimRenewed(TestCase):
 
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_pre_hook(self, device, dtype, optim_info):
-        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data += 2
 
@@ -1970,19 +1970,19 @@ class TestOptimRenewed(TestCase):
 
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_all_hooks(self, device, dtype, optim_info):
-        def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(0)
 
-        def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(5)
 
-        def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(1)
 
-        def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
+        def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
             nonlocal data
             data.append(2)
 

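The hook signatures above match the optimizer step-hook registration API. A minimal wiring sketch (assuming the `register_step_pre_hook`/`register_step_post_hook` methods available on recent `torch.optim.Optimizer` versions):

```python
import torch

param = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.SGD([param], lr=0.1)


def pre_hook(optimizer, args, kwargs):
    print("before step")


def post_hook(optimizer, args, kwargs):
    print("after step")


opt.register_step_pre_hook(pre_hook)
opt.register_step_post_hook(post_hook)

param.grad = torch.ones_like(param)
opt.step()  # prints "before step", then "after step"
```
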
@@ -5,7 +5,7 @@ import torch
 import numpy as np
 
 import math
-from typing import Dict, List, Sequence
+from collections.abc import Sequence
 import random
 from functools import partial
 from itertools import product, combinations, permutations
@@ -736,7 +736,7 @@ class TestReductions(TestCase):
 
     # TODO: kill this ane replace with common creation ops
     def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
-                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
+                      use_complex=False) -> dict[str, list[torch.Tensor]]:
         float_types = [torch.double,
                        torch.float]
         int_types = [torch.int64,
@@ -778,7 +778,7 @@ class TestReductions(TestCase):
             types += int_types
         if use_complex:
             types += complex_types
-        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
+        tensors: dict[str, list[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
         for dtype in types:
             tensors["cont"].append(make_contiguous(shape, dtype))
             tensors["noncont"].append(make_non_contiguous(shape, dtype))

@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
     skipIfCrossRef
 from torch.testing._internal.common_cuda import TEST_CUDA
 from numbers import Number
-from typing import Dict, Any
+from typing import Any
 from packaging import version
 from torch.testing._internal.common_cuda import \
     (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
@@ -334,7 +334,7 @@ class TestSparse(TestSparseBase):
             self.assertEqual(t._values(), tc._values())
             return tc
 
-        value_map: Dict[Any, Any] = {}
+        value_map: dict[Any, Any] = {}
         for idx, val in zip(t._indices().t(), t._values()):
             idx_tup = tuple(idx.tolist())
             if idx_tup in value_map:

@@ -21,7 +21,7 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_cuda import SM53OrLater
 from torch._prims_common import corresponding_complex_dtype
 
-from typing import Optional, List
+from typing import Optional
 from packaging import version
 
 
@@ -597,7 +597,7 @@ class TestFFT(TestCase):
         else:
             numpy_fn = getattr(np.fft, fname)
 
-        def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
+        def fn(t: torch.Tensor, s: Optional[list[int]], dim: list[int] = (-2, -1), norm: Optional[str] = None):
             return torch_fn(t, s, dim, norm)
 
         torch_fns = (torch_fn, torch.jit.script(fn))

@@ -2,14 +2,13 @@
 # ruff: noqa: F841
 
 import unittest
-from typing import Dict, Optional
+from typing import Optional
 
 import numpy as np
 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.static_module import StaticModule
-from typing import List
 
 
 def linear_shim(
@@ -108,7 +107,7 @@ def fork_wait_graph2(input1, input2):
     :param iters: number of future/wait pairs to be created
 """
 def fork_wait_graph3(input, iters: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(iters):
         futures.append(torch.jit.fork(torch.neg, input))
     results = []
@@ -123,7 +122,7 @@ def fork_wait_graph3(input, iters: int):
     :param num_child_forks: number of child forks per parent fork
 """
 def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
-    futures : List[torch.jit.Future[torch.Tensor]] = []
+    futures : list[torch.jit.Future[torch.Tensor]] = []
     for _ in range(num_forks):
         futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
     results = []
@@ -150,7 +149,7 @@ def loop_graph(a, b, iters: int):
 def output_graph(a, b, c, iters: int):
     s = torch.tensor([[3, 3], [3, 3]])
     k = a + b * c + s
-    d: Dict[int, torch.Tensor] = {}
+    d: dict[int, torch.Tensor] = {}
     for i in range(iters):
         d[i] = k + i
     return d

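The fork/wait graphs above all build on the same primitive: `torch.jit.fork` launches a task and returns a future, and `torch.jit.wait` blocks on it. Stripped to its core (a sketch; this also runs in eager mode):

```python
import torch


def fork_wait(x: torch.Tensor, iters: int):
    # Launch `iters` asynchronous negations, then gather the results.
    futures = [torch.jit.fork(torch.neg, x) for _ in range(iters)]
    return [torch.jit.wait(f) for f in futures]


print(fork_wait(torch.ones(2), iters=3))
```
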
@@ -5,7 +5,7 @@ import itertools
 import math
 import pickle
 import sys
-from typing import Callable, List, Tuple, Type
+from typing import Callable
 
 import sympy
 
@@ -594,7 +594,7 @@ class TestSympyInterp(TestCase):
             self.fail(f"Unexpected error for {fn}{args}: {str(e)}")
 
 
-def type_name_fn(type: Type) -> str:
+def type_name_fn(type: type) -> str:
     return type.__name__
 
 
@@ -606,7 +606,7 @@ def parametrize_relational_types(*types):
 
 
 class TestSympySolve(TestCase):
-    def _create_integer_symbols(self) -> List[sympy.Symbol]:
+    def _create_integer_symbols(self) -> list[sympy.Symbol]:
         return sympy.symbols("a b c", integer=True)
 
     def test_give_up(self):
@@ -665,9 +665,9 @@ class TestSympySolve(TestCase):
 
     def _test_cases(
         self,
-        cases: List[Tuple[sympy.Basic, sympy.Basic]],
+        cases: list[tuple[sympy.Basic, sympy.Basic]],
         thing: sympy.Basic,
-        op: Type[sympy.Rel],
+        op: type[sympy.Rel],
         **kwargs,
     ):
         for source, expected in cases:
@@ -761,7 +761,7 @@ class TestSympySolve(TestCase):
             Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
         }[op]
 
-        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
+        cases: list[tuple[sympy.Basic, sympy.Basic]] = [
             # 'b' is not strictly positive
             (op(FloorDiv(a, b), integer), None),
             # 'c' is not strictly positive

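Side note on the `FloorDiv` rows above: for integers $a$, $n$ and a strictly positive divisor $p$,

$$\lfloor a/p \rfloor \ge n \iff a \ge np, \qquad \lfloor a/p \rfloor \le n \iff a < (n+1)\,p,$$

which is where the `(integer + 1) * pos` bound in the `Le` row comes from. When the divisor (`b`, `c`) is not known to be strictly positive, no such rewrite is sound, so the expected result is `None`.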
@@ -11,7 +11,7 @@ import unittest
 from itertools import product, combinations, combinations_with_replacement, permutations
 import random
 import tempfile
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
@@ -4125,7 +4125,7 @@ class TestAsArray(TestCase):
     def test_default_device(self, device):
         original = torch.arange(5)
 
-        examples: List[Tuple[Any, Dict]] = [
+        examples: list[tuple[Any, dict]] = [
             (3, {}),
             (original, {}),
             (to_numpy(original), {}),

@@ -12,7 +12,8 @@ import re
 import subprocess
 import sys
 import unittest.mock
-from typing import Any, Callable, Iterator, List, Tuple
+from typing import Any, Callable
+from collections.abc import Iterator
 
 import torch
 
@@ -496,7 +497,7 @@ if __name__ == '__main__':
         self.assertNotIn('OK', stderr.decode('ascii'))
 
 
-def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
+def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]:
     """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.
 
     Args:

@@ -53,7 +53,6 @@ from torch.testing._internal.common_device_type import (
     dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
     skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, skipCUDAIfNotRocm,
     get_all_device_types, skipXLA)
-from typing import Tuple
 import torch.backends.quantized
 import torch.testing._internal.data
 from torch.testing._internal.common_cuda import (
@@ -3511,7 +3510,7 @@ else:
 
     def _prepare_data_for_index_copy_and_add_deterministic(
         self, dim: int, device: torch.device
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert (dim >= 0 and dim < 3)
         a = [5, 4, 3]
         a[dim] = 2000

@@ -18,7 +18,7 @@ import math
 import itertools
 import torch.optim as optim
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
-from typing import List, Tuple, Optional, Dict
+from typing import Optional
 import torch.utils.cpp_extension
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
@@ -149,12 +149,12 @@ def _check_equal(
 
 
 def check_out_and_grad(
-    out_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_query_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_key_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_value_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    grad_attn_mask_tuple: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
-    fudge_factors: Optional[Dict[str, float]] = None
+    out_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_query_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_key_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_value_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    grad_attn_mask_tuple: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
+    fudge_factors: Optional[dict[str, float]] = None
 ) -> None:
     """
     Check output and gradients of attention mechanism tensors.
@@ -2574,7 +2574,7 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
 
-        def test_attention(backend: SDPBackend, permute_order: List[List[int]]):
+        def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
             BHSqD = [4, 16, 256, 64]
             BHSkvD = [4, 16, 512, 64]
 
@@ -2602,7 +2602,7 @@ class TestSDPACudaOnly(NNTestCase):
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
-    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
+    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):
         dtype = torch.float16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
         batch, num_heads, head_dim = 8, 8, 64
@@ -3307,7 +3307,7 @@ class TestSDPACudaOnly(NNTestCase):
     @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
-                                               scale: str, enable_gqa: bool, n_heads: List[int]):
+                                               scale: str, enable_gqa: bool, n_heads: list[int]):
         if isSM8XDevice and head_dim in range(193, 256 + 1):
             self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
@@ -3905,7 +3905,7 @@ class TestAttnBias(NNTestCase):
         "shape",
         [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
     )
-    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
@@ -3942,7 +3942,7 @@ class TestAttnBias(NNTestCase):
     )
     @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
     @skipIfTorchDynamo("This function already calls torch.compile.")
-    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
         if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
             self.skipTest("No support for LOWER_RIGHT variant for now")
             return
@@ -3975,7 +3975,7 @@ class TestAttnBias(NNTestCase):
         self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
 
     @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
-    def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
+    def test_is_causal_equals_upper_left(self, device, shape: list[tuple[int]]):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )

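All of the SDPA tests above ultimately exercise `torch.nn.functional.scaled_dot_product_attention` over `(batch, heads, seq, head_dim)` tensors. A minimal call, for orientation (standard public API; shapes borrowed from the test above):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 4, 16, 256, 64  # batch, heads, sequence length, head dim
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# is_causal=True applies the upper-left causal mask variant.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([4, 16, 256, 64])
```
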
@@ -8,7 +8,7 @@ import shutil
 import unittest
 from collections import defaultdict
 from threading import Lock
-from typing import Dict, IO, List, Optional
+from typing import IO, Optional
 
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -49,12 +49,12 @@ def _strip_filename(msg: str) -> str:
     return tail.split(":", 1)[-1]
 
 
-def _run_mypy() -> Dict[str, List[str]]:
+def _run_mypy() -> dict[str, list[str]]:
     """Clears the cache and run mypy before running any of the typing tests."""
     if os.path.isdir(CACHE_DIR):
         shutil.rmtree(CACHE_DIR)
 
-    rc: Dict[str, List[str]] = {}
+    rc: dict[str, list[str]] = {}
     for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
         # Run mypy
         stdout, stderr, _ = api.run(
@@ -119,10 +119,10 @@ def _construct_format_dict():
 
 #: A dictionary with all supported format keys (as keys)
 #: and matching values
-FORMAT_DICT: Dict[str, str] = _construct_format_dict()
+FORMAT_DICT: dict[str, str] = _construct_format_dict()
 
 
-def _parse_reveals(file: IO[str]) -> List[str]:
+def _parse_reveals(file: IO[str]) -> list[str]:
     """Extract and parse all ``" # E: "`` comments from the passed file-like object.
 
     All format keys will be substituted for their respective value from `FORMAT_DICT`,
@@ -160,10 +160,10 @@ def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> N
 @unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
 class TestTyping(TestCase):
     _lock = Lock()
-    _cached_output: Optional[Dict[str, List[str]]] = None
+    _cached_output: Optional[dict[str, list[str]]] = None
 
     @classmethod
-    def get_mypy_output(cls) -> Dict[str, List[str]]:
+    def get_mypy_output(cls) -> dict[str, list[str]]:
         with cls._lock:
             if cls._cached_output is None:
                 cls._cached_output = _run_mypy()
@@ -192,7 +192,7 @@ class TestTyping(TestCase):
         with open(path) as fin:
             lines = fin.readlines()
 
-        errors = defaultdict(lambda: "")
+        errors = defaultdict(str)
 
         output_mypy = self.get_mypy_output()
         self.assertIn(path, output_mypy)

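The `api.run(...)` call above is mypy's programmatic entry point: it takes command-line arguments as a list and returns a `(stdout, stderr, exit_status)` triple. Standalone (assuming mypy is installed):

```python
from mypy import api

# Type-check one file; exit_status 0 means no errors were reported.
stdout, stderr, exit_status = api.run(["some_module.py"])
print(exit_status)
print(stdout)
```
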
@@ -12,7 +12,7 @@ import textwrap
 import traceback
 import unittest
 import warnings
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 import torch.cuda
@@ -439,7 +439,7 @@ class TestCheckpoint(TestCase):
         # get de-allocated directly. So using cuda memory usage as a proxy
 
         def _do_test(fn, should_free):
-            stats: List[int] = []
+            stats: list[int] = []
 
             def track(x, idx):
                 # Track that at each step of the backward, some Tensor were
@@ -1203,7 +1203,7 @@ def f(x):
     return g(x) + 1
 """
 
    out: Dict[str, Any] = {}
-    out: Dict[str, Any] = {}
+    out: dict[str, Any] = {}
     scope = {"__compile_source__": source}
     exec(source, scope, out)
 

@@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
     """
     Annotate the type of getitem nodes, inferred from the type of sequence node.
     If sequence node is not annotated with a type, do nothing.
-    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
+    Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
 
     This is helpful since annotations on local names within function are lost during FX transforms.
     Adding back known type annotation for getitem nodes to improve jit scriptability.
@@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
                 elif sequence_node.type._name == "List":
                     assert len(parameterized_types) == 1
                     node.type = parameterized_types[0]
+            # Generic Alias Type
+            elif hasattr(sequence_node.type, "__origin__"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type.__origin__ is tuple:
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type.__origin__ is list:
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
             # NamedTuple type
             elif hasattr(sequence_node.type, "__annotations__"):
                 if sequence_node.type == torch.Tensor:

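With the new `__origin__` branch, builtin generic annotations like `tuple[...]` and `list[...]` now survive the pass, matching the `MyModule2` test added earlier in this diff. An end-to-end sketch (module and names are illustrative):

```python
import operator

import torch
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes


class Pair(torch.nn.Module):
    def forward(self, inp: tuple[torch.Tensor, torch.Tensor]):
        return inp[0] + inp[1]


traced = torch.fx.symbolic_trace(Pair())
annotate_getitem_nodes(traced.graph)
for node in traced.graph.nodes:
    if node.target == operator.getitem:
        print(node, node.type)  # both getitem nodes annotated as torch.Tensor
```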