From db4ce78d46c0aa136b3eff10ef3e5684eb78eb3f Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 19 Feb 2025 14:37:18 -0800 Subject: [PATCH] PEP585: More UP006 fixes (#146392) This should be the final PR before we can enable RUFF UP006. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392 Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007 --- .../fsdp/test_fully_shard_ignore_params.py | 3 +- .../test_c10d_functional_native.py | 14 +++--- test/functorch/test_ac_logging.py | 25 +++++----- test/onnx/test_onnxscript_runtime.py | 2 +- test/onnx/torchlib/error_reproduction.py | 6 ++- torch/_C/_monitor.pyi | 4 +- torch/_decomp/__init__.py | 12 +---- torch/_dynamo/polyfills/__init__.py | 2 +- torch/_dynamo/utils.py | 3 +- torch/_dynamo/variables/base.py | 3 +- torch/_dynamo/variables/builtin.py | 4 +- torch/_dynamo/variables/functions.py | 8 ++-- torch/_export/__init__.py | 2 +- .../ac_logging_utils.py | 48 +++++++++---------- .../_aot_autograd/subclass_parametrization.py | 3 +- torch/_higher_order_ops/aoti_call_delegate.py | 20 ++++---- torch/_higher_order_ops/utils.py | 8 ++-- torch/_inductor/__init__.py | 2 +- torch/_inductor/codecache.py | 3 +- torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/cpp_wrapper_gpu.py | 4 +- .../codegen/cuda_combined_scheduling.py | 6 +-- torch/_inductor/codegen/simd.py | 13 +---- .../_inductor/codegen/simd_kernel_features.py | 38 ++++++++------- torch/_inductor/compile_fx.py | 18 ++----- torch/_inductor/cudagraph_trees.py | 14 ++---- torch/_inductor/dependencies.py | 24 +++++----- torch/_inductor/freezing_utils.py | 3 +- torch/_inductor/fuzzer.py | 3 +- torch/_inductor/fx_passes/pad_mm.py | 9 ++-- torch/_inductor/ir.py | 5 +- torch/_inductor/kernel/flex_attention.py | 4 +- torch/_inductor/metrics.py | 10 ++-- torch/_inductor/mkldnn_ir.py | 3 +- torch/_inductor/output_code.py | 12 ++--- torch/_inductor/scheduler.py | 22 +++------ torch/_inductor/utils.py | 11 ++--- torch/_inductor/wrapper_benchmark.py | 6 +-- torch/_jit_internal.py | 27 ++++++----- torch/_prims/__init__.py | 2 +- torch/_prims_common/__init__.py | 3 -- torch/_refs/__init__.py | 2 +- torch/_refs/linalg/__init__.py | 2 +- torch/_refs/nn/__init__.py | 3 -- torch/_subclasses/functional_tensor.py | 11 +++-- torch/_subclasses/meta_utils.py | 6 ++- torch/ao/quantization/__init__.py | 2 +- torch/autograd/__init__.py | 2 +- torch/backends/quantized/__init__.py | 1 - torch/compiler/__init__.py | 2 +- torch/cuda/__init__.py | 4 +- .../_composable/checkpoint_activation.py | 6 +-- torch/distributed/_serialization.py | 6 +-- .../_shard/sharded_optim/__init__.py | 2 +- .../_shard/sharded_tensor/__init__.py | 2 +- .../distributed/_symmetric_memory/__init__.py | 2 +- torch/distributed/checkpoint/planner.py | 10 ++-- torch/distributed/elastic/events/__init__.py | 2 +- .../elastic/multiprocessing/__init__.py | 2 +- .../multiprocessing/errors/__init__.py | 2 +- torch/distributed/pipelining/_utils.py | 4 +- torch/distributed/pipelining/schedules.py | 4 +- torch/distributed/rpc/__init__.py | 1 - torch/export/__init__.py | 12 +---- torch/export/passes/__init__.py | 2 +- torch/futures/__init__.py | 2 +- torch/fx/_graph_pickler.py | 42 ++++++++-------- torch/fx/experimental/sym_node.py | 4 +- torch/fx/graph.py | 10 ++-- torch/jit/frontend.py | 18 +++++-- torch/mtia/__init__.py | 2 +- torch/nested/__init__.py | 2 +- torch/nn/attention/__init__.py | 2 +- .../_internal/exporter/_dynamic_shapes.py | 6 ++- torch/serialization.py | 4 +- 
torch/sparse/__init__.py | 2 +- .../distributed/common_state_dict.py | 4 +- .../_internal/opinfo/definitions/__init__.py | 2 - torch/utils/data/dataset.py | 4 +- torch/utils/model_dump/__init__.py | 1 - torch/xpu/__init__.py | 2 +- 81 files changed, 283 insertions(+), 329 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py b/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py index 83ffd0d8a9e..45dc850fe8d 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import sys -from typing import List import torch import torch.distributed as dist @@ -119,7 +118,7 @@ def _find_name_param_mappings(module: torch.nn.Module, prefix: str): def _discover_ddp_ignored_params(module: torch.nn.Module, prefix: str): - ddp_ignore_parameters: List[str] = [] + ddp_ignore_parameters: list[str] = [] if isinstance(module, FSDP2): ddp_ignore_parameters = [name for name, _ in module.named_parameters(prefix)] else: diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 442c6bdd530..4c4940dbccc 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -3,7 +3,7 @@ import gc import threading import unittest from datetime import timedelta -from typing import List, Optional +from typing import Optional import torch import torch.distributed as dist @@ -576,24 +576,24 @@ class ProcessGroupDummy(dist.ProcessGroup): self.waits = 0 self.dels = 0 - def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work: + def broadcast(self, tensor_list: list[torch.Tensor], opts: object) -> dist.Work: return _DummyWork(self) def allgather_into_tensor_coalesced( self, - output_lists: List[torch.Tensor], - input_list: List[torch.Tensor], + output_lists: list[torch.Tensor], + input_list: list[torch.Tensor], opts: object, ) -> dist.Work: return _DummyWork(self) - def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work: + def allreduce(self, tensors: list[torch.Tensor], opts: object) -> dist.Work: return _DummyWork(self) def reduce_scatter_tensor_coalesced( self, - outputTensors: List[torch.Tensor], - inputTensors: List[torch.Tensor], + outputTensors: list[torch.Tensor], + inputTensors: list[torch.Tensor], opts: object, ) -> dist.Work: return _DummyWork(self) diff --git a/test/functorch/test_ac_logging.py b/test/functorch/test_ac_logging.py index f4574a499ca..03ddb7d4584 100644 --- a/test/functorch/test_ac_logging.py +++ b/test/functorch/test_ac_logging.py @@ -1,5 +1,4 @@ # Owner(s): ["module: functorch"] -from typing import Dict, List, Tuple from unittest.mock import MagicMock, patch from torch._functorch._activation_checkpointing.ac_logging_utils import ( @@ -33,17 +32,17 @@ class TestAcLogging(TestCase): self.graph.nodes = [self.node1, self.node2] - self.all_recomputable_banned_nodes: List[Node] = [self.node1] - self.saved_node_idxs: List[int] = [0] - self.recomputable_node_idxs: List[int] = [] + self.all_recomputable_banned_nodes: list[Node] = [self.node1] + self.saved_node_idxs: list[int] = [0] + self.recomputable_node_idxs: list[int] = [] self.expected_runtime: int = 100 - self.memories_banned_nodes: List[int] = [50] - self.runtimes_banned_nodes: List[int] = [10] - self.min_cut_saved_values: List[Node] = [self.node1] + self.memories_banned_nodes: list[int] = [50] + 
self.runtimes_banned_nodes: list[int] = [10] + self.min_cut_saved_values: list[Node] = [self.node1] def test_create_joint_graph_node_information(self) -> None: - recomputable_node_info: Dict[str, int] = {"node1": 0} - expected_output: Dict[str, Dict] = { + recomputable_node_info: dict[str, int] = {"node1": 0} + expected_output: dict[str, dict] = { "node1": { "index": 0, "name": "node1", @@ -68,12 +67,12 @@ class TestAcLogging(TestCase): self.assertEqual(result, expected_output) def test_create_joint_graph_edges(self) -> None: - expected_edges: List[Tuple[str, str]] = [("node1", "node2")] + expected_edges: list[tuple[str, str]] = [("node1", "node2")] result = create_joint_graph_edges(self.graph) self.assertEqual(result, expected_edges) def test_create_activation_checkpointing_logging_structure_payload(self) -> None: - input_joint_graph_node_information: Dict[str, Dict] = { + input_joint_graph_node_information: dict[str, dict] = { "node1": { "index": 0, "name": "node1", @@ -85,8 +84,8 @@ class TestAcLogging(TestCase): "recomputable_candidate_info": {"recomputable_node_idx": 0}, } } - joint_graph_edges: List[Tuple[str, str]] = [("node1", "node2")] - expected_payload: Dict[str, any] = { + joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")] + expected_payload: dict[str, any] = { "Joint Graph Size": 2, "Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges}, "Joint Graph Node Information": input_joint_graph_node_information, diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py index 2eb2405535c..23205045e83 100644 --- a/test/onnx/test_onnxscript_runtime.py +++ b/test/onnx/test_onnxscript_runtime.py @@ -2,7 +2,7 @@ """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime.""" -from typing import Sequence +from typing import Sequence # noqa: UP035 import onnx_test_common import onnxscript diff --git a/test/onnx/torchlib/error_reproduction.py b/test/onnx/torchlib/error_reproduction.py index dc30d8e5c86..260a37b65f1 100644 --- a/test/onnx/torchlib/error_reproduction.py +++ b/test/onnx/torchlib/error_reproduction.py @@ -9,7 +9,7 @@ import platform import sys import time import traceback -from typing import Any, Mapping +from typing import Any, TYPE_CHECKING import numpy as np @@ -20,6 +20,10 @@ import onnxscript import torch +if TYPE_CHECKING: + from collections.abc import Mapping + + _REPRODUCTION_TEMPLATE = '''\ import google.protobuf.text_format import numpy as np diff --git a/torch/_C/_monitor.pyi b/torch/_C/_monitor.pyi index d0201218170..d28c373e528 100644 --- a/torch/_C/_monitor.pyi +++ b/torch/_C/_monitor.pyi @@ -3,7 +3,7 @@ import datetime from enum import Enum from types import TracebackType -from typing import Callable, Optional, Type +from typing import Callable, Optional class Aggregation(Enum): VALUE = ... @@ -48,7 +48,7 @@ class _WaitCounterTracker: def __enter__(self) -> None: ... def __exit__( self, - exec_type: Optional[Type[BaseException]] = None, + exec_type: Optional[type[BaseException]] = None, exec_value: Optional[BaseException] = None, traceback: Optional[TracebackType] = None, ) -> None: ... 
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 37b50a2efdd..a45988ca469 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -4,17 +4,7 @@ from collections import defaultdict from collections.abc import Sequence from functools import lru_cache, partial, wraps from itertools import chain -from typing import ( - Callable, - Dict, - FrozenSet, - List, - Optional, - Set, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index c6afee2274d..6b684977ad4 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -11,7 +11,7 @@ Python polyfills for common builtins. import types from collections.abc import MutableMapping, Sequence from itertools import repeat as _repeat -from typing import Any, Callable, List, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING import torch diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 347b05f5383..2b90ba2ec8d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -59,7 +59,6 @@ from typing import ( Generic, Optional, overload, - Set, TypeVar, Union, ) @@ -1393,7 +1392,7 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]: except Exception: return "Value is not JSON serializable" - keys_to_scrub: Set[Any] = set() + keys_to_scrub: set[Any] = set() inductor_conf_str = None inductor_config_copy = ( torch._inductor.config.get_config_copy() if torch._inductor.config else None diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index eed78c93593..a74efedf9ac 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -16,8 +16,9 @@ computations. """ import collections +from collections.abc import Sequence from enum import Enum -from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING from .. 
import variables from ..current_scope_id import current_scope_id diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index e00bc431c4f..460fe310744 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -11,8 +11,8 @@ import sys import types import typing from collections import defaultdict, OrderedDict -from collections.abc import KeysView -from typing import Callable, Sequence, TYPE_CHECKING, Union +from collections.abc import KeysView, Sequence +from typing import Callable, TYPE_CHECKING, Union import torch from torch import sym_float, sym_int diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 8a2a31ae4ac..d2b211bbafa 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -30,7 +30,7 @@ import itertools import sys import types from collections.abc import Sequence -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from unittest.mock import patch @@ -517,7 +517,7 @@ class LocalGeneratorObjectVariable(VariableTracker): def has_force_unpack_var_sequence(self, tx) -> builtins.bool: return True - def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: result = [] while True: try: @@ -547,8 +547,8 @@ class LocalGeneratorObjectVariable(VariableTracker): self, tx: "InstructionTranslator", name: str, - args: "List[VariableTracker]", - kwargs: "Dict[str, VariableTracker]", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if name == "__next__": return self.next_variable(tx) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index ddba404ccdd..4affd3698d4 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -16,7 +16,7 @@ from collections import OrderedDict from contextlib import contextmanager from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from unittest.mock import patch import torch diff --git a/torch/_functorch/_activation_checkpointing/ac_logging_utils.py b/torch/_functorch/_activation_checkpointing/ac_logging_utils.py index 8e3a39dc9b6..fe22a383795 100644 --- a/torch/_functorch/_activation_checkpointing/ac_logging_utils.py +++ b/torch/_functorch/_activation_checkpointing/ac_logging_utils.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict, List, Tuple +from typing import Any from torch._logging import trace_structured from torch.fx import Graph, Node @@ -11,9 +11,9 @@ log: logging.Logger = logging.getLogger(__name__) def create_joint_graph_node_information( joint_graph: Graph, - recomputable_node_info: Dict[str, int], -) -> Dict[str, Any]: - joint_graph_node_information: Dict[str, Any] = {} + recomputable_node_info: dict[str, int], +) -> dict[str, Any]: + joint_graph_node_information: dict[str, Any] = {} for i, joint_graph_node in enumerate(joint_graph.nodes): is_recomputable_candidate: bool = ( @@ -22,7 +22,7 @@ def create_joint_graph_node_information( tensor_meta = joint_graph_node.meta.get("tensor_meta") shape = getattr(tensor_meta, "shape", []) if tensor_meta else [] - node_info: Dict[str, Any] = { + node_info: dict[str, Any] = { "index": i, "name": joint_graph_node.name, "is_recomputable_candidate": 
is_recomputable_candidate, @@ -43,8 +43,8 @@ def create_joint_graph_node_information( return joint_graph_node_information -def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]: - joint_graph_edges: List[Tuple[str, str]] = [ +def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]: + joint_graph_edges: list[tuple[str, str]] = [ (inp.name, node.name) for node in joint_graph.nodes for inp in node.all_input_nodes @@ -54,17 +54,17 @@ def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]: def create_activation_checkpointing_logging_structure_payload( joint_graph: Graph, - joint_graph_node_information: Dict[str, Any], - joint_graph_edges: List[Tuple[str, str]], - all_recomputable_banned_nodes: List[Node], + joint_graph_node_information: dict[str, Any], + joint_graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[Node], expected_runtime: float, - saved_node_idxs: List[int], - recomputable_node_idxs: List[int], - memories_banned_nodes: List[float], - runtimes_banned_nodes: List[float], - min_cut_saved_values: List[Node], -) -> Dict[str, Any]: - activation_checkpointing_logging_structure_payload: Dict[str, Any] = { + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> dict[str, Any]: + activation_checkpointing_logging_structure_payload: dict[str, Any] = { "Joint Graph Size": len(joint_graph.nodes), "Joint Graph Edges": { "Total": len(joint_graph_edges), @@ -86,15 +86,15 @@ def create_activation_checkpointing_logging_structure_payload( def create_structured_trace_for_min_cut_info( joint_graph: Graph, - all_recomputable_banned_nodes: List[Node], - saved_node_idxs: List[int], - recomputable_node_idxs: List[int], + all_recomputable_banned_nodes: list[Node], + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], expected_runtime: float, - memories_banned_nodes: List[float], - runtimes_banned_nodes: List[float], - min_cut_saved_values: List[Node], + memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], ) -> None: - recomputable_node_info: Dict[str, int] = { + recomputable_node_info: dict[str, int] = { node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes) } joint_graph_node_information = create_joint_graph_node_information( diff --git a/torch/_functorch/_aot_autograd/subclass_parametrization.py b/torch/_functorch/_aot_autograd/subclass_parametrization.py index e24978b490e..5d6d17ca099 100644 --- a/torch/_functorch/_aot_autograd/subclass_parametrization.py +++ b/torch/_functorch/_aot_autograd/subclass_parametrization.py @@ -1,6 +1,7 @@ import dataclasses import itertools -from typing import Any, Iterable, Union +from collections.abc import Iterable +from typing import Any, Union import torch from torch.utils._python_dispatch import is_traceable_wrapper_subclass diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index cec5c019400..0fb0e0ea4a5 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import List - import torch import torch.utils._pytree as pytree from torch._ops import HigherOrderOperator @@ -50,9 +48,9 @@ class AOTICallDelegate(HigherOrderOperator): self, lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] 
original_gm: torch.fx.GraphModule, - weight_args: List[torch.Tensor], - input_args: List[torch.Tensor], - ) -> List[torch.Tensor]: + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], + ) -> list[torch.Tensor]: return super().__call__(lowered_module, original_gm, weight_args, input_args) @@ -68,9 +66,9 @@ aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) def call_delegate_cpu( lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] original_gm: torch.fx.GraphModule, - weight_args: List[torch.Tensor], - input_args: List[torch.Tensor], -) -> List[torch.Tensor]: + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +) -> list[torch.Tensor]: # FX creates this immutable_dict/list concept. Get rid of this. map_types = { torch.fx.immutable_collections.immutable_dict: dict, @@ -104,8 +102,8 @@ def call_delegate_fake_tensor_mode( mode: FakeTensorMode, lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] original_gm: torch.fx.GraphModule, - weight_args: List[torch.Tensor], - input_args: List[torch.Tensor], -) -> List[torch.Tensor]: + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +) -> list[torch.Tensor]: with mode: return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 787c186fcc5..15d5ea8b8b2 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -2,7 +2,7 @@ import functools from contextlib import contextmanager, ExitStack from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.fx.traceback as fx_traceback @@ -173,7 +173,7 @@ def _detect_input_alias(gm: torch.fx.GraphModule) -> bool: # The invariant here is that we always trace the branch with fake tensor -def _maybe_fake_tracing(fn, inputs: List[Any], pre_dispatch): +def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): fake_mode = detect_fake_mode(inputs) tracing_mode = "real" if fake_mode is None: @@ -565,8 +565,8 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]) def check_input_alias_and_mutation( gm: torch.fx.GraphModule, - fake_args: List[FakeTensor], -) -> Tuple[List[int], dict[int, int], dict[int, int], dict[int, int]]: + fake_args: list[FakeTensor], +) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]: with disable_proxy_modes_tracing(): """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias in the graph module gm. 
It checks whether input tensor versions have diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 8f409ceb14b..3cef26f31e8 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations import io import logging import os -from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, IO, Optional, TYPE_CHECKING, Union import torch._inductor.config import torch.fx diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 7cef9194c01..8b558901e1a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -37,7 +37,6 @@ from typing import ( cast, NoReturn, Optional, - Tuple, TYPE_CHECKING, TypeVar, Union, @@ -523,7 +522,7 @@ class FxGraphCachePickler(pickle.Pickler): def _reduce_tensor( self, t: Tensor - ) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]: + ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]: """ Custom reducer to pickle Tensors. If we see tensors, we know they're constants stored as attributes on the GraphModule. diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0e0dd09ac5a..1c89b4ec60f 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -18,8 +18,6 @@ from typing import ( cast, ClassVar, Generic, - Iterator, - MutableMapping, NamedTuple, Optional, TYPE_CHECKING, @@ -59,7 +57,7 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, MutableMapping, Sequence from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode from ..loop_body import LoopBody diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 9c419765da9..2f52b10534a 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import os from itertools import chain, count, zip_longest -from typing import Any, Callable, Hashable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import sympy @@ -24,6 +24,8 @@ from .wrapper import PythonWrapperCodegen, SymbolicCallArg if TYPE_CHECKING: + from collections.abc import Hashable + from ..graph import GraphLowering diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index f5ba35b2f3a..3af7d72f710 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union from ..scheduler import ( BaseSchedulerNode, @@ -114,7 +114,7 @@ class CUDACombinedScheduling(BaseScheduling): def benchmark_fused_nodes( self, nodes: Sequence[BaseSchedulerNode] - ) -> Tuple[float, str]: + ) -> tuple[float, str]: return self._triton_scheduling.benchmark_fused_nodes(nodes) def benchmark_codegened_module(self, module): @@ -129,5 +129,5 @@ class CUDACombinedScheduling(BaseScheduling): def benchmark_combo_kernel( self, node_list: Sequence[BaseSchedulerNode] - ) -> tuple[float, float, List[Optional[str]]]: + ) -> tuple[float, float, list[Optional[str]]]: return 
self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 5dfc5b3bf58..3d5589cd45b 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -11,16 +11,7 @@ import math import operator import textwrap from collections import Counter -from typing import ( - Any, - Callable, - Generic, - Iterator, - no_type_check, - Optional, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Generic, no_type_check, Optional, TYPE_CHECKING, Union from typing_extensions import TypeVar import sympy @@ -72,7 +63,7 @@ from .simd_kernel_features import ( if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Iterable, Iterator, Sequence log = logging.getLogger(__name__) diff --git a/torch/_inductor/codegen/simd_kernel_features.py b/torch/_inductor/codegen/simd_kernel_features.py index 83e5a768f91..1a8cac07170 100644 --- a/torch/_inductor/codegen/simd_kernel_features.py +++ b/torch/_inductor/codegen/simd_kernel_features.py @@ -5,7 +5,7 @@ import dataclasses import functools import itertools import typing -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import sympy @@ -21,6 +21,10 @@ from ..utils import cache_on_self from ..virtualized import V +if typing.TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + class NodeScheduleMarker: @staticmethod def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]: @@ -81,7 +85,7 @@ class SIMDKernelFeatures: # numel excludes reduction_numel self.numel: sympy.Expr = V.graph.sizevars.simplify(numel) self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel) - self._stats_cache: Dict[Tuple[sympy.Expr, ...], MemoryStats] = {} + self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {} @cache_on_self def is_reduction(self) -> bool: @@ -205,7 +209,7 @@ class SIMDKernelFeatures: return node.node.data.reduction_hint def memory_stats( - self, groups_dict: Optional[Dict[str, sympy.Expr]] = None + self, groups_dict: Optional[dict[str, sympy.Expr]] = None ) -> MemoryStats: """Analysis to generate features that can be used in heuristics""" if groups_dict is None: @@ -228,11 +232,11 @@ class MemoryEstimator: We simulate the memory effects of CSE/buffer elimination in codegen. """ - kernel_sizes: Tuple[sympy.Expr, ...] + kernel_sizes: tuple[sympy.Expr, ...] 
outside_loop: MemoryEstimate - loops: List[MemoryEstimate] + loops: list[MemoryEstimate] persistent: MemoryEstimate - symbols: List[sympy.Symbol] + symbols: list[sympy.Symbol] def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]): self.features = features @@ -341,7 +345,7 @@ class MemoryEstimator: return True return False - def set_ranges(self, *lengths: List[List[sympy.Expr]]) -> List[List[sympy.Expr]]: + def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]: assert len(self.kernel_sizes) == len(lengths) return [ self.make_flat_range(sym, numel, length) @@ -350,8 +354,8 @@ class MemoryEstimator: @staticmethod def make_flat_range( - sym: sympy.Symbol, numel: sympy.Expr, lengths: List[sympy.Expr] - ) -> List[sympy.Expr]: + sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr] + ) -> list[sympy.Expr]: if len(lengths) == 1 and numel == lengths[0]: return [sym] divisor = sympy.S.One @@ -370,10 +374,10 @@ class MemoryEstimator: class MemoryEstimate: """Tracks the memory usage of a single loop in the generated kernel""" - reads: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( default_factory=functools.partial(collections.defaultdict, OrderedSet) ) - writes: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( default_factory=functools.partial(collections.defaultdict, OrderedSet) ) @@ -474,8 +478,8 @@ class StatsForLoop: class StatsForReadsOrWrites: """Memory usage stats that are collected for reads/writes/both""" - dim: List[StatsForDim] - loop: List[StatsForLoop] + dim: list[StatsForDim] + loop: list[StatsForLoop] # total bytes contiguous in any dimension bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero bytes_non_contiguous: sympy.Expr = sympy.S.Zero @@ -506,8 +510,8 @@ class StatsForReadsOrWrites: @classmethod def compute( cls, - loop_deps: List[Dict[str, OrderedSet[MemoryDep]]], - index_symbols: List[sympy.Symbol], + loop_deps: list[dict[str, OrderedSet[MemoryDep]]], + index_symbols: list[sympy.Symbol], ) -> typing.Self: ndim = len(index_symbols) result = cls(dim := [StatsForDim() for _ in range(ndim)], []) @@ -521,7 +525,7 @@ class StatsForReadsOrWrites: loop_stats.count_per_thread += len(deps) loop_stats.bytes_per_thread += itemsize * len(deps) for dep in deps: - strides: List[sympy.Expr] = V.graph.sizevars.stride_vars( + strides: list[sympy.Expr] = V.graph.sizevars.stride_vars( dep.index, index_symbols ) for i in range(ndim): @@ -568,7 +572,7 @@ class StatsForKernelType: @classmethod def compute( - cls, loops: List[MemoryEstimate], estimator: MemoryEstimator + cls, loops: list[MemoryEstimate], estimator: MemoryEstimator ) -> typing.Self: reads = StatsForReadsOrWrites.compute( [loop.reads for loop in loops], estimator.symbols diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 4a212d89732..9c18c8236d7 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -12,19 +12,11 @@ import sys import time import warnings from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from dataclasses import dataclass from inspect import currentframe from itertools import count -from typing import ( - Any, - Callable, - ContextManager, - Mapping, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import Never, override, ParamSpec, 
Protocol, TypedDict, Unpack from unittest import mock @@ -127,7 +119,7 @@ from .virtualized import V if TYPE_CHECKING: - from collections.abc import Generator, Sequence + from collections.abc import Generator, Mapping, Sequence from torch._inductor.output_code import _StrideExprStr from torch._ops import OpOverload @@ -480,7 +472,7 @@ def is_tf32_warning_applicable(gm: GraphModule) -> bool: def maybe_disable_comprehensive_padding( example_inputs: Sequence[InputType], -) -> contextlib.AbstractContextManager[None, None]: +) -> AbstractContextManager[None, None]: """ For CPU backend, enable comprehensive padding causes some unit tests fail due to changing number of generated kernels. Skip for now. @@ -1780,7 +1772,7 @@ def get_cpp_wrapper_config() -> dict[str, object]: } -def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]: +def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]: """ Returns a cuda device context manager if there is a single device in the graph """ diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index c53f3b797c4..5421eecc6a7 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -48,17 +48,9 @@ import traceback import warnings import weakref from collections import defaultdict +from contextlib import AbstractContextManager from enum import auto, Enum -from typing import ( - Any, - Callable, - cast, - ContextManager, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union import torch.fx from torch import Tensor @@ -179,7 +171,7 @@ def enable_history_recording() -> Generator[None, None, None]: torch.cuda.memory._record_memory_history(None) -def get_history_recording() -> ContextManager[None]: +def get_history_recording() -> AbstractContextManager[None]: # TODO - remove, prevents cleanup if not config.triton.cudagraph_trees_history_recording: return contextlib.nullcontext() diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 149cde77619..39ae0efb0a1 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -3,8 +3,8 @@ import dataclasses import itertools import logging import re -from collections.abc import Sequence -from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, TypeVar, Union from typing_extensions import Self from unittest.mock import patch @@ -80,7 +80,7 @@ class MemoryDep(Dep): def num_vars(self) -> int: return len(self.var_names) - def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[List[int]]: + def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]: """ Can return None if not able to decide loop orders. 
""" @@ -451,8 +451,8 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] @staticmethod def drop_unused_symbols( index: Union[int, sympy.Expr], - var_names: List[sympy.Expr], - sizes: List[sympy.Expr], + var_names: list[sympy.Expr], + sizes: list[sympy.Expr], ) -> None: """ Reduction has last (reduced) dim in its sizes, but @@ -571,7 +571,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy def index_vars_no_squeeze( *argsizes: Sequence[sympy.Expr], prefix: str -) -> Tuple[List[List[sympy.Symbol]], VarRanges]: +) -> tuple[list[list[sympy.Symbol]], VarRanges]: var_ranges, add_var = var_builder(prefix) args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes] return args, var_ranges @@ -579,7 +579,7 @@ def index_vars_no_squeeze( def index_vars_squeeze( *argsizes: Sequence[sympy.Expr], prefix: str = "d" -) -> Tuple[List[List[sympy.Expr]], VarRanges]: +) -> tuple[list[list[sympy.Expr]], VarRanges]: from .ir import SqueezeView var_ranges, add_var = var_builder(prefix) @@ -597,7 +597,7 @@ def extract_read_writes( *argsizes: Sequence[sympy.Expr], normalize: bool = False, prefix: str = "d", - hidden_args: Sequence[List[sympy.Expr]] = (), + hidden_args: Sequence[list[sympy.Expr]] = (), ) -> ReadWrites: args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) @@ -630,7 +630,7 @@ def extract_read_writes( def extract_loop_body_with_args( fn: Any, - args: List[List[sympy.Expr]], + args: list[list[sympy.Expr]], var_ranges: VarRanges, normalize: bool = False, ) -> _RecordLoadStoreInner: @@ -761,17 +761,17 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler): self.symbols |= free_unbacked_symbols(size) return sympy_index_symbol(f"({str(index_var)})") - def frexp(self, x: Any) -> Tuple[None, ...]: + def frexp(self, x: Any) -> tuple[None, ...]: return (None,) * 2 def scan( self, dtypes: Any, combine_fn: Any, values: Sequence[Any] - ) -> Tuple[None, ...]: + ) -> tuple[None, ...]: return (None,) * len(values) def sort( self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any - ) -> Tuple[None, ...]: + ) -> tuple[None, ...]: return (None,) * len(values) def reduction( diff --git a/torch/_inductor/freezing_utils.py b/torch/_inductor/freezing_utils.py index 560b1afc266..8a14890aacb 100644 --- a/torch/_inductor/freezing_utils.py +++ b/torch/_inductor/freezing_utils.py @@ -1,6 +1,7 @@ import contextlib import threading -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import torch diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 0d408ad9cf2..07432704b18 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -7,7 +7,7 @@ import signal import string import sys import traceback -from collections.abc import KeysView +from collections.abc import KeysView, Sequence from enum import Enum from functools import partial, wraps from types import FrameType @@ -18,7 +18,6 @@ from typing import ( get_origin, Literal, Optional, - Sequence, TypeVar, Union, ) diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 220c041168b..ef6353a0c88 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,7 +2,8 @@ import functools import itertools import operator import typing -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch import 
torch._inductor.runtime.runtime_utils @@ -275,7 +276,7 @@ def should_pad_bench_key( input: Optional[Tensor] = None, is_base_time_key: bool = False, ) -> str: - def tensor_key(t: Tensor) -> Tuple[torch.Size, Tuple[int, ...], torch.dtype]: + def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]: return (t.shape, t.stride(), t.dtype) tf32_key = ( @@ -436,8 +437,8 @@ def _should_pad_bench( return False def realize_symbols( - ds: Union[torch.Size, Tuple[torch.SymInt, ...]] - ) -> List[int]: + ds: Union[torch.Size, tuple[torch.SymInt, ...]] + ) -> list[int]: return [d if isinstance(d, int) else d.node.hint for d in ds] if any( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 71b03ef89f4..b6865d06930 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -9,14 +9,13 @@ import textwrap import traceback import typing from collections.abc import Generator, Iterable, Sequence -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext from enum import Enum from functools import partial from typing import ( Any, Callable, ClassVar, - ContextManager, Literal, Optional, overload, @@ -6710,7 +6709,7 @@ class FallbackKernel(ExternKernelAlloc): @classmethod def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) - context: ContextManager[None] = ( + context: AbstractContextManager[None] = ( V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] ) with context: diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 6ac06cb77cd..023c0aaddac 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -7,7 +7,7 @@ import math from collections.abc import Sequence from dataclasses import dataclass from enum import auto, Enum -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import sympy @@ -1195,7 +1195,7 @@ def next_power_of_two(n): def set_head_dim_values( - kernel_options: Dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars ): """ Mutates kernel options, adding head dimension calculations. 
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index a67ef6be172..1407ae09b88 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -7,7 +7,7 @@ import os import re from dataclasses import dataclass from functools import lru_cache -from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Callable, cast, Optional, TYPE_CHECKING, Union from torch._inductor import config from torch._inductor.utils import get_benchmark_name @@ -92,7 +92,7 @@ class CachedMetricsDeltas: num_matches_for_scatter_upon_const_tensor: int -def get_metric_fields() -> List[str]: +def get_metric_fields() -> list[str]: return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] @@ -132,7 +132,7 @@ class MetricTable: num_rows_added: int = 0 def add_row( - self, row_fn: Callable[[], Dict[str, Optional[Union[str, float]]]] + self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]] ) -> None: if self.table_name not in enabled_metric_tables(): return @@ -160,7 +160,7 @@ class MetricTable: writer = csv.writer(fd, lineterminator="\n") writer.writerow(["model_name"] + self.column_names) - def _write_row(self, row: List[str]) -> None: + def _write_row(self, row: list[str]) -> None: filename = self.output_filename() if self.num_rows_added == 0 and not os.path.exists(filename): self.write_header() @@ -181,7 +181,7 @@ class MetricTable: writer.writerow(row) @staticmethod - def register_table(name: str, column_names: List[str]) -> None: + def register_table(name: str, column_names: list[str]) -> None: table = MetricTable(name, column_names) REGISTERED_METRIC_TABLES[name] = table diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index cb2e92ea448..688353ddf91 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import sympy diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index bde517463db..9637ce8c4c2 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -27,7 +27,7 @@ import logging import os import re from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import TypeAlias import torch @@ -277,7 +277,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants): def __init__(self, gm: torch.fx.GraphModule) -> None: self.gm = gm - def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]: + def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]: frozen_params = { name: getattr(self.gm, orig_name) for name, orig_name in g.frozen_param_names.items() @@ -301,10 +301,10 @@ class CompiledFxGraph(OutputCode): device_idxs: OrderedSet[int] mutated_inputs: OrderedSet[str] mutated_input_idxs: OrderedSet[int] - constants: Optional[Dict[str, torch.Tensor]] - frozen_param_names: Dict[str, str] - torchbind_constants: Dict[str, torch._C.ScriptObject] - output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]] + constants: Optional[dict[str, torch.Tensor]] + frozen_param_names: dict[str, str] + torchbind_constants: dict[str, torch._C.ScriptObject] + output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]] disabled_cudagraphs_reason: Optional[str] metrics_deltas: metrics.CachedMetricsDeltas counter_deltas: 
Counter[str] diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a413e5c9cec..ad99a1ecc45 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -14,21 +14,11 @@ import textwrap import traceback import typing from collections import Counter, defaultdict -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - Sequence, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union if TYPE_CHECKING: + from collections.abc import Sequence from types import ModuleType import sympy @@ -2794,7 +2784,7 @@ class Scheduler: ) # Start compiling choices in parallel - future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] + future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] triton_choices = 0 for choice, unfused_time in sorted( choice_timings.items(), key=lambda x: x[1] @@ -2964,7 +2954,7 @@ class Scheduler: # These are potential fusions which we are async compiling, # and which we will benchmark profitability of. - pending_fusions: Dict[ + pending_fusions: dict[ BaseSchedulerNode, tuple[Callable[[], bool], BaseSchedulerNode, BaseSchedulerNode], ] = {} @@ -4068,7 +4058,7 @@ class Scheduler: def benchmark_combo_kernel( self, node_list: Sequence[BaseSchedulerNode] - ) -> tuple[float, float, List[Optional[str]]]: + ) -> tuple[float, float, list[Optional[str]]]: """ Benchmark fused list of nodes and return the execution time in milliseconds on randomly generated inputs. @@ -4308,7 +4298,7 @@ class BaseScheduling: def benchmark_combo_kernel( self, node_list: Sequence[BaseSchedulerNode] - ) -> tuple[float, float, List[Optional[str]]]: + ) -> tuple[float, float, list[Optional[str]]]: """ Benchmark the list of nodes to combine and return the execution time and memory copy time in milliseconds on randomly generated inputs. 
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 44cd0e0dc56..9ce6145a20c 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -21,23 +21,18 @@ import tempfile import textwrap import time import unittest +from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet from datetime import datetime from io import StringIO from typing import ( Any, Callable, cast, - Collection, Generic, - Iterator, Literal, - Mapping, - MutableMapping, - MutableSet, NamedTuple, Optional, Protocol, - Type, TYPE_CHECKING, TypeVar, Union, @@ -1361,7 +1356,7 @@ def _rocm_native_device_arch_name(device: str) -> str: @functools.lru_cache(None) def try_import_ck_lib() -> ( - tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], Type[Any]] + tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]] ): try: import ck4inductor # type: ignore[import] @@ -2608,7 +2603,7 @@ class ScopedDict(MutableMapping[KeyType, ValType]): @dataclass_transform(frozen_default=True) -def ir_dataclass(cls: Optional[Type[Any]] = None, /, *, frozen: bool = True) -> Any: +def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any: def wrap(cls: _T) -> _T: if sys.version_info >= (3, 10): return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload] diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index be5c1051c83..736326f4940 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -3,7 +3,7 @@ import datetime import tempfile from collections import defaultdict from types import ModuleType -from typing import Any, Dict, Optional, Protocol +from typing import Any, Optional, Protocol import torch from torch.autograd import DeviceType @@ -73,7 +73,7 @@ def get_triton_kernel(mod: ModuleType): # type: ignore[no-untyped-def] def benchmark_all_kernels( - benchmark_name: str, benchmark_all_configs: Optional[Dict[Any, Any]] + benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]] ) -> None: """ An experimental API used only when config.benchmark_kernel is true. @@ -184,7 +184,7 @@ def parse_profile_event_list( """ return ev.self_device_time_total / 1000 / nruns # type: ignore[attr-defined] - all_events: Dict[str, list[ProfileEvent]] = defaultdict(list) + all_events: dict[str, list[ProfileEvent]] = defaultdict(list) def add_event( ev: torch.autograd.profiler_util.EventList, diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index afed93ad6a6..d97647afa87 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -1353,11 +1353,11 @@ def _is_exception(obj) -> bool: def raise_error_container_parameter_missing(target_type) -> None: - if target_type == "Dict": + if target_type.endswith("ict"): raise RuntimeError( - "Attempted to use Dict without " + f"Attempted to use {target_type} without " "contained types. Please add contained type, e.g. 
" - "Dict[int, int]" + f"{target_type}[int, int]" ) raise RuntimeError( f"Attempted to use {target_type} without a " @@ -1366,15 +1366,20 @@ def raise_error_container_parameter_missing(target_type) -> None: ) +_RAW_TYPE_NAME_MAPPING = { + dict: "dict", + list: "list", + tuple: "tuple", + typing.Dict: "Dict", # noqa: UP006 + typing.List: "List", # noqa: UP006 + typing.Optional: "Optional", + typing.Tuple: "Tuple", # noqa: UP006 +} + + def check_args_exist(target_type) -> None: - if target_type is typing.List or target_type is list: # noqa: UP006 - raise_error_container_parameter_missing("List") - elif target_type is typing.Tuple or target_type is tuple: # noqa: UP006 - raise_error_container_parameter_missing("Tuple") - elif target_type is typing.Dict or target_type is dict: # noqa: UP006 - raise_error_container_parameter_missing("Dict") - elif target_type is None or target_type is Optional: - raise_error_container_parameter_missing("Optional") + if name := _RAW_TYPE_NAME_MAPPING.get(target_type): + raise_error_container_parameter_missing(name) def check_empty_containers(obj) -> None: diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 172e5728d03..a5b364f437d 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -3,7 +3,7 @@ import operator from collections.abc import Sequence from enum import Enum from functools import partial, reduce -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import Callable, Optional, Union import torch import torch._prims_common as utils diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 57bc423cf3d..e8339b789f5 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -12,12 +12,9 @@ from typing import ( Any, Callable, cast, - List, NamedTuple, Optional, overload, - Tuple, - Type, TYPE_CHECKING, TypeVar, Union, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 3db1dad6c59..13434e86b34 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -10,7 +10,7 @@ import warnings from collections.abc import Iterable, Sequence from enum import Enum from functools import partial, reduce, singledispatch, wraps -from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union +from typing import Any, Callable, cast, Optional, overload, Union import torch import torch._prims as prims diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 04187913aac..00d95445c6f 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from functools import partial -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch._prims as prims diff --git a/torch/_refs/nn/__init__.py b/torch/_refs/nn/__init__.py index 840ecd9ca20..c9c2ef67bd9 100644 --- a/torch/_refs/nn/__init__.py +++ b/torch/_refs/nn/__init__.py @@ -1,4 +1 @@ -from typing import List - - __all__: list[str] = [] diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 1629cd52cae..d5de918a2ca 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -3,7 +3,8 @@ import contextlib import warnings import weakref from abc import ABC, abstractmethod -from typing import Any, Callable, ContextManager, Optional, Union +from contextlib import AbstractContextManager +from typing import Any, Callable, Optional, Union import torch import 
torch.utils._pytree as pytree @@ -665,7 +666,7 @@ class BaseFunctionalizeAPI(ABC): pass @abstractmethod - def redispatch_to_next(self) -> ContextManager: + def redispatch_to_next(self) -> AbstractContextManager: pass @abstractmethod @@ -709,7 +710,7 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI): def functionalize(self, inner_f: Callable) -> Callable: return dispatch_functionalize(inner_f, self.mode) - def redispatch_to_next(self) -> ContextManager: + def redispatch_to_next(self) -> AbstractContextManager: # [NOTE] We don't do anything here because at the time # we exercise this path, we would have already popped the # FunctionalTensorMode from mode stack. Since FunctionalTensorMode @@ -753,7 +754,7 @@ class CppFunctionalizeAPI(BaseFunctionalizeAPI): def functionalize(self, inner_f: Callable) -> Callable: return torch.func.functionalize(inner_f) - def redispatch_to_next(self) -> ContextManager: + def redispatch_to_next(self) -> AbstractContextManager: return torch._C._ExcludeDispatchKeyGuard( torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) ) @@ -801,7 +802,7 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI): ), ) - def redispatch_to_next(self) -> ContextManager: + def redispatch_to_next(self) -> AbstractContextManager: return self.interpreter.lower() def replace(self, input_tensor, output_tensor) -> None: diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 3ffbaa96ec5..15fa4ad0f44 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -7,12 +7,12 @@ import typing import warnings import weakref from abc import abstractmethod +from contextlib import AbstractContextManager from dataclasses import dataclass from typing import ( Any, Callable, ClassVar, - ContextManager, Generic, NewType, Optional, @@ -1741,7 +1741,9 @@ class MetaConverter(Generic[_TensorT]): # subclasses. 
Relevant test is # DynamicShapesFunctionTests::test_add_dynamic_shapes in # test/dynamo/test_dynamic_shapes.py - maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() + maybe_fake_mgr: AbstractContextManager[ + None + ] = contextlib.nullcontext() from torch._subclasses.fake_tensor import ( in_kernel_invocation_manager, maybe_get_fake_mode, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 7393bc215d7..57ed1f60f94 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from torch import Tensor diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 62aa5cd3acc..c370a0368d7 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -10,7 +10,7 @@ half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdo import warnings from collections.abc import Sequence -from typing import cast, List, Optional, Tuple, Union +from typing import cast, Optional, Union import torch from torch import _vmap_internals diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py index 50248505f64..caabfdf2437 100644 --- a/torch/backends/quantized/__init__.py +++ b/torch/backends/quantized/__init__.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import sys import types -from typing import List import torch diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 9aa06211b00..46cd6e3ff3d 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4192e283119..a3bb560ad85 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -17,7 +17,7 @@ import threading import traceback import warnings from functools import lru_cache -from typing import Any, Callable, cast, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Optional, Union import torch import torch._C @@ -85,7 +85,7 @@ try: paths = ["libamd_smi.so"] if rocm_home := os.getenv("ROCM_HOME", os.getenv("ROCM_PATH")): paths = [os.path.join(rocm_home, "lib/libamd_smi.so")] + paths - self.paths: List[str] = paths + self.paths: list[str] = paths def hooked_CDLL( self, name: Union[str, Path, None], *args: Any, **kwargs: Any diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index df801543af5..0fe23cab72c 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from collections.abc import Generator -from contextlib import contextmanager, nullcontext -from typing import Any, ContextManager, Optional +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any, Optional import torch import torch.nn as nn @@ -14,7 +14,7 @@ from .contract import _State, contract @contextmanager -def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None): +def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None): r""" Disable hooks installed 
by checkpoint to avoid unintentional recursion during backward recomputation. diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index 44c13a9edd5..4c49f2585bc 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -1,14 +1,14 @@ import pickle from dataclasses import dataclass from io import BufferedIOBase -from typing import Any, Dict, List, Tuple +from typing import Any import torch import torch._weights_only_unpickler as _weights_only_unpickler from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION -__all__: List[str] = [] +__all__: list[str] = [] @dataclass @@ -23,7 +23,7 @@ _weights_only_unpickler._add_safe_globals([_Entry]) class _PseudoZipFile: def __init__(self) -> None: - self.records: Dict[str, Tuple[object, int]] = {} + self.records: dict[str, tuple[object, int]] = {} def write_record(self, key: str, data: object, length: int) -> None: self.records[key] = (data, length) diff --git a/torch/distributed/_shard/sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py index 7deab8d253d..8555dcd2d09 100644 --- a/torch/distributed/_shard/sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Iterator -from typing import Tuple, Union +from typing import Union import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 881193cf0ce..e1e9983d526 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import functools -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING import torch from torch.distributed._shard.op_registry_utils import _decorator_func diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index f8fde2b20f8..62299e2fa5d 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from datetime import timedelta from enum import Enum from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import torch import torch.distributed._functional_collectives as funcol diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index 9306801f26d..45eb7ab41a6 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -4,7 +4,7 @@ import operator from dataclasses import dataclass from enum import auto, Enum from functools import reduce -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import torch from torch.distributed.checkpoint.metadata import ( @@ -205,21 +205,21 @@ class SavePlanner(abc.ABC): # Save plan for the current rank as computed by `create_local_plan` API # Cached on the local rank. - _cached_save_plan: Dict[str, SavePlan] = {} + _cached_save_plan: dict[str, SavePlan] = {} # Final save plan for the current rank. # This is created by merging the plan created by `create_local_plan` API # and the result of `create_global_plan` for the given rank. # This is the final plan computed by the `finish_plan` API that gets # sent to the `write_data`. 
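Note: the _serialization.py hunk above shows the core PEP 585 rewrite: from Python 3.9 the builtin list, dict and tuple are generic, so the typing.List/Dict/Tuple imports can be dropped. A simplified stand-in for the records bookkeeping, assuming nothing beyond the annotations visible in the diff (the Archive name is illustrative):

class Archive:
    def __init__(self) -> None:
        # Builtin generics replace typing.Dict[...] and typing.Tuple[...].
        self.records: dict[str, tuple[object, int]] = {}

    def write_record(self, key: str, data: object, length: int) -> None:
        self.records[key] = (data, length)


archive = Archive()
archive.write_record("entry", b"\x00\x01", 2)
print(archive.records["entry"])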
# Cached on the local rank. - _cached_final_save_plan: Dict[str, SavePlan] = {} + _cached_final_save_plan: dict[str, SavePlan] = {} # Collection of all the local plans from all the ranks. # This is the input to the `create_global_plan` API. # Cached on the coordinator rank. - _cached_all_plans: Dict[str, list[SavePlan]] = {} + _cached_all_plans: dict[str, list[SavePlan]] = {} # Global checkpoint plan as computed by `create_global_plan` API. # Cached on the coordinator rank. - _cached_global_plan: Dict[str, list[SavePlan]] = {} + _cached_global_plan: dict[str, list[SavePlan]] = {} @abc.abstractmethod def set_up_planner( diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index 23c74458f51..e6c2a271644 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -24,7 +24,7 @@ import logging import os import socket import traceback -from typing import Dict, Optional +from typing import Optional from torch.distributed.elastic.events.handlers import get_logging_handler diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index 74d612ce635..d0d311d2fb4 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -62,7 +62,7 @@ was launched a :class:`api.SubprocessContext` is returned. Both are specific implementations of the parent :class:`api.PContext` class. """ -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 _validate_full_rank, diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 6703e355b7b..34b22bbd8a2 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -58,7 +58,7 @@ from dataclasses import dataclass, field from datetime import datetime from functools import wraps from string import Template -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Optional, TypeVar from torch.distributed.elastic.utils.logging import get_logger diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index d13ed78d6b0..0a4da5c098b 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging from dataclasses import dataclass -from typing import Dict, Union +from typing import Union import torch from torch import fx @@ -90,7 +90,7 @@ def validate_tensors_metadata( def generate_stage_to_rank_mapping( pp_size: int, num_stages: int, style: str = "loop" -) -> Dict[int, int]: +) -> dict[int, int]: """ Compute the stage id to rank mapping for either a looped or V-style schedule. 
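Note: generate_stage_to_rank_mapping above (and _validate_schedule in the schedules.py hunk that follows) now annotate their return type as the builtin dict[int, int]. A sketch of the annotation style only; the round-robin body below is a simplified illustration of a "loop"-style placement, not the patch's actual scheduling logic:

def stage_to_rank(pp_size: int, num_stages: int) -> dict[int, int]:
    # Simplified loop style: stage i is placed on rank i % pp_size.
    return {stage: stage % pp_size for stage in range(num_stages)}


assert stage_to_rank(4, 8) == {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}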
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 7749b46e957..1462d6ad842 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -9,7 +9,7 @@ import re from abc import ABC, abstractmethod from collections import Counter, defaultdict from enum import Enum -from typing import Any, Callable, Dict, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -1025,7 +1025,7 @@ def _validate_schedule( pp_group_size: int, num_stages: int, num_microbatches: int, -) -> Dict[int, int]: +) -> dict[int, int]: assert ( len(actions) == pp_group_size ), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index a17185f520a..3d71b2fc22b 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -5,7 +5,6 @@ import threading import warnings from collections.abc import Generator from datetime import timedelta -from typing import Tuple from urllib.parse import urlparse import torch diff --git a/torch/export/__init__.py b/torch/export/__init__.py index e38a730858f..0ab7d1a8048 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -9,17 +9,7 @@ import warnings import zipfile from collections.abc import Iterator from enum import auto, Enum -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index cf9775ba8a5..4e1d21de660 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Union import torch import torch.utils._pytree as pytree diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 6e40f4c84e8..236165f61ef 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Callable, cast, Generic, List, Optional, TypeVar, Union +from typing import Callable, cast, Generic, Optional, TypeVar, Union import torch diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index 5a5c77cd588..809f67db263 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -3,7 +3,7 @@ import importlib import io import pickle from abc import abstractmethod -from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, NewType, Optional, TypeVar, Union from typing_extensions import override, Self import torch @@ -45,7 +45,7 @@ class GraphPickler(pickle.Pickler): @override def reducer_override( self, obj: object - ) -> Tuple[Callable[..., Any], Tuple[Any, ...]]: + ) -> tuple[Callable[..., Any], tuple[Any, ...]]: # This function is supposed to return either NotImplemented (meaning to # do the default pickle behavior) or a pair of (unpickle callable, data # to pass to unpickle). 
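Note: the _graph_pickler.py hunk above applies the same rewrite to reducer_override: the Tuple[...] return type becomes tuple[...]. A minimal, self-contained Pickler subclass using that annotation style (TaggingPickler and its trivial fallback are illustrative, not the GraphPickler logic):

import io
import pickle
from typing import Any, Callable


class TaggingPickler(pickle.Pickler):
    def reducer_override(
        self, obj: object
    ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
        # Returning NotImplemented defers to the default pickling path; the
        # annotation describes the (callable, args) pair of the handled case.
        return NotImplemented


buf = io.BytesIO()
TaggingPickler(buf).dump({"a": 1})
print(len(buf.getvalue()))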
@@ -138,13 +138,13 @@ class _GraphUnpickler(pickle.Unpickler): class _ShapeEnvPickleData: - data: Dict[str, object] + data: dict[str, object] @classmethod def reduce_helper( cls, pickler: GraphPickler, obj: ShapeEnv - ) -> Tuple[ - Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken] + ) -> tuple[ + Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken] ]: return cls.unpickle, (cls(obj), pickler._unpickle_state) @@ -174,8 +174,8 @@ class _SymNodePickleData: cls, pickler: GraphPickler, obj: _SymNodeT, - ) -> Tuple[ - Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken] + ) -> tuple[ + Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken] ]: args = (cls(obj.node), pickler._unpickle_state) if isinstance(obj, torch.SymInt): @@ -205,8 +205,8 @@ class _TensorPickleData: @classmethod def reduce_helper( cls, pickler: GraphPickler, obj: FakeTensor - ) -> Tuple[ - Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken] + ) -> tuple[ + Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken] ]: return cls.unpickle, ( cls(pickler._meta_tensor_describer, obj), @@ -266,8 +266,8 @@ class _TorchNumpyPickleData: def reduce_helper( cls, pickler: GraphPickler, obj: object ) -> Optional[ - Tuple[ - Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken] + tuple[ + Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken] ] ]: if data := cls.from_object(obj): @@ -309,9 +309,9 @@ class _GraphModulePickleData: @classmethod def reduce_helper( cls, pickler: GraphPickler, obj: torch.fx.GraphModule - ) -> Tuple[ + ) -> tuple[ Callable[[Self, _UnpickleState], torch.fx.GraphModule], - Tuple[Self, _UnpickleStateToken], + tuple[Self, _UnpickleStateToken], ]: return cls.unpickle, ( cls(obj), @@ -337,7 +337,7 @@ class _GraphModulePickleData: class _NodePickleData: def __init__( - self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"] + self, node: torch.fx.Node, mapping: dict[torch.fx.Node, "_NodePickleData"] ) -> None: self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args) self.kwargs = pytree.tree_map_only( @@ -358,7 +358,7 @@ class _NodePickleData: def unpickle( self, graph: torch.fx.Graph, - mapping: Dict["_NodePickleData", torch.fx.Node], + mapping: dict["_NodePickleData", torch.fx.Node], unpickle_state: _UnpickleState, ) -> torch.fx.Node: args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args) @@ -376,7 +376,7 @@ class _OpPickleData: @classmethod def reduce_helper( cls, pickler: GraphPickler, op: object - ) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]: + ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]: result = cls.pickle(op) return (result.unpickle, (pickler._unpickle_state,)) @@ -404,7 +404,7 @@ class _OpPickleData: def _pickle_op( name: str, datacls: Union[ - Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"] + type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"] ], ) -> "_OpPickleData": if not name.startswith("torch.ops.aten"): # TODO: What's the full list? 
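Note: typing.Type is likewise a deprecated alias of the builtin type, so _pickle_op's datacls parameter above switches to type[...]. A minimal sketch (the two placeholder classes and the make helper are illustrative):

from typing import Union


class OverloadData: ...


class PacketData: ...


def make(datacls: Union[type[OverloadData], type[PacketData]]) -> object:
    # type[X] means "the class X itself (or a subclass)", not an instance of it.
    return datacls()


print(type(make(PacketData)).__name__)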
@@ -501,7 +501,7 @@ class _GraphPickleData: self.tracer_cls = graph._tracer_cls self.tracer_extras = graph._tracer_extras - nodes: Dict[torch.fx.Node, _NodePickleData] = {} + nodes: dict[torch.fx.Node, _NodePickleData] = {} for node in graph.nodes: nodes[node] = _NodePickleData(node, nodes) self.nodes = tuple(nodes.values()) @@ -521,7 +521,7 @@ class _GraphPickleData: ) -> torch.fx.Graph: graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras) - nodes: Dict[_NodePickleData, torch.fx.Node] = {} + nodes: dict[_NodePickleData, torch.fx.Node] = {} for nd in self.nodes: nodes[nd] = nd.unpickle(graph, nodes, unpickle_state) @@ -532,9 +532,9 @@ class _TracingContextPickleData: @classmethod def reduce_helper( cls, pickler: GraphPickler, obj: torch._guards.TracingContext - ) -> Tuple[ + ) -> tuple[ Callable[[Self, _UnpickleState], torch._guards.TracingContext], - Tuple[Self, _UnpickleStateToken], + tuple[Self, _UnpickleStateToken], ]: return ( cls.unpickle, diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 53103898ea2..e6d0a9e417b 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -23,7 +23,7 @@ import math import operator import sys from functools import lru_cache, update_wrapper -from typing import Optional, Set, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union import torch import torch._logging.structured as structured @@ -1233,7 +1233,7 @@ def _make_node_magic(method, func): else: method_attr = method - def uninteresting_files() -> Set[str]: + def uninteresting_files() -> set[str]: import torch mods = [ diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 548b4fda920..cbc377a6113 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -36,11 +36,11 @@ if TYPE_CHECKING: # Mapping of builtins to their `typing` equivalent. # (PEP585: See D68459095 test plan) _origin_type_map = { - list: typing.List, - dict: typing.Dict, - set: typing.Set, - frozenset: typing.FrozenSet, - tuple: typing.Tuple, + list: typing.List, # noqa: UP006 + dict: typing.Dict, # noqa: UP006 + set: typing.Set, # noqa: UP006 + frozenset: typing.FrozenSet, # noqa: UP006 + tuple: typing.Tuple, # noqa: UP006 } _legal_ops = dict.fromkeys( diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 3486c80b492..76682e75229 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -1,12 +1,12 @@ # mypy: allow-untyped-defs import ast +import copy import dataclasses import inspect import re import string from collections import namedtuple from textwrap import dedent -from typing import List, Tuple # noqa: F401 import torch import torch.jit.annotations @@ -551,7 +551,7 @@ def build_ignore_context_manager(ctx, stmt): return_type_ann = " -> " + outputs[0].ann return_statement_str += outputs[0].name if len(outputs) > 1: - return_type_ann = " -> Tuple" + return_type_ann = " -> tuple" return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]" return_statement_str += ", ".join([var.name for var in outputs]) return return_type_ann, return_statement_str @@ -581,10 +581,18 @@ def build_ignore_context_manager(ctx, stmt): return_stmt = ast.parse(return_stmt).body[0] ignore_function.body.append(return_stmt) # type: ignore[attr-defined] + ignore_func_str = f"""\ +# Backward compat: These used to be imported into the outer global scope so some +# code may still expect them. 
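Note on the fx/graph.py hunk above (before the jit/frontend.py change that continues below): it is the one spot where the typing aliases are kept deliberately. _origin_type_map exists to map a builtin origin back to its typing spelling, so its values must stay typing.List, typing.Dict, etc., and each entry suppresses UP006 instead of being rewritten. A quick illustration using only the standard typing module:

import typing

# typing.get_origin maps a typing generic back to its builtin origin;
# _origin_type_map goes in the opposite direction.
assert typing.get_origin(typing.List[int]) is list
assert typing.get_origin(typing.Dict[str, int]) is dict
print("origins resolve to builtins")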
+from typing import List, Dict, Tuple + +@torch.jit.ignore +{astunparse.unparse(ignore_function)} +""" + g = copy.copy(globals()) + exec(ignore_func_str, g) # noqa: P204 # registers the custom function in the global context - ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function) - ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}' - exec(ignore_func_str) # noqa: P204 + globals()[ignore_function_name] = g[ignore_function_name] # build the statements as: # , , ... = torch.jit.frontend.(, ) diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 29df5c7c90e..b413dd4b572 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -5,7 +5,7 @@ This package enables an interface for accessing MTIA backend in python import threading import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import device as _device, Tensor diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index d66b1d94d69..433c22489f0 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import torch import torch.nn.functional as F diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 07b10629ce2..3c83bbb0d0a 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -2,7 +2,7 @@ """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib from collections.abc import Iterable -from typing import List, Union +from typing import Union from warnings import warn import torch.backends.cuda diff --git a/torch/onnx/_internal/exporter/_dynamic_shapes.py b/torch/onnx/_internal/exporter/_dynamic_shapes.py index f80b211cd16..735700fde31 100644 --- a/torch/onnx/_internal/exporter/_dynamic_shapes.py +++ b/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect import warnings -from typing import Any, Sequence +from typing import Any, TYPE_CHECKING import torch from torch.export.dynamic_shapes import _Dim, _DimHint @@ -13,6 +13,10 @@ from torch.onnx._internal._lazy_import import onnxscript_ir as ir from torch.utils import _pytree +if TYPE_CHECKING: + from collections.abc import Sequence + + def from_dynamic_axes_to_dynamic_shapes( model, args: tuple[Any, ...], diff --git a/torch/serialization.py b/torch/serialization.py index 8b494aa288a..e578eb35434 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -15,7 +15,7 @@ import threading import warnings from contextlib import closing, contextmanager from enum import Enum -from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union +from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union from typing_extensions import TypeAlias, TypeIs import torch @@ -1898,7 +1898,7 @@ def _load( data_descripter_size64 = 24 data_descripter_size32 = 16 mz_uint32_max = 0xFFFFFFFF - offsets: Dict[str, int] = dict() + offsets: dict[str, int] = dict() def _get_offset(key, name, numel): """ diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 801334cdd8f..858cb7fbd86 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # The Tensor classes are added to this 
module by python_tensor.cpp # A workaround to support both TorchScript and MyPy: -from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch from torch import Tensor diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py index 8d80737359d..148a1749c43 100644 --- a/torch/testing/_internal/distributed/common_state_dict.py +++ b/torch/testing/_internal/distributed/common_state_dict.py @@ -4,7 +4,7 @@ import copy from itertools import chain -from typing import Any, Dict +from typing import Any import torch import torch.nn as nn @@ -164,7 +164,7 @@ class FusionEmbeddingWithModifier(FusionEmbeddingWithHook): # _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict # keys, they need to provide a mapping from the new key to the original key. This is used to ensure # consistency between the state_dict keys and fqn. - def _fqn_modifiers(self) -> Dict[str, str]: + def _fqn_modifiers(self) -> dict[str, str]: return { "weight": "embedding", } diff --git a/torch/testing/_internal/opinfo/definitions/__init__.py b/torch/testing/_internal/opinfo/definitions/__init__.py index 84bc75a74d7..f26d3f402e7 100644 --- a/torch/testing/_internal/opinfo/definitions/__init__.py +++ b/torch/testing/_internal/opinfo/definitions/__init__.py @@ -1,7 +1,5 @@ # mypy: ignore-errors -from typing import List - from torch.testing._internal.opinfo.core import OpInfo from torch.testing._internal.opinfo.definitions import ( _masked, diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index 5758fc565af..cb18d8064b0 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -3,8 +3,8 @@ import bisect import itertools import math import warnings -from collections.abc import Sequence -from typing import cast, Generic, Iterable, Optional, TypeVar, Union +from collections.abc import Iterable, Sequence +from typing import cast, Generic, Optional, TypeVar, Union from typing_extensions import deprecated # No 'default_generator' in torch/__init__.pyi diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 6e0c3a1246b..5ab8fd9a35e 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -75,7 +75,6 @@ import sys import urllib.parse import zipfile from pathlib import Path -from typing import Dict import warnings import torch.utils.show_pickle diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 2163cdd8018..67aa865ce6c 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -9,7 +9,7 @@ This package is lazily initialized, so you can always import it, and use import threading import traceback from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch._C
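Note: two related cleanups recur through the tail of the patch: ABCs such as Iterable and Sequence are imported from collections.abc rather than typing (dataset.py above), and imports needed only for annotations are placed under TYPE_CHECKING where `from __future__ import annotations` is in effect (the onnx exporter hunk earlier). A combined sketch, with process and its arguments as illustrative names:

from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; no runtime import cost.
    from collections.abc import Sequence


def process(values: Iterable[int], weights: Sequence[float]) -> float:
    return sum(v * w for v, w in zip(values, weights))


print(process([1, 2, 3], [0.5, 0.25, 0.25]))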