[BE] enable UFMT for top-level files torch/*.py (#127707)

Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707
Approved by: https://github.com/ezyang
Author: Xuehai Pan
Date: 2024-06-12 20:10:47 +08:00
Committer: PyTorch MergeBot
Parent: cc231a8e2b
Commit: dd143d44cc

15 changed files with 1548 additions and 875 deletions
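For anyone reproducing this formatting pass locally, a minimal sketch (file list abbreviated; in PyTorch the UFMT check normally runs through lintrunner, and `ufmt` here refers to the standalone usort+black wrapper the linter is named after — treat the exact invocation as an assumption and check the repo's CONTRIBUTING docs):

import subprocess

# Hedged sketch: `ufmt format` rewrites files in place; install with `pip install ufmt`.
for path in ["torch/functional.py", "torch/hub.py", "torch/library.py"]:
    subprocess.run(["ufmt", "format", path], check=True)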

View File: .lintrunner.toml

@@ -1556,7 +1556,6 @@ exclude_patterns = [
'torch/distributed/tensor/parallel/style.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/functional.py',
'torch/futures/__init__.py',
'torch/fx/__init__.py',
'torch/fx/_compatibility.py',
@@ -1642,8 +1641,6 @@ exclude_patterns = [
'torch/fx/subgraph_rewriter.py',
'torch/fx/tensor_type.py',
'torch/fx/traceback.py',
'torch/hub.py',
'torch/library.py',
'torch/linalg/__init__.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',
@@ -1767,11 +1764,6 @@ exclude_patterns = [
'torch/nn/utils/rnn.py',
'torch/nn/utils/spectral_norm.py',
'torch/nn/utils/weight_norm.py',
'torch/overrides.py',
'torch/quasirandom.py',
'torch/random.py',
'torch/return_types.py',
'torch/serialization.py',
'torch/signal/__init__.py',
'torch/signal/windows/__init__.py',
'torch/signal/windows/windows.py',

View File: torch/_guards.py

@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
import dataclasses
import enum
import functools
@@ -31,6 +30,7 @@ from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary
log = logging.getLogger(__name__)
@@ -40,7 +40,6 @@ if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
import torch
@@ -176,7 +175,7 @@ class Guard:
def sort_key(self):
# Put the duplicate input guards at the end. The duplicate guards have
# two sources while guard.name only considers one source.
from ._dynamo.guards import GuardBuilder
from torch._dynamo.guards import GuardBuilder
is_duplicate_input = (
isinstance(self.create_fn, functools.partial)

View File: torch/_lobpcg.py

@@ -7,9 +7,8 @@
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor
from . import _linalg_utils as _utils
from .overrides import handle_torch_function, has_torch_function
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function
__all__ = ["lobpcg"]

View File: torch/_lowrank.py

@@ -6,9 +6,8 @@ __all__ = ["svd_lowrank", "pca_lowrank"]
from typing import Optional, Tuple
import torch
from torch import Tensor
from . import _linalg_utils as _utils
from .overrides import handle_torch_function, has_torch_function
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function
def get_approximate_basis(

View File: torch/_tensor.py

@@ -761,22 +761,22 @@ class Tensor(torch._C.TensorBase):
return torch.norm(self, p, dim, keepdim, dtype=dtype)
def solve(self, other):
from ._linalg_utils import solve
from torch._linalg_utils import solve
return solve(self, other)
def lstsq(self, other):
from ._linalg_utils import lstsq
from torch._linalg_utils import lstsq
return lstsq(self, other)
def eig(self, eigenvectors=False):
from ._linalg_utils import eig
from torch._linalg_utils import eig
return eig(self, eigenvectors=eigenvectors)
def symeig(self, eigenvectors=False):
from ._linalg_utils import _symeig
from torch._linalg_utils import _symeig
return _symeig(self, eigenvectors=eigenvectors)
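Context for the hunk above: these Tensor methods are deprecated shims that forward to torch._linalg_utils (which raises with a pointer to the supported API); the replacements live in torch.linalg. A hedged sketch of the equivalents (shapes illustrative):

import torch

A = torch.randn(3, 3)
b = torch.randn(3, 2)
x = torch.linalg.solve(A, b)              # instead of Tensor.solve
lstsq_result = torch.linalg.lstsq(A, b)   # instead of Tensor.lstsq
eigvals, eigvecs = torch.linalg.eig(A)    # instead of Tensor.eig
w, v = torch.linalg.eigh(A @ A.T)         # instead of Tensor.symeig (symmetric input)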

View File: torch/functional.py

@@ -1,47 +1,46 @@
# mypy: allow-untyped-defs
from typing import (
List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
)
import operator
import itertools
import operator
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
from torch._C import _add_docstr
import torch.nn.functional as F
from ._lowrank import svd_lowrank, pca_lowrank
from .overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
handle_torch_function)
from ._jit_internal import boolean_dispatch
from ._jit_internal import _overload as overload
from torch import _VF, Tensor
from torch._C import _add_docstr
from torch._jit_internal import _overload as overload, boolean_dispatch
from torch._lowrank import pca_lowrank, svd_lowrank
from torch.overrides import (
handle_torch_function,
has_torch_function,
has_torch_function_unary,
has_torch_function_variadic,
)
Tensor = torch.Tensor
from torch import _VF
__all__ = [
'atleast_1d',
'atleast_2d',
'atleast_3d',
'align_tensors',
'broadcast_shapes',
'broadcast_tensors',
'cartesian_prod',
'block_diag',
'cdist',
'chain_matmul',
'einsum',
'istft',
'lu',
'norm',
'meshgrid',
'pca_lowrank',
'split',
'stft',
'svd_lowrank',
'tensordot',
'unique',
'unique_consecutive',
'unravel_index',
"atleast_1d",
"atleast_2d",
"atleast_3d",
"align_tensors",
"broadcast_shapes",
"broadcast_tensors",
"cartesian_prod",
"block_diag",
"cdist",
"chain_matmul",
"einsum",
"istft",
"lu",
"norm",
"meshgrid",
"pca_lowrank",
"split",
"stft",
"svd_lowrank",
"tensordot",
"unique",
"unique_consecutive",
"unravel_index",
]
@@ -124,16 +123,25 @@ def broadcast_shapes(*shapes):
if isinstance(shape, (tuple, list)):
for i in range(-1, -1 - len(shape), -1):
if shape[i] < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
raise RuntimeError(
f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
)
# NB: result is initialized to 1 so this is effectively an
# equals one test
if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
shape[i] == result[i]
):
continue
if result[i] != 1:
raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
raise RuntimeError(
"Shape mismatch: objects cannot be broadcast to a single shape"
)
result[i] = shape[i]
else:
raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
raise RuntimeError(
"Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
shape,
)
return torch.Size(result)
else:
# with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
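To make concrete what broadcast_shapes computes in the loop above (shapes are right-aligned and size-1 dims stretch; values illustrative):

import torch

torch.broadcast_shapes((2, 1), (3, 1, 7))  # -> torch.Size([3, 2, 7])
# mismatched non-1 sizes trigger the "Shape mismatch" error raised above:
# torch.broadcast_shapes((2,), (3,))  # RuntimeError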
@@ -188,7 +196,8 @@ def split(
"""
if has_torch_function_unary(tensor):
return handle_torch_function(
split, (tensor,), tensor, split_size_or_sections, dim=dim)
split, (tensor,), tensor, split_size_or_sections, dim=dim
)
# Overwriting reason:
# This dispatches to two ATen functions depending on the type of
# split_size_or_sections. The branching code is in _tensor.py, which we
@@ -335,10 +344,13 @@ def einsum(*args: Any) -> Tensor:
[ 0.3311, 5.5201, -3.0356]])
"""
import torch.backends.opt_einsum as opt_einsum
# This wrapper exists to support variadic args.
if len(args) < 2:
raise ValueError('einsum(): must specify the equation string and at least one operand, '
'or at least one operand and its subscripts list')
raise ValueError(
"einsum(): must specify the equation string and at least one operand, "
"or at least one operand and its subscripts list"
)
equation = None
operands = None
@@ -350,19 +362,21 @@ def einsum(*args: Any) -> Tensor:
# input operands into a tensorlist (List[Tensor]).
def parse_subscript(n: int) -> str:
if n == Ellipsis:
return '...'
return "..."
if n >= 0 and n < 26:
return chr(ord('A') + n)
return chr(ord("A") + n)
if n >= 26 and n < 52:
return chr(ord('a') + n - 26)
raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
return chr(ord("a") + n - 26)
raise ValueError(
"einsum(): subscript in subscript list is not within the valid range [0, 52)"
)
# Parse subscripts for input operands
equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
# Parse optional output subscripts (provided when the number of arguments is odd)
if len(args) % 2 == 1:
equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
equation += "->" + "".join(parse_subscript(s) for s in args[-1])
operands = args[:-1:2]
else:
operands = args[::2]
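The parse_subscript helper above implements einsum's sublist calling convention; a sketch of both forms (operands illustrative):

import torch

a, b = torch.randn(2, 3), torch.randn(3, 4)
out1 = torch.einsum("ij,jk->ik", a, b)             # equation form
out2 = torch.einsum(a, [0, 1], b, [1, 2], [0, 2])  # sublist form: ints are mapped to subscript letters
assert torch.allclose(out1, out2)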
@@ -388,7 +402,9 @@ def einsum(*args: Any) -> Tensor:
path = None
if opt_einsum.is_available():
_opt_einsum = opt_einsum.get_opt_einsum()
tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
tupled_path = _opt_einsum.contract_path(
equation, *operands, optimize=opt_einsum.strategy
)[0]
# flatten path for dispatching to C++
path = [item for pair in tupled_path for item in pair]
return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined]
@@ -397,10 +413,13 @@ def einsum(*args: Any) -> Tensor:
# This wrapper exists to support variadic args.
if TYPE_CHECKING:
# The JIT doesn't understand Union, so only add type annotation for mypy
def meshgrid(*tensors: Union[Tensor, List[Tensor]],
indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
def meshgrid(
*tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
) -> Tuple[Tensor, ...]:
return _meshgrid(*tensors, indexing=indexing)
else:
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
@@ -509,15 +528,22 @@ def _meshgrid(*tensors, indexing: Optional[str]):
# kwarg for forward compatibility reasons.
#
# Remove this two weeks after landing.
kwargs = {} if indexing is None else {'indexing': indexing}
kwargs = {} if indexing is None else {"indexing": indexing}
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: Optional[Tensor] = None,
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
def stft(
input: Tensor,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None) -> Tensor:
return_complex: Optional[bool] = None,
) -> Tensor:
r"""Short-time Fourier transform (STFT).
.. warning::
@@ -652,9 +678,19 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
"""
if has_torch_function_unary(input):
return handle_torch_function(
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex)
stft,
(input,),
input,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode,
normalized=normalized,
onesided=onesided,
return_complex=return_complex,
)
# NOTE: Do not edit. This code will be removed once the forward-compatibility
# period is over for PR #73432
if center:
@@ -663,8 +699,16 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
pad = int(n_fft // 2)
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
input = input.view(input.shape[-signal_dim:])
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
normalized, onesided, return_complex)
return _VF.stft( # type: ignore[attr-defined]
input,
n_fft,
hop_length,
win_length,
window,
normalized,
onesided,
return_complex,
)
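For reference, a usage sketch of the stft wrapper being reformatted here (signal length and window parameters illustrative):

import torch

x = torch.randn(16000)
window = torch.hann_window(400)
# one-sided complex spectrogram of shape (n_fft // 2 + 1, n_frames)
spec = torch.stft(x, n_fft=400, hop_length=160, window=window, return_complex=True)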
istft = _add_docstr(
@@ -746,7 +790,8 @@ Args:
Returns:
Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
`B?` is an optional batch dimension from the input tensor.
""")
""",
)
if TYPE_CHECKING:
@@ -758,9 +803,13 @@ else:
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
def _unique_impl(input: Tensor, sorted: bool = True,
return_inverse: bool = False, return_counts: bool = False,
dim: Optional[int] = None) -> _unique_impl_out:
def _unique_impl(
input: Tensor,
sorted: bool = True,
return_inverse: bool = False,
return_counts: bool = False,
dim: Optional[int] = None,
) -> _unique_impl_out:
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
Returns the unique elements of the input tensor.
@@ -896,8 +945,14 @@ def _unique_impl(input: Tensor, sorted: bool = True,
"""
if has_torch_function_unary(input):
return handle_torch_function(
unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
unique,
(input,),
input,
sorted=sorted,
return_inverse=return_inverse,
return_counts=return_counts,
dim=dim,
)
if dim is not None:
output, inverse_indices, counts = _VF.unique_dim(
@@ -917,9 +972,12 @@ def _unique_impl(input: Tensor, sorted: bool = True,
return output, inverse_indices, counts
def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
def _unique_consecutive_impl(
input: Tensor,
return_inverse: bool = False,
return_counts: bool = False,
dim: Optional[int] = None) -> _unique_impl_out:
dim: Optional[int] = None,
) -> _unique_impl_out:
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
.. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -971,14 +1029,22 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
"""
if has_torch_function_unary(input):
return handle_torch_function(
unique_consecutive, (input,), input, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
unique_consecutive,
(input,),
input,
return_inverse=return_inverse,
return_counts=return_counts,
dim=dim,
)
output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
)
return output, inverse_indices, counts
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
def _return_counts(
input, sorted=True, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
@@ -988,7 +1054,9 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False
return output, counts
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
def _return_output(
input, sorted=True, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
if has_torch_function_unary(input):
@@ -998,59 +1066,72 @@ def _return_output(input, sorted=True, return_inverse=False, return_counts=False
return output
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
def _return_inverse(
input, sorted=True, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_impl(
input, sorted, return_inverse, return_counts, dim
)
return output, inverse_indices
_return_inverse_false = boolean_dispatch(
arg_name='return_counts',
arg_name="return_counts",
arg_index=3,
default=False,
if_true=_return_counts,
if_false=_return_output,
module_name=__name__,
func_name='unique')
func_name="unique",
)
_return_inverse_true = boolean_dispatch(
arg_name='return_counts',
arg_name="return_counts",
arg_index=3,
default=False,
if_true=_unique_impl,
if_false=_return_inverse,
module_name=__name__,
func_name='unique')
func_name="unique",
)
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters
unique = boolean_dispatch(
arg_name='return_inverse',
arg_name="return_inverse",
arg_index=2,
default=False,
if_true=_return_inverse_true,
if_false=_return_inverse_false,
module_name=__name__,
func_name='unique')
func_name="unique",
)
unique.__doc__ = _unique_impl.__doc__
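A sketch of how the boolean_dispatch tree built above resolves at call time (input illustrative):

import torch

t = torch.tensor([1, 3, 2, 3, 1])
vals = torch.unique(t)                                    # -> _return_output
vals, inv = torch.unique(t, return_inverse=True)          # -> _return_inverse
vals, inv, cnt = torch.unique(t, return_inverse=True, return_counts=True)  # -> _unique_impl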
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
def _consecutive_return_counts(
input, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, _, counts = _unique_consecutive_impl(
input, return_inverse, return_counts, dim
)
return output, counts
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
def _consecutive_return_output(
input, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
if has_torch_function_unary(input):
@@ -1060,45 +1141,52 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False,
return output
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
def _consecutive_return_inverse(
input, return_inverse=False, return_counts=False, dim=None
):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_consecutive_impl(
input, return_inverse, return_counts, dim
)
return output, inverse_indices
_consecutive_return_inverse_false = boolean_dispatch(
arg_name='return_counts',
arg_name="return_counts",
arg_index=1,
default=False,
if_true=_consecutive_return_counts,
if_false=_consecutive_return_output,
module_name=__name__,
func_name='unique_consecutive')
func_name="unique_consecutive",
)
_consecutive_return_inverse_true = boolean_dispatch(
arg_name='return_counts',
arg_name="return_counts",
arg_index=1,
default=False,
if_true=_unique_consecutive_impl,
if_false=_consecutive_return_inverse,
module_name=__name__,
func_name='unique_consecutive')
func_name="unique_consecutive",
)
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters
unique_consecutive = boolean_dispatch(
arg_name='return_inverse',
arg_name="return_inverse",
arg_index=2,
default=False,
if_true=_consecutive_return_inverse_true,
if_false=_consecutive_return_inverse_false,
module_name=__name__,
func_name='unique_consecutive')
func_name="unique_consecutive",
)
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
if TYPE_CHECKING:
@@ -1106,24 +1194,50 @@ if TYPE_CHECKING:
# There's no good way to use this type annotation without breaking JIT
# overloads. So leave untyped for mypy for now.
else:
@overload
def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
def tensordot(
a,
b,
dims: int = 2,
out: Optional[torch.Tensor] = None,
):
pass
@overload # noqa: F811
def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
@overload
def tensordot( # noqa: F811
a,
b,
dims: Tuple[List[int], List[int]],
out: Optional[torch.Tensor] = None,
):
pass
@overload # noqa: F811
def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
@overload
def tensordot( # noqa: F811
a,
b,
dims: List[List[int]],
out: Optional[torch.Tensor] = None,
):
pass
@overload # noqa: F811
def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None): # noqa: F811
@overload
def tensordot( # noqa: F811
a,
b,
dims: torch.Tensor,
out: Optional[torch.Tensor] = None,
):
pass
def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
def tensordot( # noqa: F811
a,
b,
dims=2,
out: Optional[torch.Tensor] = None,
):
r"""Returns a contraction of a and b over multiple dimensions.
:attr:`tensordot` implements a generalized matrix product.
@@ -1178,10 +1292,12 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
raise RuntimeError("tensordot expects dims to be int or "
raise RuntimeError(
"tensordot expects dims to be int or "
+ "Tuple[List[int], List[int]] or "
+ "List[List[int]] containing two lists, but got "
+ f"dims={dims}")
+ f"dims={dims}"
)
dims_a: List[int] = []
dims_b: List[int] = []
@@ -1206,7 +1322,9 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
if dims < 0:
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
if dims > min(a.dim(), b.dim()):
raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
raise RuntimeError(
f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
)
dims_a = list(range(-dims, 0))
dims_b = list(range(dims))
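To make the two accepted forms of `dims` validated above concrete (shapes illustrative):

import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 5, 6)
torch.tensordot(a, b, dims=2)                 # contract last 2 dims of a with first 2 of b: shape (3, 6)
torch.tensordot(a, b, dims=([1, 2], [0, 1]))  # same contraction with explicit axis lists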
@@ -1287,7 +1405,7 @@ def block_diag(*tensors):
return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
# type: (Tensor, Tensor, float, str) -> (Tensor)
r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
@@ -1331,12 +1449,13 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
"""
if has_torch_function_variadic(x1, x2):
return handle_torch_function(
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
)
if compute_mode == "use_mm_for_euclid_dist_if_necessary":
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
elif compute_mode == 'use_mm_for_euclid_dist':
elif compute_mode == "use_mm_for_euclid_dist":
return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
elif compute_mode == 'donot_use_mm_for_euclid_dist':
elif compute_mode == "donot_use_mm_for_euclid_dist":
return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
else:
raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
@@ -1478,27 +1597,62 @@ else:
# TODO: type dim as BroadcastingList when
# https://github.com/pytorch/pytorch/issues/33782 is fixed
@overload
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
def norm(
input,
p="fro",
dim=None,
keepdim=False,
out=None,
dtype=None,
):
# type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
@overload # noqa: F811
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
@overload
def norm( # noqa: F811
input,
p="fro",
dim=None,
keepdim=False,
out=None,
dtype=None,
):
# type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
@overload # noqa: F811
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
@overload
def norm( # noqa: F811
input,
p="fro",
dim=None,
keepdim=False,
out=None,
dtype=None,
):
# type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
@overload # noqa: F811
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
@overload
def norm( # noqa: F811
input,
p="fro",
dim=None,
keepdim=False,
out=None,
dtype=None,
):
# type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
def norm( # noqa: F811
input,
p: Optional[Union[float, str]] = "fro",
dim=None,
keepdim=False,
out=None,
dtype=None,
):
r"""Returns the matrix norm or vector norm of a given tensor.
.. warning::
@@ -1594,14 +1748,19 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
if has_torch_function_unary(input):
return handle_torch_function(
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
)
# NB. All the repeated code and weird python is to please TorchScript.
# For a more compact implementation see the relevant function in `_refs/__init__.py`
# We don't do this for MPS or sparse tensors
if input.layout == torch.strided and input.device.type in \
("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
if input.layout == torch.strided and input.device.type in (
"cpu",
"cuda",
"meta",
torch.utils.backend_registration._privateuse1_backend_name,
):
if dim is not None:
if isinstance(dim, (int, torch.SymInt)):
_dim = [dim]
@@ -1611,11 +1770,17 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
_dim = None # type: ignore[assignment]
if isinstance(p, str):
if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
if p == "fro" and (
dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
):
if out is None:
return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
return torch.linalg.vector_norm(
input, 2, _dim, keepdim, dtype=dtype
)
else:
return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
return torch.linalg.vector_norm(
input, 2, _dim, keepdim, dtype=dtype, out=out
)
# Here we either call the nuclear norm, or we call matrix_norm with some arguments
# that will throw an error
@@ -1624,14 +1789,18 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
if out is None:
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
else:
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
return torch.linalg.matrix_norm(
input, p, _dim, keepdim, dtype=dtype, out=out
)
else:
# NB. p should be Union[str, number], not Optional!
_p = 2.0 if p is None else p
if out is None:
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
else:
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
return torch.linalg.vector_norm(
input, _p, _dim, keepdim, dtype=dtype, out=out
)
ndim = input.dim()
@@ -1641,7 +1810,7 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
if p == "fro":
return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
if not isinstance(p, str):
_dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
_dim = list(range(ndim))
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
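A short sketch of the forwarding above from the legacy torch.norm to torch.linalg (input illustrative):

import torch

m = torch.randn(3, 4)
torch.norm(m)                       # defaults to p="fro", forwarded to vector_norm
torch.norm(m, dim=1)                # per-row 2-norms
torch.linalg.vector_norm(m, 2)      # modern equivalent (flattens by default)
torch.linalg.matrix_norm(m, "fro")  # modern matrix-norm entry point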
@@ -1695,7 +1864,10 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
else:
return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined]
def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
def unravel_index(
indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
) -> Tuple[Tensor, ...]:
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
index into an arbitrary tensor of the specified shape.
@@ -1745,19 +1917,23 @@ def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size])
tensor([[34], [78]]))
"""
if has_torch_function_unary(indices):
return handle_torch_function(
unravel_index, (indices,), indices, shape=shape)
return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
res_tensor = _unravel_index(indices, shape)
return res_tensor.unbind(-1)
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
torch._check_type(
not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
not indices.is_complex()
and not indices.is_floating_point()
and not indices.dtype == torch.bool,
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
)
torch._check_type(
isinstance(shape, (int, torch.SymInt, Sequence)),
lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
)
if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape])
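A quick example of the row-major coordinate recovery that the coefficient computation in _unravel_index below implements (values illustrative):

import torch

idx = torch.tensor([4, 7])
torch.unravel_index(idx, (3, 3))  # (tensor([1, 2]), tensor([1, 1]))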
@@ -1765,18 +1941,29 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
for dim in shape:
torch._check_type(
isinstance(dim, (int, torch.SymInt)),
lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
)
shape = torch.Size(shape)
torch._check_value(
all(dim >= 0 for dim in shape),
lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
)
coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
coefs = list(
reversed(
list(
itertools.accumulate(
reversed(shape[1:] + torch.Size([1])), func=operator.mul
)
)
)
)
return indices.unsqueeze(-1).floor_divide(
torch.tensor(coefs, device=indices.device, dtype=torch.int64)
) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
def chain_matmul(*matrices, out=None):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
@@ -1923,6 +2110,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
# If get_infos is True, then we don't need to check for errors and vice versa
return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
if TYPE_CHECKING:
_ListOrSeq = Sequence[Tensor]
else:
@@ -1932,16 +2120,21 @@ else:
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
get_infos_int = 1 if get_infos else 0
if out_len - get_infos_int != 2:
raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
raise TypeError(
f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
)
if not isinstance(out, (tuple, list)):
raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
raise TypeError(
f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
)
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
)
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
@@ -1957,7 +2150,8 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
# need to check for torch_function here so that we exit if
if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
)
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
@@ -1967,18 +2161,20 @@ else:
else:
return result[0], result[1] # A_LU, pivots
# The return type of lu depends on `get_infos`, so in order to resolve the output type
# of lu in TorchScript we need to statically know the value of `get_infos`
lu = boolean_dispatch(
arg_name='get_infos',
arg_name="get_infos",
arg_index=2,
default=False,
if_true=_lu_with_infos,
if_false=_lu_no_infos,
module_name=__name__,
func_name='lu')
func_name="lu",
)
lu.__doc__ = _lu_impl.__doc__
def align_tensors(*tensors):
raise RuntimeError('`align_tensors` not yet implemented.')
raise RuntimeError("`align_tensors` not yet implemented.")
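Because lu is built with boolean_dispatch on get_infos (above), its return arity changes with the flag; a usage sketch (matrix illustrative):

import torch

A = torch.randn(3, 3)
A_LU, pivots = torch.lu(A)                         # get_infos=False: 2-tuple
A_LU, pivots, infos = torch.lu(A, get_infos=True)  # get_infos=True: 3-tuple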

View File: torch/hub.py

@@ -8,22 +8,22 @@ import re
import shutil
import sys
import tempfile
import torch
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Any
from typing import Any, Dict, Optional
from typing_extensions import deprecated
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
from urllib.request import Request, urlopen
import torch
from torch.serialization import MAP_LOCATION
class _Faketqdm: # type: ignore[no-redef]
def __init__(self, total=None, disable=False,
unit=None, *args, **kwargs):
class _Faketqdm: # type: ignore[no-redef]
def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
self.total = total
self.disable = disable
self.n = 0
@@ -57,7 +57,8 @@ class _Faketqdm: # type: ignore[no-redef]
if self.disable:
return
sys.stderr.write('\n')
sys.stderr.write("\n")
try:
from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper
@@ -65,25 +66,30 @@ except ImportError:
tqdm = _Faketqdm
__all__ = [
'download_url_to_file',
'get_dir',
'help',
'list',
'load',
'load_state_dict_from_url',
'set_dir',
"download_url_to_file",
"get_dir",
"help",
"list",
"load",
"load_state_dict_from_url",
"set_dir",
]
# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
VAR_DEPENDENCY = 'dependencies'
MODULE_HUBCONF = 'hubconf.py'
_TRUSTED_REPO_OWNERS = (
"facebookresearch",
"facebookincubator",
"pytorch",
"fairinternal",
)
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
VAR_DEPENDENCY = "dependencies"
MODULE_HUBCONF = "hubconf.py"
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None
@@ -101,6 +107,7 @@ def _add_to_sys_path(path):
def _import_module(name, path):
import importlib.util
from importlib.abc import Loader
spec = importlib.util.spec_from_file_location(name, path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
@@ -131,18 +138,20 @@ def _load_attr_from_module(module, func_name):
def _get_torch_home():
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
DEFAULT_CACHE_DIR), 'torch')))
os.getenv(
ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
)
)
return torch_home
def _parse_repo_info(github):
if ':' in github:
repo_info, ref = github.split(':')
if ":" in github:
repo_info, ref = github.split(":")
else:
repo_info, ref = github, None
repo_owner, repo_name = repo_info.split('/')
repo_owner, repo_name = repo_info.split("/")
if ref is None:
# The ref wasn't specified by the user, so we need to figure out the
@@ -150,16 +159,18 @@ def _parse_repo_info(github):
# then it's the default branch, otherwise it's master.
try:
with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
ref = 'main'
ref = "main"
except HTTPError as e:
if e.code == 404:
ref = 'master'
ref = "master"
else:
raise
except URLError as e:
# No internet connection, need to check for cache as last resort
for possible_ref in ("main", "master"):
if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
if os.path.exists(
f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
):
ref = possible_ref
break
if ref is None:
@@ -172,35 +183,40 @@ def _parse_repo_info(github):
def _read_url(url):
with urlopen(url) as r:
return r.read().decode(r.headers.get_content_charset('utf-8'))
return r.read().decode(r.headers.get_content_charset("utf-8"))
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
# Use urlopen to avoid depending on local git.
headers = {'Accept': 'application/vnd.github.v3+json'}
headers = {"Accept": "application/vnd.github.v3+json"}
token = os.environ.get(ENV_GITHUB_TOKEN)
if token is not None:
headers['Authorization'] = f'token {token}'
headers["Authorization"] = f"token {token}"
for url_prefix in (
f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
):
page = 0
while True:
page += 1
url = f'{url_prefix}?per_page=100&page={page}'
url = f"{url_prefix}?per_page=100&page={page}"
response = json.loads(_read_url(Request(url, headers=headers)))
# Empty response means no more data to process
if not response:
break
for br in response:
if br['name'] == ref or br['commit']['sha'].startswith(ref):
if br["name"] == ref or br["commit"]["sha"].startswith(ref):
return
raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
raise ValueError(
f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
"If it's a commit from a forked repo, please call hub.load() with forked repo directly."
)
def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
def _get_cache_or_reload(
github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
):
# Setup hub_dir to save downloaded files
hub_dir = get_dir()
os.makedirs(hub_dir, exist_ok=True)
@@ -210,27 +226,33 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
# this causes confusion with path on both Linux and Windows.
# Backslash is not allowed in Github branch name so no need to
# to worry about it.
normalized_br = ref.replace('/', '_')
normalized_br = ref.replace("/", "_")
# Github renames folder repo-v1.x.x to repo-1.x.x
# We don't know the repo name before downloading the zip file
# and inspect name from it.
# To check if cached repo exists, we need to normalize folder names.
owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
repo_dir = os.path.join(hub_dir, owner_name_branch)
# Check that the repo is in the trusted list
_check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
_check_repo_is_trusted(
repo_owner,
repo_name,
owner_name_branch,
trust_repo=trust_repo,
calling_fn=calling_fn,
)
use_cache = (not force_reload) and os.path.exists(repo_dir)
if use_cache:
if verbose:
sys.stderr.write(f'Using cache found in {repo_dir}\n')
sys.stderr.write(f"Using cache found in {repo_dir}\n")
else:
# Validate the tag/branch is from the original repo instead of a forked repo
if not skip_validation:
_validate_not_a_forked_repo(repo_owner, repo_name, ref)
cached_file = os.path.join(hub_dir, normalized_br + '.zip')
cached_file = os.path.join(hub_dir, normalized_br + ".zip")
_remove_if_exists(cached_file)
try:
@@ -250,7 +272,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
"refs/tags/tag_name as the ref. That might require using skip_validation=True."
)
disambiguated_branch_ref = f"refs/heads/{ref}"
url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
url = _git_archive_link(
repo_owner, repo_name, ref=disambiguated_branch_ref
)
download_url_to_file(url, cached_file, progress=False)
else:
raise
@@ -269,7 +293,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
return repo_dir
def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
def _check_repo_is_trusted(
repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
):
hub_dir = get_dir()
filepath = os.path.join(hub_dir, "trusted_list")
@@ -282,7 +308,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
# if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
trusted_repos_legacy = next(os.walk(hub_dir))[1]
owner_name = '_'.join([repo_owner, repo_name])
owner_name = "_".join([repo_owner, repo_name])
is_trusted = (
owner_name in trusted_repos
or owner_name_branch in trusted_repos_legacy
@@ -298,13 +324,15 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
)
return
if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
response = input(
f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
"Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
"Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
)
if response.lower() in ("y", "yes"):
if is_trusted:
print("The repository is already trusted.")
@@ -321,6 +349,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
def _check_module_exists(name):
import importlib.util
return importlib.util.find_spec(name) is not None
@@ -335,7 +364,7 @@ def _check_dependencies(m):
def _load_entry_from_hubconf(m, model):
if not isinstance(model, str):
raise ValueError('Invalid input: model should be a string of function name')
raise ValueError("Invalid input: model should be a string of function name")
# Note that if a missing dependency is imported at top level of hubconf, it will
# throw before this function. It's a chicken and egg situation where we have to
@@ -346,7 +375,7 @@ def _load_entry_from_hubconf(m, model):
func = _load_attr_from_module(m, model)
if func is None or not callable(func):
raise RuntimeError(f'Cannot find callable {model} in hubconf')
raise RuntimeError(f"Cannot find callable {model} in hubconf")
return func
@@ -362,12 +391,12 @@ def get_dir():
variable is not set.
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_HUB'):
warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
if os.getenv("TORCH_HUB"):
warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
if _hub_dir is not None:
return _hub_dir
return os.path.join(_get_torch_home(), 'hub')
return os.path.join(_get_torch_home(), "hub")
def set_dir(d):
@@ -381,7 +410,9 @@ def set_dir(d):
_hub_dir = os.path.expanduser(d)
def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
def list(
github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
):
r"""
List all callable entrypoints available in the repo specified by ``github``.
@@ -424,15 +455,25 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None, ver
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
"""
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
skip_validation=skip_validation)
repo_dir = _get_cache_or_reload(
github,
force_reload,
trust_repo,
"list",
verbose=verbose,
skip_validation=skip_validation,
)
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
# We take functions starts with '_' as internal helper functions
entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
entrypoints = [
f
for f in dir(hub_module)
if callable(getattr(hub_module, f)) and not f.startswith("_")
]
return entrypoints
@@ -474,8 +515,14 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
"""
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
skip_validation=skip_validation)
repo_dir = _get_cache_or_reload(
github,
force_reload,
trust_repo,
"help",
verbose=True,
skip_validation=skip_validation,
)
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
@@ -486,9 +533,17 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
return entry.__doc__
def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
def load(
repo_or_dir,
model,
*args,
source="github",
trust_repo=None,
force_reload=False,
verbose=True,
skip_validation=False,
**kwargs):
**kwargs,
):
r"""
Load a model from a github repo or a local directory.
@@ -559,13 +614,20 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo
"""
source = source.lower()
if source not in ('github', 'local'):
if source not in ("github", "local"):
raise ValueError(
f'Unknown source: "{source}". Allowed values: "github" | "local".')
f'Unknown source: "{source}". Allowed values: "github" | "local".'
)
if source == 'github':
repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
verbose=verbose, skip_validation=skip_validation)
if source == "github":
repo_or_dir = _get_cache_or_reload(
repo_or_dir,
force_reload,
trust_repo,
"load",
verbose=verbose,
skip_validation=skip_validation,
)
model = _load_local(repo_or_dir, model, *args, **kwargs)
return model
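For reference, the two source modes handled above (network access and the entrypoint name assume the standard torchvision hubconf; local path is a placeholder):

import torch

model = torch.hub.load("pytorch/vision", "resnet18", source="github")
# model = torch.hub.load("/path/to/a/local/checkout", "resnet18", source="local")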
@@ -601,8 +663,9 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
return model
def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
progress: bool = True) -> None:
def download_url_to_file(
url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
) -> None:
r"""Download object at the given URL to a local path.
Args:
@@ -623,7 +686,7 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
req = Request(url, headers={"User-Agent": "torch.hub"})
u = urlopen(req)
meta = u.info()
if hasattr(meta, 'getheaders'):
if hasattr(meta, "getheaders"):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
@@ -637,20 +700,25 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
# file permissions being applied to the downloaded file.
dst = os.path.expanduser(dst)
for seq in range(tempfile.TMP_MAX):
tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
try:
f = open(tmp_dst, 'w+b')
f = open(tmp_dst, "w+b")
except FileExistsError:
continue
break
else:
raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
raise FileExistsError(errno.EEXIST, "No usable temporary file name found")
try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
with tqdm(total=file_size, disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
with tqdm(
total=file_size,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
while True:
buffer = u.read(READ_DATA_CHUNK)
if len(buffer) == 0:
@@ -664,7 +732,9 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
if hash_prefix is not None:
digest = sha256.hexdigest() # type: ignore[possibly-undefined]
if digest[: len(hash_prefix)] != hash_prefix:
raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
raise RuntimeError(
f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
)
shutil.move(f.name, dst)
finally:
f.close()
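A usage sketch of download_url_to_file with the sha256-prefix check shown above (URL is the standard torchvision resnet18 weight file; treat it as illustrative):

import torch

torch.hub.download_url_to_file(
    "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet18.pth",
    hash_prefix="f37072fd",  # compared against the leading hex of the sha256 digest
    progress=True,
)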
@@ -683,23 +753,30 @@ def _is_legacy_zip_format(filename: str) -> bool:
@deprecated(
'Falling back to the old format < 1.6. This support will be '
'deprecated in favor of default zipfile format introduced in 1.6. '
'Please redo torch.save() to save it in the new zipfile format.',
"Falling back to the old format < 1.6. This support will be "
"deprecated in favor of default zipfile format introduced in 1.6. "
"Please redo torch.save() to save it in the new zipfile format.",
category=FutureWarning,
)
def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
def _legacy_zip_load(
filename: str,
model_dir: str,
map_location: MAP_LOCATION,
weights_only: bool,
) -> Dict[str, Any]:
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
with zipfile.ZipFile(filename) as f:
members = f.infolist()
if len(members) != 1:
raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
f.extractall(model_dir)
extraced_name = members[0].filename
extracted_file = os.path.join(model_dir, extraced_name)
return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
return torch.load(
extracted_file, map_location=map_location, weights_only=weights_only
)
def load_state_dict_from_url(
@@ -742,12 +819,14 @@ def load_state_dict_from_url(
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
if os.getenv("TORCH_MODEL_ZOO"):
warnings.warn(
"TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
)
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
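And the higher-level helper, which caches under the <hub_dir>/checkpoints directory set up above (URL illustrative as before):

import torch

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    progress=True,
)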

View File: torch/library.py

@@ -1,28 +1,34 @@
# mypy: allow-untyped-defs
from ._ops import OpOverload
from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
from typing_extensions import deprecated
import traceback
import torch
import weakref
import contextlib
import functools
import inspect
import re
import contextlib
import sys
from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated
import torch
import torch._library as _library
from torch._library.custom_ops import (
_maybe_get_opdef,
custom_op,
CustomOpDef,
device_types_t,
)
from torch._ops import OpOverload
__all__ = [
'Library',
'impl',
'define',
'fallthrough_kernel',
'impl_abstract',
'register_fake',
'get_ctx',
'custom_op',
"Library",
"impl",
"define",
"fallthrough_kernel",
"impl_abstract",
"register_fake",
"get_ctx",
"custom_op",
]
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
@@ -33,7 +39,8 @@ _impls: Set[str] = set()
_defs: Set[str] = set()
# prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim']
_reserved_namespaces = ["prim"]
def fallthrough_kernel():
"""
@@ -41,6 +48,7 @@ def fallthrough_kernel():
"""
raise NotImplementedError("fallthrough_kernel() should never be called.")
class Library:
"""
A class to create libraries that can be used to register new operators or
@@ -59,16 +67,22 @@ class Library:
kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
dispatch_key: PyTorch dispatch key (default: "")
"""
def __init__(self, ns, kind, dispatch_key=""):
if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
if kind not in ("IMPL", "DEF", "FRAGMENT"):
raise ValueError("Unsupported kind: ", kind)
if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.")
if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
raise ValueError(
ns,
" is a reserved namespace. Please try creating a library with another name.",
)
frame = traceback.extract_stack(limit=3)[0]
filename, lineno = frame.filename, frame.lineno
self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
self.m: Optional[Any] = torch._C._dispatch_library(
kind, ns, dispatch_key, filename, lineno
)
self.ns = ns
self._op_defs: Set[str] = set()
self._op_impls: Set[str] = set()
@@ -79,13 +93,21 @@ class Library:
# Python __del__ can lead to weird things (globals and locals may already
# be gone when __del__ actually gets called!). finalizers help the
# situation because it lets us capture references and keeps them alive
weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
weakref.finalize(
self,
_del_library,
_impls,
self._op_impls,
_defs,
self._op_defs,
self._registration_handles,
)
def __repr__(self):
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
def define(self, schema, alias_analysis="", *, tags=()):
r'''Defines a new operator and its semantics in the ns namespace.
r"""Defines a new operator and its semantics in the ns namespace.
Args:
schema: function schema to define a new operator.
@@ -102,7 +124,7 @@ class Library:
Example::
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
'''
"""
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
# AliasAnalysis type in C++
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
@@ -113,7 +135,9 @@ class Library:
name = schema.split("(")[0]
packet_name = name.split(".")[0] if "." in name else name
has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name)
has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
getattr(torch.ops, self.ns), packet_name
)
result = self.m.define(schema, alias_analysis, tuple(tags))
name = schema.split("(")[0]
@@ -131,7 +155,7 @@ class Library:
return result
def _register_fake(self, op_name, fn, _stacklevel=1):
r'''Registers the fake impl for an operator defined in the library.'''
r"""Registers the fake impl for an operator defined in the library."""
source = torch._library.utils.get_source(_stacklevel + 1)
frame = sys._getframe(_stacklevel)
caller_module = inspect.getmodule(frame)
@@ -141,7 +165,9 @@ class Library:
# TODO(rzou): We're gonna need to stage this change with torchvision,
# since torchvision is github first.
if caller_module_name is not None and caller_module_name.startswith("torchvision."):
if caller_module_name is not None and caller_module_name.startswith(
"torchvision."
):
caller_module_name = None
qualname = f"{self.ns}::{op_name}"
@@ -154,8 +180,8 @@ class Library:
handle = entry.abstract_impl.register(func_to_register, source)
self._registration_handles.append(handle)
def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
r'''Register the operator to use the AOTI-compiled implementation.
def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
r"""Register the operator to use the AOTI-compiled implementation.
Args:
op_name: operator name (along with the overload) or OpOverload object.
@@ -165,8 +191,8 @@ class Library:
Example::
>>> my_lib = Library("aten", "IMPL")
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
'''
if dispatch_key == '':
"""
if dispatch_key == "":
dispatch_key = self.dispatch_key
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
@@ -175,19 +201,24 @@ class Library:
elif isinstance(op_name, OpOverload):
name = op_name._schema.name
overload_name = op_name._schema.overload_name
if overload_name != '':
name = name + '.' + overload_name
if overload_name != "":
name = name + "." + overload_name
else:
raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
"as the first argument")
raise RuntimeError(
"_impl_with_aoti_compile should be passed either a name or an OpOverload object "
"as the first argument"
)
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
if key in _impls:
# TODO: in future, add more info about where the existing function is registered (this info is
# today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
"'s behavior for {} dispatch key and {} namespace.".
format(name.split("::")[-1], dispatch_key, self.ns))
raise RuntimeError(
"This is not allowed since there's already a kernel registered from python overriding {}"
"'s behavior for {} dispatch key and {} namespace.".format(
name.split("::")[-1], dispatch_key, self.ns
)
)
assert self.m is not None
impl_fn: Callable = self.m.impl_with_aoti_compile
@ -196,8 +227,8 @@ class Library:
_impls.add(key)
self._op_impls.add(key)
def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
r'''Registers the function implementation for an operator defined in the library.
def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
r"""Registers the function implementation for an operator defined in the library.
Args:
op_name: operator name (along with the overload) or OpOverload object.
@ -211,10 +242,12 @@ class Library:
>>> def div_cpu(self, other):
>>> return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
'''
"""
if not callable(fn):
raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
if dispatch_key == '':
raise TypeError(
f"Input function is required to be a callable but found type {type(fn)}"
)
if dispatch_key == "":
dispatch_key = self.dispatch_key
if isinstance(op_name, str):
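# A minimal sketch of the docstring example above: replacing the CPU
# kernel of an existing aten op. This rebinds div for every caller in
# the process, so it is shown for illustration only.
import torch
from torch.library import Library

my_lib = Library("aten", "IMPL")

def div_cpu(self, other):
    return self * (1 / other)

my_lib.impl("div.Tensor", div_cpu, "CPU")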
@ -222,37 +255,50 @@ class Library:
elif isinstance(op_name, OpOverload):
name = op_name._schema.name
overload_name = op_name._schema.overload_name
if overload_name != '':
name = name + '.' + overload_name
if overload_name != "":
name = name + "." + overload_name
else:
raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
raise RuntimeError(
"impl should be passed either a name or an OpOverload object as the first argument"
)
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
if key in _impls:
# TODO: in future, add more info about where the existing function is registered (this info is
# today already returned by the C++ warning when impl is called but we error out before that)
raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
"'s behavior for {} dispatch key and {} namespace.".
format(name.split("::")[-1], dispatch_key, self.ns))
raise RuntimeError(
"This is not allowed since there's already a kernel registered from python overriding {}"
"'s behavior for {} dispatch key and {} namespace.".format(
name.split("::")[-1], dispatch_key, self.ns
)
)
if dispatch_key == "Meta":
dispatcher_op_name = name
if '::' not in dispatcher_op_name:
dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
if "::" not in dispatcher_op_name:
dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
# Internally, we shouldn't be registering meta kernels for any operators that
# have CompositeImplicitAutograd kernels.
# Instead, we should be letting those decompositions run, and writing meta kernels
# only for the base operators.
if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
if torch._C._dispatch_has_kernel_for_dispatch_key(
dispatcher_op_name, "CompositeImplicitAutograd"
):
raise RuntimeError(
f"We should not register a meta kernel directly to the operator '{name}',"
" because it has a CompositeImplicitAutograd kernel in core."
" Instead we should let the operator decompose, and ensure that we have meta kernels"
" for the base ops that it decomposes into.")
" for the base ops that it decomposes into."
)
assert self.m is not None
self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn, with_keyset)
self.m.impl(
name,
dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
fn,
with_keyset,
)
_impls.add(key)
self._op_impls.add(key)
@ -283,7 +329,9 @@ class Library:
delattr(namespace, name)
def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
def _del_library(
captured_impls, op_impls, captured_defs, op_defs, registration_handles
):
captured_impls -= op_impls
captured_defs -= op_defs
for handle in registration_handles:
@ -357,7 +405,8 @@ def define(qualname, schema, *, lib=None, tags=()):
if not isinstance(qualname, str):
raise ValueError(
f"define(qualname, schema): expected qualname "
f"to be instance of str, got {type(qualname)}")
f"to be instance of str, got {type(qualname)}"
)
namespace, name = torch._library.utils.parse_namespace(qualname)
if lib is None:
lib = Library(namespace, "FRAGMENT")
@ -366,7 +415,8 @@ def define(qualname, schema, *, lib=None, tags=()):
raise ValueError(
f"define(qualname, schema, ...): expected schema "
f'to look like e.g. "(Tensor x) -> Tensor" but '
f'got "{schema}"')
f'got "{schema}"'
)
lib.define(name + schema, alias_analysis="", tags=tags)
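# A minimal sketch of the functional form: the qualname carries the
# namespace, while the schema string starts at "(". The names here are
# hypothetical.
import torch

torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::mysin", "cpu")
def _(x):
    return torch.sin(x)

print(torch.ops.mylib.mysin(torch.zeros(2)))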
@ -375,10 +425,12 @@ def _(lib: Library, schema, alias_analysis=""):
"""The old torch.library.define.
We're keeping this around for BC reasons
"""
def wrap(f):
name = lib.define(schema, alias_analysis)
lib.impl(name, f)
return f
return wrap
@ -460,9 +512,11 @@ def _device_type_to_key(device_type: str) -> str:
@impl.register
def _(lib: Library, name, dispatch_key=""):
"""Legacy torch.library.impl API. Kept around for BC"""
def wrap(f):
lib.impl(name, f, dispatch_key)
return f
return wrap
@ -480,7 +534,9 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
_op_identifier = Union[
str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]
def register_kernel(
@ -489,7 +545,8 @@ def register_kernel(
func: Optional[Callable] = None,
/,
*,
lib: Optional[Library] = None):
lib: Optional[Library] = None,
):
"""Register an implementation for a device type for this operator.
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
@ -530,7 +587,9 @@ def register_kernel(
"""
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
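# A minimal sketch of register_kernel: extend a cpu-only custom op with
# a cuda kernel. The op name is hypothetical and numpy stands in for a
# real backend.
import numpy as np
import torch

@torch.library.custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(np.sin(x.numpy()))

@torch.library.register_kernel("mylib::numpy_sin", "cuda")
def _(x):
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)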
@ -549,7 +608,8 @@ def register_fake(
/,
*,
lib: Optional[Library] = None,
_stacklevel: int = 1):
_stacklevel: int = 1,
):
r"""Register a FakeTensor implementation ("fake impl") for this operator.
Also sometimes known as a "meta kernel", "abstract impl".
@ -630,7 +690,9 @@ def register_fake(
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
"""
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
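# A minimal sketch of register_fake: the fake impl computes only output
# metadata (shape/dtype/device) and never reads real data. The op name
# is hypothetical.
import torch

@torch.library.custom_op("mylib::twice", mutates_args=())
def twice(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x])

@torch.library.register_fake("mylib::twice")
def _(x):
    return x.new_empty(2 * x.shape[0], *x.shape[1:])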
@ -661,7 +723,14 @@ def register_fake(
return register(func)
def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
def register_autograd(
op: _op_identifier,
backward: Callable,
/,
*,
setup_context: Optional[Callable] = None,
lib=None,
) -> None:
r"""Register a backward formula for this custom op.
In order for an operator to work with autograd, you need to register
@ -737,8 +806,12 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
"""
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
raise ValueError(f"register_autograd(op): got unexpected type for op: {type(op)}")
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_autograd(op): got unexpected type for op: {type(op)}"
)
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
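# A minimal sketch of the backward/setup_context pairing described in
# the docstring above; "mylib::mymul" is hypothetical.
import torch

@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * y

def setup_context(ctx, inputs, output):
    x, y = inputs
    ctx.save_for_backward(x, y)

def backward(ctx, grad):
    x, y = ctx.saved_tensors
    return grad * y, grad * x

torch.library.register_autograd("mylib::mymul", backward, setup_context=setup_context)

x = torch.randn(3, requires_grad=True)
mymul(x, torch.ones(3)).sum().backward()  # x.grad == ones(3)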
@ -760,7 +833,8 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
raise NotImplementedError(
f"register_autograd with kwarg-only Tensor args. In the original "
f"definition of the op, please make your tensors not kwarg-only. "
f"Got: {schema}")
f"Got: {schema}"
)
info = _library.autograd.Info(backward, setup_context)
autograd_kernel = _library.autograd.make_autograd_impl(op, info)
@ -788,8 +862,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
return func(*args, **kwargs)
maybe_pystub = torch._C._dispatch_pystub(
op._schema.name,
op._schema.overload_name)
op._schema.name, op._schema.overload_name
)
if maybe_pystub is None:
if torch._library.utils.requires_set_python_module():
namespace = op.namespace
@ -800,7 +874,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
f'companion C++ `m.set_python_module("{actual_module_name}")` '
f"call, but we could not find one. Please add that to "
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
f"operator was registered in ({cpp_filename})")
f"operator was registered in ({cpp_filename})"
)
else:
pystub_module = maybe_pystub[0]
if actual_module_name != pystub_module:
@ -809,9 +884,11 @@ def _check_pystubs_once(func, qualname, actual_module_name):
f"Operator '{qualname}' specified that its python fake impl "
f"is in the Python module '{pystub_module}' but it was actually found "
f"in '{actual_module_name}'. Please either move the fake impl "
f"or correct the m.set_python_module call ({cpp_filename})")
f"or correct the m.set_python_module call ({cpp_filename})"
)
checked = True
return func(*args, **kwargs)
return inner
@ -929,4 +1006,7 @@ def opcheck(
"""
import torch.testing._internal.optests as optests
return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
return optests.opcheck(
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
)
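# A minimal sketch of opcheck against a hypothetical custom op; by
# default it runs every listed test util and raises on the first
# failure.
import torch

@torch.library.custom_op("mylib::myclone", mutates_args=())
def myclone(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@torch.library.register_fake("mylib::myclone")
def _(x):
    return torch.empty_like(x)

torch.library.opcheck(torch.ops.mylib.myclone.default, (torch.randn(3),))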

@ -23,19 +23,26 @@ instructions in the ``README.md`` in that directory.
import __future__ # noqa: F404
import collections
import contextlib
import functools
import types
import warnings
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
from functools import wraps
import contextlib
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type
import torch
from torch._C import (
_has_torch_function, _has_torch_function_unary,
_has_torch_function_variadic, _add_docstr,
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
_is_torch_function_mode_enabled)
_add_docstr,
_get_function_stack_at,
_has_torch_function,
_has_torch_function_unary,
_has_torch_function_variadic,
_is_torch_function_mode_enabled,
_len_torch_function_stack,
_pop_torch_function_stack,
_push_on_torch_function_stack,
)
__all__ = [
"get_ignored_functions",
@ -52,7 +59,8 @@ __all__ = [
def _disable_user_warnings(
func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
) -> Callable:
"""
Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
given ``regex`` pattern.
@ -75,8 +83,11 @@ def _disable_user_warnings(
@wraps(func)
def wrapper(*args, **kwargs):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
warnings.filterwarnings(
"ignore", category=UserWarning, message=regex, module=module
)
return func(*args, **kwargs)
return wrapper
@ -470,8 +481,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
torch.bernoulli: lambda input, generator=None, out=None: -1,
torch.bilinear: lambda input1, input2, weight, bias: -1,
torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
reduction='mean', pos_weight=None: -1),
torch.binary_cross_entropy_with_logits: (
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
),
torch.bincount: lambda input, weights=None, minlength=0: -1,
torch.binomial: lambda count, prob, generator=None: -1,
torch.bitwise_and: lambda input, other, out=None: -1,
@ -489,9 +501,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.cat: lambda tensors, dim=0, out=None: -1,
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
torch.ceil: lambda input, out=None: -1,
torch.celu: lambda input, alpha=1., inplace=False: -1,
torch.celu: lambda input, alpha=1.0, inplace=False: -1,
torch.chain_matmul: lambda *matrices, out=None: -1,
torch.channel_shuffle: lambda input, groups: -1,
torch.cholesky: lambda input, upper=False, out=None: -1,
@ -528,14 +540,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
torch.corrcoef: lambda input: -1,
torch.cos: lambda input, out=None: -1,
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
torch.cosh: lambda input, out=None: -1,
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
torch.count_nonzero: lambda input: -1,
torch.cross: lambda input, other, dim=None, out=None: -1,
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
zero_infinity=False: -1),
torch.ctc_loss: (
lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
),
torch.cummax: lambda input, dim, out=None: -1,
torch.cummin: lambda input, dim, out=None: -1,
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
@ -570,10 +583,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
torch.einsum: lambda equation, *operands: -1,
torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
sparse=False: -1),
torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
torch.embedding: (
lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
),
torch.embedding_bag: (
lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950
),
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.eq: lambda input, other, out=None: -1,
torch.equal: lambda input, other: -1,
@ -585,14 +600,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.expm1: lambda input, out=None: -1,
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
per_row_fake_quant=False, symmetric_quant=False: -1),
torch.fused_moving_avg_obs_fake_quant: (
lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950
),
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
weight_zero_point, bias: -1),
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950
torch.fbgemm_linear_int8_weight_fp32_activation: (
lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
),
torch.fbgemm_linear_quantize_weight: lambda input: -1,
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
@ -630,7 +646,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.frexp: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950
torch._functional_assert_async: lambda input, msg, dep_token: -1,
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
@ -653,7 +669,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.greater: lambda input, other, out=None: -1,
torch.hardshrink: lambda input, lambd=0.5: -1,
torch.heaviside: lambda input, values, out=None: -1,
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
@ -677,8 +693,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.isreal: lambda tensor: -1,
torch.isposinf: lambda input, out=None: -1,
torch.isneginf: lambda input, out=None: -1,
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
cudnn_enabled: -1),
torch.instance_norm: (
lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
),
torch.int_repr: lambda input: -1,
torch.inverse: lambda input, out=None: -1,
torch.linalg.inv: lambda input, out=None: -1,
@ -694,9 +711,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.is_signed: lambda input: -1,
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
torch.isnan: lambda input: -1,
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
normalized=False, onesided=None, length=None, return_complex=False: -1),
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.istft: (
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950
),
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
torch.kron: lambda input, other: -1,
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
@ -709,8 +727,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.less_equal: lambda input, other, out=None: -1,
torch.lerp: lambda input, end, weight, out=None: -1,
torch.lgamma: lambda input, out=None: -1,
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950
torch.log: lambda input, out=None: -1,
torch.log_softmax: lambda input, dim, dtype=None: -1,
torch.log10: lambda input, out=None: -1,
@ -732,7 +749,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.less: lambda input, other, out=None: -1,
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950
torch.masked_fill: lambda input, mask, value: -1,
torch.masked_scatter: lambda input, mask, source: -1,
torch.masked_select: lambda input, mask, out=None: -1,
@ -754,8 +771,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.max_pool1d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.mean: lambda input, dim=None: -1,
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
torch.median: lambda input, dim=None: -1,
@ -764,17 +782,21 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.min: lambda input, out=None: -1,
torch.minimum: lambda input, other, out=None: -1,
torch.fmin: lambda input, other, out=None: -1,
torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
exponential_average_factor, epsilon: -1),
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
torch.miopen_batch_norm: (
lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
),
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950
torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
groups, benchmark, deterministic: -1),
torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
deterministic: -1),
torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
dropout, train, bidirectional, batch_sizes, dropout_state: -1),
torch.miopen_convolution_transpose: (
lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
),
torch.miopen_depthwise_convolution: (
lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
),
torch.miopen_rnn: (
lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950
),
torch.mm: lambda input, mat2, out=None: -1,
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
@ -809,62 +831,76 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True, divisor_override=None: -1),
torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True, divisor_override=None: -1),
torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
momentum=0.1, eps=1e-05: -1),
torch.nn.functional.avg_pool2d: (
lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
),
torch.nn.functional.avg_pool3d: (
lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
),
torch.nn.functional.batch_norm: (
lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
),
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
reduction="mean": -1),
torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
reduce=None, reduction="mean", pos_weight=None: -1),
torch.nn.functional.binary_cross_entropy: (
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.binary_cross_entropy_with_logits: (
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
),
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction="mean", label_smoothing=0.0: -1),
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
reduction='mean', zero_infinity=False: -1),
torch.nn.functional.cosine_embedding_loss: (
lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.cross_entropy: (
lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950
),
torch.nn.functional.ctc_loss: (
lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
),
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False: -1),
torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
include_last_offset=False, padding_idx=None: -1),
torch.nn.functional.embedding: (
lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
),
torch.nn.functional.embedding_bag: (
lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950
),
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
return_indices=False, _random_samples=None: -1),
torch.nn.functional.fractional_max_pool2d: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
),
torch.nn.functional.fractional_max_pool2d_with_indices: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
return_indices=False, _random_samples=None: -1),
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
),
torch.nn.functional.fractional_max_pool3d: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
),
torch.nn.functional.fractional_max_pool3d_with_indices: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
torch.nn.functional.gelu: lambda input, approximate='none': -1,
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
torch.nn.functional.gelu: lambda input, approximate="none": -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
reduction='mean': -1),
torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
use_input_stats=True, momentum=0.1, eps=1e-05: -1),
torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
recompute_scale_factor=None, antialias=False: -1),
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
torch.nn.functional.hinge_embedding_loss: (
lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.instance_norm: (
lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950
),
torch.nn.functional.interpolate: (
lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950
),
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@ -874,55 +910,65 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False: -1),
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False: -1),
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.margin_ranking_loss: (
lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.max_pool1d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool1d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool2d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool2d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.multi_head_attention_forward: (
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
reduction='mean': -1),
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean': -1),
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
),
torch.nn.functional.multi_margin_loss: (
lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_margin_loss: (
lambda input, target, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_soft_margin_loss: (
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.nll_loss: (
lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
),
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
eps=1e-08, reduce=None, reduction='mean': -1),
torch.nn.functional.poisson_nll_loss: (
lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
@ -930,25 +976,29 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.softsign: lambda input: -1,
torch.nn.functional.tanhshrink: lambda input: -1,
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.triplet_margin_loss: (
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.nn.functional.triplet_margin_with_distance_loss: (
lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
-2,
-1,
), keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
torch.orgqr: lambda input, tau: -1,
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
@ -975,28 +1025,43 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.q_scale: lambda input: -1,
torch.q_zero_point: lambda input: -1,
torch.qr: lambda input, some=True, out=None: -1,
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
dilation=(1,), ceil_mode=False: -1),
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
dilation=(1, 1), ceil_mode=False: -1),
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
dilation=(1, 1, 1), ceil_mode=False: -1),
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_gru_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_lstm_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_max_pool1d: (
lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
1,
), ceil_mode=False: -1
),
torch.quantized_max_pool2d: (
lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
1,
1,
), ceil_mode=False: -1
),
torch.quantized_max_pool3d: (
lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
1,
1,
1,
), ceil_mode=False: -1
),
torch.quantized_rnn_relu_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_rnn_tanh_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
@ -1014,16 +1079,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.repeat_interleave: lambda input, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.roll: lambda input, shifts, dims=None: -1,
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
@ -1031,7 +1096,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@ -1061,8 +1126,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.stack: lambda tensors, dim=0, out=None: -1,
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
torch.stft: (
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
@ -1164,9 +1230,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
size_average=None, reduce=None, reduction='mean': -1),
torch.triplet_margin_loss: (
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.triu: lambda input, diagonal=0, out=None: -1,
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
@ -1436,10 +1502,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
}
privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
privateuse1_backend_name = (
torch.utils.backend_registration._privateuse1_backend_name
)
if hasattr(Tensor, privateuse1_backend_name):
ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1 # noqa: B009
ret[
getattr(Tensor, privateuse1_backend_name)
] = lambda self, device=None, non_blocking=False, **kwargs: -1
ret[
getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
] = lambda self: -1 # noqa: B009
ret2 = {}
ignored = get_ignored_functions()
@ -1458,11 +1530,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
# bitwise_<op> have dunder methods of the form __<op>__
# And so on.
subname = k.__name__[len("bitwise_") :]
names.extend([
"__" + subname + "__",
"__i" + subname + "__",
"__r" + subname + "__"
])
names.extend(
["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
)
for name in names:
func = getattr(Tensor, name, None)
@ -1472,6 +1542,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
ret.update(ret2)
return ret
def wrap_torch_function(dispatcher: Callable):
"""Wraps a given function with ``__torch_function__`` -related functionality.
@ -1495,6 +1566,7 @@ def wrap_torch_function(dispatcher: Callable):
>>> def func(a): # This will make func dispatchable by __torch_function__
... return a + 0
"""
def inner(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
@ -1508,7 +1580,10 @@ def wrap_torch_function(dispatcher: Callable):
return inner
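# A minimal sketch: once wrapped, func takes part in __torch_function__
# dispatch, so a Tensor subclass can observe the call. "LoggingTensor"
# is hypothetical.
import torch
from torch.overrides import wrap_torch_function

@wrap_torch_function(lambda a: (a,))
def func(a):
    return a + 0

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, f, types, args=(), kwargs=None):
        print(f"saw {getattr(f, '__name__', f)}")
        return super().__torch_function__(f, types, args, kwargs or {})

func(torch.ones(2).as_subclass(LoggingTensor))  # prints "saw func"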
def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
def _get_overloaded_args(
relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
) -> List[Any]:
"""Returns a list of arguments on which to call __torch_function__.
Checks arguments in relevant_args for __torch_function__ implementations,
@ -1559,8 +1634,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
#
# NB: Important to exclude _disabled_torch_function_impl, otherwise
# https://github.com/pytorch/pytorch/issues/64687
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
if (
arg_type not in overloaded_types
and hasattr(arg_type, "__torch_function__")
and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
):
# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
@ -1581,7 +1659,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
def handle_torch_function(
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
) -> Any:
"""Implement a function with checks for ``__torch_function__`` overrides.
See torch::autograd::handle_torch_function for the equivalent of this
@ -1636,11 +1715,16 @@ def handle_torch_function(
# This call needs to become a classmethod call in the future.
# See https://github.com/pytorch/pytorch/issues/63767
torch_func_method = overloaded_arg.__torch_function__
if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
torch_func_method is not torch._C._disabled_torch_function_impl:
warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
if (
hasattr(torch_func_method, "__self__")
and torch_func_method.__self__ is overloaded_arg
and torch_func_method is not torch._C._disabled_torch_function_impl
):
warnings.warn(
"Defining your `__torch_function__ as a plain method is deprecated and "
"will be an error in future, please define it as a classmethod.",
DeprecationWarning)
DeprecationWarning,
)
# Use `public_api` instead of `implementation` so __torch_function__
# implementations can do equality/identity comparisons.
@ -1649,15 +1733,16 @@ def handle_torch_function(
if result is not NotImplemented:
return result
func_name = f'{public_api.__module__}.{public_api.__name__}'
func_name = f"{public_api.__module__}.{public_api.__name__}"
msg = (
f"no implementation found for '{func_name}' on types that implement "
f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
)
if _is_torch_function_mode_enabled():
msg += f" nor in mode {_get_current_function_mode()}"
raise TypeError(msg)
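# A minimal sketch of the pattern this function exists for: a plain
# Python function made overridable by Tensor-likes. "my_add" is
# hypothetical.
import torch
from torch.overrides import handle_torch_function, has_torch_function

def my_add(a, b):
    if has_torch_function((a, b)):
        return handle_torch_function(my_add, (a, b), a, b)
    return a + b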
has_torch_function = _add_docstr(
_has_torch_function,
r"""Check for __torch_function__ implementations in the elements of an iterable
@ -1678,7 +1763,7 @@ has_torch_function = _add_docstr(
________
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
"""
""",
)
has_torch_function_unary = _add_docstr(
@ -1689,7 +1774,7 @@ has_torch_function_unary = _add_docstr(
call:
`has_torch_function_unary(t)`
which skips unnecessary packing and unpacking work.
"""
""",
)
has_torch_function_variadic = _add_docstr(
@ -1703,11 +1788,14 @@ has_torch_function_variadic = _add_docstr(
call:
`has_torch_function_variadic(a, b)`
which skips unnecessary packing and unpacking work.
"""
""",
)
@functools.lru_cache(None)
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
def _get_overridable_functions() -> (
Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
):
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
@ -1725,21 +1813,21 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
ignore = False
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
if func_name.startswith('__'):
if func_name.startswith("__"):
continue
elif func_name.startswith('_'):
elif func_name.startswith("_"):
ignore = True
elif func_name.endswith('_'):
elif func_name.endswith("_"):
ignore = True
elif not func_name[0].islower():
ignore = True
elif func_name == 'unique_dim':
elif func_name == "unique_dim":
continue
else:
func = getattr(namespace, func_name)
if getattr(object, func_name, None) == func:
continue
if func_name == '__weakref__':
if func_name == "__weakref__":
continue
func = getattr(namespace, func_name)
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
@ -1757,9 +1845,13 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
if ignore:
continue
if func.__get__ in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
msg = (
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override"
)
assert func.__get__ not in get_testing_overrides(), msg.format(
namespace, func.__name__
)
continue
else:
overridable_funcs[func].append(func.__get__)
@ -1775,13 +1867,18 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
# cannot be overridden by __torch_function__
if func in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
msg = (
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override"
)
assert func not in get_testing_overrides(), msg.format(
namespace, func.__name__
)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs, index
@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
@ -1794,6 +1891,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""
return _get_overridable_functions()[0]
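As a quick sanity probe of the returned mapping, assuming the public ``torch`` namespace is among the tested ones (as the helper above suggests), something like this should hold:

import torch

funcs = torch.overrides.get_overridable_functions()
print(torch.add in funcs[torch])  # expected: True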
@_disable_user_warnings
def resolve_name(f):
"""Get a human readable string name for a function passed to
@ -1814,6 +1912,7 @@ def resolve_name(f):
return str(f)
return _get_overridable_functions()[1].get(f)
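The index built above lets ``resolve_name`` map common callables back to their dotted names; sketched behavior, not verified against this exact revision:

import torch

print(torch.overrides.resolve_name(torch.add))  # expected: 'torch.add'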
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
"""Returns a set of the overridable methods on ``torch.Tensor``"""
@ -1821,6 +1920,7 @@ def _get_tensor_methods() -> Set[Callable]:
methods = set(overridable_funcs[torch.Tensor])
return methods
@_disable_user_warnings
def is_tensor_method_or_property(func: Callable) -> bool:
"""
@ -1846,6 +1946,7 @@ def is_tensor_method_or_property(func: Callable) -> bool:
"""
return func in _get_tensor_methods() or func.__name__ == "__get__"
def is_tensor_like(inp):
"""
Returns ``True`` if the passed-in input is a Tensor-like.
@ -1882,6 +1983,7 @@ def is_tensor_like(inp):
"""
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
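The check is duck-typed, as the predicate above shows; a small illustration with an invented class:

import torch

class Duck:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

print(torch.overrides.is_tensor_like(Duck()))         # True: defines __torch_function__
print(torch.overrides.is_tensor_like(torch.ones(2)))  # True: exact Tensor
print(torch.overrides.is_tensor_like([1.0, 2.0]))     # False: plain list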
class TorchFunctionMode:
"""
A ``TorchFunctionMode`` allows you to override the meaning of all
@ -1912,6 +2014,7 @@ class TorchFunctionMode:
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""
inner: "TorchFunctionMode"
# Force metaclass to generate constructor at the base of the hierarchy
@ -1930,7 +2033,9 @@ class TorchFunctionMode:
@classmethod
def push(cls, *args, **kwargs):
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
warnings.warn(
"`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
)
instance = cls(*args, **kwargs)
return instance
@ -1944,6 +2049,7 @@ def _get_current_function_mode_stack():
stack_len = _len_torch_function_stack()
return [_get_function_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
_push_on_torch_function_stack(mode)
@ -1961,6 +2067,7 @@ def _pop_mode_temporarily():
finally:
_push_mode(old)
class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
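The file's diff ends here; for context, a minimal ``TorchFunctionMode`` subclass along the lines the class docstring describes (a sketch; ``LoggingMode`` is invented):

import torch
from torch.overrides import TorchFunctionMode, resolve_name

class LoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"calling {resolve_name(func)}")
        # Delegating to func re-runs the call without this mode active.
        return func(*args, **kwargs)

with LoggingMode():
    torch.ones(2) + 1  # each intercepted call is printed before running normally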

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import torch
from typing import Optional
import torch
class SobolEngine:
r"""
@ -48,8 +49,10 @@ class SobolEngine:
def __init__(self, dimension, scramble=False, seed=None):
if dimension > self.MAXDIM or dimension < 1:
raise ValueError("Supported range of dimensionality "
f"for SobolEngine is [1, {self.MAXDIM}]")
raise ValueError(
"Supported range of dimensionality "
f"for SobolEngine is [1, {self.MAXDIM}]"
)
self.seed = seed
self.scramble = scramble
@ -57,7 +60,9 @@ class SobolEngine:
cpu = torch.device("cpu")
self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
self.sobolstate = torch.zeros(
dimension, self.MAXBIT, device=cpu, dtype=torch.long
)
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
if not self.scramble:
@ -69,8 +74,12 @@ class SobolEngine:
self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
self.num_generated = 0
def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
def draw(
self,
n: int = 1,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@ -92,12 +101,22 @@ class SobolEngine:
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
self.quasi,
n - 1,
self.sobolstate,
self.dimension,
self.num_generated,
dtype=dtype,
)
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
self.quasi,
n,
self.sobolstate,
self.dimension,
self.num_generated - 1,
dtype=dtype,
)
self.num_generated += n
@ -108,8 +127,12 @@ class SobolEngine:
return result
def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
def draw_base2(
self,
m: int,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@ -125,7 +148,8 @@ class SobolEngine:
n = 2**m
total_n = self.num_generated + n
if not (total_n & (total_n - 1) == 0):
raise ValueError("The balance properties of Sobol' points require "
raise ValueError(
"The balance properties of Sobol' points require "
f"n to be a power of 2. {self.num_generated} points have been "
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
"If you still want to do this, please use "
@ -151,9 +175,13 @@ class SobolEngine:
n (Int): The number of steps to fast-forward by.
"""
if self.num_generated == 0:
torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
torch._sobol_engine_ff_(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
)
else:
torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
torch._sobol_engine_ff_(
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
)
self.num_generated += n
return self
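Concretely, the power-of-two bookkeeping enforced in ``draw_base2`` above plays out like this (a hedged sketch of expected behavior):

import torch

engine = torch.quasirandom.SobolEngine(dimension=2)
engine.draw_base2(3)  # 8 points; total generated is 8 == 2**3
engine.draw_base2(3)  # 8 more; total 16 is again a power of two
engine.draw(1)        # total 17 -- any further draw_base2 should raise ValueError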
@ -166,8 +194,12 @@ class SobolEngine:
cpu = torch.device("cpu")
# Generate shift vector
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
shift_ints = torch.randint(
2, (self.dimension, self.MAXBIT), device=cpu, generator=g
)
self.shift = torch.mv(
shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
)
# Generate lower triangular matrices (stacked across dimensions)
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@ -176,9 +208,9 @@ class SobolEngine:
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
def __repr__(self):
fmt_string = [f'dimension={self.dimension}']
fmt_string = [f"dimension={self.dimension}"]
if self.scramble:
fmt_string += ['scramble=True']
fmt_string += ["scramble=True"]
if self.seed is not None:
fmt_string += [f'seed={self.seed}']
return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
fmt_string += [f"seed={self.seed}"]
return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
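End of the ``SobolEngine`` changes; a brief usage sketch tying the constructor, scrambling, and ``__repr__`` together (expected output hedged):

import torch

engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=0)
print(engine)             # SobolEngine(dimension=3, scramble=True, seed=0)
samples = engine.draw(8)  # shape (8, 3), quasi-random points in [0, 1)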

View File

@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Generator
import warnings
from typing import Generator
from torch._C import default_generator
import torch
from torch._C import default_generator
def set_rng_state(new_state: torch.Tensor) -> None:
@ -46,10 +46,12 @@ def manual_seed(seed) -> torch._C.Generator:
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
@ -69,10 +71,12 @@ def seed() -> int:
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
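For reference, the net effect of the seeding helpers touched here (hedged, since the available backends vary by build):

import torch

g = torch.manual_seed(1234)  # seeds CPU and, when present, CUDA/MPS/XPU RNGs
print(isinstance(g, torch.Generator))  # True: the default CPU generator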
@ -95,7 +99,9 @@ def _seed_custom_device(seed) -> None:
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
custom_device_mod, _seed_all_name
):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
@ -117,7 +123,13 @@ _fork_rng_warned_already = False
@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
def fork_rng(
devices=None,
enabled=True,
_caller="fork_rng",
_devices_kw="devices",
device_type="cuda",
) -> Generator:
"""
Forks the RNG, so that when you return, the RNG is reset
to the state that it was previously in.
@ -138,8 +150,10 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
device_type = torch.device(device_type).type
device_mod = getattr(torch, device_type, None)
if device_mod is None:
raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
"a module by `torch._register_device_module`.")
raise RuntimeError(
f"torch has no module of `{device_type}`, you should register "
+ "a module by `torch._register_device_module`."
)
global _fork_rng_warned_already
# Internal arguments:
@ -153,7 +167,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
if devices is None:
num_devices = device_mod.device_count()
if num_devices > 1 and not _fork_rng_warned_already:
message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
message = (
f"{device_type.upper()} reports that you have {num_devices} available devices, and "
f"you have used {_caller} without explicitly specifying which devices are being used. "
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
@ -163,7 +178,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
f"`range(torch.{device_type}.device_count())`.")
f"`range(torch.{device_type}.device_count())`."
)
warnings.warn(message)
_fork_rng_warned_already = True
devices = list(range(num_devices))
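A short sketch of the contract ``fork_rng`` promises, assuming a CPU-only run so no device RNGs need forking:

import torch

state_before = torch.get_rng_state()
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(0)
    torch.randn(3)  # consumes RNG state inside the fork only
assert torch.equal(state_before, torch.get_rng_state())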

View File

@ -1,8 +1,9 @@
import torch
import inspect
import torch
from torch.utils._pytree import register_pytree_node, SequenceKey
__all__ = ["pytree_register_structseq", "all_return_types"]
all_return_types = []
@ -10,6 +11,7 @@ all_return_types = []
# error: Module has no attribute "_return_types"
return_types = torch._C._return_types # type: ignore[attr-defined]
def pytree_register_structseq(cls):
def structseq_flatten(structseq):
return list(structseq), None
@ -28,14 +30,15 @@ def pytree_register_structseq(cls):
flatten_with_keys_fn=structseq_flatten_with_keys,
)
for name in dir(return_types):
if name.startswith('__'):
if name.startswith("__"):
continue
_attr = getattr(return_types, name)
globals()[name] = _attr
if not name.startswith('_'):
if not name.startswith("_"):
__all__.append(name)
all_return_types.append(_attr)
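The structseq types registered above behave like named tuples; a quick illustration (exact type name hedged):

import torch

x = torch.tensor([[1.0, 3.0], [2.0, 0.5]])
result = torch.max(x, dim=0)
print(type(result))             # e.g. <class 'torch.return_types.max'>
values, indices = result        # unpacks like a plain tuple
print(result.values is values)  # True: named-field access sees the same tensors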

File diff suppressed because it is too large

View File

@ -2,8 +2,9 @@
from typing import Any, Iterable
from ._vendor.packaging.version import InvalidVersion, Version
from .version import __version__ as internal_version
from torch._vendor.packaging.version import InvalidVersion, Version
from torch.version import __version__ as internal_version
__all__ = ["TorchVersion"]

View File

@ -1,9 +1,14 @@
# mypy: allow-untyped-defs
import builtins
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
from torch.autograd.graph import GradientEdge
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
@ -11,8 +16,8 @@ _TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
_TensorOrTensorsOrGradEdge = Union[
torch.Tensor,
Sequence[torch.Tensor],
"torch.autograd.graph.GradientEdge",
Sequence["torch.autograd.graph.GradientEdge"],
"GradientEdge",
Sequence["GradientEdge"],
]
# In some cases, these basic types are shadowed by corresponding
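The ``TYPE_CHECKING`` guard introduced above keeps ``GradientEdge`` available to annotations without importing ``torch.autograd.graph`` at runtime. The same idiom in isolation (module and type names here are placeholders):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static analysis; never executed at runtime,
    # so it cannot introduce an import cycle or startup cost.
    from some_heavy_module import HeavyType  # hypothetical

def consume(x: HeavyType) -> None:
    ...

With ``from __future__ import annotations``, the annotation is never evaluated, so the runtime never needs ``some_heavy_module``.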