Migrate from Tuple -> tuple in torch/_dynamo (#144261)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144261
Approved by: https://github.com/aorenste, https://github.com/zou3519
Author: bobrenjc93 (2025-01-09 20:10:46 -08:00)
Committed by: PyTorch MergeBot
Parent: f295eff512
Commit: 1fe3af2c68
26 changed files with 107 additions and 155 deletions
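For context: PEP 585 (Python 3.9) made the builtin `tuple` subscriptable in annotations, so `typing.Tuple` is redundant; this changeset rewrites every `Tuple[...]` under torch/_dynamo to the builtin form and drops the now-unused imports. A minimal sketch of the pattern, illustrative code rather than lines from the diff:

    # Before: needs an import from typing
    from typing import Tuple

    def swap(pair: Tuple[int, str]) -> Tuple[str, int]:
        a, b = pair
        return b, a

    # After: the builtin type is subscriptable on Python 3.9+,
    # so the typing import can be dropped entirely
    def swap(pair: tuple[int, str]) -> tuple[str, int]:
        a, b = pair
        return b, a

Only `Tuple` is migrated here; `Dict`, `List`, `Set`, `Type`, and the other typing imports are left untouched throughout the diff.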

----------------------------------------

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -100,7 +100,7 @@ class ModIndex(torch.autograd.Function):
         return torch.ops.aten.index(x, indices)
 
     @staticmethod
-    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
+    def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
         x, indices = inputs
         ctx.save_for_backward(*indices)
         ctx.input_shape = x.shape
@@ -131,8 +131,8 @@ class TransformGetItemToIndex(TorchFunctionMode):
     def __torch_function__(
         self,
         func: OpOverload,
-        types: Tuple[torch._C._TensorMeta, ...],
-        args: Tuple[object, ...] = (),
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
         kwargs: Optional[Dict[str, object]] = None,
     ) -> object:
         if func == torch.Tensor.__getitem__:
@@ -161,8 +161,8 @@ _trace_wrapped_op = TraceWrapped()
 
 def _assert_meta(
     grad: torch.Tensor,
-    size: Tuple[int, ...],
-    stride: Tuple[int, ...],
+    size: tuple[int, ...],
+    stride: tuple[int, ...],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     assert grad.size() == size, "size mismatch"
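The `setup_context` hunk above touches the structured `torch.autograd.Function` API, where the forward inputs arrive packed as a tuple. A hedged sketch of that API with the new builtin-generic annotations (the `Scale` function and its names are hypothetical, not part of this PR):

    import torch
    from typing import Any

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(x: torch.Tensor, factor: float) -> torch.Tensor:
            return x * factor

        @staticmethod
        def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
            # `inputs` mirrors the positional arguments passed to forward()
            _, factor = inputs
            ctx.factor = factor

        @staticmethod
        def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
            # one gradient slot per forward input; the float input gets None
            return grad_output * ctx.factor, None

    y = Scale.apply(torch.randn(3, requires_grad=True), 2.0)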

----------------------------------------

@@ -4,7 +4,7 @@ import functools
 import logging
 import sys
 from importlib.metadata import EntryPoint
-from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Protocol, Sequence
 
 import torch
 from torch import fx
@@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
 
 
 class CompiledFn(Protocol):
-    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
         ...

----------------------------------------

@@ -6,18 +6,7 @@ import functools
 import itertools
 import sys
 import types
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Sequence, Union
 
 from ..utils._backport_slots import dataclass_slots
 from .bytecode_analysis import (
@@ -447,7 +436,7 @@ def create_swap(n) -> List[Instruction]:
 
 def lnotab_writer(
     lineno: int, byteno: int = 0
-) -> Tuple[List[int], Callable[[int, int], None]]:
+) -> tuple[List[int], Callable[[int, int], None]]:
     """
     Used to create typing.CodeType.co_lnotab
     See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
@@ -537,7 +526,7 @@ def linetable_311_writer(first_lineno: int):
         assert 0 < size <= 8
        # first byte - use 13 (no column info) if positions is
        # malformed, otherwise use 14 (long form)
-        other_varints: Tuple[int, ...] = ()
+        other_varints: tuple[int, ...] = ()
         if (
            positions
            and positions.lineno is not None
@@ -670,7 +659,7 @@ def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
     return bytes(b)
 
 
-def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]:
+def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
     """Do the opposite of dis.get_instructions()"""
     code: List[int] = []
     if sys.version_info >= (3, 11):
@@ -854,7 +843,7 @@ def compute_exception_table(
     instructions: List[Instruction],
 ) -> List[ExceptionTableEntry]:
     """Compute exception table in list format from instructions with exn_tab_entries"""
-    exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {}
+    exn_dict: Dict[tuple[int, int], tuple[int, int, bool]] = {}
     indexof = get_indexof(instructions)
 
     for inst in instructions:
@@ -888,7 +877,7 @@ def compute_exception_table(
     # smallest byte that the next exception table entry can start at
     nexti = 0
     # stack of current nested keys
-    key_stack: List[Tuple[int, int]] = []
+    key_stack: List[tuple[int, int]] = []
     exn_tab: List[ExceptionTableEntry] = []
 
     def pop():
@@ -933,7 +922,7 @@ def check_inst_exn_tab_entries_nested(
     "Properly sorted" means entries are sorted by increasing starts, then
     decreasing ends.
     """
-    entry_stack: List[Tuple[int, int]] = []
+    entry_stack: List[tuple[int, int]] = []
     for entry in tab:
         key = (indexof[entry.start], indexof[entry.end])
         while entry_stack and entry_stack[-1][1] < key[0]:
@@ -949,7 +938,7 @@ def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
     Supports nested exception table entries.
     """
     indexof = get_indexof(instructions)
-    entries: Dict[Tuple[int, int], InstructionExnTabEntry] = {}
+    entries: Dict[tuple[int, int], InstructionExnTabEntry] = {}
     for inst in instructions:
         if inst.exn_tab_entry:
             key = (
@@ -1417,7 +1406,7 @@ def transform_code_object(code, transformations, safe=False) -> types.CodeType:
 
 def clean_and_assemble_instructions(
     instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
-) -> Tuple[List[Instruction], types.CodeType]:
+) -> tuple[List[Instruction], types.CodeType]:
     # also implicitly checks for no duplicate instructions
     check_inst_exn_tab_entries_valid(instructions)
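Several of the hunks above annotate `(start, end)` interval keys for exception-table handling. The nesting invariant that `check_inst_exn_tab_entries_nested` enforces ("sorted by increasing starts, then decreasing ends"; intervals may nest but never partially overlap) boils down to a stack walk. A standalone sketch under that assumption (hypothetical helper, not the PR's code):

    def check_nested(intervals: list[tuple[int, int]]) -> None:
        # expects intervals sorted by (start ascending, end descending)
        stack: list[tuple[int, int]] = []
        for start, end in intervals:
            # drop intervals that ended before this one starts
            while stack and stack[-1][1] < start:
                stack.pop()
            # any interval still open must fully contain the new one
            assert not stack or stack[-1][1] >= end, "partial overlap"
            stack.append((start, end))

    check_nested([(0, 10), (1, 4), (5, 9), (12, 13)])  # properly nested: passes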

----------------------------------------

@@ -2,7 +2,6 @@
 import logging
 import weakref
 from dataclasses import dataclass
-from typing import Tuple
 
 from torch._guards import CompileId
@@ -167,7 +166,7 @@ def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
 
 def exceeds_recompile_limit(
     cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
-) -> Tuple[bool, str]:
+) -> tuple[bool, str]:
     """
     Checks if we are exceeding the cache size limit.
     """

----------------------------------------

@@ -5,7 +5,7 @@ import itertools
 import operator
 import time
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch._dynamo.external_utils import (
@@ -118,7 +118,7 @@ class AutogradCompilerInstance:
         inputs: List[torch.Tensor],
         sizes: List[int],
         scalars: List[Union[int, float]],
-        origins: List[List[Tuple[int, str]]],
+        origins: List[List[tuple[int, str]]],
     ):
         counters["compiled_autograd"]["captures"] += 1
         self.id = next(COMPILE_COUNTER)
@@ -785,7 +785,7 @@ class AutogradCompilerInstance:
         return proxy_tensor.proxy
 
     def bind_tensors_to_proxies(
-        self, tensors, proxies, origins: Optional[List[Tuple[int, str]]] = None
+        self, tensors, proxies, origins: Optional[List[tuple[int, str]]] = None
     ):
         if isinstance(proxies, torch.fx.Proxy):
             if origins:

----------------------------------------

@@ -22,7 +22,7 @@ import typing
 import weakref
 from pathlib import Path
 from types import CellType, CodeType, FunctionType, ModuleType
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
 from typing_extensions import ParamSpec
 from weakref import ReferenceType
@@ -619,7 +619,7 @@ def _compile(
     globals: Dict[str, object],
     locals: Dict[str, object],
     builtins: Dict[str, object],
-    closure: Tuple[CellType],
+    closure: tuple[CellType],
     compiler_fn: CompilerFn,
     one_graph: bool,
     export: bool,

----------------------------------------

@@ -1,6 +1,6 @@
 import threading
 from contextlib import contextmanager
-from typing import Any, Generator, Tuple
+from typing import Any, Generator
 
 import torch
@@ -24,7 +24,7 @@ class TracableCreateParameter(torch.autograd.Function):
         return placeholder.set_(tensor)
 
     @staticmethod
-    def backward(ctx: Any, *grad_outputs: torch.Tensor) -> Tuple[None, torch.Tensor]:
+    def backward(ctx: Any, *grad_outputs: torch.Tensor) -> tuple[None, torch.Tensor]:
         grad = grad_outputs[0]
         return None, grad  # grad flows to placeholder
@@ -38,7 +38,7 @@ def tracable_create_parameter(
 
 def new_parameter_placeholder(
-    size: Tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
+    size: tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
 ) -> torch.nn.Parameter:
     """Create a placeholder to be passed to the above functions"""
     result = torch.nn.Parameter(

----------------------------------------

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import time
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Type, Union
 
 import torch
@@ -397,7 +397,7 @@ def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInt
     raise NotImplementedError(f"No interface for device {device}")
 
 
-def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
+def get_registered_device_interfaces() -> Iterable[tuple[str, Type[DeviceInterface]]]:
     if not _device_initialized:
         init_device_reg()
     return device_interfaces.items()
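For reference, `get_registered_device_interfaces` yields `(device_name, interface_class)` pairs, which is exactly what the new `Iterable[tuple[str, Type[DeviceInterface]]]` return type spells out. A usage sketch (this is a private torch._dynamo API, shown only to illustrate the annotation):

    from torch._dynamo.device_interface import get_registered_device_interfaces

    for name, iface in get_registered_device_interfaces():
        # e.g. prints something like "cuda CudaInterface" on a CUDA build
        print(name, iface.__name__)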

----------------------------------------

@@ -35,7 +35,6 @@ from typing import (
     NamedTuple,
     Optional,
     Set,
-    Tuple,
     TYPE_CHECKING,
     Union,
 )
@@ -1014,7 +1013,7 @@ class FlattenInputOutputSignature(torch.fx.Transformer):
     def __init__(
         self,
         m: torch.fx.GraphModule,
-        flat_args: Tuple[Any],
+        flat_args: tuple[Any],
         matched_input_elements_positions: List[int],
         flat_results: List[Any],
         matched_output_elements_positions: List[int],
@@ -1381,7 +1380,7 @@ def export(
         Dict[torch._ops.OpOverload, Callable[..., Any]]
     ] = None,
     tracing_mode: str = "symbolic",
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], tuple[Any], List[Any]]] = None,
     specialize_float: bool = True,
     assume_static_by_default: bool = False,
     same_signature: bool = True,
@@ -1472,7 +1471,7 @@ def export(
     graph = None
     out_guards = None
     graph_captured_input = None
-    graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
+    graph_captured_result: Optional[tuple[torch.Tensor, ...]] = None
     fake_mode = None
 
     result_traced = None

----------------------------------------

@@ -6,7 +6,7 @@ import textwrap
 import typing
 from enum import auto, Enum
 from traceback import extract_stack, format_exc, format_list, StackSummary
-from typing import Any, NoReturn, Optional, Tuple, Type, TYPE_CHECKING
+from typing import Any, NoReturn, Optional, Type, TYPE_CHECKING
 
 import torch._guards
@@ -425,7 +425,7 @@ def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -
 
 def get_exc_message(
     e: Exception, compile_id: CompileId
-) -> Tuple[Optional[str], Optional[int]]:
+) -> tuple[Optional[str], Optional[int]]:
     filename = None
     lineno = None
     if e.innermost_user_frame_summary is not None:  # type: ignore[attr-defined]

----------------------------------------

@@ -1,6 +1,6 @@
 import logging
 import operator
-from typing import Any, Dict, Iterable, List, Set, Tuple
+from typing import Any, Dict, Iterable, List, Set
 
 import torch.fx
 from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
@@ -104,7 +104,7 @@ def _replace_region_with_subgraph(
     graph: torch.fx.Graph,
     region: Region,
     get_subgraph_node: Node,
-    node_ind_arg_ind: Iterable[Tuple[int, int]],
+    node_ind_arg_ind: Iterable[tuple[int, int]],
     inds_with_external_users: List[int],
     sub_gm: torch.fx.GraphModule,
     subgraph_name: str,
@@ -147,7 +147,7 @@ def _replace_region_with_subgraph(
 
 def _get_external_inputs(
     region: Region,
-) -> Dict[Node, Tuple[int, int]]:
+) -> Dict[Node, tuple[int, int]]:
     external_node_to_indices = dict()
     region_unique = set(region)
     for node_ind, node in enumerate(region):
@@ -183,9 +183,9 @@ def _get_inds_with_external_users(region: Region, inds_unique: Set[int]) -> None
 
 def _copy_nodes_and_remap_inputs(
     subgraph: torch.fx.Graph, region: Region
-) -> Dict[Tuple[int, int], Any]:
+) -> Dict[tuple[int, int], Any]:
     external_inputs_to_indices = _get_external_inputs(region)
-    indices_to_placeholder_ind: Dict[Tuple[int, int], Any] = {}
+    indices_to_placeholder_ind: Dict[tuple[int, int], Any] = {}
     region_to_subgraph_node = {}
     for node in external_inputs_to_indices.keys():
         placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
@@ -219,7 +219,7 @@ def _create_subgraph_outputs(
 
 def _create_subgraph(
     region: Region,
     inds_with_external_users: List[int],
-) -> Tuple[torch.fx.Graph, Dict[Tuple[int, int], Any]]:
+) -> tuple[torch.fx.Graph, Dict[tuple[int, int], Any]]:
     subgraph: torch.fx.Graph = torch.fx.Graph()
     node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
     _create_subgraph_outputs(subgraph, inds_with_external_users)

----------------------------------------

@@ -13,7 +13,6 @@ from typing import (
     List,
     Optional,
     Set,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
 )
@@ -34,7 +33,7 @@ if TYPE_CHECKING:
 Node = torch.fx.Node
 Region = List[Node]
 IdenticalNodes = List[Node]
-GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
+GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
 
 log = logging.getLogger(__name__)
 graph_expansion_log = torch._logging.getArtifactLogger(
@@ -48,7 +47,7 @@ def debug_log(msg: str, *args) -> None:  # type: ignore[no-untyped-def]
 
 def _extract_tensor_metadata_for_node_hash(
     x: torch.Tensor,
-) -> Tuple[Callable[[T], T], Tuple[Any, ...]]:
+) -> tuple[Callable[[T], T], tuple[Any, ...]]:
     from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key
 
     out = []
@@ -104,7 +103,7 @@ def _extract_tensor_arg(arg: Any) -> Any:
 
 def _normalize_args(
     node: Node,
-) -> Tuple[Tuple[str, ...], Tuple[Optional[Any], ...]]:
+) -> tuple[tuple[str, ...], tuple[Optional[Any], ...]]:
     flat_args, _ = tree_flatten(node.args)
     sorted_kwargs = sorted(node.kwargs.items(), key=lambda x: x[0])
     sorted_keys = tuple(sorted(node.kwargs.keys()))

----------------------------------------

@@ -21,18 +21,7 @@ import weakref
 from contextlib import contextmanager
 from copy import deepcopy
 from inspect import currentframe
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
 from weakref import ReferenceType
 
 import torch
@@ -647,7 +636,7 @@ class GuardBuilder(GuardBuilderBase):
         self._cached_guard_managers: Dict[
             str, torch._C._dynamo.guards.GuardManager
         ] = {}
-        self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()
+        self._cached_duplicate_input_guards: Set[tuple[str, str]] = set()
 
     def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
         dict_mgr = self.get_guard_manager(guard)
@@ -1511,7 +1500,7 @@ class GuardBuilder(GuardBuilderBase):
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
         if np:
-            np_types: Tuple[Type[Any], ...] = (
+            np_types: tuple[Type[Any], ...] = (
                 np.int8,
                 np.int16,
                 np.int32,
@@ -1817,10 +1806,10 @@ class GuardBuilder(GuardBuilderBase):
         ]
 
         if output_graph.export_constraints:
-            names: Dict[str, Tuple[int, int]] = {}
-            source_pairs: List[Tuple[Source, Source]] = []
+            names: Dict[str, tuple[int, int]] = {}
+            source_pairs: List[tuple[Source, Source]] = []
             derived_equalities: List[  # type: ignore[type-arg]
-                Tuple[Source, Union[Source, Symbol], Callable]
+                tuple[Source, Union[Source, Symbol], Callable]
             ] = []
             phantom_symbols: Dict[str, Symbol] = {}
             relaxed_sources: Set[Source] = set()
@@ -2178,7 +2167,7 @@ class PyExprCSEPass:
             log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e)
             raise
 
-    def replace(self, expr: str) -> Tuple[List[str], str]:
+    def replace(self, expr: str) -> tuple[List[str], str]:
         replacer = self.Replacer(self._config, self._new_var)
         new_node = replacer.visit(ast.parse(expr))
         return replacer.preface, _ast_unparse(new_node)
@@ -2559,13 +2548,13 @@ class CheckFunctionManager:
         return None
 
 
-def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
+def build_guard_function(code_parts, closure_args) -> tuple[str, str]:
     from torch._inductor.utils import IndentedBuffer
 
     csepass = PyExprCSEPass()
     csepass.count(code_parts)
 
-    def replace(expr: str) -> Tuple[List[str], str]:
+    def replace(expr: str) -> tuple[List[str], str]:
         return csepass.replace(expr)
 
     # Generate the inner body of the guard function.
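`PyExprCSEPass.replace` hands back a `(preface, expression)` pair: hoisted assignment lines plus the rewritten guard expression. Roughly what that looks like for a guard expression with a repeated subexpression (illustrative input and output; the temporary's exact name is an assumption, not the pass's documented output):

    csepass = PyExprCSEPass()
    csepass.count(["x.attr.a + x.attr.b"])
    preface, expr = csepass.replace("x.attr.a + x.attr.b")
    # preface ~ ["_var0 = x.attr"], expr ~ "_var0.a + _var0.b"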

----------------------------------------

@@ -12,18 +12,7 @@ import sys
 import traceback
 import weakref
 from dataclasses import dataclass
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union
 
 import sympy
@@ -427,7 +416,7 @@ class OutputGraph:
         # random_calls tracks calls to random() and random_values_var stores the name of
         # the variable that stores __gen_rand_values results.
         self.random_calls: List[
-            Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+            tuple[Callable[..., object], tuple[object, ...], Dict[str, object]]
         ] = []
         self.random_values_var = None
@@ -638,7 +627,7 @@ class OutputGraph:
         Saves to out if it is provided. Else saves to the tracing context's global_state.
         """
         global_state = cast(
-            Dict[str, Tuple[Callable[..., Any], bool]],
+            Dict[str, tuple[Callable[..., Any], bool]],
             out
             if out is not None
             else self.tracing_context.global_context.global_state,
@@ -1270,7 +1259,7 @@ class OutputGraph:
         Momentarily restores the global state to what it was prior to tracing the current output
         """
         prior_global_state = self.tracing_context.global_context.copy_graphstate()
-        current_global_state: Dict[str, Tuple[Any, bool]] = {}
+        current_global_state: Dict[str, tuple[Any, bool]] = {}
         self.save_global_state(out=current_global_state)
         try:
             # Set to state prior to tracing the graph

----------------------------------------

@@ -8,7 +8,7 @@ import logging
 import os
 import pickle
 from collections import defaultdict
-from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
+from typing import DefaultDict, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Self
 
 import torch._dynamo.config
@@ -168,10 +168,10 @@ class FrameStateSizeEntry:
     # NB: We don't have cases where we have a known dimensionality but
     # we know NOTHING about the individual sizes
     size: Union[
-        AutoDynamic, AutoUnset, Tuple[Union[int, AutoDynamic], ...]
+        AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]
     ] = dataclasses.field(default=auto_unset)
     stride: Union[
-        AutoDynamic, AutoUnset, Tuple[Union[int, AutoDynamic, InferStride], ...]
+        AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
     ] = dataclasses.field(default=auto_unset)
 
     def render(self) -> str:
@@ -187,7 +187,7 @@ class FrameStateSizeEntry:
             else:
                 return str(s)
 
-        def render_tuple(ss: Tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
+        def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
             return "[" + ", ".join(render_single(s) for s in ss) + "]"
 
         # Common cases
@@ -246,7 +246,7 @@ class FrameStateSizeEntry:
         return self.stride[dim] is auto_dynamic
 
     @staticmethod
-    def _munge_symint(xs: Tuple[int, ...]) -> Tuple[Union[AutoDynamic, int], ...]:
+    def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
         return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)
 
     @classmethod
@@ -255,7 +255,7 @@ class FrameStateSizeEntry:
     @classmethod
     def make_tensor(
-        cls, size: Tuple[int, ...], stride: Tuple[int, ...]
+        cls, size: tuple[int, ...], stride: tuple[int, ...]
     ) -> FrameStateSizeEntry:
         return FrameStateSizeEntry(
             scalar=auto_dynamic,
@@ -264,7 +264,7 @@ class FrameStateSizeEntry:
         )
 
     @classmethod
-    def make_size(cls, size: Tuple[int, ...]) -> FrameStateSizeEntry:
+    def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
         return FrameStateSizeEntry(
             scalar=auto_unset,
             size=cls._munge_symint(size),
@@ -284,9 +284,9 @@ class FrameStateSizeEntry:
     @classmethod
     def _merge_atom_tup(
         cls,
-        xs: Union[AutoDynamic, AutoUnset, Tuple[_T, ...]],
-        ys: Union[AutoDynamic, AutoUnset, Tuple[_T, ...]],
-    ) -> Union[AutoDynamic, AutoUnset, Tuple[Union[AutoDynamic, _T], ...]]:
+        xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
+        ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
+    ) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]:
         if xs is auto_unset:
             return ys
         if ys is auto_unset:
@@ -651,7 +651,7 @@ def put_code_state() -> None:
         put_remote_code_state(cache_key)
 
 
-def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[Tuple[str, int]]:
+def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
     path = code_state_path(cache_key)
 
     if path is None:
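`_merge_atom_tup` is the widening step PGO uses when the same code sees different shapes across runs: an unset side defers to the other, a length mismatch collapses the whole tuple, and per-dimension disagreement marks only that dimension dynamic. A simplified sketch of that merge, using plain stand-ins instead of the real `AutoUnset`/`AutoDynamic` types:

    AUTO_DYNAMIC = object()  # illustrative sentinel, not the PR's type

    def merge_sizes(xs, ys):
        if xs is None:  # "unset": take whatever the other run observed
            return ys
        if ys is None:
            return xs
        if len(xs) != len(ys):
            return AUTO_DYNAMIC  # dimensionality changed: all dynamic
        # elementwise: keep agreeing sizes, mark disagreements dynamic
        return tuple(x if x == y else AUTO_DYNAMIC for x, y in zip(xs, ys))

    merge_sizes((2, 8), (2, 16))  # -> (2, AUTO_DYNAMIC): only dim 1 is dynamic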

----------------------------------------

@@ -2,7 +2,7 @@
 # Please add a new import when adding a new polyfill module.
 import importlib
-from typing import Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 from .. import polyfills, trace_rules
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
 # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
-POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
+POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
     "builtins",
     "functools",
     "itertools",
@@ -21,7 +21,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
     "pytree",
     "sys",
 )
-POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
+POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
     importlib.import_module(f".{submodule}", package=polyfills.__name__)
     for submodule in POLYFILLED_MODULE_NAMES
 )

----------------------------------------

@@ -1,7 +1,7 @@
 import dataclasses
 from dataclasses import field
 from types import CellType, CodeType, ModuleType
-from typing import Any, BinaryIO, Dict, IO, Tuple
+from typing import Any, BinaryIO, Dict, IO
 from typing_extensions import Self
 
 from torch.utils._import_utils import import_dill
@@ -29,7 +29,7 @@ class DummyModule:
 @dataclasses.dataclass
 class ExecutionRecord:
     code: CodeType
-    closure: Tuple[CellType]
+    closure: tuple[CellType]
     globals: Dict[str, Any] = field(default_factory=dict)
     locals: Dict[str, Any] = field(default_factory=dict)
     builtins: Dict[str, Any] = field(default_factory=dict)
@@ -50,7 +50,7 @@ class ExecutionRecorder:
     LOCAL_MOD_PREFIX = "___local_mod_"
 
     code: CodeType
-    closure: Tuple[CellType]
+    closure: tuple[CellType]
     globals: Dict[str, Any] = field(default_factory=dict)
     locals: Dict[str, Any] = field(default_factory=dict)
     builtins: Dict[str, Any] = field(default_factory=dict)

----------------------------------------

@@ -9,7 +9,7 @@ import shutil
 import sys
 import textwrap
 from importlib import import_module
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch._dynamo.debug_utils import (
@@ -89,7 +89,7 @@ def save_graph_repro_ep(
     *,
     exported_program: Optional[ExportedProgram] = None,
     gm: Optional[torch.nn.Module] = None,
-    args: Optional[Tuple[Any]] = None,
+    args: Optional[tuple[Any]] = None,
     config_patches: Optional[Dict[str, str]] = None,
     stable_output=False,
     save_dir=None,

----------------------------------------

@@ -3,7 +3,7 @@ import copy
 import dataclasses
 import sys
 import types
-from typing import Any, cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, List, Optional
 
 from .bytecode_transformation import (
     bytecode_from_template,
@@ -89,7 +89,7 @@ def _try_except_tf_mode_template(dummy, stack_var_name):
 @dataclasses.dataclass(frozen=True)
 class ReenterWith:
     stack_index: int
-    target_values: Optional[Tuple[Any, ...]] = None
+    target_values: Optional[tuple[Any, ...]] = None
 
     def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
         """
@@ -271,14 +271,14 @@ class ContinueExecutionCache:
         code,
         lineno,
         offset: int,
-        setup_fn_target_offsets: Tuple[int, ...],  # only used in Python 3.11+
+        setup_fn_target_offsets: tuple[int, ...],  # only used in Python 3.11+
         nstack: int,
-        argnames: Tuple[str, ...],
-        argnames_null: Tuple[str, ...],
-        setup_fns: Tuple[ReenterWith, ...],
-        stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...],
-        argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...],
-        null_idxes: Tuple[int, ...],
+        argnames: tuple[str, ...],
+        argnames_null: tuple[str, ...],
+        setup_fns: tuple[ReenterWith, ...],
+        stack_ctx_vars: tuple[tuple[int, tuple[Any]], ...],
+        argnames_ctx_vars: tuple[tuple[str, tuple[Any]], ...],
+        null_idxes: tuple[int, ...],
     ) -> types.CodeType:
         assert offset is not None
         assert not (
@@ -461,7 +461,7 @@ class ContinueExecutionCache:
 
     @classmethod
     def generate_based_on_original_code_object(
-        cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int, ...], *args
+        cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args
     ):
        """
        This handles the case of generating a resume into code generated

----------------------------------------

@@ -19,7 +19,7 @@ import traceback
 import types
 import typing
 import weakref
-from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Type, Union
 from unittest.mock import patch
 
 import torch
@@ -2634,7 +2634,7 @@ class InstructionTranslatorBase(
         speculation_log: SpeculationLog,
         distributed_state: Optional[DistributedState],
         # This determines whether to use the execution recorder.
-        closure: Optional[Tuple[types.CellType]] = None,
+        closure: Optional[tuple[types.CellType]] = None,
     ) -> None:
         super().__init__()
         self.speculation_log = speculation_log
@@ -2678,7 +2678,7 @@ class InstructionTranslatorBase(
         # Stack of module being parsed, current nn.module is at the end of ordered dict.
         # The first field of tuple is the fully qualified name of current module
        # in original hierarchy. The second field is the type of current nn.module
-        self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
+        self.nn_module_stack: Dict[str, tuple[str, Type[Any]]] = {}
         self.num_calls: Dict[str, int] = {}
         # Flag to indicate whether tracing is used for export.
         self.export = export
@@ -2861,7 +2861,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             torch_function_mode_stack
         )
 
-        self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
+        self.debug_locals: List[tuple[VariableTracker, List[VariableTracker]]] = []
         if export:
             # export gets confused if we never realize unused inputs
             # in export mode just eagerly realize everything
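Per the comment in the hunk above, `nn_module_stack` values are `(fully_qualified_name, module_type)` pairs keyed by a per-module string. Roughly what an entry looks like while tracing a model holding `self.block = torch.nn.Linear(4, 4)` (the key and name shown are illustrative; the real keys are mangled source names):

    nn_module_stack: dict[str, tuple[str, type]] = {
        "L__self___block": ("block", torch.nn.Linear),
    }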

----------------------------------------

@@ -1,7 +1,7 @@
 import contextlib
 import importlib
 import logging
-from typing import Tuple, Union
+from typing import Union
 
 import torch
 import torch.testing
@@ -19,7 +19,7 @@ from . import config, reset, utils
 log = logging.getLogger(__name__)
 
 
-def run_tests(needs: Union[str, Tuple[str, ...]] = ()) -> None:
+def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
     from torch.testing._internal.common_utils import run_tests
 
     if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:

----------------------------------------

@@ -16,7 +16,6 @@ from typing import (
     Optional,
     overload,
     Sequence,
-    Tuple,
     TypeVar,
     Union,
 )
@@ -141,7 +140,7 @@ def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor:
 
 @overload
 def reduce_to_scalar_loss(
-    out: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]
+    out: Union[List[Any], tuple[Any, ...], Dict[Any, Any]]
 ) -> float:
     ...

----------------------------------------

@@ -54,7 +54,6 @@ from typing import (
     Optional,
     overload,
     Set,
-    Tuple,
     Type,
     TypeVar,
     Union,
@@ -106,7 +105,7 @@ try:
 # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
 if np:
-    NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
+    NP_SUPPORTED_MODULES: tuple[types.ModuleType, ...] = (
         np,
         np.fft,
         np.linalg,
@@ -202,8 +201,8 @@ class ReinplaceCounters:
 
 def tabulate(
-    rows: Union[List[Tuple[str, object]], List[List[object]]],
-    headers: Union[Tuple[str, ...], List[str]],
+    rows: Union[List[tuple[str, object]], List[List[object]]],
+    headers: Union[tuple[str, ...], List[str]],
 ) -> str:
     try:
         import tabulate
@@ -590,7 +589,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
 @overload
 def compile_times(
     repr: Literal["csv"], aggregate: bool = False
-) -> Tuple[List[str], List[object]]:
+) -> tuple[List[str], List[object]]:
     ...
@@ -658,7 +657,7 @@ class DuplicateWarningChecker:
     def reset(self):
         self.set = OrderedDict()
 
-    def add(self, key: Union[str, Tuple[object, object]]) -> bool:
+    def add(self, key: Union[str, tuple[object, object]]) -> bool:
         if key in self.set:
             self.set.move_to_end(key, last=True)
             if not config.verbose:
@@ -797,7 +796,7 @@ def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
 
 @overload
 def istype(
-    obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
+    obj: object, allowed_types: tuple[Type[List[T]], Type[tuple[T, ...]]]
 ) -> TypeIs[T]:
     ...
@@ -940,7 +939,7 @@ def is_numpy_ndarray(value):
 
 def istensor(obj):
     """Check if obj is a tensor"""
-    tensor_list: Tuple[type, ...] = (
+    tensor_list: tuple[type, ...] = (
         torch.Tensor,
         torch.nn.Parameter,
         *config.traceable_tensor_subclasses,
@@ -1900,7 +1899,7 @@ def is_namedtuple_cls(cls):
 
 @functools.lru_cache(1)
-def namedtuple_fields(cls) -> Tuple[str, ...]:
+def namedtuple_fields(cls) -> tuple[str, ...]:
     """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple"""
     if cls is slice:
         return ("start", "stop", "step")
@@ -2188,7 +2187,7 @@ def tuple_iterator_getitem(it, index):
 iter_next = next
 
 
-def normalize_range_iter(range_iter) -> Tuple[int, int, int]:
+def normalize_range_iter(range_iter) -> tuple[int, int, int]:
     _, (range_obj,), maybe_idx = range_iter.__reduce__()
     # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been
     # already incremented by the current index.
@@ -3070,7 +3069,7 @@ def tensor_always_has_static_shape(
     tensor: Union[torch.Tensor, Any],
     is_tensor: bool,
     tensor_source: Source,
-) -> Tuple[bool, Optional[TensorStaticReason]]:
+) -> tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
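The `normalize_range_iter` hunk above leans on the range iterator's pickle protocol. For illustration, what `__reduce__` exposes (behavior as described by the source comment; exact values are CPython-version-dependent):

    it = iter(range(2, 10, 2))
    next(it)
    _, (range_obj,), maybe_idx = it.__reduce__()
    # CPython 3.11: range_obj == range(2, 10, 2), maybe_idx == 1
    # CPython 3.12+: start is pre-advanced and the index slot is None,
    # e.g. range_obj == range(4, 10, 2), maybe_idx is None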

----------------------------------------

@@ -6,17 +6,7 @@ import functools
 import inspect
 import itertools
 import types
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    TYPE_CHECKING,
-    TypeVar,
-)
+from typing import Any, Callable, Dict, List, Optional, Sequence, TYPE_CHECKING, TypeVar
 from typing_extensions import Never
 from unittest.mock import patch
@@ -1095,7 +1085,7 @@ class DynamoTritonHOPifier(TritonHOPifier):
     def get_value(self, val: Any) -> Any:
         return val.value
 
-    def check_grid(self, grid) -> Tuple[torch.fx.proxy.Proxy, ...]:
+    def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
         from .lists import BaseListVariable
 
         if isinstance(grid, BaseListVariable):

----------------------------------------

@@ -8,7 +8,7 @@ import itertools
 import logging
 import types
 import warnings
-from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 import torch._C
 import torch.fx
@@ -615,7 +615,7 @@ def speculate_subgraph(
         # The following code re-orders the placeholders to
         # O1, O2, O3, O4, O5, X1, X2, X3
         def move_lifted_freevars_phs_to_end(
-            graph: torch.fx.Graph, lifted_freevars: Tuple[torch.fx.Node]
+            graph: torch.fx.Graph, lifted_freevars: tuple[torch.fx.Node]
         ):
             lifted_ph_set = {
                 child_p.node for child_p in lifted_freevars.values()

----------------------------------------

@@ -1,7 +1,7 @@
 import collections
 import functools
 import inspect
-from typing import Any, Callable, Dict, final, Optional, Tuple, Union
+from typing import Any, Callable, Dict, final, Optional, Union
 from typing_extensions import Self
 
 from ..utils import is_function_or_wrapper
@@ -108,7 +108,7 @@ class LazyVariableTracker(VariableTracker):
     def realize_all(
         cls,
         value: Any,
-        cache: Optional[Dict[int, Tuple[Any, Any]]] = None,
+        cache: Optional[Dict[int, tuple[Any, Any]]] = None,
     ) -> Any:
         """
         Walk an object and realize all LazyVariableTrackers inside it.
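`realize_all` threads an id-keyed memo dict so shared sub-objects are realized once, and the `Dict[int, tuple[Any, Any]]` annotation hints at why each cache slot is a pair: the realized result plus (on my reading, an assumption) the original object, kept alive so its `id()` cannot be recycled. The general pattern, as a generic sketch rather than dynamo's actual walker:

    from typing import Any, Optional

    def walk(value: Any, cache: Optional[dict[int, tuple[Any, Any]]] = None) -> Any:
        if cache is None:
            cache = {}
        key = id(value)
        if key in cache:
            return cache[key][0]  # already realized: reuse the result
        if isinstance(value, list):
            result = [walk(v, cache) for v in value]
        else:
            result = value
        # keep `value` in the tuple so its id() stays valid for the walk
        cache[key] = (result, value)
        return result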