Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
parent 76ad19a549
commit db4ce78d46
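For context, RUFF rule UP006 enforces PEP 585: on Python 3.9+ the builtin containers can be parameterized directly (list[int], dict[str, int], tuple[int, ...], type[T]), so the typing.List/Dict/Tuple/Type aliases become unnecessary imports. A minimal before/after sketch of the rewrite this PR applies throughout; the function and names here are hypothetical, not taken from the diff:

# Before (flagged by UP006):
#   from typing import Dict, List
#   def count_words(lines: List[str]) -> Dict[str, int]: ...

# After (PEP 585 builtin generics; no typing import needed):
def count_words(lines: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for line in lines:
        for word in line.split():
            counts[word] = counts.get(word, 0) + 1
    return counts

The diff below is the same mechanical substitution applied file by file.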
@@ -1,7 +1,6 @@
 # Owner(s): ["oncall: distributed"]

 import sys
-from typing import List

 import torch
 import torch.distributed as dist
@@ -119,7 +118,7 @@ def _find_name_param_mappings(module: torch.nn.Module, prefix: str):


 def _discover_ddp_ignored_params(module: torch.nn.Module, prefix: str):
-    ddp_ignore_parameters: List[str] = []
+    ddp_ignore_parameters: list[str] = []
     if isinstance(module, FSDP2):
         ddp_ignore_parameters = [name for name, _ in module.named_parameters(prefix)]
     else:
@@ -3,7 +3,7 @@ import gc
 import threading
 import unittest
 from datetime import timedelta
-from typing import List, Optional
+from typing import Optional

 import torch
 import torch.distributed as dist
@@ -576,24 +576,24 @@ class ProcessGroupDummy(dist.ProcessGroup):
         self.waits = 0
         self.dels = 0

-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work:
+    def broadcast(self, tensor_list: list[torch.Tensor], opts: object) -> dist.Work:
         return _DummyWork(self)

     def allgather_into_tensor_coalesced(
         self,
-        output_lists: List[torch.Tensor],
-        input_list: List[torch.Tensor],
+        output_lists: list[torch.Tensor],
+        input_list: list[torch.Tensor],
         opts: object,
     ) -> dist.Work:
         return _DummyWork(self)

-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work:
+    def allreduce(self, tensors: list[torch.Tensor], opts: object) -> dist.Work:
         return _DummyWork(self)

     def reduce_scatter_tensor_coalesced(
         self,
-        outputTensors: List[torch.Tensor],
-        inputTensors: List[torch.Tensor],
+        outputTensors: list[torch.Tensor],
+        inputTensors: list[torch.Tensor],
         opts: object,
     ) -> dist.Work:
         return _DummyWork(self)
@@ -1,5 +1,4 @@
 # Owner(s): ["module: functorch"]
-from typing import Dict, List, Tuple
 from unittest.mock import MagicMock, patch

 from torch._functorch._activation_checkpointing.ac_logging_utils import (
@@ -33,17 +32,17 @@ class TestAcLogging(TestCase):

         self.graph.nodes = [self.node1, self.node2]

-        self.all_recomputable_banned_nodes: List[Node] = [self.node1]
-        self.saved_node_idxs: List[int] = [0]
-        self.recomputable_node_idxs: List[int] = []
+        self.all_recomputable_banned_nodes: list[Node] = [self.node1]
+        self.saved_node_idxs: list[int] = [0]
+        self.recomputable_node_idxs: list[int] = []
         self.expected_runtime: int = 100
-        self.memories_banned_nodes: List[int] = [50]
-        self.runtimes_banned_nodes: List[int] = [10]
-        self.min_cut_saved_values: List[Node] = [self.node1]
+        self.memories_banned_nodes: list[int] = [50]
+        self.runtimes_banned_nodes: list[int] = [10]
+        self.min_cut_saved_values: list[Node] = [self.node1]

     def test_create_joint_graph_node_information(self) -> None:
-        recomputable_node_info: Dict[str, int] = {"node1": 0}
-        expected_output: Dict[str, Dict] = {
+        recomputable_node_info: dict[str, int] = {"node1": 0}
+        expected_output: dict[str, dict] = {
             "node1": {
                 "index": 0,
                 "name": "node1",
@@ -68,12 +67,12 @@ class TestAcLogging(TestCase):
         self.assertEqual(result, expected_output)

     def test_create_joint_graph_edges(self) -> None:
-        expected_edges: List[Tuple[str, str]] = [("node1", "node2")]
+        expected_edges: list[tuple[str, str]] = [("node1", "node2")]
         result = create_joint_graph_edges(self.graph)
         self.assertEqual(result, expected_edges)

     def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
-        input_joint_graph_node_information: Dict[str, Dict] = {
+        input_joint_graph_node_information: dict[str, dict] = {
             "node1": {
                 "index": 0,
                 "name": "node1",
@@ -85,8 +84,8 @@ class TestAcLogging(TestCase):
                 "recomputable_candidate_info": {"recomputable_node_idx": 0},
             }
         }
-        joint_graph_edges: List[Tuple[str, str]] = [("node1", "node2")]
-        expected_payload: Dict[str, any] = {
+        joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")]
+        expected_payload: dict[str, any] = {
             "Joint Graph Size": 2,
             "Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges},
             "Joint Graph Node Information": input_joint_graph_node_information,
@@ -2,7 +2,7 @@

 """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""

-from typing import Sequence
+from typing import Sequence  # noqa: UP035

 import onnx_test_common
 import onnxscript
@@ -9,7 +9,7 @@ import platform
 import sys
 import time
 import traceback
-from typing import Any, Mapping
+from typing import Any, TYPE_CHECKING

 import numpy as np

@@ -20,6 +20,10 @@ import onnxscript
 import torch


+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+
 _REPRODUCTION_TEMPLATE = '''\
 import google.protobuf.text_format
 import numpy as np
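The hunk above also shows a companion pattern used throughout this PR: abstract container types come from collections.abc instead of typing (typing.Mapping is deprecated by PEP 585), and imports needed only for annotations move under an if TYPE_CHECKING guard so they cost nothing at runtime. A standalone sketch of that pattern, with hypothetical names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported by type checkers, never at runtime.
    from collections.abc import Mapping


def render(env: "Mapping[str, str]") -> str:
    # The annotation is quoted, so Mapping need not exist at runtime.
    return ",".join(f"{key}={value}" for key, value in env.items())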
@@ -3,7 +3,7 @@
 import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Optional, Type
+from typing import Callable, Optional

 class Aggregation(Enum):
     VALUE = ...
@@ -48,7 +48,7 @@ class _WaitCounterTracker:
     def __enter__(self) -> None: ...
     def __exit__(
         self,
-        exec_type: Optional[Type[BaseException]] = None,
+        exec_type: Optional[type[BaseException]] = None,
         exec_value: Optional[BaseException] = None,
         traceback: Optional[TracebackType] = None,
     ) -> None: ...
@@ -4,17 +4,7 @@ from collections import defaultdict
 from collections.abc import Sequence
 from functools import lru_cache, partial, wraps
 from itertools import chain
-from typing import (
-    Callable,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Set,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec

@@ -11,7 +11,7 @@ Python polyfills for common builtins.
 import types
 from collections.abc import MutableMapping, Sequence
 from itertools import repeat as _repeat
-from typing import Any, Callable, List, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING

 import torch
@@ -59,7 +59,6 @@ from typing import (
     Generic,
     Optional,
     overload,
-    Set,
     TypeVar,
     Union,
 )
@@ -1393,7 +1392,7 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]:
     except Exception:
         return "Value is not JSON serializable"

-    keys_to_scrub: Set[Any] = set()
+    keys_to_scrub: set[Any] = set()
     inductor_conf_str = None
     inductor_config_copy = (
         torch._inductor.config.get_config_copy() if torch._inductor.config else None
@@ -16,8 +16,9 @@ computations.
 """

 import collections
+from collections.abc import Sequence
 from enum import Enum
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING

 from .. import variables
 from ..current_scope_id import current_scope_id
@@ -11,8 +11,8 @@ import sys
 import types
 import typing
 from collections import defaultdict, OrderedDict
-from collections.abc import KeysView
-from typing import Callable, Sequence, TYPE_CHECKING, Union
+from collections.abc import KeysView, Sequence
+from typing import Callable, TYPE_CHECKING, Union

 import torch
 from torch import sym_float, sym_int
@@ -30,7 +30,7 @@ import itertools
 import sys
 import types
 from collections.abc import Sequence
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
 from typing_extensions import Never
 from unittest.mock import patch

@@ -517,7 +517,7 @@ class LocalGeneratorObjectVariable(VariableTracker):
     def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
         return True

-    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
+    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
         result = []
         while True:
             try:
@@ -547,8 +547,8 @@ class LocalGeneratorObjectVariable(VariableTracker):
         self,
         tx: "InstructionTranslator",
         name: str,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "__next__":
             return self.next_variable(tx)
@@ -16,7 +16,7 @@ from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch

 import torch
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import Any

 from torch._logging import trace_structured
 from torch.fx import Graph, Node
@@ -11,9 +11,9 @@ log: logging.Logger = logging.getLogger(__name__)

 def create_joint_graph_node_information(
     joint_graph: Graph,
-    recomputable_node_info: Dict[str, int],
-) -> Dict[str, Any]:
-    joint_graph_node_information: Dict[str, Any] = {}
+    recomputable_node_info: dict[str, int],
+) -> dict[str, Any]:
+    joint_graph_node_information: dict[str, Any] = {}

     for i, joint_graph_node in enumerate(joint_graph.nodes):
         is_recomputable_candidate: bool = (
@@ -22,7 +22,7 @@ def create_joint_graph_node_information(
         tensor_meta = joint_graph_node.meta.get("tensor_meta")
         shape = getattr(tensor_meta, "shape", []) if tensor_meta else []

-        node_info: Dict[str, Any] = {
+        node_info: dict[str, Any] = {
             "index": i,
             "name": joint_graph_node.name,
             "is_recomputable_candidate": is_recomputable_candidate,
@@ -43,8 +43,8 @@ def create_joint_graph_node_information(
     return joint_graph_node_information


-def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
-    joint_graph_edges: List[Tuple[str, str]] = [
+def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]:
+    joint_graph_edges: list[tuple[str, str]] = [
         (inp.name, node.name)
         for node in joint_graph.nodes
         for inp in node.all_input_nodes
@@ -54,17 +54,17 @@ def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:

 def create_activation_checkpointing_logging_structure_payload(
     joint_graph: Graph,
-    joint_graph_node_information: Dict[str, Any],
-    joint_graph_edges: List[Tuple[str, str]],
-    all_recomputable_banned_nodes: List[Node],
+    joint_graph_node_information: dict[str, Any],
+    joint_graph_edges: list[tuple[str, str]],
+    all_recomputable_banned_nodes: list[Node],
     expected_runtime: float,
-    saved_node_idxs: List[int],
-    recomputable_node_idxs: List[int],
-    memories_banned_nodes: List[float],
-    runtimes_banned_nodes: List[float],
-    min_cut_saved_values: List[Node],
-) -> Dict[str, Any]:
-    activation_checkpointing_logging_structure_payload: Dict[str, Any] = {
+    saved_node_idxs: list[int],
+    recomputable_node_idxs: list[int],
+    memories_banned_nodes: list[float],
+    runtimes_banned_nodes: list[float],
+    min_cut_saved_values: list[Node],
+) -> dict[str, Any]:
+    activation_checkpointing_logging_structure_payload: dict[str, Any] = {
         "Joint Graph Size": len(joint_graph.nodes),
         "Joint Graph Edges": {
             "Total": len(joint_graph_edges),
@@ -86,15 +86,15 @@ def create_activation_checkpointing_logging_structure_payload(

 def create_structured_trace_for_min_cut_info(
     joint_graph: Graph,
-    all_recomputable_banned_nodes: List[Node],
-    saved_node_idxs: List[int],
-    recomputable_node_idxs: List[int],
+    all_recomputable_banned_nodes: list[Node],
+    saved_node_idxs: list[int],
+    recomputable_node_idxs: list[int],
     expected_runtime: float,
-    memories_banned_nodes: List[float],
-    runtimes_banned_nodes: List[float],
-    min_cut_saved_values: List[Node],
+    memories_banned_nodes: list[float],
+    runtimes_banned_nodes: list[float],
+    min_cut_saved_values: list[Node],
 ) -> None:
-    recomputable_node_info: Dict[str, int] = {
+    recomputable_node_info: dict[str, int] = {
         node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes)
     }
     joint_graph_node_information = create_joint_graph_node_information(
@@ -1,6 +1,7 @@
 import dataclasses
 import itertools
-from typing import Any, Iterable, Union
+from collections.abc import Iterable
+from typing import Any, Union

 import torch
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -8,8 +8,6 @@

 from __future__ import annotations

-from typing import List
-
 import torch
 import torch.utils._pytree as pytree
 from torch._ops import HigherOrderOperator
@@ -50,9 +48,9 @@ class AOTICallDelegate(HigherOrderOperator):
         self,
         lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
         original_gm: torch.fx.GraphModule,
-        weight_args: List[torch.Tensor],
-        input_args: List[torch.Tensor],
-    ) -> List[torch.Tensor]:
+        weight_args: list[torch.Tensor],
+        input_args: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         return super().__call__(lowered_module, original_gm, weight_args, input_args)


@@ -68,9 +66,9 @@ aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
 def call_delegate_cpu(
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
     original_gm: torch.fx.GraphModule,
-    weight_args: List[torch.Tensor],
-    input_args: List[torch.Tensor],
-) -> List[torch.Tensor]:
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+) -> list[torch.Tensor]:
     # FX creates this immutable_dict/list concept. Get rid of this.
     map_types = {
         torch.fx.immutable_collections.immutable_dict: dict,
@@ -104,8 +102,8 @@ def call_delegate_fake_tensor_mode(
     mode: FakeTensorMode,
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
     original_gm: torch.fx.GraphModule,
-    weight_args: List[torch.Tensor],
-    input_args: List[torch.Tensor],
-) -> List[torch.Tensor]:
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+) -> list[torch.Tensor]:
     with mode:
         return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
@@ -2,7 +2,7 @@
 import functools
 from contextlib import contextmanager, ExitStack
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.fx.traceback as fx_traceback
@@ -173,7 +173,7 @@ def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:


 # The invariant here is that we always trace the branch with fake tensor
-def _maybe_fake_tracing(fn, inputs: List[Any], pre_dispatch):
+def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
     fake_mode = detect_fake_mode(inputs)
     tracing_mode = "real"
     if fake_mode is None:
@@ -565,8 +565,8 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])

 def check_input_alias_and_mutation(
     gm: torch.fx.GraphModule,
-    fake_args: List[FakeTensor],
-) -> Tuple[List[int], dict[int, int], dict[int, int], dict[int, int]]:
+    fake_args: list[FakeTensor],
+) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]:
     with disable_proxy_modes_tracing():
         """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias
         in the graph module gm. It checks whether input tensor versions have
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, IO, Optional, TYPE_CHECKING, Union

 import torch._inductor.config
 import torch.fx
@@ -37,7 +37,6 @@ from typing import (
     cast,
     NoReturn,
     Optional,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -523,7 +522,7 @@ class FxGraphCachePickler(pickle.Pickler):

     def _reduce_tensor(
         self, t: Tensor
-    ) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
+    ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
         """
         Custom reducer to pickle Tensors. If we see tensors, we know they're constants
         stored as attributes on the GraphModule.
@@ -18,8 +18,6 @@ from typing import (
     cast,
     ClassVar,
     Generic,
-    Iterator,
-    MutableMapping,
     NamedTuple,
     Optional,
     TYPE_CHECKING,
@@ -59,7 +57,7 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V


 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Iterator, MutableMapping, Sequence

     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import os
 from itertools import chain, count, zip_longest
-from typing import Any, Callable, Hashable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import sympy

@@ -24,6 +24,8 @@ from .wrapper import PythonWrapperCodegen, SymbolicCallArg


 if TYPE_CHECKING:
+    from collections.abc import Hashable
+
     from ..graph import GraphLowering

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations

-from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 from ..scheduler import (
     BaseSchedulerNode,
@@ -114,7 +114,7 @@ class CUDACombinedScheduling(BaseScheduling):

     def benchmark_fused_nodes(
         self, nodes: Sequence[BaseSchedulerNode]
-    ) -> Tuple[float, str]:
+    ) -> tuple[float, str]:
         return self._triton_scheduling.benchmark_fused_nodes(nodes)

     def benchmark_codegened_module(self, module):
@@ -129,5 +129,5 @@ class CUDACombinedScheduling(BaseScheduling):

     def benchmark_combo_kernel(
         self, node_list: Sequence[BaseSchedulerNode]
-    ) -> tuple[float, float, List[Optional[str]]]:
+    ) -> tuple[float, float, list[Optional[str]]]:
         return self._triton_scheduling.benchmark_combo_kernel(node_list)
@@ -11,16 +11,7 @@ import math
 import operator
 import textwrap
 from collections import Counter
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    Iterator,
-    no_type_check,
-    Optional,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Generic, no_type_check, Optional, TYPE_CHECKING, Union
 from typing_extensions import TypeVar

 import sympy
@@ -72,7 +63,7 @@ from .simd_kernel_features import (


 if TYPE_CHECKING:
-    from collections.abc import Iterable, Sequence
+    from collections.abc import Iterable, Iterator, Sequence


 log = logging.getLogger(__name__)
@@ -5,7 +5,7 @@ import dataclasses
 import functools
 import itertools
 import typing
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Union

 import sympy

@@ -21,6 +21,10 @@ from ..utils import cache_on_self
 from ..virtualized import V


+if typing.TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+
 class NodeScheduleMarker:
     @staticmethod
     def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
@@ -81,7 +85,7 @@ class SIMDKernelFeatures:
         # numel excludes reduction_numel
         self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
         self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)
-        self._stats_cache: Dict[Tuple[sympy.Expr, ...], MemoryStats] = {}
+        self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {}

     @cache_on_self
     def is_reduction(self) -> bool:
@@ -205,7 +209,7 @@ class SIMDKernelFeatures:
         return node.node.data.reduction_hint

     def memory_stats(
-        self, groups_dict: Optional[Dict[str, sympy.Expr]] = None
+        self, groups_dict: Optional[dict[str, sympy.Expr]] = None
     ) -> MemoryStats:
         """Analysis to generate features that can be used in heuristics"""
         if groups_dict is None:
@@ -228,11 +232,11 @@ class MemoryEstimator:
     We simulate the memory effects of CSE/buffer elimination in codegen.
     """

-    kernel_sizes: Tuple[sympy.Expr, ...]
+    kernel_sizes: tuple[sympy.Expr, ...]
     outside_loop: MemoryEstimate
-    loops: List[MemoryEstimate]
+    loops: list[MemoryEstimate]
     persistent: MemoryEstimate
-    symbols: List[sympy.Symbol]
+    symbols: list[sympy.Symbol]

     def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]):
         self.features = features
@@ -341,7 +345,7 @@ class MemoryEstimator:
             return True
         return False

-    def set_ranges(self, *lengths: List[List[sympy.Expr]]) -> List[List[sympy.Expr]]:
+    def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]:
         assert len(self.kernel_sizes) == len(lengths)
         return [
             self.make_flat_range(sym, numel, length)
@@ -350,8 +354,8 @@ class MemoryEstimator:

     @staticmethod
     def make_flat_range(
-        sym: sympy.Symbol, numel: sympy.Expr, lengths: List[sympy.Expr]
-    ) -> List[sympy.Expr]:
+        sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr]
+    ) -> list[sympy.Expr]:
         if len(lengths) == 1 and numel == lengths[0]:
             return [sym]
         divisor = sympy.S.One
@@ -370,10 +374,10 @@ class MemoryEstimator:
 class MemoryEstimate:
     """Tracks the memory usage of a single loop in the generated kernel"""

-    reads: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
+    reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
         default_factory=functools.partial(collections.defaultdict, OrderedSet)
     )
-    writes: Dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
+    writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field(
         default_factory=functools.partial(collections.defaultdict, OrderedSet)
     )

@@ -474,8 +478,8 @@ class StatsForLoop:
 class StatsForReadsOrWrites:
     """Memory usage stats that are collected for reads/writes/both"""

-    dim: List[StatsForDim]
-    loop: List[StatsForLoop]
+    dim: list[StatsForDim]
+    loop: list[StatsForLoop]
     # total bytes contiguous in any dimension
     bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero
     bytes_non_contiguous: sympy.Expr = sympy.S.Zero
@@ -506,8 +510,8 @@ class StatsForReadsOrWrites:
     @classmethod
     def compute(
         cls,
-        loop_deps: List[Dict[str, OrderedSet[MemoryDep]]],
-        index_symbols: List[sympy.Symbol],
+        loop_deps: list[dict[str, OrderedSet[MemoryDep]]],
+        index_symbols: list[sympy.Symbol],
     ) -> typing.Self:
         ndim = len(index_symbols)
         result = cls(dim := [StatsForDim() for _ in range(ndim)], [])
@@ -521,7 +525,7 @@ class StatsForReadsOrWrites:
                 loop_stats.count_per_thread += len(deps)
                 loop_stats.bytes_per_thread += itemsize * len(deps)
                 for dep in deps:
-                    strides: List[sympy.Expr] = V.graph.sizevars.stride_vars(
+                    strides: list[sympy.Expr] = V.graph.sizevars.stride_vars(
                         dep.index, index_symbols
                     )
                     for i in range(ndim):
@@ -568,7 +572,7 @@ class StatsForKernelType:

     @classmethod
     def compute(
-        cls, loops: List[MemoryEstimate], estimator: MemoryEstimator
+        cls, loops: list[MemoryEstimate], estimator: MemoryEstimator
     ) -> typing.Self:
         reads = StatsForReadsOrWrites.compute(
             [loop.reads for loop in loops], estimator.symbols
@@ -12,19 +12,11 @@ import sys
 import time
 import warnings
 from abc import ABC, abstractmethod
+from contextlib import AbstractContextManager
 from dataclasses import dataclass
 from inspect import currentframe
 from itertools import count
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Mapping,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
 from unittest import mock

@@ -127,7 +119,7 @@ from .virtualized import V


 if TYPE_CHECKING:
-    from collections.abc import Generator, Sequence
+    from collections.abc import Generator, Mapping, Sequence

     from torch._inductor.output_code import _StrideExprStr
     from torch._ops import OpOverload
@@ -480,7 +472,7 @@ def is_tf32_warning_applicable(gm: GraphModule) -> bool:

 def maybe_disable_comprehensive_padding(
     example_inputs: Sequence[InputType],
-) -> contextlib.AbstractContextManager[None, None]:
+) -> AbstractContextManager[None, None]:
     """
     For CPU backend, enable comprehensive padding causes some unit tests
     fail due to changing number of generated kernels. Skip for now.
@@ -1780,7 +1772,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
     }


-def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
+def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
     """
     Returns a cuda device context manager if there is a single device in the graph
     """
@@ -48,17 +48,9 @@ import traceback
 import warnings
 import weakref
 from collections import defaultdict
+from contextlib import AbstractContextManager
 from enum import auto, Enum
-from typing import (
-    Any,
-    Callable,
-    cast,
-    ContextManager,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union

 import torch.fx
 from torch import Tensor
@@ -179,7 +171,7 @@ def enable_history_recording() -> Generator[None, None, None]:
        torch.cuda.memory._record_memory_history(None)


-def get_history_recording() -> ContextManager[None]:
+def get_history_recording() -> AbstractContextManager[None]:
     # TODO - remove, prevents cleanup
     if not config.triton.cudagraph_trees_history_recording:
         return contextlib.nullcontext()
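The two hunks above show the related cleanup that several files in this PR receive: typing.ContextManager, deprecated since Python 3.9 as an alias of contextlib.AbstractContextManager, is replaced by the contextlib class in annotations. A small self-contained sketch, with a hypothetical function name:

from contextlib import AbstractContextManager, nullcontext


def get_noop_context() -> AbstractContextManager[None]:
    # nullcontext() satisfies AbstractContextManager[None] at type-check time
    # and is a no-op at runtime.
    return nullcontext()


with get_noop_context():
    pass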
@@ -3,8 +3,8 @@ import dataclasses
 import itertools
 import logging
 import re
-from collections.abc import Sequence
-from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, TypeVar, Union
 from typing_extensions import Self
 from unittest.mock import patch

@@ -80,7 +80,7 @@ class MemoryDep(Dep):
     def num_vars(self) -> int:
         return len(self.var_names)

-    def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[List[int]]:
+    def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]:
         """
         Can return None if not able to decide loop orders.
         """
@@ -451,8 +451,8 @@ class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
     @staticmethod
     def drop_unused_symbols(
         index: Union[int, sympy.Expr],
-        var_names: List[sympy.Expr],
-        sizes: List[sympy.Expr],
+        var_names: list[sympy.Expr],
+        sizes: list[sympy.Expr],
     ) -> None:
         """
         Reduction has last (reduced) dim in its sizes, but
@@ -571,7 +571,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy

 def index_vars_no_squeeze(
     *argsizes: Sequence[sympy.Expr], prefix: str
-) -> Tuple[List[List[sympy.Symbol]], VarRanges]:
+) -> tuple[list[list[sympy.Symbol]], VarRanges]:
     var_ranges, add_var = var_builder(prefix)
     args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
     return args, var_ranges
@@ -579,7 +579,7 @@ def index_vars_no_squeeze(

 def index_vars_squeeze(
     *argsizes: Sequence[sympy.Expr], prefix: str = "d"
-) -> Tuple[List[List[sympy.Expr]], VarRanges]:
+) -> tuple[list[list[sympy.Expr]], VarRanges]:
     from .ir import SqueezeView

     var_ranges, add_var = var_builder(prefix)
@@ -597,7 +597,7 @@ def extract_read_writes(
     *argsizes: Sequence[sympy.Expr],
     normalize: bool = False,
     prefix: str = "d",
-    hidden_args: Sequence[List[sympy.Expr]] = (),
+    hidden_args: Sequence[list[sympy.Expr]] = (),
 ) -> ReadWrites:
     args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)

@@ -630,7 +630,7 @@ def extract_read_writes(

 def extract_loop_body_with_args(
     fn: Any,
-    args: List[List[sympy.Expr]],
+    args: list[list[sympy.Expr]],
     var_ranges: VarRanges,
     normalize: bool = False,
 ) -> _RecordLoadStoreInner:
@@ -761,17 +761,17 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
             self.symbols |= free_unbacked_symbols(size)
         return sympy_index_symbol(f"({str(index_var)})")

-    def frexp(self, x: Any) -> Tuple[None, ...]:
+    def frexp(self, x: Any) -> tuple[None, ...]:
         return (None,) * 2

     def scan(
         self, dtypes: Any, combine_fn: Any, values: Sequence[Any]
-    ) -> Tuple[None, ...]:
+    ) -> tuple[None, ...]:
         return (None,) * len(values)

     def sort(
         self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any
-    ) -> Tuple[None, ...]:
+    ) -> tuple[None, ...]:
         return (None,) * len(values)

     def reduction(
@@ -1,6 +1,7 @@
 import contextlib
 import threading
-from typing import Any, Generator
+from collections.abc import Generator
+from typing import Any

 import torch
@@ -7,7 +7,7 @@ import signal
 import string
 import sys
 import traceback
-from collections.abc import KeysView
+from collections.abc import KeysView, Sequence
 from enum import Enum
 from functools import partial, wraps
 from types import FrameType
@@ -18,7 +18,6 @@ from typing import (
     get_origin,
     Literal,
     Optional,
-    Sequence,
     TypeVar,
     Union,
 )
@@ -2,7 +2,8 @@ import functools
 import itertools
 import operator
 import typing
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union

 import torch
 import torch._inductor.runtime.runtime_utils
@@ -275,7 +276,7 @@ def should_pad_bench_key(
     input: Optional[Tensor] = None,
     is_base_time_key: bool = False,
 ) -> str:
-    def tensor_key(t: Tensor) -> Tuple[torch.Size, Tuple[int, ...], torch.dtype]:
+    def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
         return (t.shape, t.stride(), t.dtype)

     tf32_key = (
@@ -436,8 +437,8 @@ def _should_pad_bench(
         return False

     def realize_symbols(
-        ds: Union[torch.Size, Tuple[torch.SymInt, ...]]
-    ) -> List[int]:
+        ds: Union[torch.Size, tuple[torch.SymInt, ...]]
+    ) -> list[int]:
         return [d if isinstance(d, int) else d.node.hint for d in ds]

     if any(
@@ -9,14 +9,13 @@ import textwrap
 import traceback
 import typing
 from collections.abc import Generator, Iterable, Sequence
-from contextlib import nullcontext
+from contextlib import AbstractContextManager, nullcontext
 from enum import Enum
 from functools import partial
 from typing import (
     Any,
     Callable,
     ClassVar,
-    ContextManager,
     Literal,
     Optional,
     overload,
@@ -6710,7 +6709,7 @@ class FallbackKernel(ExternKernelAlloc):
     @classmethod
     def create(cls, kernel, *args, **kwargs):  # type: ignore[no-untyped-def]
         fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
-        context: ContextManager[None] = (
+        context: AbstractContextManager[None] = (
             V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()  # type: ignore[assignment]
         )
         with context:
@@ -7,7 +7,7 @@ import math
 from collections.abc import Sequence
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 import sympy

@@ -1195,7 +1195,7 @@ def next_power_of_two(n):


 def set_head_dim_values(
-    kernel_options: Dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
+    kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars
 ):
     """
     Mutates kernel options, adding head dimension calculations.
@@ -7,7 +7,7 @@ import os
 import re
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Callable, cast, Optional, TYPE_CHECKING, Union

 from torch._inductor import config
 from torch._inductor.utils import get_benchmark_name
@@ -92,7 +92,7 @@ class CachedMetricsDeltas:
     num_matches_for_scatter_upon_const_tensor: int


-def get_metric_fields() -> List[str]:
+def get_metric_fields() -> list[str]:
     return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]


@@ -132,7 +132,7 @@ class MetricTable:
     num_rows_added: int = 0

     def add_row(
-        self, row_fn: Callable[[], Dict[str, Optional[Union[str, float]]]]
+        self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]]
     ) -> None:
         if self.table_name not in enabled_metric_tables():
             return
@@ -160,7 +160,7 @@ class MetricTable:
             writer = csv.writer(fd, lineterminator="\n")
             writer.writerow(["model_name"] + self.column_names)

-    def _write_row(self, row: List[str]) -> None:
+    def _write_row(self, row: list[str]) -> None:
         filename = self.output_filename()
         if self.num_rows_added == 0 and not os.path.exists(filename):
             self.write_header()
@@ -181,7 +181,7 @@ class MetricTable:
             writer.writerow(row)

     @staticmethod
-    def register_table(name: str, column_names: List[str]) -> None:
+    def register_table(name: str, column_names: list[str]) -> None:
         table = MetricTable(name, column_names)
         REGISTERED_METRIC_TABLES[name] = table

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
-from typing import Any, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional

 import sympy

@@ -27,7 +27,7 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias

 import torch
@@ -277,7 +277,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
     def __init__(self, gm: torch.fx.GraphModule) -> None:
         self.gm = gm

-    def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
+    def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
         frozen_params = {
             name: getattr(self.gm, orig_name)
             for name, orig_name in g.frozen_param_names.items()
@@ -301,10 +301,10 @@ class CompiledFxGraph(OutputCode):
     device_idxs: OrderedSet[int]
     mutated_inputs: OrderedSet[str]
     mutated_input_idxs: OrderedSet[int]
-    constants: Optional[Dict[str, torch.Tensor]]
-    frozen_param_names: Dict[str, str]
-    torchbind_constants: Dict[str, torch._C.ScriptObject]
-    output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]]
+    constants: Optional[dict[str, torch.Tensor]]
+    frozen_param_names: dict[str, str]
+    torchbind_constants: dict[str, torch._C.ScriptObject]
+    output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
     disabled_cudagraphs_reason: Optional[str]
     metrics_deltas: metrics.CachedMetricsDeltas
     counter_deltas: Counter[str]
@@ -14,21 +14,11 @@ import textwrap
 import traceback
 import typing
 from collections import Counter, defaultdict
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Sequence,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union


 if TYPE_CHECKING:
+    from collections.abc import Sequence
     from types import ModuleType

     import sympy
@@ -2794,7 +2784,7 @@ class Scheduler:
         )

         # Start compiling choices in parallel
-        future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
+        future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
         triton_choices = 0
         for choice, unfused_time in sorted(
             choice_timings.items(), key=lambda x: x[1]
@@ -2964,7 +2954,7 @@ class Scheduler:

         # These are potential fusions which we are async compiling,
         # and which we will benchmark profitability of.
-        pending_fusions: Dict[
+        pending_fusions: dict[
             BaseSchedulerNode,
             tuple[Callable[[], bool], BaseSchedulerNode, BaseSchedulerNode],
         ] = {}
@@ -4068,7 +4058,7 @@ class Scheduler:

     def benchmark_combo_kernel(
         self, node_list: Sequence[BaseSchedulerNode]
-    ) -> tuple[float, float, List[Optional[str]]]:
+    ) -> tuple[float, float, list[Optional[str]]]:
         """
         Benchmark fused list of nodes and return the execution time
         in milliseconds on randomly generated inputs.
@@ -4308,7 +4298,7 @@ class BaseScheduling:

     def benchmark_combo_kernel(
         self, node_list: Sequence[BaseSchedulerNode]
-    ) -> tuple[float, float, List[Optional[str]]]:
+    ) -> tuple[float, float, list[Optional[str]]]:
         """
         Benchmark the list of nodes to combine and return the execution time
         and memory copy time in milliseconds on randomly generated inputs.
@@ -21,23 +21,18 @@ import tempfile
 import textwrap
 import time
 import unittest
+from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet
 from datetime import datetime
 from io import StringIO
 from typing import (
     Any,
     Callable,
     cast,
-    Collection,
     Generic,
-    Iterator,
     Literal,
-    Mapping,
-    MutableMapping,
-    MutableSet,
     NamedTuple,
     Optional,
     Protocol,
-    Type,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -1361,7 +1356,7 @@ def _rocm_native_device_arch_name(device: str) -> str:

 @functools.lru_cache(None)
 def try_import_ck_lib() -> (
-    tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], Type[Any]]
+    tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
 ):
     try:
         import ck4inductor  # type: ignore[import]
@@ -2608,7 +2603,7 @@ class ScopedDict(MutableMapping[KeyType, ValType]):


 @dataclass_transform(frozen_default=True)
-def ir_dataclass(cls: Optional[Type[Any]] = None, /, *, frozen: bool = True) -> Any:
+def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
     def wrap(cls: _T) -> _T:
         if sys.version_info >= (3, 10):
             return dataclasses.dataclass(cls, kw_only=True, frozen=frozen)  # type: ignore[call-overload]
@@ -3,7 +3,7 @@ import datetime
 import tempfile
 from collections import defaultdict
 from types import ModuleType
-from typing import Any, Dict, Optional, Protocol
+from typing import Any, Optional, Protocol

 import torch
 from torch.autograd import DeviceType
@@ -73,7 +73,7 @@ def get_triton_kernel(mod: ModuleType):  # type: ignore[no-untyped-def]


 def benchmark_all_kernels(
-    benchmark_name: str, benchmark_all_configs: Optional[Dict[Any, Any]]
+    benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]]
 ) -> None:
     """
     An experimental API used only when config.benchmark_kernel is true.
@@ -184,7 +184,7 @@ def parse_profile_event_list(
        """
        return ev.self_device_time_total / 1000 / nruns  # type: ignore[attr-defined]

-    all_events: Dict[str, list[ProfileEvent]] = defaultdict(list)
+    all_events: dict[str, list[ProfileEvent]] = defaultdict(list)

     def add_event(
         ev: torch.autograd.profiler_util.EventList,
@@ -1353,11 +1353,11 @@ def _is_exception(obj) -> bool:


 def raise_error_container_parameter_missing(target_type) -> None:
-    if target_type == "Dict":
+    if target_type.endswith("ict"):
         raise RuntimeError(
-            "Attempted to use Dict without "
+            f"Attempted to use {target_type} without "
             "contained types. Please add contained type, e.g. "
-            "Dict[int, int]"
+            f"{target_type}[int, int]"
         )
     raise RuntimeError(
         f"Attempted to use {target_type} without a "
@@ -1366,15 +1366,20 @@ def raise_error_container_parameter_missing(target_type) -> None:
     )


+_RAW_TYPE_NAME_MAPPING = {
+    dict: "dict",
+    list: "list",
+    tuple: "tuple",
+    typing.Dict: "Dict",  # noqa: UP006
+    typing.List: "List",  # noqa: UP006
+    typing.Optional: "Optional",
+    typing.Tuple: "Tuple",  # noqa: UP006
+}
+
+
 def check_args_exist(target_type) -> None:
-    if target_type is typing.List or target_type is list:  # noqa: UP006
-        raise_error_container_parameter_missing("List")
-    elif target_type is typing.Tuple or target_type is tuple:  # noqa: UP006
-        raise_error_container_parameter_missing("Tuple")
-    elif target_type is typing.Dict or target_type is dict:  # noqa: UP006
-        raise_error_container_parameter_missing("Dict")
-    elif target_type is None or target_type is Optional:
-        raise_error_container_parameter_missing("Optional")
+    if name := _RAW_TYPE_NAME_MAPPING.get(target_type):
+        raise_error_container_parameter_missing(name)


 def check_empty_containers(obj) -> None:
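The check_args_exist hunk above is the one place this PR goes beyond annotation rewrites: an if/elif chain becomes a lookup table (_RAW_TYPE_NAME_MAPPING) keyed by both the builtin containers and their legacy typing aliases, guarded by an assignment expression. The same table-driven dispatch in isolation, with hypothetical names, as a sketch rather than the real torch.jit helper:

import typing

# Map both builtin generics and their legacy typing aliases to a display name.
_CONTAINER_NAMES = {
    dict: "dict",
    list: "list",
    typing.Dict: "Dict",  # noqa: UP006
    typing.List: "List",  # noqa: UP006
}


def describe_missing_args(target_type) -> str:
    if name := _CONTAINER_NAMES.get(target_type):
        return f"{name} was used without contained types, e.g. {name}[int]"
    return "ok"


print(describe_missing_args(list))  # list was used without contained types, e.g. list[int]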
@@ -3,7 +3,7 @@ import operator
 from collections.abc import Sequence
 from enum import Enum
 from functools import partial, reduce
-from typing import Callable, List, Optional, Tuple, Type, Union
+from typing import Callable, Optional, Union

 import torch
 import torch._prims_common as utils
@@ -12,12 +12,9 @@ from typing import (
     Any,
     Callable,
     cast,
-    List,
     NamedTuple,
     Optional,
     overload,
-    Tuple,
-    Type,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -10,7 +10,7 @@ import warnings
 from collections.abc import Iterable, Sequence
 from enum import Enum
 from functools import partial, reduce, singledispatch, wraps
-from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union
+from typing import Any, Callable, cast, Optional, overload, Union

 import torch
 import torch._prims as prims
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import torch
 import torch._prims as prims
@@ -1,4 +1 @@
-from typing import List
-
-
 __all__: list[str] = []
@@ -3,7 +3,8 @@ import contextlib
 import warnings
 import weakref
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ContextManager, Optional, Union
+from contextlib import AbstractContextManager
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.utils._pytree as pytree
@@ -665,7 +666,7 @@ class BaseFunctionalizeAPI(ABC):
         pass

     @abstractmethod
-    def redispatch_to_next(self) -> ContextManager:
+    def redispatch_to_next(self) -> AbstractContextManager:
         pass

     @abstractmethod
@@ -709,7 +710,7 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
     def functionalize(self, inner_f: Callable) -> Callable:
         return dispatch_functionalize(inner_f, self.mode)

-    def redispatch_to_next(self) -> ContextManager:
+    def redispatch_to_next(self) -> AbstractContextManager:
         # [NOTE] We don't do anything here because at the time
         # we exercise this path, we would have already popped the
         # FunctionalTensorMode from mode stack. Since FunctionalTensorMode
@@ -753,7 +754,7 @@ class CppFunctionalizeAPI(BaseFunctionalizeAPI):
     def functionalize(self, inner_f: Callable) -> Callable:
         return torch.func.functionalize(inner_f)

-    def redispatch_to_next(self) -> ContextManager:
+    def redispatch_to_next(self) -> AbstractContextManager:
         return torch._C._ExcludeDispatchKeyGuard(
             torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
         )
@@ -801,7 +802,7 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
         ),
     )

-    def redispatch_to_next(self) -> ContextManager:
+    def redispatch_to_next(self) -> AbstractContextManager:
         return self.interpreter.lower()

     def replace(self, input_tensor, output_tensor) -> None:
@@ -7,12 +7,12 @@ import typing
 import warnings
 import weakref
 from abc import abstractmethod
+from contextlib import AbstractContextManager
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
     ClassVar,
-    ContextManager,
     Generic,
     NewType,
     Optional,
@@ -1741,7 +1741,9 @@ class MetaConverter(Generic[_TensorT]):
             # subclasses. Relevant test is
             # DynamicShapesFunctionTests::test_add_dynamic_shapes in
             # test/dynamo/test_dynamic_shapes.py
-            maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
+            maybe_fake_mgr: AbstractContextManager[
+                None
+            ] = contextlib.nullcontext()
             from torch._subclasses.fake_tensor import (
                 in_kernel_invocation_manager,
                 maybe_get_fake_mode,
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs

-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import Tensor
@@ -10,7 +10,7 @@ half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdo

 import warnings
 from collections.abc import Sequence
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import _vmap_internals
@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import sys
 import types
-from typing import List

 import torch

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
 from typing_extensions import ParamSpec

 import torch
@@ -17,7 +17,7 @@ import threading
 import traceback
 import warnings
 from functools import lru_cache
-from typing import Any, Callable, cast, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Union

 import torch
 import torch._C
@@ -85,7 +85,7 @@ try:
             paths = ["libamd_smi.so"]
             if rocm_home := os.getenv("ROCM_HOME", os.getenv("ROCM_PATH")):
                 paths = [os.path.join(rocm_home, "lib/libamd_smi.so")] + paths
-            self.paths: List[str] = paths
+            self.paths: list[str] = paths

         def hooked_CDLL(
             self, name: Union[str, Path, None], *args: Any, **kwargs: Any
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 from collections.abc import Generator
-from contextlib import contextmanager, nullcontext
-from typing import Any, ContextManager, Optional
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from typing import Any, Optional

 import torch
 import torch.nn as nn
@@ -14,7 +14,7 @@ from .contract import _State, contract


 @contextmanager
-def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
+def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
     r"""
     Disable hooks installed by checkpoint to avoid unintentional recursion
     during backward recomputation.
@@ -1,14 +1,14 @@
 import pickle
 from dataclasses import dataclass
 from io import BufferedIOBase
-from typing import Any, Dict, List, Tuple
+from typing import Any

 import torch
 import torch._weights_only_unpickler as _weights_only_unpickler
 from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION


-__all__: List[str] = []
+__all__: list[str] = []


 @dataclass
@@ -23,7 +23,7 @@ _weights_only_unpickler._add_safe_globals([_Entry])

 class _PseudoZipFile:
     def __init__(self) -> None:
-        self.records: Dict[str, Tuple[object, int]] = {}
+        self.records: dict[str, tuple[object, int]] = {}

     def write_record(self, key: str, data: object, length: int) -> None:
         self.records[key] = (data, length)
@@ -1,5 +1,5 @@
 from collections.abc import Iterator
-from typing import Tuple, Union
+from typing import Union

 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
-from typing import List, TYPE_CHECKING
+from typing import TYPE_CHECKING

 import torch
 from torch.distributed._shard.op_registry_utils import _decorator_func
@@ -7,7 +7,7 @@ from contextlib import contextmanager
 from datetime import timedelta
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch
 import torch.distributed._functional_collectives as funcol
@@ -4,7 +4,7 @@ import operator
 from dataclasses import dataclass
 from enum import auto, Enum
 from functools import reduce
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch.distributed.checkpoint.metadata import (
@@ -205,21 +205,21 @@ class SavePlanner(abc.ABC):

     # Save plan for the current rank as computed by `create_local_plan` API
     # Cached on the local rank.
-    _cached_save_plan: Dict[str, SavePlan] = {}
+    _cached_save_plan: dict[str, SavePlan] = {}
     # Final save plan for the current rank.
     # This is created by merging the plan created by `create_local_plan` API
     # and the result of `create_global_plan` for the given rank.
     # This is the final plan computed by the `finish_plan` API that gets
     # sent to the `write_data`.
     # Cached on the local rank.
-    _cached_final_save_plan: Dict[str, SavePlan] = {}
+    _cached_final_save_plan: dict[str, SavePlan] = {}
     # Collection of all the local plans from all the ranks.
     # This is the input to the `create_global_plan` API.
     # Cached on the coordinator rank.
-    _cached_all_plans: Dict[str, list[SavePlan]] = {}
+    _cached_all_plans: dict[str, list[SavePlan]] = {}
     # Global checkpoint plan as computed by `create_global_plan` API.
     # Cached on the coordinator rank.
-    _cached_global_plan: Dict[str, list[SavePlan]] = {}
+    _cached_global_plan: dict[str, list[SavePlan]] = {}

     @abc.abstractmethod
     def set_up_planner(
|
|
|
|||
|
|
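The comments above describe a three-phase protocol: each rank builds a local plan, the coordinator merges all local plans into a global plan, and each rank finalizes its own slice before writing. A simplified, torch-free sketch of that call order; only the method names are taken from the hunk above, everything else is illustrative:

class ToySavePlanner:
    def create_local_plan(self, rank: int) -> dict:
        # Each rank describes what it intends to save.
        return {"rank": rank, "items": [f"tensor_{rank}"]}

    def create_global_plan(self, all_plans: list[dict]) -> list[dict]:
        # Coordinator-only: deduplicate / rebalance across ranks.
        return all_plans

    def finish_plan(self, merged: dict) -> dict:
        # Final per-rank plan handed to write_data().
        return merged


planner = ToySavePlanner()
local_plans = [planner.create_local_plan(r) for r in range(2)]  # one per rank
global_plan = planner.create_global_plan(local_plans)           # on the coordinator
final_plans = [planner.finish_plan(p) for p in global_plan]     # back on each rank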

@ -24,7 +24,7 @@ import logging
import os
import socket
import traceback
from typing import Dict, Optional
from typing import Optional

from torch.distributed.elastic.events.handlers import get_logging_handler

@ -62,7 +62,7 @@ was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""

from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Optional, Union

from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
    _validate_full_rank,

@ -58,7 +58,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar

from torch.distributed.elastic.utils.logging import get_logger

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Dict, Union
from typing import Union

import torch
from torch import fx

@ -90,7 +90,7 @@ def validate_tensors_metadata(


def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> Dict[int, int]:
) -> dict[int, int]:
    """
    Compute the stage id to rank mapping for either a looped or V-style schedule.
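The implementation is not part of this hunk. As a plausible sketch of the "loop" style only (an assumption for illustration, not the code in the repo), stage i would run on rank i % pp_size:

def loop_stage_to_rank(pp_size: int, num_stages: int) -> dict[int, int]:
    # "Loop" style: stages wrap around the ranks, e.g. 4 ranks / 8 stages
    # gives {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}.
    return {stage: stage % pp_size for stage in range(num_stages)}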

@ -9,7 +9,7 @@ import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist

@ -1025,7 +1025,7 @@ def _validate_schedule(
    pp_group_size: int,
    num_stages: int,
    num_microbatches: int,
) -> Dict[int, int]:
) -> dict[int, int]:
    assert (
        len(actions) == pp_group_size
    ), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"

@ -5,7 +5,6 @@ import threading
import warnings
from collections.abc import Generator
from datetime import timedelta
from typing import Tuple
from urllib.parse import urlparse

import torch

@ -9,17 +9,7 @@ import warnings
import zipfile
from collections.abc import Iterator
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
import torch.utils._pytree as pytree

@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Union

import torch
import torch.utils._pytree as pytree

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations

from typing import Callable, cast, Generic, List, Optional, TypeVar, Union
from typing import Callable, cast, Generic, Optional, TypeVar, Union

import torch

@ -3,7 +3,7 @@ import importlib
import io
import pickle
from abc import abstractmethod
from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, NewType, Optional, TypeVar, Union
from typing_extensions import override, Self

import torch

@ -45,7 +45,7 @@ class GraphPickler(pickle.Pickler):
    @override
    def reducer_override(
        self, obj: object
    ) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
    ) -> tuple[Callable[..., Any], tuple[Any, ...]]:
        # This function is supposed to return either NotImplemented (meaning to
        # do the default pickle behavior) or a pair of (unpickle callable, data
        # to pass to unpickle).
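The contract described in that comment is the standard `pickle.Pickler.reducer_override` hook (available since Python 3.8). A minimal standalone illustration; `Point` and `PointPickler` are invented for the example:

import io
import pickle


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y


class PointPickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, Point):
            # (unpickle callable, args tuple) -- same contract as above.
            return Point, (obj.x, obj.y)
        return NotImplemented  # fall back to default pickling


buf = io.BytesIO()
PointPickler(buf).dump(Point(1, 2))
restored = pickle.loads(buf.getvalue())
assert (restored.x, restored.y) == (1, 2)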

@ -138,13 +138,13 @@ class _GraphUnpickler(pickle.Unpickler):


class _ShapeEnvPickleData:
    data: Dict[str, object]
    data: dict[str, object]

    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: ShapeEnv
    ) -> Tuple[
        Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken]
    ) -> tuple[
        Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
    ]:
        return cls.unpickle, (cls(obj), pickler._unpickle_state)

@ -174,8 +174,8 @@ class _SymNodePickleData:
        cls,
        pickler: GraphPickler,
        obj: _SymNodeT,
    ) -> Tuple[
        Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken]
    ) -> tuple[
        Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
    ]:
        args = (cls(obj.node), pickler._unpickle_state)
        if isinstance(obj, torch.SymInt):

@ -205,8 +205,8 @@ class _TensorPickleData:
    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: FakeTensor
    ) -> Tuple[
        Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken]
    ) -> tuple[
        Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
    ]:
        return cls.unpickle, (
            cls(pickler._meta_tensor_describer, obj),

@ -266,8 +266,8 @@ class _TorchNumpyPickleData:
    def reduce_helper(
        cls, pickler: GraphPickler, obj: object
    ) -> Optional[
        Tuple[
            Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken]
        tuple[
            Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
        ]
    ]:
        if data := cls.from_object(obj):

@ -309,9 +309,9 @@ class _GraphModulePickleData:
    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: torch.fx.GraphModule
    ) -> Tuple[
    ) -> tuple[
        Callable[[Self, _UnpickleState], torch.fx.GraphModule],
        Tuple[Self, _UnpickleStateToken],
        tuple[Self, _UnpickleStateToken],
    ]:
        return cls.unpickle, (
            cls(obj),

@ -337,7 +337,7 @@ class _GraphModulePickleData:

class _NodePickleData:
    def __init__(
        self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"]
        self, node: torch.fx.Node, mapping: dict[torch.fx.Node, "_NodePickleData"]
    ) -> None:
        self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
        self.kwargs = pytree.tree_map_only(

@ -358,7 +358,7 @@ class _NodePickleData:
    def unpickle(
        self,
        graph: torch.fx.Graph,
        mapping: Dict["_NodePickleData", torch.fx.Node],
        mapping: dict["_NodePickleData", torch.fx.Node],
        unpickle_state: _UnpickleState,
    ) -> torch.fx.Node:
        args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)

@ -376,7 +376,7 @@ class _OpPickleData:
    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, op: object
    ) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]:
    ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
        result = cls.pickle(op)
        return (result.unpickle, (pickler._unpickle_state,))

@ -404,7 +404,7 @@ class _OpPickleData:
    def _pickle_op(
        name: str,
        datacls: Union[
            Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"]
            type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
        ],
    ) -> "_OpPickleData":
        if not name.startswith("torch.ops.aten"):  # TODO: What's the full list?

@ -501,7 +501,7 @@ class _GraphPickleData:
        self.tracer_cls = graph._tracer_cls
        self.tracer_extras = graph._tracer_extras

        nodes: Dict[torch.fx.Node, _NodePickleData] = {}
        nodes: dict[torch.fx.Node, _NodePickleData] = {}
        for node in graph.nodes:
            nodes[node] = _NodePickleData(node, nodes)
        self.nodes = tuple(nodes.values())

@ -521,7 +521,7 @@ class _GraphPickleData:
    ) -> torch.fx.Graph:
        graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)

        nodes: Dict[_NodePickleData, torch.fx.Node] = {}
        nodes: dict[_NodePickleData, torch.fx.Node] = {}
        for nd in self.nodes:
            nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)

@ -532,9 +532,9 @@ class _TracingContextPickleData:
    @classmethod
    def reduce_helper(
        cls, pickler: GraphPickler, obj: torch._guards.TracingContext
    ) -> Tuple[
    ) -> tuple[
        Callable[[Self, _UnpickleState], torch._guards.TracingContext],
        Tuple[Self, _UnpickleStateToken],
        tuple[Self, _UnpickleStateToken],
    ]:
        return (
            cls.unpickle,

@ -23,7 +23,7 @@ import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Set, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union

import torch
import torch._logging.structured as structured

@ -1233,7 +1233,7 @@ def _make_node_magic(method, func):
    else:
        method_attr = method


def uninteresting_files() -> Set[str]:
def uninteresting_files() -> set[str]:
    import torch

    mods = [

@ -36,11 +36,11 @@ if TYPE_CHECKING:
# Mapping of builtins to their `typing` equivalent.
# (PEP585: See D68459095 test plan)
_origin_type_map = {
    list: typing.List,
    dict: typing.Dict,
    set: typing.Set,
    frozenset: typing.FrozenSet,
    tuple: typing.Tuple,
    list: typing.List,  # noqa: UP006
    dict: typing.Dict,  # noqa: UP006
    set: typing.Set,  # noqa: UP006
    frozenset: typing.FrozenSet,  # noqa: UP006
    tuple: typing.Tuple,  # noqa: UP006
}

_legal_ops = dict.fromkeys(
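The `# noqa: UP006` markers here are deliberate: this table exists precisely to translate runtime builtins back to their `typing` spellings, so the aliases must stay even once the Ruff rule is enabled repo-wide. For reference, `typing.get_origin` already performs the opposite translation, which is why the reverse direction needs an explicit map:

import typing

# Both the typing alias and the PEP 585 builtin report the same runtime origin
# (requires Python >= 3.9 for the list[int] spelling).
assert typing.get_origin(typing.List[int]) is list
assert typing.get_origin(list[int]) is list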

@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
import ast
import copy
import dataclasses
import inspect
import re
import string
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple  # noqa: F401

import torch
import torch.jit.annotations

@ -551,7 +551,7 @@ def build_ignore_context_manager(ctx, stmt):
        return_type_ann = " -> " + outputs[0].ann
        return_statement_str += outputs[0].name
    if len(outputs) > 1:
        return_type_ann = " -> Tuple"
        return_type_ann = " -> tuple"
        return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
        return_statement_str += ", ".join([var.name for var in outputs])
    return return_type_ann, return_statement_str

@ -581,10 +581,18 @@ def build_ignore_context_manager(ctx, stmt):
    return_stmt = ast.parse(return_stmt).body[0]
    ignore_function.body.append(return_stmt)  # type: ignore[attr-defined]

    ignore_func_str = f"""\
# Backward compat: These used to be imported into the outer global scope so some
# code may still expect them.
from typing import List, Dict, Tuple

@torch.jit.ignore
{astunparse.unparse(ignore_function)}
"""
    g = copy.copy(globals())
    exec(ignore_func_str, g)  # noqa: P204
    # registers the custom function in the global context
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_str)  # noqa: P204
    globals()[ignore_function_name] = g[ignore_function_name]

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
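The new code above execs the generated function into a copy of `globals()` and then publishes only the one generated name back, instead of letting the exec'd string mutate the real module globals. A standalone demo of that pattern (the `greet` function is invented for illustration):

import copy

src = '''
def greet() -> str:
    return "hi"
'''

g = copy.copy(globals())
exec(src, g)  # define greet() inside the copied namespace only
globals()["greet"] = g["greet"]  # publish just the one name back

assert greet() == "hi"  # noqa: F821 -- injected above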

@ -5,7 +5,7 @@ This package enables an interface for accessing MTIA backend in python

import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch
from torch import device as _device, Tensor

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
from typing import Optional, Union

import torch
import torch.nn.functional as F

@ -2,7 +2,7 @@
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
import contextlib
from collections.abc import Iterable
from typing import List, Union
from typing import Union
from warnings import warn

import torch.backends.cuda

@ -5,7 +5,7 @@ from __future__ import annotations

import inspect
import warnings
from typing import Any, Sequence
from typing import Any, TYPE_CHECKING

import torch
from torch.export.dynamic_shapes import _Dim, _DimHint

@ -13,6 +13,10 @@ from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.utils import _pytree


if TYPE_CHECKING:
    from collections.abc import Sequence


def from_dynamic_axes_to_dynamic_shapes(
    model,
    args: tuple[Any, ...],
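Moving `Sequence` under `if TYPE_CHECKING:` works here because the file already has `from __future__ import annotations` (visible in the previous hunk), so annotations are never evaluated at runtime. The general pattern, as a self-contained sketch:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for the type checker only; no runtime cost or import cycle.
    from collections.abc import Sequence


def head(xs: Sequence[int]) -> int:
    return xs[0]


assert head([3, 1]) == 3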

@ -15,7 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union
from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
from typing_extensions import TypeAlias, TypeIs

import torch

@ -1898,7 +1898,7 @@ def _load(
        data_descripter_size64 = 24
        data_descripter_size32 = 16
        mz_uint32_max = 0xFFFFFFFF
        offsets: Dict[str, int] = dict()
        offsets: dict[str, int] = dict()

        def _get_offset(key, name, numel):
            """

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
from torch import Tensor

@ -4,7 +4,7 @@

import copy
from itertools import chain
from typing import Any, Dict
from typing import Any

import torch
import torch.nn as nn

@ -164,7 +164,7 @@ class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
    # _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict
    # keys, they need to provide a mapping from the new key to the original key. This is used to ensure
    # consistency between the state_dict keys and fqn.
    def _fqn_modifiers(self) -> Dict[str, str]:
    def _fqn_modifiers(self) -> dict[str, str]:
        return {
            "weight": "embedding",
        }

@ -1,7 +1,5 @@
# mypy: ignore-errors

from typing import List

from torch.testing._internal.opinfo.core import OpInfo
from torch.testing._internal.opinfo.definitions import (
    _masked,

@ -3,8 +3,8 @@ import bisect
import itertools
import math
import warnings
from collections.abc import Sequence
from typing import cast, Generic, Iterable, Optional, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import cast, Generic, Optional, TypeVar, Union
from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
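PEP 585 deprecates the `typing` aliases for ABCs as well: `Iterable`, `Sequence`, and friends now live in `collections.abc` and have been subscriptable there since Python 3.9, which is exactly the move this hunk makes. A quick check:

from collections.abc import Iterable, Sequence


def total(xs: Iterable[int]) -> int:
    return sum(xs)


def first(xs: Sequence[int]) -> int:
    return xs[0]


assert total([1, 2, 3]) == 6
assert first((7, 8)) == 7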

@ -75,7 +75,6 @@ import sys
import urllib.parse
import zipfile
from pathlib import Path
from typing import Dict
import warnings

import torch.utils.show_pickle

@ -9,7 +9,7 @@ This package is lazily initialized, so you can always import it, and use
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch
import torch._C