PEP585: More UP006 fixes (#146392)

This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2025-02-19 14:37:18 -08:00 committed by PyTorch MergeBot
parent 76ad19a549
commit db4ce78d46
81 changed files with 283 additions and 329 deletions

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import sys
from typing import List
import torch
import torch.distributed as dist
@ -119,7 +118,7 @@ def _find_name_param_mappings(module: torch.nn.Module, prefix: str):
def _discover_ddp_ignored_params(module: torch.nn.Module, prefix: str):
ddp_ignore_parameters: List[str] = []
ddp_ignore_parameters: list[str] = []
if isinstance(module, FSDP2):
ddp_ignore_parameters = [name for name, _ in module.named_parameters(prefix)]
else:

View File

@ -3,7 +3,7 @@ import gc
import threading
import unittest
from datetime import timedelta
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed as dist
@ -576,24 +576,24 @@ class ProcessGroupDummy(dist.ProcessGroup):
self.waits = 0
self.dels = 0
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work:
def broadcast(self, tensor_list: list[torch.Tensor], opts: object) -> dist.Work:
return _DummyWork(self)
def allgather_into_tensor_coalesced(
self,
output_lists: List[torch.Tensor],
input_list: List[torch.Tensor],
output_lists: list[torch.Tensor],
input_list: list[torch.Tensor],
opts: object,
) -> dist.Work:
return _DummyWork(self)
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work:
def allreduce(self, tensors: list[torch.Tensor], opts: object) -> dist.Work:
return _DummyWork(self)
def reduce_scatter_tensor_coalesced(
self,
outputTensors: List[torch.Tensor],
inputTensors: List[torch.Tensor],
outputTensors: list[torch.Tensor],
inputTensors: list[torch.Tensor],
opts: object,
) -> dist.Work:
return _DummyWork(self)

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: functorch"]
from typing import Dict, List, Tuple
from unittest.mock import MagicMock, patch
from torch._functorch._activation_checkpointing.ac_logging_utils import (
@ -33,17 +32,17 @@ class TestAcLogging(TestCase):
self.graph.nodes = [self.node1, self.node2]
self.all_recomputable_banned_nodes: List[Node] = [self.node1]
self.saved_node_idxs: List[int] = [0]
self.recomputable_node_idxs: List[int] = []
self.all_recomputable_banned_nodes: list[Node] = [self.node1]
self.saved_node_idxs: list[int] = [0]
self.recomputable_node_idxs: list[int] = []
self.expected_runtime: int = 100
self.memories_banned_nodes: List[int] = [50]
self.runtimes_banned_nodes: List[int] = [10]
self.min_cut_saved_values: List[Node] = [self.node1]
self.memories_banned_nodes: list[int] = [50]
self.runtimes_banned_nodes: list[int] = [10]
self.min_cut_saved_values: list[Node] = [self.node1]
def test_create_joint_graph_node_information(self) -> None:
recomputable_node_info: Dict[str, int] = {"node1": 0}
expected_output: Dict[str, Dict] = {
recomputable_node_info: dict[str, int] = {"node1": 0}
expected_output: dict[str, dict] = {
"node1": {
"index": 0,
"name": "node1",
@ -68,12 +67,12 @@ class TestAcLogging(TestCase):
self.assertEqual(result, expected_output)
def test_create_joint_graph_edges(self) -> None:
expected_edges: List[Tuple[str, str]] = [("node1", "node2")]
expected_edges: list[tuple[str, str]] = [("node1", "node2")]
result = create_joint_graph_edges(self.graph)
self.assertEqual(result, expected_edges)
def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
input_joint_graph_node_information: Dict[str, Dict] = {
input_joint_graph_node_information: dict[str, dict] = {
"node1": {
"index": 0,
"name": "node1",
@ -85,8 +84,8 @@ class TestAcLogging(TestCase):
"recomputable_candidate_info": {"recomputable_node_idx": 0},
}
}
joint_graph_edges: List[Tuple[str, str]] = [("node1", "node2")]
expected_payload: Dict[str, any] = {
joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")]
expected_payload: dict[str, any] = {
"Joint Graph Size": 2,
"Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges},
"Joint Graph Node Information": input_joint_graph_node_information,

View File

@ -2,7 +2,7 @@
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
from typing import Sequence
from typing import Sequence # noqa: UP035
import onnx_test_common
import onnxscript

View File

@ -9,7 +9,7 @@ import platform
import sys
import time
import traceback
from typing import Any, Mapping
from typing import Any, TYPE_CHECKING
import numpy as np
@ -20,6 +20,10 @@ import onnxscript
import torch
if TYPE_CHECKING:
from collections.abc import Mapping
_REPRODUCTION_TEMPLATE = '''\
import google.protobuf.text_format
import numpy as np

View File

@ -3,7 +3,7 @@
import datetime
from enum import Enum
from types import TracebackType
from typing import Callable, Optional, Type
from typing import Callable, Optional
class Aggregation(Enum):
VALUE = ...
@ -48,7 +48,7 @@ class _WaitCounterTracker:
def __enter__(self) -> None: ...
def __exit__(
self,
exec_type: Optional[Type[BaseException]] = None,
exec_type: Optional[type[BaseException]] = None,
exec_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None: ...

View File

@ -4,17 +4,7 @@ from collections import defaultdict
from collections.abc import Sequence
from functools import lru_cache, partial, wraps
from itertools import chain
from typing import (
Callable,
Dict,
FrozenSet,
List,
Optional,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec

View File

@ -11,7 +11,7 @@ Python polyfills for common builtins.
import types
from collections.abc import MutableMapping, Sequence
from itertools import repeat as _repeat
from typing import Any, Callable, List, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
import torch

View File

@ -59,7 +59,6 @@ from typing import (
Generic,
Optional,
overload,
Set,
TypeVar,
Union,
)
@ -1393,7 +1392,7 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]:
except Exception:
return "Value is not JSON serializable"
keys_to_scrub: Set[Any] = set()
keys_to_scrub: set[Any] = set()
inductor_conf_str = None
inductor_config_copy = (
torch._inductor.config.get_config_copy() if torch._inductor.config else None

View File

@ -16,8 +16,9 @@ computations.
"""
import collections
from collections.abc import Sequence
from enum import Enum
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
from .. import variables
from ..current_scope_id import current_scope_id

View File

@ -11,8 +11,8 @@ import sys
import types
import typing
from collections import defaultdict, OrderedDict
from collections.abc import KeysView
from typing import Callable, Sequence, TYPE_CHECKING, Union
from collections.abc import KeysView, Sequence
from typing import Callable, TYPE_CHECKING, Union
import torch
from torch import sym_float, sym_int

View File

@ -30,7 +30,7 @@ import itertools
import sys
import types
from collections.abc import Sequence
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch
@ -517,7 +517,7 @@ class LocalGeneratorObjectVariable(VariableTracker):
def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
return True
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
result = []
while True:
try:
@ -547,8 +547,8 @@ class LocalGeneratorObjectVariable(VariableTracker):
self,
tx: "InstructionTranslator",
name: str,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__next__":
return self.next_variable(tx)

View File

@ -16,7 +16,7 @@ from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict, List, Tuple
from typing import Any
from torch._logging import trace_structured
from torch.fx import Graph, Node
@ -11,9 +11,9 @@ log: logging.Logger = logging.getLogger(__name__)
def create_joint_graph_node_information(
joint_graph: Graph,
recomputable_node_info: Dict[str, int],
) -> Dict[str, Any]:
joint_graph_node_information: Dict[str, Any] = {}
recomputable_node_info: dict[str, int],
) -> dict[str, Any]:
joint_graph_node_information: dict[str, Any] = {}
for i, joint_graph_node in enumerate(joint_graph.nodes):
is_recomputable_candidate: bool = (
@ -22,7 +22,7 @@ def create_joint_graph_node_information(
tensor_meta = joint_graph_node.meta.get("tensor_meta")
shape = getattr(tensor_meta, "shape", []) if tensor_meta else []
node_info: Dict[str, Any] = {
node_info: dict[str, Any] = {
"index": i,
"name": joint_graph_node.name,
"is_recomputable_candidate": is_recomputable_candidate,
@ -43,8 +43,8 @@ def create_joint_graph_node_information(
return joint_graph_node_information
def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
joint_graph_edges: List[Tuple[str, str]] = [
def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]:
joint_graph_edges: list[tuple[str, str]] = [
(inp.name, node.name)
for node in joint_graph.nodes
for inp in node.all_input_nodes
@ -54,17 +54,17 @@ def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
def create_activation_checkpointing_logging_structure_payload(
joint_graph: Graph,
joint_graph_node_information: Dict[str, Any],
joint_graph_edges: List[Tuple[str, str]],
all_recomputable_banned_nodes: List[Node],
joint_graph_node_information: dict[str, Any],
joint_graph_edges: list[tuple[str, str]],
all_recomputable_banned_nodes: list[Node],
expected_runtime: float,
saved_node_idxs: List[int],
recomputable_node_idxs: List[int],
memories_banned_nodes: List[float],
runtimes_banned_nodes: List[float],
min_cut_saved_values: List[Node],
) -> Dict[str, Any]:
activation_checkpointing_logging_structure_payload: Dict[str, Any] = {
saved_node_idxs: list[int],
recomputable_node_idxs: list[int],
memories_banned_nodes: list[float],
runtimes_banned_nodes: list[float],
min_cut_saved_values: list[Node],
) -> dict[str, Any]:
activation_checkpointing_logging_structure_payload: dict[str, Any] = {
"Joint Graph Size": len(joint_graph.nodes),
"Joint Graph Edges": {
"Total": len(joint_graph_edges),
@ -86,15 +86,15 @@ def create_activation_checkpointing_logging_structure_payload(
def create_structured_trace_for_min_cut_info(
joint_graph: Graph,
all_recomputable_banned_nodes: List[Node],
saved_node_idxs: List[int],
recomputable_node_idxs: List[int],
all_recomputable_banned_nodes: list[Node],
saved_node_idxs: list[int],
recomputable_node_idxs: list[int],
expected_runtime: float,
memories_banned_nodes: List[float],
runtimes_banned_nodes: List[float],
min_cut_saved_values: List[Node],
memories_banned_nodes: list[float],
runtimes_banned_nodes: list[float],
min_cut_saved_values: list[Node],
) -> None:
recomputable_node_info: Dict[str, int] = {
recomputable_node_info: dict[str, int] = {
node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes)
}
joint_graph_node_information = create_joint_graph_node_information(

View File

@ -1,6 +1,7 @@
import dataclasses
import itertools
from typing import Any, Iterable, Union
from collections.abc import Iterable
from typing import Any, Union
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

View File

@ -8,8 +8,6 @@
from __future__ import annotations
from typing import List
import torch
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
@ -50,9 +48,9 @@ class AOTICallDelegate(HigherOrderOperator):
self,
lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type]
original_gm: torch.fx.GraphModule,
weight_args: List[torch.Tensor],
input_args: List[torch.Tensor],
) -> List[torch.Tensor]:
weight_args: list[torch.Tensor],
input_args: list[torch.Tensor],
) -> list[torch.Tensor]:
return super().__call__(lowered_module, original_gm, weight_args, input_args)
@ -68,9 +66,9 @@ aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
def call_delegate_cpu(
lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type]
original_gm: torch.fx.GraphModule,
weight_args: List[torch.Tensor],
input_args: List[torch.Tensor],
) -> List[torch.Tensor]:
weight_args: list[torch.Tensor],
input_args: list[torch.Tensor],
) -> list[torch.Tensor]:
# FX creates this immutable_dict/list concept. Get rid of this.
map_types = {
torch.fx.immutable_collections.immutable_dict: dict,
@ -104,8 +102,8 @@ def call_delegate_fake_tensor_mode(
mode: FakeTensorMode,
lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type]
original_gm: torch.fx.GraphModule,
weight_args: List[torch.Tensor],
input_args: List[torch.Tensor],
) -> List[torch.Tensor]:
weight_args: list[torch.Tensor],
input_args: list[torch.Tensor],
) -> list[torch.Tensor]:
with mode:
return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)

View File

@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.fx.traceback as fx_traceback
@ -173,7 +173,7 @@ def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:
# The invariant here is that we always trace the branch with fake tensor
def _maybe_fake_tracing(fn, inputs: List[Any], pre_dispatch):
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
fake_mode = detect_fake_mode(inputs)
tracing_mode = "real"
if fake_mode is None:
@ -565,8 +565,8 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
def check_input_alias_and_mutation(
gm: torch.fx.GraphModule,
fake_args: List[FakeTensor],
) -> Tuple[List[int], dict[int, int], dict[int, int], dict[int, int]]:
fake_args: list[FakeTensor],
) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]:
with disable_proxy_modes_tracing():
"""This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias
in the graph module gm. It checks whether input tensor versions have

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import io
import logging
import os
from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, IO, Optional, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx

View File

@ -37,7 +37,6 @@ from typing import (
cast,
NoReturn,
Optional,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
@ -523,7 +522,7 @@ class FxGraphCachePickler(pickle.Pickler):
def _reduce_tensor(
self, t: Tensor
) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
"""
Custom reducer to pickle Tensors. If we see tensors, we know they're constants
stored as attributes on the GraphModule.

View File

@ -18,8 +18,6 @@ from typing import (
cast,
ClassVar,
Generic,
Iterator,
MutableMapping,
NamedTuple,
Optional,
TYPE_CHECKING,
@ -59,7 +57,7 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterator, MutableMapping, Sequence
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
from ..loop_body import LoopBody

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import os
from itertools import chain, count, zip_longest
from typing import Any, Callable, Hashable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import sympy
@ -24,6 +24,8 @@ from .wrapper import PythonWrapperCodegen, SymbolicCallArg
if TYPE_CHECKING:
from collections.abc import Hashable
from ..graph import GraphLowering

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
from ..scheduler import (
BaseSchedulerNode,
@ -114,7 +114,7 @@ class CUDACombinedScheduling(BaseScheduling):
def benchmark_fused_nodes(
self, nodes: Sequence[BaseSchedulerNode]
) -> Tuple[float, str]:
) -> tuple[float, str]:
return self._triton_scheduling.benchmark_fused_nodes(nodes)
def benchmark_codegened_module(self, module):
@ -129,5 +129,5 @@ class CUDACombinedScheduling(BaseScheduling):
def benchmark_combo_kernel(
self, node_list: Sequence[BaseSchedulerNode]
) -> tuple[float, float, List[Optional[str]]]:
) -> tuple[float, float, list[Optional[str]]]:
return self._triton_scheduling.benchmark_combo_kernel(node_list)

View File

@ -11,16 +11,7 @@ import math
import operator
import textwrap
from collections import Counter
from typing import (
Any,
Callable,
Generic,
Iterator,
no_type_check,
Optional,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Generic, no_type_check, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeVar
import sympy
@ -72,7 +63,7 @@ from .simd_kernel_features import (
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Iterator, Sequence
log = logging.getLogger(__name__)

View File

@ -5,7 +5,7 @@ import dataclasses
import functools
import itertools
import typing
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Union
import sympy
@ -21,6 +21,10 @@ from ..utils import cache_on_self
from ..virtualized import V
if typing.TYPE_CHECKING:
from collections.abc import Iterable, Sequence
class NodeScheduleMarker:
@staticmethod
def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
@ -81,7 +85,7 @@ class SIMDKernelFeatures:
# numel excludes reduction_numel
self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)
self._stats_cache: Dict[Tuple[sympy.Expr, ...], MemoryStats] = {}
self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {}
@cache_on_self
def is_reduction(self) -> bool:
@ -205,7 +209,7 @@ class SIMDKernelFeatures:
return node.node.data.reduction_hint
def memory_stats(
self, groups_dict: Optional[Dict[str, sympy.Expr]] = None
self, groups_dict: Optional[dict[str, sympy.Expr]] = None
) -> MemoryStats:
"""Analysis to generate features that can be used in heuristics"""
if groups_dict is None:
@ -228,11 +232,11 @@ class MemoryEstimator:
We simulate the memory effects of CSE/buffer elimination in codegen.
"""
kernel_sizes: Tuple[sympy.Expr, ...]
kernel_sizes: tuple[sympy.Expr, ...]
outside_loop: MemoryEstimate
loops: List[MemoryEstimate]
loops: list[MemoryEstimate]
persistent: MemoryEstimate
symbols: List[sympy.Symbol]
symbols: list[sympy.Symbol]
def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]):
self.features = features
@ -341,7 +345,7 @@ class MemoryEstimator:
return True
return False
def set_ranges(self, *lengths: List[List[sympy.Expr]]) -> List[List[sympy.Expr]]:
def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]:
assert len(self.kernel_sizes) == len(lengths)
return [
self.make_flat_range(sym, numel, length)
@ -350,8 +354,8 @@ class MemoryEstimator:
@staticmethod
def make_flat_range(
sym: sympy.Symbol, numel: sympy.Expr, lengths: List[sympy.Expr]
) -> List[sympy.Expr]:
sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr]
) -> list[sympy.Expr]:
if len(lengths) == 1 and numel == lengths[0]:
return [sym]
divisor = sympy.S.One
@ -370,10 +374,10 @@ class MemoryEstimator:
class MemoryEstimate:
"""Tracks the memory usage of a single loop in the generated kernel"""
reads: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
default_factory=functools.partial(collections.defaultdict, OrderedSet)
)
writes: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
default_factory=functools.partial(collections.defaultdict, OrderedSet)
)
@ -474,8 +478,8 @@ class StatsForLoop:
class StatsForReadsOrWrites:
"""Memory usage stats that are collected for reads/writes/both"""
dim: List[StatsForDim]
loop: List[StatsForLoop]
dim: list[StatsForDim]
loop: list[StatsForLoop]
# total bytes contiguous in any dimension
bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero
bytes_non_contiguous: sympy.Expr = sympy.S.Zero
@ -506,8 +510,8 @@ class StatsForReadsOrWrites:
@classmethod
def compute(
cls,
loop_deps: List[Dict[str, OrderedSet[MemoryDep]]],
index_symbols: List[sympy.Symbol],
loop_deps: list[dict[str, OrderedSet[MemoryDep]]],
index_symbols: list[sympy.Symbol],
) -> typing.Self:
ndim = len(index_symbols)
result = cls(dim := [StatsForDim() for _ in range(ndim)], [])
@ -521,7 +525,7 @@ class StatsForReadsOrWrites:
loop_stats.count_per_thread += len(deps)
loop_stats.bytes_per_thread += itemsize * len(deps)
for dep in deps:
strides: List[sympy.Expr] = V.graph.sizevars.stride_vars(
strides: list[sympy.Expr] = V.graph.sizevars.stride_vars(
dep.index, index_symbols
)
for i in range(ndim):
@ -568,7 +572,7 @@ class StatsForKernelType:
@classmethod
def compute(
cls, loops: List[MemoryEstimate], estimator: MemoryEstimator
cls, loops: list[MemoryEstimate], estimator: MemoryEstimator
) -> typing.Self:
reads = StatsForReadsOrWrites.compute(
[loop.reads for loop in loops], estimator.symbols

View File

@ -12,19 +12,11 @@ import sys
import time
import warnings
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass
from inspect import currentframe
from itertools import count
from typing import (
Any,
Callable,
ContextManager,
Mapping,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
from unittest import mock
@ -127,7 +119,7 @@ from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from collections.abc import Generator, Mapping, Sequence
from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
@ -480,7 +472,7 @@ def is_tf32_warning_applicable(gm: GraphModule) -> bool:
def maybe_disable_comprehensive_padding(
example_inputs: Sequence[InputType],
) -> contextlib.AbstractContextManager[None, None]:
) -> AbstractContextManager[None, None]:
"""
For CPU backend, enable comprehensive padding causes some unit tests
fail due to changing number of generated kernels. Skip for now.
@ -1780,7 +1772,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
}
def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
"""
Returns a cuda device context manager if there is a single device in the graph
"""

View File

@ -48,17 +48,9 @@ import traceback
import warnings
import weakref
from collections import defaultdict
from contextlib import AbstractContextManager
from enum import auto, Enum
from typing import (
Any,
Callable,
cast,
ContextManager,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
import torch.fx
from torch import Tensor
@ -179,7 +171,7 @@ def enable_history_recording() -> Generator[None, None, None]:
torch.cuda.memory._record_memory_history(None)
def get_history_recording() -> ContextManager[None]:
def get_history_recording() -> AbstractContextManager[None]:
# TODO - remove, prevents cleanup
if not config.triton.cudagraph_trees_history_recording:
return contextlib.nullcontext()

View File

@ -3,8 +3,8 @@ import dataclasses
import itertools
import logging
import re
from collections.abc import Sequence
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import Self
from unittest.mock import patch
@ -80,7 +80,7 @@ class MemoryDep(Dep):
def num_vars(self) -> int:
return len(self.var_names)
def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[List[int]]:
def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]:
"""
Can return None if not able to decide loop orders.
"""
@ -451,8 +451,8 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
@staticmethod
def drop_unused_symbols(
index: Union[int, sympy.Expr],
var_names: List[sympy.Expr],
sizes: List[sympy.Expr],
var_names: list[sympy.Expr],
sizes: list[sympy.Expr],
) -> None:
"""
Reduction has last (reduced) dim in its sizes, but
@ -571,7 +571,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy
def index_vars_no_squeeze(
*argsizes: Sequence[sympy.Expr], prefix: str
) -> Tuple[List[List[sympy.Symbol]], VarRanges]:
) -> tuple[list[list[sympy.Symbol]], VarRanges]:
var_ranges, add_var = var_builder(prefix)
args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
return args, var_ranges
@ -579,7 +579,7 @@ def index_vars_no_squeeze(
def index_vars_squeeze(
*argsizes: Sequence[sympy.Expr], prefix: str = "d"
) -> Tuple[List[List[sympy.Expr]], VarRanges]:
) -> tuple[list[list[sympy.Expr]], VarRanges]:
from .ir import SqueezeView
var_ranges, add_var = var_builder(prefix)
@ -597,7 +597,7 @@ def extract_read_writes(
*argsizes: Sequence[sympy.Expr],
normalize: bool = False,
prefix: str = "d",
hidden_args: Sequence[List[sympy.Expr]] = (),
hidden_args: Sequence[list[sympy.Expr]] = (),
) -> ReadWrites:
args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
@ -630,7 +630,7 @@ def extract_read_writes(
def extract_loop_body_with_args(
fn: Any,
args: List[List[sympy.Expr]],
args: list[list[sympy.Expr]],
var_ranges: VarRanges,
normalize: bool = False,
) -> _RecordLoadStoreInner:
@ -761,17 +761,17 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
self.symbols |= free_unbacked_symbols(size)
return sympy_index_symbol(f"({str(index_var)})")
def frexp(self, x: Any) -> Tuple[None, ...]:
def frexp(self, x: Any) -> tuple[None, ...]:
return (None,) * 2
def scan(
self, dtypes: Any, combine_fn: Any, values: Sequence[Any]
) -> Tuple[None, ...]:
) -> tuple[None, ...]:
return (None,) * len(values)
def sort(
self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any
) -> Tuple[None, ...]:
) -> tuple[None, ...]:
return (None,) * len(values)
def reduction(

View File

@ -1,6 +1,7 @@
import contextlib
import threading
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
import torch

View File

@ -7,7 +7,7 @@ import signal
import string
import sys
import traceback
from collections.abc import KeysView
from collections.abc import KeysView, Sequence
from enum import Enum
from functools import partial, wraps
from types import FrameType
@ -18,7 +18,6 @@ from typing import (
get_origin,
Literal,
Optional,
Sequence,
TypeVar,
Union,
)

View File

@ -2,7 +2,8 @@ import functools
import itertools
import operator
import typing
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
import torch._inductor.runtime.runtime_utils
@ -275,7 +276,7 @@ def should_pad_bench_key(
input: Optional[Tensor] = None,
is_base_time_key: bool = False,
) -> str:
def tensor_key(t: Tensor) -> Tuple[torch.Size, Tuple[int, ...], torch.dtype]:
def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
return (t.shape, t.stride(), t.dtype)
tf32_key = (
@ -436,8 +437,8 @@ def _should_pad_bench(
return False
def realize_symbols(
ds: Union[torch.Size, Tuple[torch.SymInt, ...]]
) -> List[int]:
ds: Union[torch.Size, tuple[torch.SymInt, ...]]
) -> list[int]:
return [d if isinstance(d, int) else d.node.hint for d in ds]
if any(

View File

@ -9,14 +9,13 @@ import textwrap
import traceback
import typing
from collections.abc import Generator, Iterable, Sequence
from contextlib import nullcontext
from contextlib import AbstractContextManager, nullcontext
from enum import Enum
from functools import partial
from typing import (
Any,
Callable,
ClassVar,
ContextManager,
Literal,
Optional,
overload,
@ -6710,7 +6709,7 @@ class FallbackKernel(ExternKernelAlloc):
@classmethod
def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def]
fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
context: ContextManager[None] = (
context: AbstractContextManager[None] = (
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment]
)
with context:

View File

@ -7,7 +7,7 @@ import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
import sympy
@ -1195,7 +1195,7 @@ def next_power_of_two(n):
def set_head_dim_values(
kernel_options: Dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
):
"""
Mutates kernel options, adding head dimension calculations.

View File

@ -7,7 +7,7 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Callable, cast, Optional, TYPE_CHECKING, Union
from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
@ -92,7 +92,7 @@ class CachedMetricsDeltas:
num_matches_for_scatter_upon_const_tensor: int
def get_metric_fields() -> List[str]:
def get_metric_fields() -> list[str]:
return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]
@ -132,7 +132,7 @@ class MetricTable:
num_rows_added: int = 0
def add_row(
self, row_fn: Callable[[], Dict[str, Optional[Union[str, float]]]]
self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]]
) -> None:
if self.table_name not in enabled_metric_tables():
return
@ -160,7 +160,7 @@ class MetricTable:
writer = csv.writer(fd, lineterminator="\n")
writer.writerow(["model_name"] + self.column_names)
def _write_row(self, row: List[str]) -> None:
def _write_row(self, row: list[str]) -> None:
filename = self.output_filename()
if self.num_rows_added == 0 and not os.path.exists(filename):
self.write_header()
@ -181,7 +181,7 @@ class MetricTable:
writer.writerow(row)
@staticmethod
def register_table(name: str, column_names: List[str]) -> None:
def register_table(name: str, column_names: list[str]) -> None:
table = MetricTable(name, column_names)
REGISTERED_METRIC_TABLES[name] = table

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional
import sympy

View File

@ -27,7 +27,7 @@ import logging
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
import torch
@ -277,7 +277,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
def __init__(self, gm: torch.fx.GraphModule) -> None:
self.gm = gm
def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
frozen_params = {
name: getattr(self.gm, orig_name)
for name, orig_name in g.frozen_param_names.items()
@ -301,10 +301,10 @@ class CompiledFxGraph(OutputCode):
device_idxs: OrderedSet[int]
mutated_inputs: OrderedSet[str]
mutated_input_idxs: OrderedSet[int]
constants: Optional[Dict[str, torch.Tensor]]
frozen_param_names: Dict[str, str]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]]
constants: Optional[dict[str, torch.Tensor]]
frozen_param_names: dict[str, str]
torchbind_constants: dict[str, torch._C.ScriptObject]
output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]

View File

@ -14,21 +14,11 @@ import textwrap
import traceback
import typing
from collections import Counter, defaultdict
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from collections.abc import Sequence
from types import ModuleType
import sympy
@ -2794,7 +2784,7 @@ class Scheduler:
)
# Start compiling choices in parallel
future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
triton_choices = 0
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
@ -2964,7 +2954,7 @@ class Scheduler:
# These are potential fusions which we are async compiling,
# and which we will benchmark profitability of.
pending_fusions: Dict[
pending_fusions: dict[
BaseSchedulerNode,
tuple[Callable[[], bool], BaseSchedulerNode, BaseSchedulerNode],
] = {}
@ -4068,7 +4058,7 @@ class Scheduler:
def benchmark_combo_kernel(
self, node_list: Sequence[BaseSchedulerNode]
) -> tuple[float, float, List[Optional[str]]]:
) -> tuple[float, float, list[Optional[str]]]:
"""
Benchmark fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.
@ -4308,7 +4298,7 @@ class BaseScheduling:
def benchmark_combo_kernel(
self, node_list: Sequence[BaseSchedulerNode]
) -> tuple[float, float, List[Optional[str]]]:
) -> tuple[float, float, list[Optional[str]]]:
"""
Benchmark the list of nodes to combine and return the execution time
and memory copy time in milliseconds on randomly generated inputs.

View File

@ -21,23 +21,18 @@ import tempfile
import textwrap
import time
import unittest
from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet
from datetime import datetime
from io import StringIO
from typing import (
Any,
Callable,
cast,
Collection,
Generic,
Iterator,
Literal,
Mapping,
MutableMapping,
MutableSet,
NamedTuple,
Optional,
Protocol,
Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -1361,7 +1356,7 @@ def _rocm_native_device_arch_name(device: str) -> str:
@functools.lru_cache(None)
def try_import_ck_lib() -> (
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], Type[Any]]
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
):
try:
import ck4inductor # type: ignore[import]
@ -2608,7 +2603,7 @@ class ScopedDict(MutableMapping[KeyType, ValType]):
@dataclass_transform(frozen_default=True)
def ir_dataclass(cls: Optional[Type[Any]] = None, /, *, frozen: bool = True) -> Any:
def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
def wrap(cls: _T) -> _T:
if sys.version_info >= (3, 10):
return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]

View File

@ -3,7 +3,7 @@ import datetime
import tempfile
from collections import defaultdict
from types import ModuleType
from typing import Any, Dict, Optional, Protocol
from typing import Any, Optional, Protocol
import torch
from torch.autograd import DeviceType
@ -73,7 +73,7 @@ def get_triton_kernel(mod: ModuleType): # type: ignore[no-untyped-def]
def benchmark_all_kernels(
benchmark_name: str, benchmark_all_configs: Optional[Dict[Any, Any]]
benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]]
) -> None:
"""
An experimental API used only when config.benchmark_kernel is true.
@ -184,7 +184,7 @@ def parse_profile_event_list(
"""
return ev.self_device_time_total / 1000 / nruns # type: ignore[attr-defined]
all_events: Dict[str, list[ProfileEvent]] = defaultdict(list)
all_events: dict[str, list[ProfileEvent]] = defaultdict(list)
def add_event(
ev: torch.autograd.profiler_util.EventList,

View File

@ -1353,11 +1353,11 @@ def _is_exception(obj) -> bool:
def raise_error_container_parameter_missing(target_type) -> None:
if target_type == "Dict":
if target_type.endswith("ict"):
raise RuntimeError(
"Attempted to use Dict without "
f"Attempted to use {target_type} without "
"contained types. Please add contained type, e.g. "
"Dict[int, int]"
f"{target_type}[int, int]"
)
raise RuntimeError(
f"Attempted to use {target_type} without a "
@ -1366,15 +1366,20 @@ def raise_error_container_parameter_missing(target_type) -> None:
)
_RAW_TYPE_NAME_MAPPING = {
dict: "dict",
list: "list",
tuple: "tuple",
typing.Dict: "Dict", # noqa: UP006
typing.List: "List", # noqa: UP006
typing.Optional: "Optional",
typing.Tuple: "Tuple", # noqa: UP006
}
def check_args_exist(target_type) -> None:
if target_type is typing.List or target_type is list: # noqa: UP006
raise_error_container_parameter_missing("List")
elif target_type is typing.Tuple or target_type is tuple: # noqa: UP006
raise_error_container_parameter_missing("Tuple")
elif target_type is typing.Dict or target_type is dict: # noqa: UP006
raise_error_container_parameter_missing("Dict")
elif target_type is None or target_type is Optional:
raise_error_container_parameter_missing("Optional")
if name := _RAW_TYPE_NAME_MAPPING.get(target_type):
raise_error_container_parameter_missing(name)
def check_empty_containers(obj) -> None:

View File

@ -3,7 +3,7 @@ import operator
from collections.abc import Sequence
from enum import Enum
from functools import partial, reduce
from typing import Callable, List, Optional, Tuple, Type, Union
from typing import Callable, Optional, Union
import torch
import torch._prims_common as utils

View File

@ -12,12 +12,9 @@ from typing import (
Any,
Callable,
cast,
List,
NamedTuple,
Optional,
overload,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,

View File

@ -10,7 +10,7 @@ import warnings
from collections.abc import Iterable, Sequence
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union
from typing import Any, Callable, cast, Optional, overload, Union
import torch
import torch._prims as prims

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch._prims as prims

View File

@ -1,4 +1 @@
from typing import List
__all__: list[str] = []

View File

@ -3,7 +3,8 @@ import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Optional, Union
from contextlib import AbstractContextManager
from typing import Any, Callable, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -665,7 +666,7 @@ class BaseFunctionalizeAPI(ABC):
pass
@abstractmethod
def redispatch_to_next(self) -> ContextManager:
def redispatch_to_next(self) -> AbstractContextManager:
pass
@abstractmethod
@ -709,7 +710,7 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
def functionalize(self, inner_f: Callable) -> Callable:
return dispatch_functionalize(inner_f, self.mode)
def redispatch_to_next(self) -> ContextManager:
def redispatch_to_next(self) -> AbstractContextManager:
# [NOTE] We don't do anything here because at the time
# we exercise this path, we would have already popped the
# FunctionalTensorMode from mode stack. Since FunctionalTensorMode
@ -753,7 +754,7 @@ class CppFunctionalizeAPI(BaseFunctionalizeAPI):
def functionalize(self, inner_f: Callable) -> Callable:
return torch.func.functionalize(inner_f)
def redispatch_to_next(self) -> ContextManager:
def redispatch_to_next(self) -> AbstractContextManager:
return torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
@ -801,7 +802,7 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
),
)
def redispatch_to_next(self) -> ContextManager:
def redispatch_to_next(self) -> AbstractContextManager:
return self.interpreter.lower()
def replace(self, input_tensor, output_tensor) -> None:

View File

@ -7,12 +7,12 @@ import typing
import warnings
import weakref
from abc import abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import (
Any,
Callable,
ClassVar,
ContextManager,
Generic,
NewType,
Optional,
@ -1741,7 +1741,9 @@ class MetaConverter(Generic[_TensorT]):
# subclasses. Relevant test is
# DynamicShapesFunctionTests::test_add_dynamic_shapes in
# test/dynamo/test_dynamic_shapes.py
maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
maybe_fake_mgr: AbstractContextManager[
None
] = contextlib.nullcontext()
from torch._subclasses.fake_tensor import (
in_kernel_invocation_manager,
maybe_get_fake_mode,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
from torch import Tensor

View File

@ -10,7 +10,7 @@ half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdo
import warnings
from collections.abc import Sequence
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Optional, Union
import torch
from torch import _vmap_internals

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import sys
import types
from typing import List
import torch

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
import torch

View File

@ -17,7 +17,7 @@ import threading
import traceback
import warnings
from functools import lru_cache
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Optional, Union
import torch
import torch._C
@ -85,7 +85,7 @@ try:
paths = ["libamd_smi.so"]
if rocm_home := os.getenv("ROCM_HOME", os.getenv("ROCM_PATH")):
paths = [os.path.join(rocm_home, "lib/libamd_smi.so")] + paths
self.paths: List[str] = paths
self.paths: list[str] = paths
def hooked_CDLL(
self, name: Union[str, Path, None], *args: Any, **kwargs: Any

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Generator
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Optional
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Any, Optional
import torch
import torch.nn as nn
@ -14,7 +14,7 @@ from .contract import _State, contract
@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
r"""
Disable hooks installed by checkpoint to avoid unintentional recursion
during backward recomputation.

View File

@ -1,14 +1,14 @@
import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Any, Dict, List, Tuple
from typing import Any
import torch
import torch._weights_only_unpickler as _weights_only_unpickler
from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION
__all__: List[str] = []
__all__: list[str] = []
@dataclass
@ -23,7 +23,7 @@ _weights_only_unpickler._add_safe_globals([_Entry])
class _PseudoZipFile:
def __init__(self) -> None:
self.records: Dict[str, Tuple[object, int]] = {}
self.records: dict[str, tuple[object, int]] = {}
def write_record(self, key: str, data: object, length: int) -> None:
self.records[key] = (data, length)

View File

@ -1,5 +1,5 @@
from collections.abc import Iterator
from typing import Tuple, Union
from typing import Union
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import functools
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING
import torch
from torch.distributed._shard.op_registry_utils import _decorator_func

View File

@ -7,7 +7,7 @@ from contextlib import contextmanager
from datetime import timedelta
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.distributed._functional_collectives as funcol

View File

@ -4,7 +4,7 @@ import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
import torch
from torch.distributed.checkpoint.metadata import (
@ -205,21 +205,21 @@ class SavePlanner(abc.ABC):
# Save plan for the current rank as computed by `create_local_plan` API
# Cached on the local rank.
_cached_save_plan: Dict[str, SavePlan] = {}
_cached_save_plan: dict[str, SavePlan] = {}
# Final save plan for the current rank.
# This is created by merging the plan created by `create_local_plan` API
# and the result of `create_global_plan` for the given rank.
# This is the final plan computed by the `finish_plan` API that gets
# sent to the `write_data`.
# Cached on the local rank.
_cached_final_save_plan: Dict[str, SavePlan] = {}
_cached_final_save_plan: dict[str, SavePlan] = {}
# Collection of all the local plans from all the ranks.
# This is the input to the `create_global_plan` API.
# Cached on the coordinator rank.
_cached_all_plans: Dict[str, list[SavePlan]] = {}
_cached_all_plans: dict[str, list[SavePlan]] = {}
# Global checkpoint plan as computed by `create_global_plan` API.
# Cached on the coordinator rank.
_cached_global_plan: Dict[str, list[SavePlan]] = {}
_cached_global_plan: dict[str, list[SavePlan]] = {}
@abc.abstractmethod
def set_up_planner(

View File

@ -24,7 +24,7 @@ import logging
import os
import socket
import traceback
from typing import Dict, Optional
from typing import Optional
from torch.distributed.elastic.events.handlers import get_logging_handler

View File

@ -62,7 +62,7 @@ was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Optional, Union
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,

View File

@ -58,7 +58,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from torch.distributed.elastic.utils.logging import get_logger

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Dict, Union
from typing import Union
import torch
from torch import fx
@ -90,7 +90,7 @@ def validate_tensors_metadata(
def generate_stage_to_rank_mapping(
pp_size: int, num_stages: int, style: str = "loop"
) -> Dict[int, int]:
) -> dict[int, int]:
"""
Compute the stage id to rank mapping for either a looped or V-style schedule.

View File

@ -9,7 +9,7 @@ import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -1025,7 +1025,7 @@ def _validate_schedule(
pp_group_size: int,
num_stages: int,
num_microbatches: int,
) -> Dict[int, int]:
) -> dict[int, int]:
assert (
len(actions) == pp_group_size
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"

View File

@ -5,7 +5,6 @@ import threading
import warnings
from collections.abc import Generator
from datetime import timedelta
from typing import Tuple
from urllib.parse import urlparse
import torch

View File

@ -9,17 +9,7 @@ import warnings
import zipfile
from collections.abc import Iterator
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree

View File

@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Union
import torch
import torch.utils._pytree as pytree

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from typing import Callable, cast, Generic, List, Optional, TypeVar, Union
from typing import Callable, cast, Generic, Optional, TypeVar, Union
import torch

View File

@ -3,7 +3,7 @@ import importlib
import io
import pickle
from abc import abstractmethod
from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, NewType, Optional, TypeVar, Union
from typing_extensions import override, Self
import torch
@ -45,7 +45,7 @@ class GraphPickler(pickle.Pickler):
@override
def reducer_override(
self, obj: object
) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
) -> tuple[Callable[..., Any], tuple[Any, ...]]:
# This function is supposed to return either NotImplemented (meaning to
# do the default pickle behavior) or a pair of (unpickle callable, data
# to pass to unpickle).
@ -138,13 +138,13 @@ class _GraphUnpickler(pickle.Unpickler):
class _ShapeEnvPickleData:
data: Dict[str, object]
data: dict[str, object]
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: ShapeEnv
) -> Tuple[
Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken]
) -> tuple[
Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
]:
return cls.unpickle, (cls(obj), pickler._unpickle_state)
@ -174,8 +174,8 @@ class _SymNodePickleData:
cls,
pickler: GraphPickler,
obj: _SymNodeT,
) -> Tuple[
Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken]
) -> tuple[
Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
]:
args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt):
@ -205,8 +205,8 @@ class _TensorPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: FakeTensor
) -> Tuple[
Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken]
) -> tuple[
Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
]:
return cls.unpickle, (
cls(pickler._meta_tensor_describer, obj),
@ -266,8 +266,8 @@ class _TorchNumpyPickleData:
def reduce_helper(
cls, pickler: GraphPickler, obj: object
) -> Optional[
Tuple[
Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken]
tuple[
Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
]
]:
if data := cls.from_object(obj):
@ -309,9 +309,9 @@ class _GraphModulePickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: torch.fx.GraphModule
) -> Tuple[
) -> tuple[
Callable[[Self, _UnpickleState], torch.fx.GraphModule],
Tuple[Self, _UnpickleStateToken],
tuple[Self, _UnpickleStateToken],
]:
return cls.unpickle, (
cls(obj),
@ -337,7 +337,7 @@ class _GraphModulePickleData:
class _NodePickleData:
def __init__(
self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"]
self, node: torch.fx.Node, mapping: dict[torch.fx.Node, "_NodePickleData"]
) -> None:
self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
self.kwargs = pytree.tree_map_only(
@ -358,7 +358,7 @@ class _NodePickleData:
def unpickle(
self,
graph: torch.fx.Graph,
mapping: Dict["_NodePickleData", torch.fx.Node],
mapping: dict["_NodePickleData", torch.fx.Node],
unpickle_state: _UnpickleState,
) -> torch.fx.Node:
args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
@ -376,7 +376,7 @@ class _OpPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, op: object
) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]:
) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
result = cls.pickle(op)
return (result.unpickle, (pickler._unpickle_state,))
@ -404,7 +404,7 @@ class _OpPickleData:
def _pickle_op(
name: str,
datacls: Union[
Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"]
type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
],
) -> "_OpPickleData":
if not name.startswith("torch.ops.aten"): # TODO: What's the full list?
@ -501,7 +501,7 @@ class _GraphPickleData:
self.tracer_cls = graph._tracer_cls
self.tracer_extras = graph._tracer_extras
nodes: Dict[torch.fx.Node, _NodePickleData] = {}
nodes: dict[torch.fx.Node, _NodePickleData] = {}
for node in graph.nodes:
nodes[node] = _NodePickleData(node, nodes)
self.nodes = tuple(nodes.values())
@ -521,7 +521,7 @@ class _GraphPickleData:
) -> torch.fx.Graph:
graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
nodes: Dict[_NodePickleData, torch.fx.Node] = {}
nodes: dict[_NodePickleData, torch.fx.Node] = {}
for nd in self.nodes:
nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
@ -532,9 +532,9 @@ class _TracingContextPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: torch._guards.TracingContext
) -> Tuple[
) -> tuple[
Callable[[Self, _UnpickleState], torch._guards.TracingContext],
Tuple[Self, _UnpickleStateToken],
tuple[Self, _UnpickleStateToken],
]:
return (
cls.unpickle,

View File

@ -23,7 +23,7 @@ import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Set, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
import torch._logging.structured as structured
@ -1233,7 +1233,7 @@ def _make_node_magic(method, func):
else:
method_attr = method
def uninteresting_files() -> Set[str]:
def uninteresting_files() -> set[str]:
import torch
mods = [

View File

@ -36,11 +36,11 @@ if TYPE_CHECKING:
# Mapping of builtins to their `typing` equivalent.
# (PEP585: See D68459095 test plan)
_origin_type_map = {
list: typing.List,
dict: typing.Dict,
set: typing.Set,
frozenset: typing.FrozenSet,
tuple: typing.Tuple,
list: typing.List, # noqa: UP006
dict: typing.Dict, # noqa: UP006
set: typing.Set, # noqa: UP006
frozenset: typing.FrozenSet, # noqa: UP006
tuple: typing.Tuple, # noqa: UP006
}
_legal_ops = dict.fromkeys(

View File

@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
import ast
import copy
import dataclasses
import inspect
import re
import string
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple # noqa: F401
import torch
import torch.jit.annotations
@ -551,7 +551,7 @@ def build_ignore_context_manager(ctx, stmt):
return_type_ann = " -> " + outputs[0].ann
return_statement_str += outputs[0].name
if len(outputs) > 1:
return_type_ann = " -> Tuple"
return_type_ann = " -> tuple"
return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
return_statement_str += ", ".join([var.name for var in outputs])
return return_type_ann, return_statement_str
@ -581,10 +581,18 @@ def build_ignore_context_manager(ctx, stmt):
return_stmt = ast.parse(return_stmt).body[0]
ignore_function.body.append(return_stmt) # type: ignore[attr-defined]
ignore_func_str = f"""\
# Backward compat: These used to be imported into the outer global scope so some
# code may still expect them.
from typing import List, Dict, Tuple
@torch.jit.ignore
{astunparse.unparse(ignore_function)}
"""
g = copy.copy(globals())
exec(ignore_func_str, g) # noqa: P204
# registers the custom function in the global context
ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
exec(ignore_func_str) # noqa: P204
globals()[ignore_function_name] = g[ignore_function_name]
# build the statements as:
# <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)

View File

@ -5,7 +5,7 @@ This package enables an interface for accessing MTIA backend in python
import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
from torch import device as _device, Tensor

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn.functional as F

View File

@ -2,7 +2,7 @@
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
import contextlib
from collections.abc import Iterable
from typing import List, Union
from typing import Union
from warnings import warn
import torch.backends.cuda

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import warnings
from typing import Any, Sequence
from typing import Any, TYPE_CHECKING
import torch
from torch.export.dynamic_shapes import _Dim, _DimHint
@ -13,6 +13,10 @@ from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.utils import _pytree
if TYPE_CHECKING:
from collections.abc import Sequence
def from_dynamic_axes_to_dynamic_shapes(
model,
args: tuple[Any, ...],

View File

@ -15,7 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union
from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
from typing_extensions import TypeAlias, TypeIs
import torch
@ -1898,7 +1898,7 @@ def _load(
data_descripter_size64 = 24
data_descripter_size32 = 16
mz_uint32_max = 0xFFFFFFFF
offsets: Dict[str, int] = dict()
offsets: dict[str, int] = dict()
def _get_offset(key, name, numel):
"""

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from torch import Tensor

View File

@ -4,7 +4,7 @@
import copy
from itertools import chain
from typing import Any, Dict
from typing import Any
import torch
import torch.nn as nn
@ -164,7 +164,7 @@ class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
# _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict
# keys, they need to provide a mapping from the new key to the original key. This is used to ensure
# consistency between the state_dict keys and fqn.
def _fqn_modifiers(self) -> Dict[str, str]:
def _fqn_modifiers(self) -> dict[str, str]:
return {
"weight": "embedding",
}

View File

@ -1,7 +1,5 @@
# mypy: ignore-errors
from typing import List
from torch.testing._internal.opinfo.core import OpInfo
from torch.testing._internal.opinfo.definitions import (
_masked,

View File

@ -3,8 +3,8 @@ import bisect
import itertools
import math
import warnings
from collections.abc import Sequence
from typing import cast, Generic, Iterable, Optional, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import cast, Generic, Optional, TypeVar, Union
from typing_extensions import deprecated
# No 'default_generator' in torch/__init__.pyi

View File

@ -75,7 +75,6 @@ import sys
import urllib.parse
import zipfile
from pathlib import Path
from typing import Dict
import warnings
import torch.utils.show_pickle

View File

@ -9,7 +9,7 @@ This package is lazily initialized, so you can always import it, and use
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch._C