mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Migrate from Tuple -> tuple in torch/_dynamo (#144261)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144261 Approved by: https://github.com/aorenste, https://github.com/zou3519
This commit is contained in:
parent
f295eff512
commit
1fe3af2c68
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -100,7 +100,7 @@ class ModIndex(torch.autograd.Function):
|
|||
return torch.ops.aten.index(x, indices)
|
||||
|
||||
@staticmethod
|
||||
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
|
||||
def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
|
||||
x, indices = inputs
|
||||
ctx.save_for_backward(*indices)
|
||||
ctx.input_shape = x.shape
|
||||
|
|
@ -131,8 +131,8 @@ class TransformGetItemToIndex(TorchFunctionMode):
|
|||
def __torch_function__(
|
||||
self,
|
||||
func: OpOverload,
|
||||
types: Tuple[torch._C._TensorMeta, ...],
|
||||
args: Tuple[object, ...] = (),
|
||||
types: tuple[torch._C._TensorMeta, ...],
|
||||
args: tuple[object, ...] = (),
|
||||
kwargs: Optional[Dict[str, object]] = None,
|
||||
) -> object:
|
||||
if func == torch.Tensor.__getitem__:
|
||||
|
|
@ -161,8 +161,8 @@ _trace_wrapped_op = TraceWrapped()
|
|||
|
||||
def _assert_meta(
|
||||
grad: torch.Tensor,
|
||||
size: Tuple[int, ...],
|
||||
stride: Tuple[int, ...],
|
||||
size: tuple[int, ...],
|
||||
stride: tuple[int, ...],
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
assert grad.size() == size, "size mismatch"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import functools
|
|||
import logging
|
||||
import sys
|
||||
from importlib.metadata import EntryPoint
|
||||
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Protocol, Sequence
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
|
@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CompiledFn(Protocol):
|
||||
def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||
def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
...
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,7 @@ import functools
|
|||
import itertools
|
||||
import sys
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from ..utils._backport_slots import dataclass_slots
|
||||
from .bytecode_analysis import (
|
||||
|
|
@ -447,7 +436,7 @@ def create_swap(n) -> List[Instruction]:
|
|||
|
||||
def lnotab_writer(
|
||||
lineno: int, byteno: int = 0
|
||||
) -> Tuple[List[int], Callable[[int, int], None]]:
|
||||
) -> tuple[List[int], Callable[[int, int], None]]:
|
||||
"""
|
||||
Used to create typing.CodeType.co_lnotab
|
||||
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
|
||||
|
|
@ -537,7 +526,7 @@ def linetable_311_writer(first_lineno: int):
|
|||
assert 0 < size <= 8
|
||||
# first byte - use 13 (no column info) is positions is
|
||||
# malformed, otherwise use 14 (long form)
|
||||
other_varints: Tuple[int, ...] = ()
|
||||
other_varints: tuple[int, ...] = ()
|
||||
if (
|
||||
positions
|
||||
and positions.lineno is not None
|
||||
|
|
@ -670,7 +659,7 @@ def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
|
|||
return bytes(b)
|
||||
|
||||
|
||||
def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]:
|
||||
def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
|
||||
"""Do the opposite of dis.get_instructions()"""
|
||||
code: List[int] = []
|
||||
if sys.version_info >= (3, 11):
|
||||
|
|
@ -854,7 +843,7 @@ def compute_exception_table(
|
|||
instructions: List[Instruction],
|
||||
) -> List[ExceptionTableEntry]:
|
||||
"""Compute exception table in list format from instructions with exn_tab_entries"""
|
||||
exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {}
|
||||
exn_dict: Dict[tuple[int, int], tuple[int, int, bool]] = {}
|
||||
indexof = get_indexof(instructions)
|
||||
|
||||
for inst in instructions:
|
||||
|
|
@ -888,7 +877,7 @@ def compute_exception_table(
|
|||
# smallest byte that the next exception table entry can start at
|
||||
nexti = 0
|
||||
# stack of current nested keys
|
||||
key_stack: List[Tuple[int, int]] = []
|
||||
key_stack: List[tuple[int, int]] = []
|
||||
exn_tab: List[ExceptionTableEntry] = []
|
||||
|
||||
def pop():
|
||||
|
|
@ -933,7 +922,7 @@ def check_inst_exn_tab_entries_nested(
|
|||
"Properly sorted" means entries are sorted by increasing starts, then
|
||||
decreasing ends.
|
||||
"""
|
||||
entry_stack: List[Tuple[int, int]] = []
|
||||
entry_stack: List[tuple[int, int]] = []
|
||||
for entry in tab:
|
||||
key = (indexof[entry.start], indexof[entry.end])
|
||||
while entry_stack and entry_stack[-1][1] < key[0]:
|
||||
|
|
@ -949,7 +938,7 @@ def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
|
|||
Supports nested exception table entries.
|
||||
"""
|
||||
indexof = get_indexof(instructions)
|
||||
entries: Dict[Tuple[int, int], InstructionExnTabEntry] = {}
|
||||
entries: Dict[tuple[int, int], InstructionExnTabEntry] = {}
|
||||
for inst in instructions:
|
||||
if inst.exn_tab_entry:
|
||||
key = (
|
||||
|
|
@ -1417,7 +1406,7 @@ def transform_code_object(code, transformations, safe=False) -> types.CodeType:
|
|||
|
||||
def clean_and_assemble_instructions(
|
||||
instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
|
||||
) -> Tuple[List[Instruction], types.CodeType]:
|
||||
) -> tuple[List[Instruction], types.CodeType]:
|
||||
# also implicitly checks for no duplicate instructions
|
||||
check_inst_exn_tab_entries_valid(instructions)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
import logging
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
|
@ -167,7 +166,7 @@ def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
|
|||
|
||||
def exceeds_recompile_limit(
|
||||
cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
|
||||
) -> Tuple[bool, str]:
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Checks if we are exceeding the cache size limit.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import itertools
|
|||
import operator
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.external_utils import (
|
||||
|
|
@ -118,7 +118,7 @@ class AutogradCompilerInstance:
|
|||
inputs: List[torch.Tensor],
|
||||
sizes: List[int],
|
||||
scalars: List[Union[int, float]],
|
||||
origins: List[List[Tuple[int, str]]],
|
||||
origins: List[List[tuple[int, str]]],
|
||||
):
|
||||
counters["compiled_autograd"]["captures"] += 1
|
||||
self.id = next(COMPILE_COUNTER)
|
||||
|
|
@ -785,7 +785,7 @@ class AutogradCompilerInstance:
|
|||
return proxy_tensor.proxy
|
||||
|
||||
def bind_tensors_to_proxies(
|
||||
self, tensors, proxies, origins: Optional[List[Tuple[int, str]]] = None
|
||||
self, tensors, proxies, origins: Optional[List[tuple[int, str]]] = None
|
||||
):
|
||||
if isinstance(proxies, torch.fx.Proxy):
|
||||
if origins:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import typing
|
|||
import weakref
|
||||
from pathlib import Path
|
||||
from types import CellType, CodeType, FunctionType, ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
from weakref import ReferenceType
|
||||
|
||||
|
|
@ -619,7 +619,7 @@ def _compile(
|
|||
globals: Dict[str, object],
|
||||
locals: Dict[str, object],
|
||||
builtins: Dict[str, object],
|
||||
closure: Tuple[CellType],
|
||||
closure: tuple[CellType],
|
||||
compiler_fn: CompilerFn,
|
||||
one_graph: bool,
|
||||
export: bool,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, Tuple
|
||||
from typing import Any, Generator
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ class TracableCreateParameter(torch.autograd.Function):
|
|||
return placeholder.set_(tensor)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: torch.Tensor) -> Tuple[None, torch.Tensor]:
|
||||
def backward(ctx: Any, *grad_outputs: torch.Tensor) -> tuple[None, torch.Tensor]:
|
||||
grad = grad_outputs[0]
|
||||
return None, grad # grad flows to placeholder
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ def tracable_create_parameter(
|
|||
|
||||
|
||||
def new_parameter_placeholder(
|
||||
size: Tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
|
||||
size: tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create a placeholder to be passed to the above functions"""
|
||||
result = torch.nn.Parameter(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -397,7 +397,7 @@ def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInt
|
|||
raise NotImplementedError(f"No interface for device {device}")
|
||||
|
||||
|
||||
def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
|
||||
def get_registered_device_interfaces() -> Iterable[tuple[str, Type[DeviceInterface]]]:
|
||||
if not _device_initialized:
|
||||
init_device_reg()
|
||||
return device_interfaces.items()
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ from typing import (
|
|||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
|
@ -1014,7 +1013,7 @@ class FlattenInputOutputSignature(torch.fx.Transformer):
|
|||
def __init__(
|
||||
self,
|
||||
m: torch.fx.GraphModule,
|
||||
flat_args: Tuple[Any],
|
||||
flat_args: tuple[Any],
|
||||
matched_input_elements_positions: List[int],
|
||||
flat_results: List[Any],
|
||||
matched_output_elements_positions: List[int],
|
||||
|
|
@ -1381,7 +1380,7 @@ def export(
|
|||
Dict[torch._ops.OpOverload, Callable[..., Any]]
|
||||
] = None,
|
||||
tracing_mode: str = "symbolic",
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], tuple[Any], List[Any]]] = None,
|
||||
specialize_float: bool = True,
|
||||
assume_static_by_default: bool = False,
|
||||
same_signature: bool = True,
|
||||
|
|
@ -1472,7 +1471,7 @@ def export(
|
|||
graph = None
|
||||
out_guards = None
|
||||
graph_captured_input = None
|
||||
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
|
||||
graph_captured_result: Optional[tuple[torch.Tensor, ...]] = None
|
||||
fake_mode = None
|
||||
result_traced = None
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import textwrap
|
|||
import typing
|
||||
from enum import auto, Enum
|
||||
from traceback import extract_stack, format_exc, format_list, StackSummary
|
||||
from typing import Any, NoReturn, Optional, Tuple, Type, TYPE_CHECKING
|
||||
from typing import Any, NoReturn, Optional, Type, TYPE_CHECKING
|
||||
|
||||
import torch._guards
|
||||
|
||||
|
|
@ -425,7 +425,7 @@ def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -
|
|||
|
||||
def get_exc_message(
|
||||
e: Exception, compile_id: CompileId
|
||||
) -> Tuple[Optional[str], Optional[int]]:
|
||||
) -> tuple[Optional[str], Optional[int]]:
|
||||
filename = None
|
||||
lineno = None
|
||||
if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import operator
|
||||
from typing import Any, Dict, Iterable, List, Set, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Set
|
||||
|
||||
import torch.fx
|
||||
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
|
||||
|
|
@ -104,7 +104,7 @@ def _replace_region_with_subgraph(
|
|||
graph: torch.fx.Graph,
|
||||
region: Region,
|
||||
get_subgraph_node: Node,
|
||||
node_ind_arg_ind: Iterable[Tuple[int, int]],
|
||||
node_ind_arg_ind: Iterable[tuple[int, int]],
|
||||
inds_with_external_users: List[int],
|
||||
sub_gm: torch.fx.GraphModule,
|
||||
subgraph_name: str,
|
||||
|
|
@ -147,7 +147,7 @@ def _replace_region_with_subgraph(
|
|||
|
||||
def _get_external_inputs(
|
||||
region: Region,
|
||||
) -> Dict[Node, Tuple[int, int]]:
|
||||
) -> Dict[Node, tuple[int, int]]:
|
||||
external_node_to_indices = dict()
|
||||
region_unique = set(region)
|
||||
for node_ind, node in enumerate(region):
|
||||
|
|
@ -183,9 +183,9 @@ def _get_inds_with_external_users(region: Region, inds_unique: Set[int]) -> None
|
|||
|
||||
def _copy_nodes_and_remap_inputs(
|
||||
subgraph: torch.fx.Graph, region: Region
|
||||
) -> Dict[Tuple[int, int], Any]:
|
||||
) -> Dict[tuple[int, int], Any]:
|
||||
external_inputs_to_indices = _get_external_inputs(region)
|
||||
indices_to_placeholder_ind: Dict[Tuple[int, int], Any] = {}
|
||||
indices_to_placeholder_ind: Dict[tuple[int, int], Any] = {}
|
||||
region_to_subgraph_node = {}
|
||||
for node in external_inputs_to_indices.keys():
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
|
|
@ -219,7 +219,7 @@ def _create_subgraph_outputs(
|
|||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: List[int],
|
||||
) -> Tuple[torch.fx.Graph, Dict[Tuple[int, int], Any]]:
|
||||
) -> tuple[torch.fx.Graph, Dict[tuple[int, int], Any]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
_create_subgraph_outputs(subgraph, inds_with_external_users)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from typing import (
|
|||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
|
|
@ -34,7 +33,7 @@ if TYPE_CHECKING:
|
|||
Node = torch.fx.Node
|
||||
Region = List[Node]
|
||||
IdenticalNodes = List[Node]
|
||||
GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
|
||||
GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_expansion_log = torch._logging.getArtifactLogger(
|
||||
|
|
@ -48,7 +47,7 @@ def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def]
|
|||
|
||||
def _extract_tensor_metadata_for_node_hash(
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Callable[[T], T], Tuple[Any, ...]]:
|
||||
) -> tuple[Callable[[T], T], tuple[Any, ...]]:
|
||||
from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key
|
||||
|
||||
out = []
|
||||
|
|
@ -104,7 +103,7 @@ def _extract_tensor_arg(arg: Any) -> Any:
|
|||
|
||||
def _normalize_args(
|
||||
node: Node,
|
||||
) -> Tuple[Tuple[str, ...], Tuple[Optional[Any], ...]]:
|
||||
) -> tuple[tuple[str, ...], tuple[Optional[Any], ...]]:
|
||||
flat_args, _ = tree_flatten(node.args)
|
||||
sorted_kwargs = sorted(node.kwargs.items(), key=lambda x: x[0])
|
||||
sorted_keys = tuple(sorted(node.kwargs.keys()))
|
||||
|
|
|
|||
|
|
@ -21,18 +21,7 @@ import weakref
|
|||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from inspect import currentframe
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import torch
|
||||
|
|
@ -647,7 +636,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
self._cached_guard_managers: Dict[
|
||||
str, torch._C._dynamo.guards.GuardManager
|
||||
] = {}
|
||||
self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()
|
||||
self._cached_duplicate_input_guards: Set[tuple[str, str]] = set()
|
||||
|
||||
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
|
||||
dict_mgr = self.get_guard_manager(guard)
|
||||
|
|
@ -1511,7 +1500,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
ref = self.arg_ref(guard)
|
||||
val = self.get(guard.name)
|
||||
if np:
|
||||
np_types: Tuple[Type[Any], ...] = (
|
||||
np_types: tuple[Type[Any], ...] = (
|
||||
np.int8,
|
||||
np.int16,
|
||||
np.int32,
|
||||
|
|
@ -1817,10 +1806,10 @@ class GuardBuilder(GuardBuilderBase):
|
|||
]
|
||||
|
||||
if output_graph.export_constraints:
|
||||
names: Dict[str, Tuple[int, int]] = {}
|
||||
source_pairs: List[Tuple[Source, Source]] = []
|
||||
names: Dict[str, tuple[int, int]] = {}
|
||||
source_pairs: List[tuple[Source, Source]] = []
|
||||
derived_equalities: List[ # type: ignore[type-arg]
|
||||
Tuple[Source, Union[Source, Symbol], Callable]
|
||||
tuple[Source, Union[Source, Symbol], Callable]
|
||||
] = []
|
||||
phantom_symbols: Dict[str, Symbol] = {}
|
||||
relaxed_sources: Set[Source] = set()
|
||||
|
|
@ -2178,7 +2167,7 @@ class PyExprCSEPass:
|
|||
log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e)
|
||||
raise
|
||||
|
||||
def replace(self, expr: str) -> Tuple[List[str], str]:
|
||||
def replace(self, expr: str) -> tuple[List[str], str]:
|
||||
replacer = self.Replacer(self._config, self._new_var)
|
||||
new_node = replacer.visit(ast.parse(expr))
|
||||
return replacer.preface, _ast_unparse(new_node)
|
||||
|
|
@ -2559,13 +2548,13 @@ class CheckFunctionManager:
|
|||
return None
|
||||
|
||||
|
||||
def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
|
||||
def build_guard_function(code_parts, closure_args) -> tuple[str, str]:
|
||||
from torch._inductor.utils import IndentedBuffer
|
||||
|
||||
csepass = PyExprCSEPass()
|
||||
csepass.count(code_parts)
|
||||
|
||||
def replace(expr: str) -> Tuple[List[str], str]:
|
||||
def replace(expr: str) -> tuple[List[str], str]:
|
||||
return csepass.replace(expr)
|
||||
|
||||
# Generate the inner body of the guard function.
|
||||
|
|
|
|||
|
|
@ -12,18 +12,7 @@ import sys
|
|||
import traceback
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, TYPE_CHECKING, Union
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -427,7 +416,7 @@ class OutputGraph:
|
|||
# random_calls tracks calls to random() and random_values_var stores the name of
|
||||
# the variable that stores __gen_rand_values results.
|
||||
self.random_calls: List[
|
||||
Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
||||
tuple[Callable[..., object], tuple[object, ...], Dict[str, object]]
|
||||
] = []
|
||||
self.random_values_var = None
|
||||
|
||||
|
|
@ -638,7 +627,7 @@ class OutputGraph:
|
|||
Saves to out if it is provided. Else saves to the tracing context's global_state.
|
||||
"""
|
||||
global_state = cast(
|
||||
Dict[str, Tuple[Callable[..., Any], bool]],
|
||||
Dict[str, tuple[Callable[..., Any], bool]],
|
||||
out
|
||||
if out is not None
|
||||
else self.tracing_context.global_context.global_state,
|
||||
|
|
@ -1270,7 +1259,7 @@ class OutputGraph:
|
|||
Momentarily restores the global state to what it was prior to tracing the current output
|
||||
"""
|
||||
prior_global_state = self.tracing_context.global_context.copy_graphstate()
|
||||
current_global_state: Dict[str, Tuple[Any, bool]] = {}
|
||||
current_global_state: Dict[str, tuple[Any, bool]] = {}
|
||||
self.save_global_state(out=current_global_state)
|
||||
try:
|
||||
# Set to state prior to tracing the graph
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import DefaultDict, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch._dynamo.config
|
||||
|
|
@ -168,10 +168,10 @@ class FrameStateSizeEntry:
|
|||
# NB: We don't have cases where we have a known dimensionality but
|
||||
# we know NOTHING about the individual sizes
|
||||
size: Union[
|
||||
AutoDynamic, AutoUnset, Tuple[Union[int, AutoDynamic], ...]
|
||||
AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]
|
||||
] = dataclasses.field(default=auto_unset)
|
||||
stride: Union[
|
||||
AutoDynamic, AutoUnset, Tuple[Union[int, AutoDynamic, InferStride], ...]
|
||||
AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
|
||||
] = dataclasses.field(default=auto_unset)
|
||||
|
||||
def render(self) -> str:
|
||||
|
|
@ -187,7 +187,7 @@ class FrameStateSizeEntry:
|
|||
else:
|
||||
return str(s)
|
||||
|
||||
def render_tuple(ss: Tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
|
||||
def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
|
||||
return "[" + ", ".join(render_single(s) for s in ss) + "]"
|
||||
|
||||
# Common cases
|
||||
|
|
@ -246,7 +246,7 @@ class FrameStateSizeEntry:
|
|||
return self.stride[dim] is auto_dynamic
|
||||
|
||||
@staticmethod
|
||||
def _munge_symint(xs: Tuple[int, ...]) -> Tuple[Union[AutoDynamic, int], ...]:
|
||||
def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
|
||||
return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -255,7 +255,7 @@ class FrameStateSizeEntry:
|
|||
|
||||
@classmethod
|
||||
def make_tensor(
|
||||
cls, size: Tuple[int, ...], stride: Tuple[int, ...]
|
||||
cls, size: tuple[int, ...], stride: tuple[int, ...]
|
||||
) -> FrameStateSizeEntry:
|
||||
return FrameStateSizeEntry(
|
||||
scalar=auto_dynamic,
|
||||
|
|
@ -264,7 +264,7 @@ class FrameStateSizeEntry:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def make_size(cls, size: Tuple[int, ...]) -> FrameStateSizeEntry:
|
||||
def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
|
||||
return FrameStateSizeEntry(
|
||||
scalar=auto_unset,
|
||||
size=cls._munge_symint(size),
|
||||
|
|
@ -284,9 +284,9 @@ class FrameStateSizeEntry:
|
|||
@classmethod
|
||||
def _merge_atom_tup(
|
||||
cls,
|
||||
xs: Union[AutoDynamic, AutoUnset, Tuple[_T, ...]],
|
||||
ys: Union[AutoDynamic, AutoUnset, Tuple[_T, ...]],
|
||||
) -> Union[AutoDynamic, AutoUnset, Tuple[Union[AutoDynamic, _T], ...]]:
|
||||
xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
|
||||
ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
|
||||
) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]:
|
||||
if xs is auto_unset:
|
||||
return ys
|
||||
if ys is auto_unset:
|
||||
|
|
@ -651,7 +651,7 @@ def put_code_state() -> None:
|
|||
put_remote_code_state(cache_key)
|
||||
|
||||
|
||||
def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[Tuple[str, int]]:
|
||||
def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
|
||||
path = code_state_path(cache_key)
|
||||
|
||||
if path is None:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Please add a new import when adding a new polyfill module.
|
||||
|
||||
import importlib
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import polyfills, trace_rules
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
|
||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
||||
POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
|
||||
"builtins",
|
||||
"functools",
|
||||
"itertools",
|
||||
|
|
@ -21,7 +21,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
|||
"pytree",
|
||||
"sys",
|
||||
)
|
||||
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
|
||||
POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
|
||||
importlib.import_module(f".{submodule}", package=polyfills.__name__)
|
||||
for submodule in POLYFILLED_MODULE_NAMES
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import dataclasses
|
||||
from dataclasses import field
|
||||
from types import CellType, CodeType, ModuleType
|
||||
from typing import Any, BinaryIO, Dict, IO, Tuple
|
||||
from typing import Any, BinaryIO, Dict, IO
|
||||
from typing_extensions import Self
|
||||
|
||||
from torch.utils._import_utils import import_dill
|
||||
|
|
@ -29,7 +29,7 @@ class DummyModule:
|
|||
@dataclasses.dataclass
|
||||
class ExecutionRecord:
|
||||
code: CodeType
|
||||
closure: Tuple[CellType]
|
||||
closure: tuple[CellType]
|
||||
globals: Dict[str, Any] = field(default_factory=dict)
|
||||
locals: Dict[str, Any] = field(default_factory=dict)
|
||||
builtins: Dict[str, Any] = field(default_factory=dict)
|
||||
|
|
@ -50,7 +50,7 @@ class ExecutionRecorder:
|
|||
LOCAL_MOD_PREFIX = "___local_mod_"
|
||||
|
||||
code: CodeType
|
||||
closure: Tuple[CellType]
|
||||
closure: tuple[CellType]
|
||||
globals: Dict[str, Any] = field(default_factory=dict)
|
||||
locals: Dict[str, Any] = field(default_factory=dict)
|
||||
builtins: Dict[str, Any] = field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import shutil
|
|||
import sys
|
||||
import textwrap
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.debug_utils import (
|
||||
|
|
@ -89,7 +89,7 @@ def save_graph_repro_ep(
|
|||
*,
|
||||
exported_program: Optional[ExportedProgram] = None,
|
||||
gm: Optional[torch.nn.Module] = None,
|
||||
args: Optional[Tuple[Any]] = None,
|
||||
args: Optional[tuple[Any]] = None,
|
||||
config_patches: Optional[Dict[str, str]] = None,
|
||||
stable_output=False,
|
||||
save_dir=None,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import copy
|
|||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
from typing import Any, cast, Dict, List, Optional
|
||||
|
||||
from .bytecode_transformation import (
|
||||
bytecode_from_template,
|
||||
|
|
@ -89,7 +89,7 @@ def _try_except_tf_mode_template(dummy, stack_var_name):
|
|||
@dataclasses.dataclass(frozen=True)
|
||||
class ReenterWith:
|
||||
stack_index: int
|
||||
target_values: Optional[Tuple[Any, ...]] = None
|
||||
target_values: Optional[tuple[Any, ...]] = None
|
||||
|
||||
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
|
||||
"""
|
||||
|
|
@ -271,14 +271,14 @@ class ContinueExecutionCache:
|
|||
code,
|
||||
lineno,
|
||||
offset: int,
|
||||
setup_fn_target_offsets: Tuple[int, ...], # only used in Python 3.11+
|
||||
setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+
|
||||
nstack: int,
|
||||
argnames: Tuple[str, ...],
|
||||
argnames_null: Tuple[str, ...],
|
||||
setup_fns: Tuple[ReenterWith, ...],
|
||||
stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...],
|
||||
argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...],
|
||||
null_idxes: Tuple[int, ...],
|
||||
argnames: tuple[str, ...],
|
||||
argnames_null: tuple[str, ...],
|
||||
setup_fns: tuple[ReenterWith, ...],
|
||||
stack_ctx_vars: tuple[tuple[int, tuple[Any]], ...],
|
||||
argnames_ctx_vars: tuple[tuple[str, tuple[Any]], ...],
|
||||
null_idxes: tuple[int, ...],
|
||||
) -> types.CodeType:
|
||||
assert offset is not None
|
||||
assert not (
|
||||
|
|
@ -461,7 +461,7 @@ class ContinueExecutionCache:
|
|||
|
||||
@classmethod
|
||||
def generate_based_on_original_code_object(
|
||||
cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int, ...], *args
|
||||
cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args
|
||||
):
|
||||
"""
|
||||
This handles the case of generating a resume into code generated
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import traceback
|
|||
import types
|
||||
import typing
|
||||
import weakref
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Type, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
|
@ -2634,7 +2634,7 @@ class InstructionTranslatorBase(
|
|||
speculation_log: SpeculationLog,
|
||||
distributed_state: Optional[DistributedState],
|
||||
# This determines whether to use the execution recorder.
|
||||
closure: Optional[Tuple[types.CellType]] = None,
|
||||
closure: Optional[tuple[types.CellType]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.speculation_log = speculation_log
|
||||
|
|
@ -2678,7 +2678,7 @@ class InstructionTranslatorBase(
|
|||
# Stack of module being parsed, current nn.module is at the end of ordered dict.
|
||||
# The first field of tuple is the fully qualified name of current module
|
||||
# in original hierarchy. The second field is the type of current nn.module
|
||||
self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
|
||||
self.nn_module_stack: Dict[str, tuple[str, Type[Any]]] = {}
|
||||
self.num_calls: Dict[str, int] = {}
|
||||
# Flag to indicate whether tracing is used for export.
|
||||
self.export = export
|
||||
|
|
@ -2861,7 +2861,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
torch_function_mode_stack
|
||||
)
|
||||
|
||||
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
|
||||
self.debug_locals: List[tuple[VariableTracker, List[VariableTracker]]] = []
|
||||
if export:
|
||||
# export gets confused if we never realize unused inputs
|
||||
# in export mode just eagerly realize everything
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import contextlib
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.testing
|
||||
|
|
@ -19,7 +19,7 @@ from . import config, reset, utils
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_tests(needs: Union[str, Tuple[str, ...]] = ()) -> None:
|
||||
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from typing import (
|
|||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
|
@ -141,7 +140,7 @@ def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor:
|
|||
|
||||
@overload
|
||||
def reduce_to_scalar_loss(
|
||||
out: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]
|
||||
out: Union[List[Any], tuple[Any, ...], Dict[Any, Any]]
|
||||
) -> float:
|
||||
...
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ from typing import (
|
|||
Optional,
|
||||
overload,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
|
|
@ -106,7 +105,7 @@ try:
|
|||
|
||||
# NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
|
||||
if np:
|
||||
NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
|
||||
NP_SUPPORTED_MODULES: tuple[types.ModuleType, ...] = (
|
||||
np,
|
||||
np.fft,
|
||||
np.linalg,
|
||||
|
|
@ -202,8 +201,8 @@ class ReinplaceCounters:
|
|||
|
||||
|
||||
def tabulate(
|
||||
rows: Union[List[Tuple[str, object]], List[List[object]]],
|
||||
headers: Union[Tuple[str, ...], List[str]],
|
||||
rows: Union[List[tuple[str, object]], List[List[object]]],
|
||||
headers: Union[tuple[str, ...], List[str]],
|
||||
) -> str:
|
||||
try:
|
||||
import tabulate
|
||||
|
|
@ -590,7 +589,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
|
|||
@overload
|
||||
def compile_times(
|
||||
repr: Literal["csv"], aggregate: bool = False
|
||||
) -> Tuple[List[str], List[object]]:
|
||||
) -> tuple[List[str], List[object]]:
|
||||
...
|
||||
|
||||
|
||||
|
|
@ -658,7 +657,7 @@ class DuplicateWarningChecker:
|
|||
def reset(self):
|
||||
self.set = OrderedDict()
|
||||
|
||||
def add(self, key: Union[str, Tuple[object, object]]) -> bool:
|
||||
def add(self, key: Union[str, tuple[object, object]]) -> bool:
|
||||
if key in self.set:
|
||||
self.set.move_to_end(key, last=True)
|
||||
if not config.verbose:
|
||||
|
|
@ -797,7 +796,7 @@ def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
|
|||
|
||||
@overload
|
||||
def istype(
|
||||
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
|
||||
obj: object, allowed_types: tuple[Type[List[T]], Type[tuple[T, ...]]]
|
||||
) -> TypeIs[T]:
|
||||
...
|
||||
|
||||
|
|
@ -940,7 +939,7 @@ def is_numpy_ndarray(value):
|
|||
|
||||
def istensor(obj):
|
||||
"""Check of obj is a tensor"""
|
||||
tensor_list: Tuple[type, ...] = (
|
||||
tensor_list: tuple[type, ...] = (
|
||||
torch.Tensor,
|
||||
torch.nn.Parameter,
|
||||
*config.traceable_tensor_subclasses,
|
||||
|
|
@ -1900,7 +1899,7 @@ def is_namedtuple_cls(cls):
|
|||
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def namedtuple_fields(cls) -> Tuple[str, ...]:
|
||||
def namedtuple_fields(cls) -> tuple[str, ...]:
|
||||
"""Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple"""
|
||||
if cls is slice:
|
||||
return ("start", "stop", "step")
|
||||
|
|
@ -2188,7 +2187,7 @@ def tuple_iterator_getitem(it, index):
|
|||
iter_next = next
|
||||
|
||||
|
||||
def normalize_range_iter(range_iter) -> Tuple[int, int, int]:
|
||||
def normalize_range_iter(range_iter) -> tuple[int, int, int]:
|
||||
_, (range_obj,), maybe_idx = range_iter.__reduce__()
|
||||
# In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been
|
||||
# already incremented by the current index.
|
||||
|
|
@ -3070,7 +3069,7 @@ def tensor_always_has_static_shape(
|
|||
tensor: Union[torch.Tensor, Any],
|
||||
is_tensor: bool,
|
||||
tensor_source: Source,
|
||||
) -> Tuple[bool, Optional[TensorStaticReason]]:
|
||||
) -> tuple[bool, Optional[TensorStaticReason]]:
|
||||
"""
|
||||
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,17 +6,7 @@ import functools
|
|||
import inspect
|
||||
import itertools
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Never
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -1095,7 +1085,7 @@ class DynamoTritonHOPifier(TritonHOPifier):
|
|||
def get_value(self, val: Any) -> Any:
|
||||
return val.value
|
||||
|
||||
def check_grid(self, grid) -> Tuple[torch.fx.proxy.Proxy, ...]:
|
||||
def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
|
||||
from .lists import BaseListVariable
|
||||
|
||||
if isinstance(grid, BaseListVariable):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import itertools
|
|||
import logging
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.fx
|
||||
|
|
@ -615,7 +615,7 @@ def speculate_subgraph(
|
|||
# The following code re-order the placeholders to
|
||||
# O1, O2, O3, O4, O5, X1, X2, X3
|
||||
def move_lifted_freevars_phs_to_end(
|
||||
graph: torch.fx.Graph, lifted_freevars: Tuple[torch.fx.Node]
|
||||
graph: torch.fx.Graph, lifted_freevars: tuple[torch.fx.Node]
|
||||
):
|
||||
lifted_ph_set = {
|
||||
child_p.node for child_p in lifted_freevars.values()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, final, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, final, Optional, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..utils import is_function_or_wrapper
|
||||
|
|
@ -108,7 +108,7 @@ class LazyVariableTracker(VariableTracker):
|
|||
def realize_all(
|
||||
cls,
|
||||
value: Any,
|
||||
cache: Optional[Dict[int, Tuple[Any, Any]]] = None,
|
||||
cache: Optional[Dict[int, tuple[Any, Any]]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Walk an object and realize all LazyVariableTrackers inside it.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user