diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index ddbc076a72d..1ff5d847b61 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
+from collections.abc import Callable
 from enum import Enum
-from typing import Any, Callable
+from typing import Any
 
 import torch
 from torch._C._profiler import (
diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi
index 321a99fc709..ef24582b502 100644
--- a/torch/_C/_dynamo/compiled_autograd.pyi
+++ b/torch/_C/_dynamo/compiled_autograd.pyi
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable
 
 from torch import Tensor
 from torch._dynamo.compiled_autograd import AutogradCompilerInstance
diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi
index 5824b705644..b8c0a93e35f 100644
--- a/torch/_C/_dynamo/guards.pyi
+++ b/torch/_C/_dynamo/guards.pyi
@@ -1,6 +1,6 @@
 import enum
-from typing import Any, Callable, Optional
-from typing_extensions import TypeAlias
+from collections.abc import Callable
+from typing import Any, Optional, TypeAlias
 
 import torch
diff --git a/torch/_C/_monitor.pyi b/torch/_C/_monitor.pyi
index be6f0f64f97..82f2a3e4427 100644
--- a/torch/_C/_monitor.pyi
+++ b/torch/_C/_monitor.pyi
@@ -1,9 +1,9 @@
 # Defined in torch/csrc/monitor/python_init.cpp
 
 import datetime
+from collections.abc import Callable
 from enum import Enum
 from types import TracebackType
-from typing import Callable
 
 class Aggregation(Enum):
     VALUE = ...
diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi
index 5e2870f72b4..d60d89a6a47 100644
--- a/torch/_C/_profiler.pyi
+++ b/torch/_C/_profiler.pyi
@@ -1,6 +1,5 @@
 from enum import Enum
-from typing import Literal
-from typing_extensions import TypeAlias
+from typing import Literal, TypeAlias
 
 from torch._C import device, dtype, layout
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 8e9796d2f7c..20a15512899 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -1,10 +1,10 @@
 # mypy: allow-untyped-defs
 import inspect
 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import lru_cache, partial, wraps
 from itertools import chain
-from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
+from typing import Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 870bd5a658a..7da9d20efc6 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -5,12 +5,12 @@ import itertools
 import numbers
 import operator
 import sys
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
 from functools import partial, reduce
 from itertools import chain, product
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, cast, Optional, Union
 
 import torch
 import torch._meta_registrations
diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py
index cd1e0426f16..e11540e0c2b 100644
--- a/torch/_decomp/decompositions_for_jvp.py
+++ b/torch/_decomp/decompositions_for_jvp.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import inspect
-from typing import Callable, Optional
+from collections.abc import Callable
+from typing import Optional
 
 import torch
 import torch._decomp
diff --git a/torch/_decomp/decompositions_for_rng.py b/torch/_decomp/decompositions_for_rng.py
index 256045498cb..455ef0cc994 100644
--- a/torch/_decomp/decompositions_for_rng.py
+++ b/torch/_decomp/decompositions_for_rng.py
@@ -2,7 +2,7 @@
 # mypy: allow-untyped-defs
 import functools
 from collections import defaultdict
-from typing import Callable
+from collections.abc import Callable
 
 import torch
 import torch._decomp as decomp
diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
index a4103eb8387..7b790fe18ea 100644
--- a/torch/_dispatch/python.py
+++ b/torch/_dispatch/python.py
@@ -1,9 +1,9 @@
 # mypy: allow-untyped-defs
 import itertools
 import unittest.mock
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
-from typing import Callable, TypeVar, Union
+from typing import TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index ccd027a6da1..d653db0c23a 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -16,7 +16,8 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
+from collections.abc import Callable
 from unittest.mock import patch
 
 import torch
diff --git a/torch/_export/converter.py b/torch/_export/converter.py
index bba7c2d16aa..125386cc8ab 100644
--- a/torch/_export/converter.py
+++ b/torch/_export/converter.py
@@ -4,9 +4,9 @@ import logging
 import operator
 import typing
 import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.export._trace
diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py
index 758c53c0c02..f14ed0af04e 100644
--- a/torch/_export/non_strict_utils.py
+++ b/torch/_export/non_strict_utils.py
@@ -6,9 +6,9 @@ import inspect
 import logging
 import math
 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py
index 952e904ca26..61c57824e7e 100644
--- a/torch/_export/pass_base.py
+++ b/torch/_export/pass_base.py
@@ -2,8 +2,9 @@
 import operator
 import traceback
 import typing
+from collections.abc import Callable
 from contextlib import nullcontext
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import fx
diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
index bd81f0a9267..d646b7edaaf 100644
--- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
+++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
@@ -3,7 +3,7 @@ import math
 import operator
 import traceback
 from functools import partial
-from typing import Callable, NamedTuple
+from typing import NamedTuple, TYPE_CHECKING
 
 import sympy
 
@@ -15,6 +15,10 @@ from torch.utils._sympy.numbers import int_oo
 from torch.utils._sympy.value_ranges import ValueRanges
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 __all__ = ["InputDim"]
diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py
index fc4149dd55b..28f2542b63d 100644
--- a/torch/_export/passes/constant_folding.py
+++ b/torch/_export/passes/constant_folding.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import collections
 from collections import defaultdict
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_export/passes/replace_with_hop_pass_util.py b/torch/_export/passes/replace_with_hop_pass_util.py
index 97405809244..622024b46fe 100644
--- a/torch/_export/passes/replace_with_hop_pass_util.py
+++ b/torch/_export/passes/replace_with_hop_pass_util.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import contextlib
 import copy
 import operator
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
 
 import torch
 
@@ -12,6 +12,8 @@ from ..utils import node_replace_, nodes_map
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from torch._ops import HigherOrderOperator
     from torch.export.graph_signature import ExportGraphSignature
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index 95154495c2c..71fd07b7a6d 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -14,11 +14,11 @@ import operator
 import traceback
 import typing
 from collections import namedtuple, OrderedDict
-from collections.abc import Iterable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Annotated, Any, Callable, cast, final, Optional, Union
+from typing import Annotated, Any, cast, final, Optional, Union
 
 import sympy
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index 06c608d20c7..2004bf6250d 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -9,10 +9,10 @@ import math
 import operator
 import re
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from contextlib import contextmanager
 from inspect import ismethod, Parameter
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch._guards import detect_fake_mode
diff --git a/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py b/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py
index 7cc60f6ed54..2a1a3db275d 100644
--- a/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py
+++ b/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py
@@ -1,6 +1,6 @@
 import operator
 from collections import deque
-from typing import Callable
+from collections.abc import Callable
 
 import networkx as nx
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 54ad74da8f7..87d3c85bc82 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -16,9 +16,10 @@ import shutil
 import time
 import traceback
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from copy import copy
 from dataclasses import dataclass
-from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import override
 
 import torch
diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
index acfd40fe78c..0f38c19bc98 100644
--- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
+++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -11,7 +11,8 @@ a functionalized version of the graph under compilation.
 import collections
 import contextlib
 import logging
-from typing import Callable, Optional
+from collections.abc import Callable
+from typing import Optional
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py
index b2d96620b4b..ac39cbddc8b 100644
--- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py
+++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py
@@ -12,9 +12,10 @@ It does so by:
 """
 
 import warnings
+from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Optional, TypeVar, Union
+from typing import Any, cast, Optional, TypeVar, Union
 from unittest.mock import patch
 
 import torch
diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py
index 2ae1263c3ae..b57aa3b1a9a 100644
--- a/torch/_functorch/_aot_autograd/graph_compile.py
+++ b/torch/_functorch/_aot_autograd/graph_compile.py
@@ -17,8 +17,9 @@ import operator
 import time
 import traceback
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import nullcontext
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 
 if TYPE_CHECKING:
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 80564a90e61..b82af9ee712 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -14,10 +14,11 @@ import copy
 import functools
 import itertools
 import pprint
+from collections.abc import Callable
 from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 
 if TYPE_CHECKING:
diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py
index a65351c3193..a6fefe1f7f5 100644
--- a/torch/_functorch/_aot_autograd/schemas.py
+++ b/torch/_functorch/_aot_autograd/schemas.py
@@ -11,16 +11,7 @@ import functools
 import itertools
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import (
-    Any,
-    Callable,
-    NewType,
-    Optional,
-    Protocol,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, NewType, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -37,7 +28,7 @@ from .utils import strict_zip
 
 if TYPE_CHECKING:
     import contextlib
-    from collections.abc import Iterable, Sequence
+    from collections.abc import Callable, Iterable, Sequence
 
     from torch._guards import Source
     from torch._inductor.output_code import OutputCode
diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py
index d06f727e25a..dfa3141a3a0 100644
--- a/torch/_functorch/_aot_autograd/subclass_utils.py
+++ b/torch/_functorch/_aot_autograd/subclass_utils.py
@@ -7,9 +7,8 @@ and this includes tensor subclasses that implement __torch_dispatch__.
 
 import collections
 import typing
-from collections.abc import Iterable
-from typing import Any, Callable, Optional, TypeVar, Union
-from typing_extensions import TypeGuard
+from collections.abc import Callable, Iterable
+from typing import Any, Optional, TypeGuard, TypeVar, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py
index 8f6c7d1478e..50aae93c631 100644
--- a/torch/_functorch/_aot_autograd/utils.py
+++ b/torch/_functorch/_aot_autograd/utils.py
@@ -6,9 +6,10 @@ Contains various utils for AOTAutograd, including those for handling collections
 import dataclasses
 import operator
 import warnings
+from collections.abc import Callable
 from contextlib import nullcontext
 from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 2b0df0be370..017a37076ff 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -2,9 +2,10 @@
 import contextlib
 import itertools
+from collections.abc import Callable
 from contextlib import nullcontext
 from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 from unittest.mock import patch
 
 import torch
diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py
index 929b58540f4..cdf2e1855a0 100644
--- a/torch/_functorch/compile_utils.py
+++ b/torch/_functorch/compile_utils.py
@@ -2,7 +2,7 @@
 
 
 import operator
-from typing import Callable
+from collections.abc import Callable
 
 import sympy
diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py
index 5295a526e25..303281f8560 100644
--- a/torch/_functorch/compilers.py
+++ b/torch/_functorch/compilers.py
@@ -5,9 +5,10 @@ import logging
 import os
 import pickle
 import random
+from collections.abc import Callable
 from contextlib import contextmanager
 from functools import partial
-from typing import Callable, Union
+from typing import Union
 
 import sympy
diff --git a/torch/_functorch/deprecated.py b/torch/_functorch/deprecated.py
index d6e295c65c7..773eb2aa8be 100644
--- a/torch/_functorch/deprecated.py
+++ b/torch/_functorch/deprecated.py
@@ -10,7 +10,8 @@ documentation.
 
 import textwrap
 import warnings
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional, Union
 
 import torch._functorch.apis as apis
 import torch._functorch.eager_transforms as _impl
diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py
index 828f5e8decc..7a6cf009b27 100644
--- a/torch/_functorch/eager_transforms.py
+++ b/torch/_functorch/eager_transforms.py
@@ -7,8 +7,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+from collections.abc import Callable
 from functools import partial, wraps
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.autograd.forward_ad as fwAD
diff --git a/torch/_functorch/fx_minifier.py b/torch/_functorch/fx_minifier.py
index 3cf5fc24f1c..60609ad95e6 100644
--- a/torch/_functorch/fx_minifier.py
+++ b/torch/_functorch/fx_minifier.py
@@ -4,9 +4,9 @@ import copy
 import math
 import os
 import sys
+from collections.abc import Callable
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Callable
 
 import torch
 import torch.fx as fx
diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py
index 16988a022a9..d56d6b591a5 100644
--- a/torch/_functorch/make_functional.py
+++ b/torch/_functorch/make_functional.py
@@ -6,8 +6,8 @@
 # LICENSE file in the root directory of this source tree.
 import copy
-from collections.abc import Iterable, Sequence
-from typing import Any, Callable, NoReturn, Union
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, NoReturn, Union
 
 import torch
 import torch.nn as nn
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index f7add3407a9..c94864c5def 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -11,8 +11,9 @@ import os
 import os.path
 import re
 from collections import defaultdict
+from collections.abc import Callable
 from dataclasses import dataclass, replace
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._inductor.inductor_prims
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 5e3893fef5c..25ffe9c525f 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -9,8 +9,9 @@
 import contextlib
 import functools
 import itertools
+from collections.abc import Callable
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import Tensor
diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py
index 2b4cc988078..374c048e477 100644
--- a/torch/_higher_order_ops/associative_scan.py
+++ b/torch/_higher_order_ops/associative_scan.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import functools
 import itertools
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import torch
 import torch._prims_common as utils
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index d8374c356ab..e00e3d3d208 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -1,9 +1,9 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, get_args, Optional, Union
+from typing import Any, get_args, Optional, Union
 
 import torch
 import torch._library.utils as library_utils
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index e7310df23c5..828a13e2fd5 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -4,7 +4,8 @@ import contextlib
 import functools
 import logging
 import warnings
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_higher_order_ops/flat_apply.py b/torch/_higher_order_ops/flat_apply.py
index 654e2ea3838..8b45cb3db63 100644
--- a/torch/_higher_order_ops/flat_apply.py
+++ b/torch/_higher_order_ops/flat_apply.py
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Callable
 
 import torch
 import torch.fx.node
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 1a0a8673658..46dd77d0cdd 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -1,6 +1,6 @@
 import math
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_higher_order_ops/foreach_map.py b/torch/_higher_order_ops/foreach_map.py
index 52841724c20..0d02515d555 100644
--- a/torch/_higher_order_ops/foreach_map.py
+++ b/torch/_higher_order_ops/foreach_map.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 from torch._higher_order_ops.base_hop import BaseHOP, FunctionWithNoFreeVars
diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py
index bf4a8632f23..2cf4a45ee2b 100644
--- a/torch/_higher_order_ops/invoke_subgraph.py
+++ b/torch/_higher_order_ops/invoke_subgraph.py
@@ -3,7 +3,7 @@
 import contextlib
 from contextlib import nullcontext
 from dataclasses import dataclass, field
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -36,6 +36,10 @@ from torch.fx.graph_module import GraphModule
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 invoke_subgraph_counter = 0
diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py
index 7b3a4db249a..7b897dc9add 100644
--- a/torch/_higher_order_ops/local_map.py
+++ b/torch/_higher_order_ops/local_map.py
@@ -6,9 +6,9 @@
 # NOTE: this file may be removed once we move to a dynamo frontend
 import functools
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py
index 57d2cd3cb90..73f66986c75 100644
--- a/torch/_higher_order_ops/map.py
+++ b/torch/_higher_order_ops/map.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import functools
-from typing import Callable, Union
+from collections.abc import Callable
+from typing import Union
 from typing_extensions import TypeVarTuple
 
 import torch
diff --git a/torch/_higher_order_ops/partitioner.py b/torch/_higher_order_ops/partitioner.py
index 86ead85927f..81ad53b3733 100644
--- a/torch/_higher_order_ops/partitioner.py
+++ b/torch/_higher_order_ops/partitioner.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Callable, Union
+from collections.abc import Callable
+from typing import Any, Union
 
 import torch
 from torch._higher_order_ops.utils import create_bw_fn, materialize_as_graph
diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py
index 197ce37e126..07852d85fbe 100644
--- a/torch/_higher_order_ops/scan.py
+++ b/torch/_higher_order_ops/scan.py
@@ -3,7 +3,8 @@ import enum
 import functools
 import itertools
 import logging
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import torch
 import torch._prims_common as utils
diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py
index 1ed920c4a15..f5875ded5a9 100644
--- a/torch/_higher_order_ops/strict_mode.py
+++ b/torch/_higher_order_ops/strict_mode.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Callable, Union
+from typing import Any, TYPE_CHECKING, Union
 
 import torch
 import torch._subclasses.functional_tensor
@@ -20,6 +20,10 @@ from torch.fx.experimental.proxy_tensor import (
 )
 from torch.utils._python_dispatch import _get_current_dispatch_mode
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 @exposed_in("torch")
 def strict_mode(callable, operands):
     from torch._dynamo.backends.debugging import (
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
index 818bf78840b..08cbdf990af 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -8,8 +8,8 @@ import logging
 import operator
 import threading
 from collections import defaultdict
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, TYPE_CHECKING, Union
 from typing_extensions import Never
 
 import sympy
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py
index 2124f56f9b8..01f380a4c6f 100644
--- a/torch/_higher_order_ops/utils.py
+++ b/torch/_higher_order_ops/utils.py
@@ -1,10 +1,10 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, overload, TypeVar, Union
+from typing import Any, Optional, overload, TypeVar, Union
 
 import torch
 import torch.fx.traceback as fx_traceback
diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py
index 64bb209310b..7aeefe7ba1d 100644
--- a/torch/_higher_order_ops/while_loop.py
+++ b/torch/_higher_order_ops/while_loop.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Any, Callable, Union
+from collections.abc import Callable
+from typing import Any, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py
index 38219a54b30..78455ebc964 100644
--- a/torch/_lazy/extract_compiled_graph.py
+++ b/torch/_lazy/extract_compiled_graph.py
@@ -3,7 +3,8 @@ import copy
 import dataclasses
 import itertools
 import os
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import torch
 import torch._lazy as lazy
diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py
index 3f3e9295549..2707d07059e 100644
--- a/torch/_library/autograd.py
+++ b/torch/_library/autograd.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import dataclasses
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Protocol
+from typing import Any, Optional, Protocol
 
 from torch import _C, _ops, autograd, Tensor
 from torch.utils import _pytree
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index 596eec5174c..cb9f97d651a 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -3,9 +3,9 @@ import collections
 import inspect
 import logging
 import weakref
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, overload, Union
+from typing import Any, Optional, overload, Union
 
 import torch
 from torch import _C, _ops, Tensor
diff --git a/torch/_library/fake_impl.py b/torch/_library/fake_impl.py
index 632020a04ba..877ebb0c591 100644
--- a/torch/_library/fake_impl.py
+++ b/torch/_library/fake_impl.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Callable
+from collections.abc import Callable
 from typing_extensions import deprecated
 
 import torch
diff --git a/torch/_library/fake_profile.py b/torch/_library/fake_profile.py
index d480f666268..3bd1a444fa2 100644
--- a/torch/_library/fake_profile.py
+++ b/torch/_library/fake_profile.py
@@ -2,9 +2,9 @@ import contextlib
 import io
 import logging
 import os
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch._library.custom_ops import _maybe_get_opdef
diff --git a/torch/_library/simple_registry.py b/torch/_library/simple_registry.py
index 1f11914e8e9..8709c9e95c2 100644
--- a/torch/_library/simple_registry.py
+++ b/torch/_library/simple_registry.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 from .fake_impl import FakeImplHolder
 from .utils import RegistrationHandle
diff --git a/torch/_library/triton.py b/torch/_library/triton.py
index 741b341f7e2..761279743f3 100644
--- a/torch/_library/triton.py
+++ b/torch/_library/triton.py
@@ -2,8 +2,8 @@ import ast
 import contextlib
 import inspect
 import threading
-from collections.abc import Generator, Iterable
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Generator, Iterable
+from typing import Any, Optional, Union
 
 from torch.utils._exposed_in import exposed_in
diff --git a/torch/_library/utils.py b/torch/_library/utils.py
index 59a316acc69..3a88e40caab 100644
--- a/torch/_library/utils.py
+++ b/torch/_library/utils.py
@@ -2,8 +2,8 @@ import dataclasses
 import inspect
 import sys
-from collections.abc import Iterable, Iterator
-from typing import Any, Callable, Literal, Optional, overload, Union
+from collections.abc import Callable, Iterable, Iterator
+from typing import Any, Literal, Optional, overload, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index a418fe3b609..5a97f7f2133 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -15,8 +15,9 @@ import tempfile
 import time
 import warnings
 from collections import defaultdict
+from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Optional, Union
+from typing import Any, Generic, Optional, Union
 from typing_extensions import ParamSpec
 from weakref import WeakSet
diff --git a/torch/_logging/scribe.py b/torch/_logging/scribe.py
index 4456a94ccc7..2feb814d4a2 100644
--- a/torch/_logging/scribe.py
+++ b/torch/_logging/scribe.py
@@ -1,5 +1,5 @@
-from typing import Callable, Union
-from typing_extensions import TypeAlias
+from collections.abc import Callable
+from typing import TypeAlias, Union
 
 
 try:
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 10feadf3ec0..e3a244476c7 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1,9 +1,9 @@
 # mypy: allow-untyped-defs
 import operator
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from enum import Enum
 from functools import partial, reduce
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch._prims_common as utils
diff --git a/torch/_prims/context.py b/torch/_prims/context.py
index 30bc1f85c0e..b125a2bcaf0 100644
--- a/torch/_prims/context.py
+++ b/torch/_prims/context.py
@@ -2,12 +2,12 @@ from __future__ import annotations
 
 import functools
 from contextlib import nullcontext
-from typing import Any, Callable, TYPE_CHECKING, TypeVar
+from typing import Any, TYPE_CHECKING, TypeVar
 from typing_extensions import ParamSpec
 
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence
 
 import torch
 import torch._decomp
diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py
index fdd2e19ab43..55eb0d35c38 100644
--- a/torch/_prims/executor.py
+++ b/torch/_prims/executor.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional, TypeVar
+from collections.abc import Callable
+from typing import Any, Optional, TypeVar
 from typing_extensions import ParamSpec, TypeVarTuple, Unpack
 
 from torch._prims.context import TorchRefsMode
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index a317e8403b5..41c400ffe64 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -4,22 +4,23 @@ from __future__ import annotations
 import operator
 import typing
 import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from contextlib import AbstractContextManager, nullcontext
 from enum import Enum
 from functools import reduce
 from typing import (
     Any,
-    Callable,
     cast,
     NamedTuple,
     Optional,
     overload,
     TYPE_CHECKING,
+    TypeAlias,
+    TypeGuard,
     TypeVar,
     Union,
 )
-from typing_extensions import deprecated, TypeAlias, TypeGuard
+from typing_extensions import deprecated
 
 import torch
 from torch import sym_float, sym_int, sym_max
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index e5e5b13f62c..692388e6846 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -2,10 +2,10 @@ import inspect
 import types
 import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import wraps
 from types import GenericAlias
-from typing import Callable, NamedTuple, Optional, overload, TypeVar, Union
+from typing import NamedTuple, Optional, overload, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 558170f39b9..1045162bb64 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -7,10 +7,10 @@ import itertools
 import math
 import operator
 import warnings
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from enum import Enum
 from functools import partial, reduce, singledispatch, wraps
-from typing import Any, Callable, cast, Optional, overload, Union
+from typing import Any, cast, Optional, overload, Union
 
 import torch
 import torch._prims as prims
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 89ead281d94..36e7cb7bb2d 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,9 +1,10 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import math
+from collections.abc import Callable
 from functools import wraps
-from typing import Callable, Optional, TypeVar, Union
-from typing_extensions import Concatenate, ParamSpec
+from typing import Concatenate, Optional, TypeVar, Union
+from typing_extensions import ParamSpec
 
 import torch
 import torch._prims as prims
diff --git a/torch/_strobelight/cli_function_profiler.py b/torch/_strobelight/cli_function_profiler.py
index 80108dc9918..0cc7db12fe2 100644
--- a/torch/_strobelight/cli_function_profiler.py
+++ b/torch/_strobelight/cli_function_profiler.py
@@ -6,10 +6,10 @@ import os
 import re
 import subprocess
 import time
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from threading import Lock
 from timeit import default_timer as timer
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Optional, TypeVar
 from typing_extensions import ParamSpec
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index e509018481b..1b03048dd60 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -5,8 +5,9 @@ import itertools
 import math
 import operator
 import sys
+from collections.abc import Callable
 from functools import reduce
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch._custom_op
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index b0cf9e6cc8a..ebfc63f7063 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -15,8 +15,17 @@ import typing
 import weakref
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union
-from typing_extensions import Self, TypeGuard
+from typing import (
+    Any,
+    cast,
+    Literal,
+    Optional,
+    TYPE_CHECKING,
+    TypeGuard,
+    TypeVar,
+    Union,
+)
+from typing_extensions import Self
 from weakref import ReferenceType
 
 import torch
@@ -53,7 +62,7 @@ from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputSt
 
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable, Mapping, Sequence
+    from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
     from types import TracebackType
 
     from torch._guards import Source
diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py
index bd481c87cf6..36bbba87199 100644
--- a/torch/_subclasses/fake_utils.py
+++ b/torch/_subclasses/fake_utils.py
@@ -2,7 +2,8 @@
 
 import functools
 import warnings
-from typing import Any, Callable, Union
+from collections.abc import Callable
+from typing import Any, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index 28cc3070aff..15ed56ddca3 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -3,8 +3,9 @@ import contextlib
 import warnings
 import weakref
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from contextlib import AbstractContextManager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index c447ffb5d73..123f7d44f84 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -11,17 +11,17 @@ from contextlib import AbstractContextManager, contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
-    Callable,
     ClassVar,
     Generic,
     NewType,
     Optional,
     Protocol,
     TYPE_CHECKING,
+    TypeGuard,
     TypeVar,
     Union,
 )
-from typing_extensions import override, TypedDict, TypeGuard, TypeIs, Unpack
+from typing_extensions import override, TypedDict, TypeIs, Unpack
 
 import torch
 from torch._C._autograd import CreationMeta
@@ -46,7 +46,7 @@ from torch.utils.weak import WeakIdKeyDictionary
 
 
 if TYPE_CHECKING:
-    from collections.abc import Generator
+    from collections.abc import Callable, Generator
 
     from torch._C._functorch import CInterpreter
     from torch._guards import Source
diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py
index 08ec23b748e..30881e06ff1 100644
--- a/torch/compiler/__init__.py
+++ b/torch/compiler/__init__.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import io
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
+from collections.abc import Callable
+from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index bd64f18483f..9181a87abe4 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 
 r"""
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 5dbbd7611b6..19d7befe6af 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -1,8 +1,8 @@
 # mypy: allow-untyped-defs
 import math
 import warnings
+from collections.abc import Callable
 from functools import total_ordering
-from typing import Callable
 
 import torch
 from torch import inf, Tensor
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 8ebed81f493..a5afc5395ee 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -1,6 +1,6 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import update_wrapper
-from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union
+from typing import Any, Final, Generic, Optional, overload, TypeVar, Union
 
 import torch
 import torch.nn.functional as F
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index 621cabf15a3..bff5f3b4071 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -2,8 +2,8 @@ import logging
 import os
 import warnings
 import zipfile
-from collections.abc import Mapping
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Mapping
+from typing import Any, Optional, Union
 from typing_extensions import deprecated
 
 import torch
diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py
index 2b14327b245..b689965c63d 100644
--- a/torch/export/_draft_export.py
+++ b/torch/export/_draft_export.py
@@ -5,10 +5,10 @@ import os
 import re
 import tempfile
 import time
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from enum import IntEnum
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch._logging._internal
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 60e81a2c63e..c9e73f0896e 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -8,9 +8,9 @@ import re
 import sys
 import time
 import warnings
+from collections.abc import Callable
 from contextlib import contextmanager, nullcontext
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
-from typing_extensions import TypeAlias
+from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
 
 
 if TYPE_CHECKING:
diff --git a/torch/export/_tree_utils.py b/torch/export/_tree_utils.py
index 1c6a05319ad..5c2d4426066 100644
--- a/torch/export/_tree_utils.py
+++ b/torch/export/_tree_utils.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 from torch.utils._pytree import Context, TreeSpec
diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py
index 2f4c86617cb..8e7e14e1b2e 100644
--- a/torch/export/decomp_utils.py
+++ b/torch/export/decomp_utils.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Callable
+from collections.abc import Callable
 
 import torch
 from torch._export.utils import (
diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py
index 7fcea48e126..375d059d64c 100644
--- a/torch/export/dynamic_shapes.py
+++ b/torch/export/dynamic_shapes.py
@@ -4,8 +4,9 @@ import inspect
 import logging
 import sys
 from collections import defaultdict
+from collections.abc import Callable
 from enum import auto, Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.utils._pytree import (
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 807321f0a1e..2d7fc4114a7 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -8,9 +8,9 @@ import operator
 import types
 import warnings
 from collections import defaultdict
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
-from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union
+from typing import Any, final, NamedTuple, Optional, TYPE_CHECKING, Union
 
 from torch._guards import tracing, TracingContext
 from torch._higher_order_ops.utils import autograd_not_implemented
diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py
index 1e63d51dc9d..2cd77e3818a 100644
--- a/torch/export/pt2_archive/_package.py
+++ b/torch/export/pt2_archive/_package.py
@@ -6,8 +6,7 @@ import os
 import tempfile
 import zipfile
 from dataclasses import dataclass
-from typing import Any, IO, Optional, TYPE_CHECKING, Union
-from typing_extensions import TypeAlias
+from typing import Any, IO, Optional, TYPE_CHECKING, TypeAlias, Union
 
 import torch
 import torch.utils._pytree as pytree
diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py
index d09307f66d6..73b24757300 100644
--- a/torch/export/unflatten.py
+++ b/torch/export/unflatten.py
@@ -5,11 +5,12 @@ import logging
 import operator
 import re
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, cast, Optional, Union
 
 import torch
 import torch.fx._pytree as fx_pytree
diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py
index 76a479d965a..cee84f54527 100644
--- a/torch/futures/__init__.py
+++ b/torch/futures/__init__.py
@@ -1,11 +1,15 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
-from typing import Callable, cast, Generic, Optional, TypeVar, Union
+from typing import cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
 
 import torch
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 __all__ = ["Future", "collect_all", "wait_all"]
diff --git a/torch/jit/_dataclass_impls.py b/torch/jit/_dataclass_impls.py
index 58abc91da04..67da5e20206 100644
--- a/torch/jit/_dataclass_impls.py
+++ b/torch/jit/_dataclass_impls.py
@@ -4,8 +4,8 @@ import ast
 import dataclasses
 import inspect
 import os
+from collections.abc import Callable
 from functools import partial
-from typing import Callable
 
 from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
 from torch._sources import ParsedDef, SourceContext
diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py
index 000ec7d0ec7..2ad0427a810 100644
--- a/torch/jit/_decompositions.py
+++ b/torch/jit/_decompositions.py
@@ -6,7 +6,8 @@ from torch import Tensor
 
 aten = torch.ops.aten
 import inspect
 import warnings
-from typing import Callable, Optional, TypeVar
+from collections.abc import Callable
+from typing import Optional, TypeVar
 from typing_extensions import ParamSpec
 
 from torch.types import Number
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 4c06ed24078..7fe8b704fb5 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -14,7 +14,8 @@ import functools
 import inspect
 import pickle
 import warnings
-from typing import Any, Callable, Union
+from collections.abc import Callable
+from typing import Any, Union
 from typing_extensions import deprecated
 
 import torch
diff --git a/torch/jit/_script.pyi b/torch/jit/_script.pyi
index f6727f9198c..7d3a5de62a9 100644
--- a/torch/jit/_script.pyi
+++ b/torch/jit/_script.pyi
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 # mypy: disable-error-code="type-arg"
-from typing import Any, Callable, NamedTuple, overload, TypeVar
-from typing_extensions import Never, TypeAlias
+from collections.abc import Callable
+from typing import Any, NamedTuple, overload, TypeAlias, TypeVar
+from typing_extensions import Never
 
 from _typeshed import Incomplete
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index 5084d7c9228..793472adfb4 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -16,8 +16,9 @@ import inspect
 import os
 import re
 import warnings
+from collections.abc import Callable
 from enum import Enum
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Optional, TypeVar
 from typing_extensions import ParamSpec
 
 import torch
diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py
index 7cd75ce75c4..8e4ea77ff74 100644
--- a/torch/masked/_ops.py
+++ b/torch/masked/_ops.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import warnings
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
-from typing_extensions import ParamSpec, TypeAlias
+from collections.abc import Callable
+from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union
+from typing_extensions import ParamSpec
 
 import torch
 from torch import sym_float, Tensor
diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py
index aff5ce5fb9c..9a0f638d799 100644
--- a/torch/masked/maskedtensor/_ops_refs.py
+++ b/torch/masked/maskedtensor/_ops_refs.py
@@ -1,8 +1,9 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+from collections.abc import Callable
 from functools import partial
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
 import torch
diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py
index 14871d42596..55bcadf0c2a 100644
--- a/torch/mtia/__init__.py
+++ b/torch/mtia/__init__.py
@@ -5,7 +5,8 @@ This package enables an interface for accessing MTIA backend in python
 
 import threading
 import warnings
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional, Union
 
 import torch
 from torch import device as _device, Tensor
diff --git a/torch/numa/binding.py b/torch/numa/binding.py
index b92a046676f..34a61e2b9c5 100644
--- a/torch/numa/binding.py
+++ b/torch/numa/binding.py
@@ -1,12 +1,12 @@
 import os
 import traceback
 from collections import defaultdict
-from collections.abc import Iterable, Iterator
+from collections.abc import Callable, Iterable, Iterator
 from contextlib import contextmanager
 from dataclasses import asdict, dataclass
 from enum import Enum
 from logging import getLogger
-from typing import Callable, Optional, TypeVar
+from typing import Optional, TypeVar
 
 import torch
 from torch._utils_internal import signpost_event
diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py
index 88cd0b2eaab..b1fad573f29 100644
--- a/torch/onnx/_internal/fx/_pass.py
+++ b/torch/onnx/_internal/fx/_pass.py
@@ -7,7 +7,7 @@ import dataclasses
 import difflib
 import io
 import sys
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
 import torch
 import torch.fx
@@ -15,6 +15,8 @@ from torch._subclasses.fake_tensor import unset_fake_temporarily
 
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from torch._subclasses import fake_tensor
diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py
index a7f6be32f47..c1b4ce9f4d7 100644
--- a/torch/onnx/_internal/fx/passes/type_promotion.py
+++ b/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -6,7 +6,7 @@ import abc
 import dataclasses
 import inspect
 import logging
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
 import torch
 import torch._dispatch.python
@@ -26,7 +26,7 @@ from torch.utils import _python_dispatch, _pytree
 
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping, Sequence
+    from collections.abc import Callable, Mapping, Sequence
     from types import ModuleType
 
     from torch._subclasses import fake_tensor
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index fdbab432ff4..b526efed735 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -11,7 +11,6 @@ from collections import Counter
 from functools import partial, wraps
 from typing import (
     Any,
-    Callable,
     cast,
     Literal,
     Optional,
@@ -29,7 +28,7 @@ from .optimizer import _to_scalar, Optimizer
 
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Sequence
+    from collections.abc import Callable, Iterable, Sequence
 
 
 __all__ = [
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 2ef6c48f4ef..b3ece106063 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -4,11 +4,11 @@ import functools
 import warnings
 from collections import defaultdict, OrderedDict
-from collections.abc import Hashable, Iterable, Sequence
+from collections.abc import Callable, Hashable, Iterable, Sequence
 from copy import deepcopy
 from itertools import chain
-from typing import Any, Callable, cast, Optional, overload, TypeVar, Union
-from typing_extensions import ParamSpec, Self, TypeAlias
+from typing import Any, cast, Optional, overload, TypeAlias, TypeVar, Union
+from typing_extensions import ParamSpec, Self
 
 import torch
 import torch.utils.hooks as hooks
diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index d19760cfeab..d647eea6043 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -4,9 +4,9 @@ r"""Implementation for Stochastic Weight Averaging implementation."""
 import itertools
 import math
 import warnings
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from copy import deepcopy
-from typing import Any, Callable, cast, Literal, Optional, Union
+from typing import Any, cast, Literal, Optional, Union
 from typing_extensions import override
 
 import torch
diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py
index d9f3a917c15..35e22b25a94 100644
--- a/torch/profiler/_memory_profiler.py
+++ b/torch/profiler/_memory_profiler.py
@@ -5,8 +5,7 @@ import enum
 import itertools as it
 import logging
 from collections.abc import Iterator
-from typing import Any, cast, Optional, Union
-from typing_extensions import Literal
+from typing import Any, cast, Literal, Optional, Union
 
 import torch
 from torch._C import FunctionSchema
diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py
index 573541799bb..3a19550f658 100644
--- a/torch/profiler/profiler.py
+++ b/torch/profiler/profiler.py
@@ -5,10 +5,10 @@ import os
 import shutil
 import tempfile
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from enum import Enum
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 from typing_extensions import Self
 from warnings import warn
diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py
index e68c202f03e..83d62c503fe 100644
--- a/torch/signal/windows/windows.py
+++ b/torch/signal/windows/windows.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from math import sqrt
-from typing import Callable, Optional, TypeVar
+from typing import Optional, TypeVar
 
 import torch
 from torch import Tensor
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
index e081e15f96d..be648fd84e7 100644
--- a/torch/sparse/semi_structured.py
+++ b/torch/sparse/semi_structured.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import warnings
 from collections import namedtuple
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 import torch
 from torch.sparse._semi_structured_conversions import (
diff --git a/torchgen/context.py b/torchgen/context.py
index a482a59eeb7..e3725d66b96 100644
--- a/torchgen/context.py
+++ b/torchgen/context.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import contextlib
 import functools
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
 
 import torchgen.local as local
 from torchgen.model import (
@@ -16,7 +16,7 @@ from torchgen.utils import context, S, T
 
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Callable, Iterator
 
 
 # Helper functions for defining generators on things in the model
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 7bbdd4a7a74..ae0e4b52a0f 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -8,7 +8,7 @@ import os
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
+from typing import Any, Literal, TYPE_CHECKING, TypeVar
 from typing_extensions import assert_never
 
 import yaml
@@ -96,7 +96,7 @@ from torchgen.yaml_utils import YamlDumper, YamlLoader
 
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence
     from typing import Optional
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index f47985837ea..666e3fc1a58 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
 
 from torchgen.api import cpp, dispatcher, functionalization
 from torchgen.api.translate import translate
@@ -51,6 +51,8 @@ from torchgen.utils import concatMap, dataclass_repr, FileManager
 
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from torchgen.selective_build.selector import SelectiveBuilder
diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py
index e397561d378..ffd0aab2a28 100644
--- a/torchgen/gen_lazy_tensor.py
+++ b/torchgen/gen_lazy_tensor.py
@@ -4,7 +4,7 @@ import argparse
 import os
 from collections import namedtuple
 from pathlib import Path
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
 import yaml
 
@@ -26,7 +26,7 @@ from torchgen.yaml_utils import YamlLoader
 
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
+    from collections.abc import Callable, Iterable, Iterator, Sequence
 
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
diff --git a/torchgen/model.py b/torchgen/model.py
index eb3a80dffe6..1712332128d 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -5,14 +5,14 @@ import itertools
 import re
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import Callable, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
 from typing_extensions import assert_never
 
 from torchgen.utils import NamespaceHelper, OrderedSet
 
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator, Sequence
+    from collections.abc import Callable, Iterator, Sequence
 
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
diff --git a/torchgen/utils.py b/torchgen/utils.py
index f6777912a8f..035e8958b40 100644
--- a/torchgen/utils.py
+++ b/torchgen/utils.py
@@ -11,7 +11,7 @@ from dataclasses import is_dataclass
 from enum import auto, Enum
 from pathlib import Path
 from pprint import pformat
-from typing import Any, Callable, Generic, NoReturn, TYPE_CHECKING, TypeVar
+from typing import Any, Generic, NoReturn, TYPE_CHECKING, TypeVar
 from typing_extensions import assert_never, deprecated, Self
 
 from torchgen.code_template import CodeTemplate
@@ -19,7 +19,7 @@ from torchgen.code_template import CodeTemplate
 
 if TYPE_CHECKING:
     from argparse import Namespace
-    from collections.abc import Iterable, Iterator, Sequence
+    from collections.abc import Callable, Iterable, Iterator, Sequence
 
 
 TORCHGEN_ROOT = Path(__file__).absolute().parent
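
Note on the pattern this patch applies throughout: since Python 3.9 (PEP 585), typing.Callable is a deprecated alias of collections.abc.Callable, so the import moves to collections.abc; likewise, names such as TypeAlias, TypeGuard, Literal, and Concatenate are imported from typing directly once the minimum supported Python provides them, dropping the typing_extensions backport. In files that use Callable only in annotations, the import is additionally gated behind TYPE_CHECKING so it carries no runtime cost. A minimal sketch of that gating, assuming a hypothetical apply_twice function that is not part of this patch:

    # from __future__ import annotations keeps annotations unevaluated at
    # runtime, so a TYPE_CHECKING-only import of Callable is sufficient.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to type checkers; never imported when the module runs.
        from collections.abc import Callable


    def apply_twice(fn: Callable[[int], int], x: int) -> int:
        return fn(fn(x))


    print(apply_twice(lambda n: n + 1, 0))  # prints 2

Files whose Callable usage is not annotation-only (e.g. isinstance checks or runtime registries) keep an unconditional from collections.abc import Callable instead, which is why both variants appear in the diff.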