diff --git a/torch/__init__.py b/torch/__init__.py index 22f5206af65..9be88b832fa 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -22,9 +22,9 @@ import platform import sys import textwrap import threading +from collections.abc import Callable as _Callable from typing import ( Any as _Any, - Callable as _Callable, get_origin as _get_origin, Optional as _Optional, overload as _overload, diff --git a/torch/_compile.py b/torch/_compile.py index 697576a3de6..76ddd3ccb05 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -4,7 +4,8 @@ circular dependencies. """ import functools -from typing import Callable, Optional, overload, TypeVar, Union +from collections.abc import Callable +from typing import Optional, overload, TypeVar, Union from typing_extensions import ParamSpec diff --git a/torch/_guards.py b/torch/_guards.py index 76a35d1060e..e3d20c9fc51 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -14,16 +14,7 @@ from abc import abstractmethod from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import ( - Any, - Callable, - Generic, - NamedTuple, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union import torch from torch.utils import _pytree as pytree @@ -36,7 +27,7 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: - from collections.abc import Generator, Iterator + from collections.abc import Callable, Generator, Iterator from types import CodeType import sympy diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 0eb457edcc0..c2ee4cd55b5 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,9 +1,9 @@ # mypy: allow-untyped-defs import math -from collections.abc import Sequence +from collections.abc import Callable, Sequence from enum import Enum from functools import wraps -from typing import Callable, Optional, TypeVar, Union +from typing import Optional, TypeVar, Union from typing_extensions import ParamSpec import torch diff --git a/torch/_ops.py b/torch/_ops.py index c4584256b3a..8f91e072c23 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -6,19 +6,19 @@ import importlib import inspect import sys import types -from collections.abc import Iterator +from collections.abc import Callable, Iterator from functools import cached_property from typing import ( Any, - Callable, ClassVar, + Concatenate, final, Generic, Optional, TYPE_CHECKING, Union, ) -from typing_extensions import Concatenate, ParamSpec, TypeVar +from typing_extensions import ParamSpec, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_tensor.py b/torch/_tensor.py index ae23989bb3b..f91539b7533 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -4,10 +4,11 @@ import enum import functools import warnings from collections import OrderedDict +from collections.abc import Callable from copy import deepcopy from numbers import Number -from typing import Any, Callable, cast, Optional, TypeVar, Union -from typing_extensions import Concatenate, ParamSpec +from typing import Any, cast, Concatenate, Optional, TypeVar, Union +from typing_extensions import ParamSpec import torch import torch._C as _C diff --git a/torch/_utils.py b/torch/_utils.py index 9bd062cb5ce..68d395a90c9 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -7,8 +7,9 @@ import sys import traceback import warnings from collections import defaultdict +from collections.abc import Callable from types import ModuleType -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING +from typing import Any, Generic, Optional, TYPE_CHECKING from typing_extensions import deprecated, ParamSpec import torch diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 8602ac955f1..0d56facc7ca 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -5,7 +5,8 @@ import os import sys import tempfile import typing_extensions -from typing import Any, Callable, Optional, TypeVar +from collections.abc import Callable +from typing import Any, Optional, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 6baee77ade5..3f303f78a47 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import functools -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union from typing_extensions import deprecated import torch diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 9382a5500e0..5cc8e523406 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -27,6 +27,7 @@ import warnings from _codecs import encode from collections import Counter, OrderedDict +from collections.abc import Callable from pickle import ( APPEND, APPENDS, @@ -68,7 +69,7 @@ from pickle import ( ) from struct import unpack from sys import maxsize -from typing import Any, Callable, Union +from typing import Any, Union import torch from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 96d24a2cf2e..1c8e751b1eb 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch import torch.ao.nn.quantized as nnq diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 680ecd9f139..a716d91bbb8 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Callable, Optional, Union +from collections.abc import Callable +from typing import Optional, Union import torch diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py index ffbb99bb296..826ad95bf63 100644 --- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs +from collections.abc import Callable from itertools import chain from operator import getitem -from typing import Callable, Optional, Union +from typing import Optional, Union import torch import torch.nn.functional as F diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index a1882af4ca1..d5e4b7823dc 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -4,7 +4,8 @@ Collection of conversion functions for linear / conv2d structured pruning Also contains utilities for bias propagation """ -from typing import Callable, cast, Optional +from collections.abc import Callable +from typing import cast, Optional import torch from torch import nn, Tensor diff --git a/torch/ao/pruning/scheduler/lambda_scheduler.py b/torch/ao/pruning/scheduler/lambda_scheduler.py index 5588c157161..7c0e8088890 100644 --- a/torch/ao/pruning/scheduler/lambda_scheduler.py +++ b/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -1,5 +1,6 @@ import warnings -from typing import Callable, Union +from collections.abc import Callable +from typing import Union from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index c9577cbb79a..c3541ac83ca 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import operator +from collections.abc import Callable from functools import reduce -from typing import Callable, Optional, Union +from typing import Optional, Union import torch import torch.nn.functional as F diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index f50b9d6cd13..3bf5d82f190 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import sys -from typing import Callable, Optional, Union +from collections.abc import Callable +from typing import Optional, Union import torch from torch import Tensor diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 781bfdc8b39..ab44cfa0919 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -2,7 +2,8 @@ import copy import operator from collections import namedtuple -from typing import Callable, Union +from collections.abc import Callable +from typing import Union import torch import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 3919b84da28..a4aa3f9a2b8 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -3,12 +3,14 @@ from __future__ import annotations from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch if TYPE_CHECKING: + from collections.abc import Callable + from torch.ao.quantization.utils import Pattern diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index 97dd6007c7f..65094392abf 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any, Union import torch import torch.nn as nn diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index fd2d8124bb7..9082d6c0f99 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import copy -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch from torch.ao.quantization.experimental.adaround_fake_quantize import ( diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py index b18b5e133f1..7c302c3d6f4 100644 --- a/torch/ao/quantization/experimental/fake_quantize.py +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch import Tensor diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 260bbee37bd..69dfe760613 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import itertools -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch.ao.nn.intrinsic as nni import torch.nn as nn diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py index 81c6a2060e7..75b20751425 100644 --- a/torch/ao/quantization/pt2e/_numeric_debugger.py +++ b/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -1,8 +1,8 @@ import copy import logging -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Callable, Optional +from typing import Optional import torch from torch.ao.ns.fx.utils import compute_sqnr diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py index 80520d1ef0d..e943981adfb 100644 --- a/torch/ao/quantization/pt2e/graph_utils.py +++ b/torch/ao/quantization/pt2e/graph_utils.py @@ -2,8 +2,8 @@ import itertools import operator from collections import OrderedDict -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union import torch from torch.export import ExportedProgram diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index b9ce762896f..982230bf860 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -3,7 +3,8 @@ import copy import dataclasses import itertools import operator -from typing import Any, Callable, Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch import torch.nn.functional as F diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 5a757a70049..0f055cc3019 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs +from collections.abc import Callable from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Optional import torch from torch._export.utils import _disable_aten_to_metadata_assertions diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index fc1da49cde5..3a90bb953f1 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import operator import types -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index bd34a6b8a1f..10111d4ab8a 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any, Callable, Union +from typing import Any, TYPE_CHECKING, Union import torch @@ -26,6 +26,10 @@ from .qconfig import ( ) +if TYPE_CHECKING: + from collections.abc import Callable + + __all__ = [ "get_default_qconfig_mapping", "get_default_qat_qconfig_mapping", diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index e22fba05bbc..b8f1e8b4e01 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch import torch.ao.nn as ao_nn diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index 450f683689f..91c7159a89a 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Annotated, Callable, Optional, Union +from typing import Annotated, Optional, Union import torch from torch import Tensor diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index e21060ba204..7c65a8e6801 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,4 +1,5 @@ -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional from torch.ao.quantization.pt2e.utils import _is_sym_size_node from torch.ao.quantization.quantizer.quantizer import ( diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 87551ce5b89..a8637e1668c 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -3,10 +3,9 @@ import functools import itertools import operator import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Any, Callable, Optional, TYPE_CHECKING, Union -from typing_extensions import TypeAlias +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union import torch import torch.nn.functional as F diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 6005152a4d7..177203e8ff4 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -4,7 +4,7 @@ from __future__ import annotations import copy import functools import typing_extensions -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch import torch._dynamo as torchdynamo @@ -35,6 +35,8 @@ from torch.fx._compatibility import compatibility if TYPE_CHECKING: + from collections.abc import Callable + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.fx import Node diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index f8ac0a7727d..dec59bb02df 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import itertools import typing +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable, NamedTuple, Optional +from typing import NamedTuple, Optional import torch import torch.nn.functional as F diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 0061342151a..1874dc6e20b 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -7,8 +7,9 @@ import functools import sys import warnings from collections import OrderedDict +from collections.abc import Callable from inspect import getfullargspec, signature -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import torch from torch.ao.quantization.quant_type import QuantType diff --git a/torch/autograd/function.py b/torch/autograd/function.py index ac3aad9f93b..d25d93d7274 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,8 +4,9 @@ import inspect import itertools import warnings from collections import OrderedDict -from typing import Any, Callable, Optional, TypeVar -from typing_extensions import Concatenate, deprecated, ParamSpec +from collections.abc import Callable +from typing import Any, Concatenate, Optional, TypeVar +from typing_extensions import deprecated, ParamSpec import torch import torch._C as _C diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index bb19894d288..7e649f1f00a 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -2,9 +2,9 @@ import collections import functools import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable from itertools import product -from typing import Callable, Optional, Union +from typing import Optional, Union from typing_extensions import deprecated import torch diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 4b2707b65d0..5de55efda97 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -4,18 +4,24 @@ import functools import logging import threading from collections import defaultdict, deque -from collections.abc import Generator, Iterable, Iterator, MutableMapping, Sequence +from collections.abc import ( + Callable, + Generator, + Iterable, + Iterator, + MutableMapping, + Sequence, +) from typing import ( Any, - Callable, cast, Literal, NamedTuple, Optional, TYPE_CHECKING, + TypeAlias, Union, ) -from typing_extensions import TypeAlias from weakref import WeakKeyDictionary, WeakValueDictionary import torch diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 2d7b7d8cf8b..bf562b68f73 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -17,8 +17,9 @@ import sys import threading import traceback import warnings +from collections.abc import Callable from functools import lru_cache -from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union +from typing import Any, cast, NewType, Optional, TYPE_CHECKING, Union import torch import torch._C diff --git a/torch/cuda/_gpu_trace.py b/torch/cuda/_gpu_trace.py index 9a23a8a2abc..d3b8f7e4626 100644 --- a/torch/cuda/_gpu_trace.py +++ b/torch/cuda/_gpu_trace.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from torch._utils import CallbackRegistry diff --git a/torch/cuda/gds.py b/torch/cuda/gds.py index d3922499682..5a7dfa388ca 100644 --- a/torch/cuda/gds.py +++ b/torch/cuda/gds.py @@ -1,6 +1,7 @@ import os import sys -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import torch from torch.types import Storage diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 3946b7b3360..0a6cbe7a0ae 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -2,8 +2,9 @@ from __future__ import annotations import gc import typing -from typing import Callable, Optional, overload, TYPE_CHECKING, Union -from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar +from collections.abc import Callable +from typing import Optional, overload, TYPE_CHECKING, TypeAlias, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch from torch import Tensor diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 8bcb14d9fcf..6eaa54b5b79 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import re -from typing import Callable +from collections.abc import Callable import torch from torch import Tensor diff --git a/torch/library.py b/torch/library.py index 57e6dd069f9..0ac29cfde3f 100644 --- a/torch/library.py +++ b/torch/library.py @@ -6,8 +6,8 @@ import re import sys import traceback import weakref -from collections.abc import Sequence -from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union +from collections.abc import Callable, Sequence +from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import deprecated, ParamSpec import torch diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 668f47c15bc..ee83d0f346b 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -14,7 +14,7 @@ __all__ = [ "ONNXProgram", ] -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch from torch._C import _onnx as _C_onnx @@ -45,7 +45,7 @@ from .errors import OnnxExporterError if TYPE_CHECKING: import os - from collections.abc import Collection, Mapping, Sequence + from collections.abc import Callable, Collection, Mapping, Sequence # Set namespace for exposed private names ONNXProgram.__module__ = "torch.onnx" diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 89a2b7e9e5e..b7f9016ae6d 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -9,7 +9,7 @@ import dataclasses import datetime import logging import pathlib -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch from torch.onnx import _flags @@ -17,6 +17,7 @@ from torch.onnx import _flags if TYPE_CHECKING: import os + from collections.abc import Callable logger = logging.getLogger(__name__) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 0bc0c6182fc..78a54c270d5 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -7,8 +7,8 @@ from __future__ import annotations import io import logging import warnings -from collections.abc import Mapping, Sequence -from typing import Any, Callable, TYPE_CHECKING +from collections.abc import Callable, Mapping, Sequence +from typing import Any, TYPE_CHECKING import torch from torch.onnx import _constants as onnx_constants diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index cdd72447ddb..7e7f206c80f 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -12,8 +12,8 @@ import pathlib import textwrap import traceback import typing -from collections.abc import Mapping, Sequence -from typing import Any, Callable, Literal +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Literal import onnxscript import onnxscript.evaluator diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py index 9227a6ee52f..4988706404e 100644 --- a/torch/onnx/_internal/exporter/_decomp.py +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -2,13 +2,15 @@ from __future__ import annotations import itertools -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING import torch import torch._ops if TYPE_CHECKING: + from collections.abc import Callable + from torch.onnx._internal.exporter import _registration diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py index 141cb76deac..1f935cfed19 100644 --- a/torch/onnx/_internal/exporter/_dispatching.py +++ b/torch/onnx/_internal/exporter/_dispatching.py @@ -2,8 +2,8 @@ from __future__ import annotations import logging -from collections.abc import Sequence -from typing import Any, Callable +from collections.abc import Callable, Sequence +from typing import Any from onnxscript import ir diff --git a/torch/onnx/_internal/exporter/_flags.py b/torch/onnx/_internal/exporter/_flags.py index 0f07508f831..8e9d8c9db6e 100644 --- a/torch/onnx/_internal/exporter/_flags.py +++ b/torch/onnx/_internal/exporter/_flags.py @@ -3,10 +3,14 @@ from __future__ import annotations import functools -from typing import Callable, TypeVar +from typing import TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec +if TYPE_CHECKING: + from collections.abc import Callable + + _is_onnx_exporting = False # Use ParamSpec to preserve parameter types instead of erasing to Any diff --git a/torch/onnx/_internal/exporter/_isolated.py b/torch/onnx/_internal/exporter/_isolated.py index ea575f07a5e..461590ec9eb 100644 --- a/torch/onnx/_internal/exporter/_isolated.py +++ b/torch/onnx/_internal/exporter/_isolated.py @@ -5,10 +5,14 @@ from __future__ import annotations import multiprocessing import os import warnings -from typing import Any, Callable, TypeVar, TypeVarTuple, Union, Unpack +from typing import Any, TYPE_CHECKING, TypeVar, TypeVarTuple, Union, Unpack from typing_extensions import ParamSpec +if TYPE_CHECKING: + from collections.abc import Callable + + _P = ParamSpec("_P") _R = TypeVar("_R") _Ts = TypeVarTuple("_Ts") diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 62333289fad..17f646c9337 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -13,8 +13,8 @@ import os import tempfile import textwrap import warnings -from collections.abc import Sequence -from typing import Any, Callable, TYPE_CHECKING +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING import torch from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index fefc8022d7e..0dd23819af1 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -18,8 +18,8 @@ import logging import math import operator import types -from typing import Callable, Literal, Union -from typing_extensions import TypeAlias +from collections.abc import Callable +from typing import Literal, TypeAlias, Union import torch import torch._ops diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index 8c045d11a2b..a2f86a6ccf2 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -8,8 +8,8 @@ from __future__ import annotations __all__ = ["onnx_impl", "get_torchlib_ops"] import logging -from collections.abc import Sequence -from typing import Any, Callable, TypeVar +from collections.abc import Callable, Sequence +from typing import Any, TypeVar from typing_extensions import ParamSpec import onnxscript diff --git a/torch/onnx/_internal/torchscript_exporter/registration.py b/torch/onnx/_internal/torchscript_exporter/registration.py index b8bba134f36..f073227f87b 100644 --- a/torch/onnx/_internal/torchscript_exporter/registration.py +++ b/torch/onnx/_internal/torchscript_exporter/registration.py @@ -2,8 +2,8 @@ """Module for handling symbolic function registration.""" import warnings -from collections.abc import Collection, Sequence -from typing import Callable, Generic, Optional, TypeVar, Union +from collections.abc import Callable, Collection, Sequence +from typing import Generic, Optional, TypeVar, Union from typing_extensions import ParamSpec from torch.onnx import _constants, errors diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py index 73e242ca323..cd7763bf41e 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py @@ -103,8 +103,14 @@ import math import sys import typing import warnings -from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar -from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec +from typing import ( + Any, + Concatenate as _Concatenate, + Literal, + NoReturn, + TypeVar as _TypeVar, +) +from typing_extensions import ParamSpec as _ParamSpec import torch import torch._C._onnx as _C_onnx @@ -115,7 +121,7 @@ from torch.onnx._internal.torchscript_exporter._globals import GLOBALS if typing.TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from torch.types import Number diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py index a0b79bd619b..65657f6a91c 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py @@ -14,7 +14,7 @@ import functools import math import sys import warnings -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING from typing_extensions import deprecated import torch @@ -33,7 +33,7 @@ from torch.onnx._internal.torchscript_exporter._globals import GLOBALS if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from torch.types import Number diff --git a/torch/onnx/_internal/torchscript_exporter/utils.py b/torch/onnx/_internal/torchscript_exporter/utils.py index 2a7339c27e0..78c903d0277 100644 --- a/torch/onnx/_internal/torchscript_exporter/utils.py +++ b/torch/onnx/_internal/torchscript_exporter/utils.py @@ -62,7 +62,7 @@ import inspect import re import typing import warnings -from typing import Any, Callable, cast +from typing import Any, cast from typing_extensions import deprecated import torch @@ -80,7 +80,7 @@ from torch.onnx._internal.torchscript_exporter._globals import GLOBALS if typing.TYPE_CHECKING: - from collections.abc import Collection, Mapping, Sequence + from collections.abc import Callable, Collection, Mapping, Sequence # TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp diff --git a/torch/onnx/ops/__init__.py b/torch/onnx/ops/__init__.py index d10ba1ac7a3..8da3fc8e587 100644 --- a/torch/onnx/ops/__init__.py +++ b/torch/onnx/ops/__init__.py @@ -17,14 +17,14 @@ __all__ = [ ] -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING import torch from torch.onnx.ops import _impl, _symbolic_impl if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence # https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597 diff --git a/torch/onnx/ops/_impl.py b/torch/onnx/ops/_impl.py index a7eba334ecf..f5e3721111d 100644 --- a/torch/onnx/ops/_impl.py +++ b/torch/onnx/ops/_impl.py @@ -1,6 +1,7 @@ # flake8: noqa: B950 import math -from typing import Callable, Optional, TypeVar +from collections.abc import Callable +from typing import Optional, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/overrides.py b/torch/overrides.py index 8dc238d114b..0e4c2252531 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -27,9 +27,9 @@ import contextlib import functools import types import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, Callable, Optional, TypeVar +from typing import Any, Optional, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/serialization.py b/torch/serialization.py index a6eb314fc1a..45a44ea37c0 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -13,10 +13,11 @@ import tarfile import tempfile import threading import warnings +from collections.abc import Callable from contextlib import closing, contextmanager from enum import Enum -from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union -from typing_extensions import TypeAlias, TypeIs +from typing import Any, cast, Generic, IO, Optional, TypeAlias, TypeVar, Union +from typing_extensions import TypeIs import torch import torch._weights_only_unpickler as _weights_only_unpickler diff --git a/torch/types.py b/torch/types.py index ab6f4639f44..01a62ff4d01 100644 --- a/torch/types.py +++ b/torch/types.py @@ -12,8 +12,8 @@ from builtins import ( # noqa: F401 str as _str, ) from collections.abc import Sequence -from typing import Any, IO, TYPE_CHECKING, Union -from typing_extensions import Self, TypeAlias +from typing import Any, IO, TYPE_CHECKING, TypeAlias, Union +from typing_extensions import Self # `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType` from torch import ( # noqa: F401