[BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review the changes in this PR via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
This commit is contained in:
Xuehai Pan 2024-07-31 19:56:45 +08:00 committed by PyTorch MergeBot
parent c59f3fff52
commit 30293319a8
120 changed files with 163 additions and 101 deletions

View File

@ -56,7 +56,6 @@ ISORT_SKIPLIST = re.compile(
# torch/[e-n]*/** # torch/[e-n]*/**
"torch/[e-n]*/**", "torch/[e-n]*/**",
# torch/[o-z]*/** # torch/[o-z]*/**
"torch/[o-z]*/**",
], ],
), ),
) )

View File

@ -3,7 +3,29 @@ from torch import _C
from torch._C import _onnx as _C_onnx from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from . import ( # usort:skip. Keep the order instead of sorting lexicographically from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation, _deprecation,
errors, errors,
symbolic_caffe2, symbolic_caffe2,
@ -25,22 +47,8 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
utils, utils,
) )
from ._exporter_states import ExportTypes
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import from ._internal.exporter import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions, DiagnosticOptions,
ExportOptions, ExportOptions,
ONNXProgram, ONNXProgram,
@ -53,12 +61,6 @@ from ._internal.exporter import ( # usort:skip. needs to be last to avoid circu
enable_fake_mode, enable_fake_mode,
) )
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
__all__ = [ __all__ = [
# Modules # Modules

View File

@ -6,6 +6,7 @@ import warnings
from typing import Callable, TypeVar from typing import Callable, TypeVar
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
_T = TypeVar("_T") _T = TypeVar("_T")
_P = ParamSpec("_P") _P = ParamSpec("_P")

View File

@ -9,6 +9,7 @@ from ._diagnostic import (
from ._rules import rules from ._rules import rules
from .infra import levels from .infra import levels
__all__ = [ __all__ = [
"TorchScriptOnnxExportDiagnostic", "TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine", "ExportDiagnosticEngine",

View File

@ -7,12 +7,12 @@ import gzip
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace from torch.utils import cpp_backtrace
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator

View File

@ -13,6 +13,7 @@ from typing import Tuple
# flake8: noqa # flake8: noqa
from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics import infra
""" """
GENERATED CODE - DO NOT EDIT DIRECTLY GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message` The purpose of generating a class for each rule is to override the `format_message`

View File

@ -14,6 +14,7 @@ from ._infra import (
) )
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
__all__ = [ __all__ = [
"Diagnostic", "Diagnostic",
"DiagnosticContext", "DiagnosticContext",

View File

@ -4,14 +4,10 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import dataclasses import dataclasses
import gzip import gzip
import logging import logging
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
from typing_extensions import Self from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics import infra

View File

@ -7,7 +7,6 @@ import traceback
from typing import Any, Callable, Union from typing import Any, Callable, Union
from torch._logging import LazyString from torch._logging import LazyString
from torch.onnx._internal.diagnostics.infra import sarif from torch.onnx._internal.diagnostics.infra import sarif

View File

@ -97,4 +97,5 @@ from torch.onnx._internal.diagnostics.infra.sarif._version_control_details impor
from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
# flake8: noqa # flake8: noqa

View File

@ -1,5 +1,6 @@
from typing import Final from typing import Final
SARIF_VERSION: Final = "2.1.0" SARIF_VERSION: Final = "2.1.0"
SARIF_SCHEMA_LINK: Final = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json" SARIF_SCHEMA_LINK: Final = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
# flake8: noqa # flake8: noqa

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
import inspect import inspect
import traceback import traceback
from typing import Any, Callable, Mapping, Sequence from typing import Any, Callable, Mapping, Sequence

View File

@ -4,12 +4,10 @@ from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (O
) )
import abc import abc
import contextlib import contextlib
import dataclasses import dataclasses
import logging import logging
import os import os
import tempfile import tempfile
import warnings import warnings
from collections import defaultdict from collections import defaultdict
@ -27,11 +25,9 @@ from typing import (
from typing_extensions import Self from typing_extensions import Self
import torch import torch
import torch._ops import torch._ops
import torch.export as torch_export import torch.export as torch_export
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch.onnx._internal import io_adapter from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import ( from torch.onnx._internal.fx import (
@ -41,6 +37,7 @@ from torch.onnx._internal.fx import (
serialization as fx_serialization, serialization as fx_serialization,
) )
# We can only import onnx from this module in a type-checking context to ensure that # We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully # 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). # 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
@ -48,6 +45,7 @@ if TYPE_CHECKING:
import io import io
import onnx import onnx
import onnxruntime # type: ignore[import] import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import] import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import] from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
@ -55,7 +53,6 @@ if TYPE_CHECKING:
) )
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor
from torch.onnx._internal.fx import diagnostics from torch.onnx._internal.fx import diagnostics
_DEFAULT_OPSET_VERSION: Final[int] = 18 _DEFAULT_OPSET_VERSION: Final[int] = 18

View File

@ -2,23 +2,20 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import contextlib import contextlib
import dataclasses import dataclasses
import difflib import difflib
import io import io
import logging import logging
import sys import sys
from typing import Any, Callable, TYPE_CHECKING from typing import Any, Callable, TYPE_CHECKING
import torch import torch
import torch.fx import torch.fx
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
if TYPE_CHECKING: if TYPE_CHECKING:
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor

View File

@ -1,5 +1,6 @@
from .unsupported_nodes import UnsupportedFxNodesAnalysis from .unsupported_nodes import UnsupportedFxNodesAnalysis
__all__ = [ __all__ = [
"UnsupportedFxNodesAnalysis", "UnsupportedFxNodesAnalysis",
] ]

View File

@ -15,7 +15,6 @@ from __future__ import annotations
import abc import abc
import contextlib import contextlib
from typing import Callable, Sequence from typing import Callable, Sequence
from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found] from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found]
@ -26,6 +25,7 @@ from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-
import torch import torch
from torch._decomp import decompositions from torch._decomp import decompositions
_NEW_OP_NAMESPACE: str = "onnx_export" _NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator.""" """The namespace for the custom operator."""

View File

@ -8,7 +8,6 @@ from typing import Callable
import torch import torch
import torch._ops import torch._ops
import torch.fx import torch.fx
from torch.onnx._internal.fx import registration from torch.onnx._internal.fx import registration

View File

@ -2,9 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import functools import functools
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
import onnxscript # type: ignore[import] import onnxscript # type: ignore[import]
@ -17,6 +15,7 @@ from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils from torch.onnx._internal.fx import registration, type_utils as fx_type_utils
if TYPE_CHECKING: if TYPE_CHECKING:
import logging import logging

View File

@ -16,7 +16,6 @@ from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
import torch import torch
import torch.fx import torch.fx
from torch.onnx import _type_utils as jit_type_utils from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import ( from torch.onnx._internal.fx import (
_pass, _pass,
diagnostics, diagnostics,

View File

@ -2,16 +2,15 @@
from __future__ import annotations from __future__ import annotations
import functools import functools
from typing import Any, Callable, Mapping, Sequence from typing import Any, Callable, Mapping, Sequence
import torch import torch
import torch.fx import torch.fx
import torch.onnx import torch.onnx
import torch.onnx._internal.fx.passes as passes import torch.onnx._internal.fx.passes as passes
from torch.onnx._internal import exporter, io_adapter from torch.onnx._internal import exporter, io_adapter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic # Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`) # data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go though Python dispatcher, so # not defined by pybind11 in C++ do not go though Python dispatcher, so

View File

@ -11,13 +11,13 @@ from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch import torch
import torch._ops import torch._ops
import torch.fx import torch.fx
from torch.onnx._internal.fx import ( from torch.onnx._internal.fx import (
diagnostics, diagnostics,
registration, registration,
type_utils as fx_type_utils, type_utils as fx_type_utils,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
import onnxscript # type: ignore[import] import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import] from onnxscript.function_libs.torch_lib import ( # type: ignore[import]

View File

@ -13,7 +13,6 @@ import torch
import torch.fx import torch.fx
from torch.fx.experimental import symbolic_shapes from torch.fx.experimental import symbolic_shapes
from torch.onnx import _constants, _type_utils as jit_type_utils from torch.onnx import _constants, _type_utils as jit_type_utils
from torch.onnx._internal.fx import ( from torch.onnx._internal.fx import (
diagnostics, diagnostics,
fx_onnx_interpreter, fx_onnx_interpreter,

View File

@ -5,6 +5,7 @@ from .readability import RestoreParameterAndBufferNames
from .type_promotion import InsertTypePromotion from .type_promotion import InsertTypePromotion
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
__all__ = [ __all__ = [
"Decompose", "Decompose",
"InsertTypePromotion", "InsertTypePromotion",

View File

@ -6,9 +6,7 @@ These functions should NOT be directly invoked outside of `passes` package.
from __future__ import annotations from __future__ import annotations
import collections import collections
import re import re
from typing import Callable from typing import Callable
import torch.fx import torch.fx

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
from typing import Callable, Mapping, TYPE_CHECKING from typing import Callable, Mapping, TYPE_CHECKING
import torch import torch
@ -12,6 +11,7 @@ from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils from torch.onnx._internal.fx.passes import _utils
if TYPE_CHECKING: if TYPE_CHECKING:
import torch.fx import torch.fx
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
from typing import Callable, TYPE_CHECKING from typing import Callable, TYPE_CHECKING
import torch import torch
@ -14,6 +13,7 @@ from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
if TYPE_CHECKING: if TYPE_CHECKING:
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor

View File

@ -2,19 +2,17 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import collections import collections
import copy import copy
import operator import operator
from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple from typing import Any, Dict, Final, Generator, Iterator, Sequence, Tuple
import torch import torch
import torch.fx import torch.fx
from torch.onnx._internal.fx import _pass, diagnostics from torch.onnx._internal.fx import _pass, diagnostics
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
_FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type] _FX_TRACER_NN_MODULE_META_TYPE = Tuple[str, type]
"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer.""" """Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer."""
_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict _FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from typing import Sequence from typing import Sequence
import torch import torch
from torch.onnx._internal.fx import _pass, diagnostics from torch.onnx._internal.fx import _pass, diagnostics

View File

@ -3,20 +3,16 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import dataclasses import dataclasses
import inspect import inspect
import logging import logging
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
import torch import torch
import torch._ops import torch._ops
import torch.fx import torch.fx
import torch.fx.traceback as fx_traceback import torch.fx.traceback as fx_traceback
from torch import _prims_common, _refs from torch import _prims_common, _refs
from torch._prims_common import ( from torch._prims_common import (
ELEMENTWISE_TYPE_PROMOTION_KIND, ELEMENTWISE_TYPE_PROMOTION_KIND,
wrappers as _prims_common_wrappers, wrappers as _prims_common_wrappers,
@ -24,15 +20,16 @@ from torch._prims_common import (
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs from torch._refs.nn import functional as _functional_refs
from torch.fx.experimental import proxy_tensor from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree from torch.utils import _python_dispatch, _pytree
if TYPE_CHECKING: if TYPE_CHECKING:
from types import ModuleType from types import ModuleType
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO(bowbao): move to type utils. # TODO(bowbao): move to type utils.

View File

@ -4,9 +4,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from torch.onnx._internal.fx import _pass from torch.onnx._internal.fx import _pass
if TYPE_CHECKING: if TYPE_CHECKING:
import torch.fx import torch.fx

View File

@ -5,6 +5,7 @@ from typing import List, TYPE_CHECKING, Union
import torch import torch
if TYPE_CHECKING: if TYPE_CHECKING:
import io import io
@ -16,7 +17,6 @@ def has_safetensors_and_transformers():
# safetensors is not an exporter requirement, but needed for some huggingface models # safetensors is not an exporter requirement, but needed for some huggingface models
import safetensors # type: ignore[import] # noqa: F401 import safetensors # type: ignore[import] # noqa: F401
import transformers # type: ignore[import] # noqa: F401 import transformers # type: ignore[import] # noqa: F401
from safetensors import torch as safetensors_torch # noqa: F401 from safetensors import torch as safetensors_torch # noqa: F401
return True return True

View File

@ -13,6 +13,7 @@ import torch.fx
from torch.onnx._internal import exporter, io_adapter from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING: if TYPE_CHECKING:
import torch.onnx import torch.onnx
from torch.export.exported_program import ExportedProgram from torch.export.exported_program import ExportedProgram

View File

@ -15,11 +15,13 @@ from typing import (
) )
import numpy import numpy
import onnx import onnx
import torch import torch
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor
if TYPE_CHECKING: if TYPE_CHECKING:
import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004 import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004

View File

@ -13,9 +13,9 @@ from typing import (
import torch import torch
import torch.export as torch_export import torch.export as torch_export
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
if TYPE_CHECKING: if TYPE_CHECKING:
import inspect import inspect

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Utilities for manipulating the torch.Graph object and the torchscript.""" """Utilities for manipulating the torch.Graph object and the torchscript."""
from __future__ import annotations
# TODO(justinchuby): Move more of the symbolic helper functions here and expose # TODO(justinchuby): Move more of the symbolic helper functions here and expose
# them to the user. # them to the user.
from __future__ import annotations
import dataclasses import dataclasses
import re import re
import typing import typing

View File

@ -4,7 +4,6 @@ import dataclasses
import importlib import importlib
import logging import logging
import os import os
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -33,6 +32,7 @@ from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree from torch.utils import _pytree
if TYPE_CHECKING: if TYPE_CHECKING:
import onnx import onnx
@ -929,7 +929,7 @@ class OrtBackend:
try: try:
from onnxscript import optimizer # type: ignore[import] from onnxscript import optimizer # type: ignore[import]
from onnxscript.rewriter import ( # type: ignore[import] from onnxscript.rewriter import ( # type: ignore[import]
onnxruntime as ort_rewriter, # type: ignore[import] onnxruntime as ort_rewriter,
) )
onnx_model = optimizer.optimize(onnx_model) onnx_model = optimizer.optimize(onnx_model)
@ -1112,7 +1112,6 @@ class OrtBackend:
the ``compile`` method is invoked directly.""" the ``compile`` method is invoked directly."""
if self._options.use_aot_autograd: if self._options.use_aot_autograd:
from functorch.compile import min_cut_rematerialization_partition from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd from torch._dynamo.backends.common import aot_autograd
return aot_autograd( return aot_autograd(

View File

@ -10,6 +10,7 @@ import torch
from torch._C import _onnx as _C_onnx from torch._C import _onnx as _C_onnx
from torch.onnx import errors from torch.onnx import errors
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
# Hack to help mypy to recognize torch._C.Value # Hack to help mypy to recognize torch._C.Value
from torch import _C # noqa: F401 from torch import _C # noqa: F401

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
from torch.onnx import _constants from torch.onnx import _constants
from torch.onnx._internal import diagnostics from torch.onnx._internal import diagnostics
if TYPE_CHECKING: if TYPE_CHECKING:
from torch import _C from torch import _C

View File

@ -19,6 +19,7 @@ from torch.onnx import _constants, _type_utils, errors, utils
from torch.onnx._globals import GLOBALS from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils from torch.onnx._internal import jit_utils
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from torch.types import Number from torch.types import Number

View File

@ -23,6 +23,7 @@ from torch.onnx import (
from torch.onnx._globals import GLOBALS from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md # see Note [Edit Symbolic Files] in README.md

View File

@ -21,6 +21,7 @@ from torch.onnx import (
) )
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md # see Note [Edit Symbolic Files] in README.md

View File

@ -25,6 +25,7 @@ from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
__all__ = [ __all__ = [
"hardswish", "hardswish",
"tril", "tril",

View File

@ -33,6 +33,7 @@ from torch import _C
from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)

View File

@ -36,6 +36,7 @@ from torch.nn.functional import (
from torch.onnx import _type_utils, errors, symbolic_helper, utils from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)

View File

@ -25,6 +25,7 @@ from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md # see Note [Edit Symbolic Files] in README.md

View File

@ -27,6 +27,7 @@ from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py # see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -26,6 +26,7 @@ Size
from typing import List from typing import List
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py # see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -24,11 +24,11 @@ New operators:
import functools import functools
import torch.nn.functional as F import torch.nn.functional as F
from torch import _C from torch import _C
from torch.onnx import symbolic_helper from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST! # EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py # see Note [Edit Symbolic Files] in symbolic_helper.py

View File

@ -39,6 +39,7 @@ from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)
block_listed_operators = ( block_listed_operators = (

View File

@ -27,6 +27,7 @@ from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_h
from torch.onnx._globals import GLOBALS from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.types import Number from torch.types import Number

View File

@ -30,6 +30,7 @@ from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils from torch.onnx._internal import onnx_proto_utils
from torch.types import Number from torch.types import Number
_ORT_PROVIDERS = ("CPUExecutionProvider",) _ORT_PROVIDERS = ("CPUExecutionProvider",)
_NumericType = Union[Number, torch.Tensor, np.ndarray] _NumericType = Union[Number, torch.Tensor, np.ndarray]

View File

@ -23,6 +23,7 @@ from torch.optim.rprop import Rprop
from torch.optim.sgd import SGD from torch.optim.sgd import SGD
from torch.optim.sparse_adam import SparseAdam from torch.optim.sparse_adam import SparseAdam
Adafactor.__module__ = "torch.optim" Adafactor.__module__ = "torch.optim"

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from .optimizer import ( from .optimizer import (
_disable_dynamo_if_unsupported, _disable_dynamo_if_unsupported,
_get_scalar_dtype, _get_scalar_dtype,
@ -12,6 +13,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Adafactor", "adafactor"] __all__ = ["Adafactor", "adafactor"]

View File

@ -20,6 +20,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Adadelta", "adadelta"] __all__ = ["Adadelta", "adadelta"]

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import ( from .optimizer import (
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
_differentiable_doc, _differentiable_doc,
@ -17,6 +18,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Adagrad", "adagrad"] __all__ = ["Adagrad", "adagrad"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import ( from .optimizer import (
_capturable_doc, _capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
@ -24,6 +25,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Adam", "adam"] __all__ = ["Adam", "adam"]

View File

@ -21,6 +21,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Adamax", "adamax"] __all__ = ["Adamax", "adamax"]

View File

@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import ( from .optimizer import (
_capturable_doc, _capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
@ -24,6 +25,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["AdamW", "adamw"] __all__ = ["AdamW", "adamw"]

View File

@ -21,6 +21,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["ASGD", "asgd"] __all__ = ["ASGD", "asgd"]

View File

@ -3,8 +3,10 @@ from typing import Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from .optimizer import Optimizer, ParamsT from .optimizer import Optimizer, ParamsT
__all__ = ["LBFGS"] __all__ = ["LBFGS"]

View File

@ -26,6 +26,7 @@ from torch import inf, Tensor
from .optimizer import Optimizer from .optimizer import Optimizer
__all__ = [ __all__ = [
"LambdaLR", "LambdaLR",
"MultiplicativeLR", "MultiplicativeLR",

View File

@ -5,6 +5,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from .optimizer import ( from .optimizer import (
_capturable_doc, _capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
@ -22,6 +23,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["NAdam", "nadam"] __all__ = ["NAdam", "nadam"]

View File

@ -35,6 +35,7 @@ from torch.utils._foreach_utils import (
) )
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
Args: TypeAlias = Tuple[Any, ...] Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any] Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any] StateDict: TypeAlias = Dict[str, Any]

View File

@ -22,6 +22,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["RAdam", "radam"] __all__ = ["RAdam", "radam"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from .optimizer import ( from .optimizer import (
_capturable_doc, _capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
@ -20,6 +21,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["RMSprop", "rmsprop"] __all__ = ["RMSprop", "rmsprop"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from .optimizer import ( from .optimizer import (
_capturable_doc, _capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
@ -20,6 +21,7 @@ from .optimizer import (
ParamsT, ParamsT,
) )
__all__ = ["Rprop", "rprop"] __all__ = ["Rprop", "rprop"]

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import ( from .optimizer import (
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
_differentiable_doc, _differentiable_doc,
@ -16,6 +17,7 @@ from .optimizer import (
Optimizer, Optimizer,
) )
__all__ = ["SGD", "sgd"] __all__ = ["SGD", "sgd"]

View File

@ -3,9 +3,11 @@ from typing import List, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from . import _functional as F from . import _functional as F
from .optimizer import _maximize_doc, Optimizer, ParamsT from .optimizer import _maximize_doc, Optimizer, ParamsT
__all__ = ["SparseAdam"] __all__ = ["SparseAdam"]

View File

@ -11,8 +11,10 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
from .optimizer import Optimizer from .optimizer import Optimizer
__all__ = [ __all__ = [
"AveragedModel", "AveragedModel",
"update_bn", "update_bn",
@ -25,6 +27,7 @@ __all__ = [
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]] PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]

View File

@ -6,6 +6,7 @@ from typing import cast
import torch import torch
from torch.types import Storage from torch.types import Storage
__serialization_id_record_name__ = ".data/serialization_id" __serialization_id_record_name__ = ".data/serialization_id"

View File

@ -2,6 +2,7 @@
import _warnings import _warnings
import os.path import os.path
# note: implementations # note: implementations
# copied from cpython's import code # copied from cpython's import code

View File

@ -4,6 +4,7 @@ See mangling.md for details.
""" """
import re import re
_mangle_index = 0 _mangle_index = 0

View File

@ -1,6 +1,7 @@
from typing import Dict, List from typing import Dict, List
from ..package_exporter import PackagingError from torch.package.package_exporter import PackagingError
__all__ = ["find_first_use_of_broken_modules"] __all__ = ["find_first_use_of_broken_modules"]

View File

@ -2,6 +2,7 @@
import sys import sys
from typing import Any, Callable, Iterable, List, Tuple from typing import Any, Callable, Iterable, List, Tuple
__all__ = ["trace_dependencies"] __all__ = ["trace_dependencies"]

View File

@ -3,6 +3,7 @@ from typing import Dict, List
from .glob_group import GlobGroup, GlobPattern from .glob_group import GlobGroup, GlobPattern
__all__ = ["Directory"] __all__ = ["Directory"]

View File

@ -2,6 +2,7 @@
import re import re
from typing import Iterable, Union from typing import Iterable, Union
GlobPattern = Union[str, Iterable[str]] GlobPattern = Union[str, Iterable[str]]

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import importlib import importlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined] from pickle import ( # type: ignore[attr-defined]
_getattribute, _getattribute,
_Pickler, _Pickler,
whichmodule as _pickle_whichmodule, whichmodule as _pickle_whichmodule,
@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
from ._mangling import demangle, get_mangle_prefix, is_mangled from ._mangling import demangle, get_mangle_prefix, is_mangled
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]

View File

@ -39,6 +39,7 @@ from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer from .importer import Importer, OrderedImporter, sys_importer
__all__ = [ __all__ = [
"PackagingErrorReason", "PackagingErrorReason",
"EmptyMatchError", "EmptyMatchError",

View File

@ -38,6 +38,7 @@ from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory from .file_structure_representation import _create_directory_from_file_list, Directory
from .importer import Importer from .importer import Importer
if TYPE_CHECKING: if TYPE_CHECKING:
from .glob_group import GlobPattern from .glob_group import GlobPattern

View File

@ -25,6 +25,7 @@ from .profiler import (
tensorboard_trace_handler, tensorboard_trace_handler,
) )
__all__ = [ __all__ = [
"profile", "profile",
"schedule", "schedule",

View File

@ -32,6 +32,7 @@ from torch._C._profiler import (
from torch._utils import _element_size from torch._utils import _element_size
from torch.profiler import _utils from torch.profiler import _utils
KeyAndID = Tuple["Key", int] KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int] TensorAndID = Tuple["TensorKey", int]

View File

@ -7,9 +7,9 @@ from dataclasses import dataclass
from typing import Dict, List, TYPE_CHECKING from typing import Dict, List, TYPE_CHECKING
from torch.autograd.profiler import profile from torch.autograd.profiler import profile
from torch.profiler import DeviceType from torch.profiler import DeviceType
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.autograd import _KinetoEvent from torch.autograd import _KinetoEvent

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from contextlib import contextmanager from contextlib import contextmanager
try: try:
from torch._C import _itt from torch._C import _itt
except ImportError: except ImportError:

View File

@ -1,16 +1,14 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from .quantize import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .fake_quantize import * # noqa: F403 from .fake_quantize import * # noqa: F403
from .fuse_modules import fuse_modules from .fuse_modules import fuse_modules
from .stubs import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantize_jit import * # noqa: F403
# from .quantize_fx import *
from .quantization_mappings import * # noqa: F403
from .fuser_method_mappings import * # noqa: F403 from .fuser_method_mappings import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403
from .quant_type import * # noqa: F403
from .quantization_mappings import * # noqa: F403
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
def default_eval_fn(model, calib_data): def default_eval_fn(model, calib_data):

View File

@ -15,6 +15,7 @@ from torch.ao.quantization.fx.pattern_utils import (
QuantizeHandler, QuantizeHandler,
) )
# QuantizeHandler.__module__ = _NAMESPACE # QuantizeHandler.__module__ = _NAMESPACE
_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" _register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils"
get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils"

View File

@ -23,6 +23,7 @@ from torch.ao.quantization.fx.quantize_handler import (
StandaloneModuleQuantizeHandler, StandaloneModuleQuantizeHandler,
) )
QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"
CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns"

View File

@ -15,6 +15,7 @@ from .semi_structured import (
to_sparse_semi_structured, to_sparse_semi_structured,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.types import _dtype as DType from torch.types import _dtype as DType

View File

@ -3,6 +3,7 @@ import contextlib
import torch import torch
__all__ = [ __all__ = [
"fallback_dispatcher", "fallback_dispatcher",
"semi_sparse_values", "semi_sparse_values",

View File

@ -8,8 +8,10 @@ from typing import Optional, Tuple
import torch import torch
from torch.utils._triton import has_triton from torch.utils._triton import has_triton
from ._triton_ops_meta import get_meta from ._triton_ops_meta import get_meta
TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
) )

View File

@ -20,6 +20,7 @@ from torch.sparse._semi_structured_ops import (
semi_sparse_view, semi_sparse_view,
) )
__all__ = [ __all__ = [
"SparseSemiStructuredTensor", "SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS", "SparseSemiStructuredTensorCUTLASS",

View File

@ -1,4 +1,5 @@
from torch._C import FileCheck as FileCheck from torch._C import FileCheck as FileCheck
from . import _utils from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor from ._creation import make_tensor as make_tensor

View File

@ -9,6 +9,7 @@ from typing import cast, List, Optional, Tuple, Union
import torch import torch
_INTEGRAL_TYPES = [ _INTEGRAL_TYPES = [
torch.uint8, torch.uint8,
torch.int8, torch.int8,

View File

@ -8,7 +8,6 @@ from typing import Any, Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor from torch.distributed._tensor import DTensor

View File

@ -3,6 +3,7 @@ import logging
import os import os
import sys import sys
# NOTE: [dynamo_test_failures.py] # NOTE: [dynamo_test_failures.py]
# #
# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures` # We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`

View File

@ -34,9 +34,9 @@ from torch.testing._internal.common_utils import (
TrackedInputIter, TrackedInputIter,
) )
from torch.testing._internal.opinfo import utils from torch.testing._internal.opinfo import utils
from torchgen.utils import dataclass_repr from torchgen.utils import dataclass_repr
# Reasonable testing sizes for dimensions # Reasonable testing sizes for dimensions
L = 20 L = 20
M = 10 M = 10

View File

@ -11,6 +11,7 @@ from torch.testing._internal.opinfo.definitions import (
special, special,
) )
# Operator database # Operator database
op_db: List[OpInfo] = [ op_db: List[OpInfo] = [
*fft.op_db, *fft.op_db,

View File

@ -7,7 +7,6 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
from torch.testing import make_tensor from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride from torch.testing._internal.common_device_type import precisionOverride
@ -31,6 +30,7 @@ from torch.testing._internal.opinfo.refs import (
PythonRefInfo, PythonRefInfo,
) )
has_scipy_fft = False has_scipy_fft = False
if TEST_SCIPY: if TEST_SCIPY:
try: try:

View File

@ -11,7 +11,6 @@ import numpy as np
from numpy import inf from numpy import inf
import torch import torch
from torch.testing import make_tensor from torch.testing import make_tensor
from torch.testing._internal.common_cuda import ( from torch.testing._internal.common_cuda import (
_get_magma_version, _get_magma_version,

View File

@ -2,7 +2,6 @@
import unittest import unittest
from functools import partial from functools import partial
from itertools import product from itertools import product
from typing import Callable, List, Tuple from typing import Callable, List, Tuple
@ -18,6 +17,7 @@ from torch.testing._internal.opinfo.core import (
SampleInput, SampleInput,
) )
if TEST_SCIPY: if TEST_SCIPY:
import scipy.signal import scipy.signal

View File

@ -7,6 +7,7 @@ from torch.testing._internal.opinfo.core import (
UnaryUfuncInfo, UnaryUfuncInfo,
) )
# NOTE [Python References] # NOTE [Python References]
# Python References emulate existing PyTorch operations, but can ultimately # Python References emulate existing PyTorch operations, but can ultimately
# be expressed in terms of "primitive" operations from torch._prims. # be expressed in terms of "primitive" operations from torch._prims.

Some files were not shown because too many files have changed in this diff Show More