[BE] enable UFMT for top-level files torch/*.py (#127707)

Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707
Approved by: https://github.com/ezyang
Xuehai Pan, 2024-06-12 20:10:47 +08:00; committed by PyTorch MergeBot
commit dd143d44cc, parent cc231a8e2b
15 changed files with 1548 additions and 875 deletions
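
UFMT here means running usort and black together over each file, which is what produces every hunk below: relative imports become absolute and sorted, single-quoted strings become double-quoted, and long signatures and calls are exploded one argument per line with a trailing comma. A minimal before/after sketch of those transformations (illustrative only; the names come from the torch/functional.py hunks below, the snippet itself is not part of the commit):

    # Before: relative import, manual wrapping, single quotes.
    from .overrides import (
        has_torch_function, handle_torch_function)

    __all__ = ['atleast_1d', 'atleast_2d']

    # After: absolute sorted import, double quotes, magic trailing comma.
    from torch.overrides import (
        handle_torch_function,
        has_torch_function,
    )

    __all__ = [
        "atleast_1d",
        "atleast_2d",
    ]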

.lintrunner.toml

@@ -1556,7 +1556,6 @@ exclude_patterns = [
     'torch/distributed/tensor/parallel/style.py',
     'torch/fft/__init__.py',
     'torch/func/__init__.py',
-    'torch/functional.py',
     'torch/futures/__init__.py',
     'torch/fx/__init__.py',
     'torch/fx/_compatibility.py',
@@ -1642,8 +1641,6 @@ exclude_patterns = [
     'torch/fx/subgraph_rewriter.py',
     'torch/fx/tensor_type.py',
     'torch/fx/traceback.py',
-    'torch/hub.py',
-    'torch/library.py',
     'torch/linalg/__init__.py',
     'torch/monitor/__init__.py',
     'torch/nested/__init__.py',
@@ -1767,11 +1764,6 @@ exclude_patterns = [
     'torch/nn/utils/rnn.py',
     'torch/nn/utils/spectral_norm.py',
     'torch/nn/utils/weight_norm.py',
-    'torch/overrides.py',
-    'torch/quasirandom.py',
-    'torch/random.py',
-    'torch/return_types.py',
-    'torch/serialization.py',
     'torch/signal/__init__.py',
     'torch/signal/windows/__init__.py',
     'torch/signal/windows/windows.py',

torch/_guards.py

@@ -2,7 +2,6 @@
 from __future__ import annotations
 
-
 import contextlib
 import dataclasses
 import enum
 import functools
@@ -31,6 +30,7 @@ from torch.utils import _pytree as pytree
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary
 
+
 log = logging.getLogger(__name__)
@@ -40,7 +40,6 @@ if TYPE_CHECKING:
     # Import the following modules during type checking to enable code intelligence features,
     # such as auto-completion in tools like pylance, even when these modules are not explicitly
     # imported in user code.
-
     import torch
@@ -176,7 +175,7 @@ class Guard:
     def sort_key(self):
         # Put the duplicate input guards at the end. The duplicate guards have
        # two sources while guard.name only considers one source.
-        from ._dynamo.guards import GuardBuilder
+        from torch._dynamo.guards import GuardBuilder
 
         is_duplicate_input = (
             isinstance(self.create_fn, functools.partial)

torch/_lobpcg.py

@@ -7,9 +7,8 @@
 from typing import Dict, Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 __all__ = ["lobpcg"]

torch/_lowrank.py

@@ -6,9 +6,8 @@ __all__ = ["svd_lowrank", "pca_lowrank"]
 from typing import Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 def get_approximate_basis(

torch/_tensor.py

@@ -761,22 +761,22 @@ class Tensor(torch._C.TensorBase):
         return torch.norm(self, p, dim, keepdim, dtype=dtype)
 
     def solve(self, other):
-        from ._linalg_utils import solve
+        from torch._linalg_utils import solve
 
         return solve(self, other)
 
     def lstsq(self, other):
-        from ._linalg_utils import lstsq
+        from torch._linalg_utils import lstsq
 
         return lstsq(self, other)
 
     def eig(self, eigenvectors=False):
-        from ._linalg_utils import eig
+        from torch._linalg_utils import eig
 
         return eig(self, eigenvectors=eigenvectors)
 
     def symeig(self, eigenvectors=False):
-        from ._linalg_utils import _symeig
+        from torch._linalg_utils import _symeig
 
         return _symeig(self, eigenvectors=eigenvectors)

torch/functional.py

@@ -1,47 +1,46 @@
 # mypy: allow-untyped-defs
-from typing import (
-    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
-)
-import operator
 import itertools
+import operator
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
-from torch._C import _add_docstr
 import torch.nn.functional as F
-from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import (
-    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
-    handle_torch_function)
-from ._jit_internal import boolean_dispatch
-from ._jit_internal import _overload as overload
-
-Tensor = torch.Tensor
-from torch import _VF
+from torch import _VF, Tensor
+from torch._C import _add_docstr
+from torch._jit_internal import _overload as overload, boolean_dispatch
+from torch._lowrank import pca_lowrank, svd_lowrank
+from torch.overrides import (
+    handle_torch_function,
+    has_torch_function,
+    has_torch_function_unary,
+    has_torch_function_variadic,
+)
 
 __all__ = [
-    'atleast_1d',
-    'atleast_2d',
-    'atleast_3d',
-    'align_tensors',
-    'broadcast_shapes',
-    'broadcast_tensors',
-    'cartesian_prod',
-    'block_diag',
-    'cdist',
-    'chain_matmul',
-    'einsum',
-    'istft',
-    'lu',
-    'norm',
-    'meshgrid',
-    'pca_lowrank',
-    'split',
-    'stft',
-    'svd_lowrank',
-    'tensordot',
-    'unique',
-    'unique_consecutive',
-    'unravel_index',
+    "atleast_1d",
+    "atleast_2d",
+    "atleast_3d",
+    "align_tensors",
+    "broadcast_shapes",
+    "broadcast_tensors",
+    "cartesian_prod",
+    "block_diag",
+    "cdist",
+    "chain_matmul",
+    "einsum",
+    "istft",
+    "lu",
+    "norm",
+    "meshgrid",
+    "pca_lowrank",
+    "split",
+    "stft",
+    "svd_lowrank",
+    "tensordot",
+    "unique",
+    "unique_consecutive",
+    "unravel_index",
 ]
@@ -124,16 +123,25 @@ def broadcast_shapes(*shapes):
             if isinstance(shape, (tuple, list)):
                 for i in range(-1, -1 - len(shape), -1):
                     if shape[i] < 0:
-                        raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
+                        raise RuntimeError(
+                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
+                        )
                     # NB: result is initialized to 1 so this is effectively an
                     # equals one test
-                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
+                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
+                        shape[i] == result[i]
+                    ):
                         continue
                     if result[i] != 1:
-                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
+                        raise RuntimeError(
+                            "Shape mismatch: objects cannot be broadcast to a single shape"
+                        )
                     result[i] = shape[i]
             else:
-                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
+                raise RuntimeError(
+                    "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+                    shape,
+                )
         return torch.Size(result)
     else:
        # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
@@ -188,7 +196,8 @@ def split(
     """
     if has_torch_function_unary(tensor):
         return handle_torch_function(
-            split, (tensor,), tensor, split_size_or_sections, dim=dim)
+            split, (tensor,), tensor, split_size_or_sections, dim=dim
+        )
     # Overwriting reason:
     # This dispatches to two ATen functions depending on the type of
     # split_size_or_sections. The branching code is in _tensor.py, which we
@@ -335,10 +344,13 @@ def einsum(*args: Any) -> Tensor:
            [ 0.3311,  5.5201, -3.0356]])
     """
     import torch.backends.opt_einsum as opt_einsum
+
     # This wrapper exists to support variadic args.
     if len(args) < 2:
-        raise ValueError('einsum(): must specify the equation string and at least one operand, '
-                         'or at least one operand and its subscripts list')
+        raise ValueError(
+            "einsum(): must specify the equation string and at least one operand, "
+            "or at least one operand and its subscripts list"
+        )
 
     equation = None
     operands = None
@@ -350,19 +362,21 @@ def einsum(*args: Any) -> Tensor:
         # input operands into a tensorlist (List[Tensor]).
         def parse_subscript(n: int) -> str:
             if n == Ellipsis:
-                return '...'
+                return "..."
             if n >= 0 and n < 26:
-                return chr(ord('A') + n)
+                return chr(ord("A") + n)
             if n >= 26 and n < 52:
-                return chr(ord('a') + n - 26)
+                return chr(ord("a") + n - 26)
-            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+            raise ValueError(
+                "einsum(): subscript in subscript list is not within the valid range [0, 52)"
+            )
 
         # Parse subscripts for input operands
-        equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+        equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
 
         # Parse optional output subscripts (provided when the number of arguments is odd)
         if len(args) % 2 == 1:
-            equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+            equation += "->" + "".join(parse_subscript(s) for s in args[-1])
             operands = args[:-1:2]
         else:
             operands = args[::2]
@@ -388,7 +402,9 @@ def einsum(*args: Any) -> Tensor:
     path = None
     if opt_einsum.is_available():
         _opt_einsum = opt_einsum.get_opt_einsum()
-        tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
+        tupled_path = _opt_einsum.contract_path(
+            equation, *operands, optimize=opt_einsum.strategy
+        )[0]
         # flatten path for dispatching to C++
         path = [item for pair in tupled_path for item in pair]
     return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
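
The sublist branch above is what makes the integer-subscript form of torch.einsum equivalent to the string form: parse_subscript maps 0 -> "A", 1 -> "B", and so on, then joins the results into an equation string. A quick illustration of standard torch.einsum behaviour (not part of the diff):

    import torch

    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    a = torch.einsum("ab,bc->ac", x, y)
    # Same contraction with integer subscripts; this input form goes
    # through the parse_subscript helper shown above.
    b = torch.einsum(x, [0, 1], y, [1, 2], [0, 2])
    assert torch.allclose(a, b)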
@@ -397,10 +413,13 @@ def einsum(*args: Any) -> Tensor:
 # This wrapper exists to support variadic args.
 if TYPE_CHECKING:
     # The JIT doesn't understand Union, so only add type annotation for mypy
-    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
-                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+    def meshgrid(
+        *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
+    ) -> Tuple[Tensor, ...]:
         return _meshgrid(*tensors, indexing=indexing)
+
 else:
+
     def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
         r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
@@ -509,15 +528,22 @@ def _meshgrid(*tensors, indexing: Optional[str]):
     # kwarg for forward compatibility reasons.
     #
     # Remove this two weeks after landing.
-    kwargs = {} if indexing is None else {'indexing': indexing}
+    kwargs = {} if indexing is None else {"indexing": indexing}
     return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
 
-def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
-         win_length: Optional[int] = None, window: Optional[Tensor] = None,
-         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
-         onesided: Optional[bool] = None,
-         return_complex: Optional[bool] = None) -> Tensor:
+def stft(
+    input: Tensor,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[Tensor] = None,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> Tensor:
     r"""Short-time Fourier transform (STFT).
 
     .. warning::
@@ -652,9 +678,19 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
-            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
-            onesided=onesided, return_complex=return_complex)
+            stft,
+            (input,),
+            input,
+            n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            normalized=normalized,
+            onesided=onesided,
+            return_complex=return_complex,
+        )
     # NOTE: Do not edit. This code will be removed once the forward-compatibility
     # period is over for PR #73432
     if center:
@@ -663,8 +699,16 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         pad = int(n_fft // 2)
         input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
         input = input.view(input.shape[-signal_dim:])
-    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
-                    normalized, onesided, return_complex)
+    return _VF.stft(  # type: ignore[attr-defined]
+        input,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        normalized,
+        onesided,
+        return_complex,
+    )
 
 
 istft = _add_docstr(
@@ -746,7 +790,8 @@ Args:
 Returns:
     Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
         `B?` is an optional batch dimension from the input tensor.
-""")
+""",
+)
 
 
 if TYPE_CHECKING:
@@ -758,9 +803,13 @@ else:
     _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
 
 
-def _unique_impl(input: Tensor, sorted: bool = True,
-                 return_inverse: bool = False, return_counts: bool = False,
-                 dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_impl(
+    input: Tensor,
+    sorted: bool = True,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
 
     Returns the unique elements of the input tensor.
@@ -896,8 +945,14 @@ def _unique_impl(input: Tensor, sorted: bool = True,
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique,
+            (input,),
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
 
     if dim is not None:
         output, inverse_indices, counts = _VF.unique_dim(
@@ -917,9 +972,12 @@ def _unique_impl(input: Tensor, sorted: bool = True,
     return output, inverse_indices, counts
 
 
-def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
-                             return_counts: bool = False,
-                             dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_consecutive_impl(
+    input: Tensor,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""Eliminates all but the first element from every consecutive group of equivalent elements.
 
     .. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -971,14 +1029,22 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique_consecutive, (input,), input, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique_consecutive,
+            (input,),
+            input,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
     output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
-        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
+    )
     return output, inverse_indices, counts
 
 
-def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_counts(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
@@ -988,7 +1054,9 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False
     return output, counts
 
 
-def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_output(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
@@ -998,59 +1066,72 @@ def _return_output(input, sorted=True, return_inverse=False, return_counts=False
     return output
 
 
-def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_inverse(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_impl(
+        input, sorted, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_return_counts,
     if_false=_return_output,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 _return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_unique_impl,
     if_false=_return_inverse,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 unique = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_return_inverse_true,
     if_false=_return_inverse_false,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 unique.__doc__ = _unique_impl.__doc__
 
 
-def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_counts(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, _, counts = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, counts
 
 
-def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_output(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
@@ -1060,45 +1141,52 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False,
     return output
 
 
-def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_inverse(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _consecutive_return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_consecutive_return_counts,
     if_false=_consecutive_return_output,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 _consecutive_return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_unique_consecutive_impl,
     if_false=_consecutive_return_inverse,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 unique_consecutive = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_consecutive_return_inverse_true,
     if_false=_consecutive_return_inverse_false,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
 
 if TYPE_CHECKING:
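
The repeated comment in the hunks above explains the pattern: boolean_dispatch exists so TorchScript can statically resolve a return type that depends on a boolean flag, by registering two concrete implementations and selecting one from the flag's value. A minimal sketch of the pattern, using the same torch._jit_internal helper as the code above but hypothetical functions:

    from torch._jit_internal import boolean_dispatch

    def _as_pair(x, pair=True):
        # fixed return type: a 2-tuple
        return x, x

    def _as_single(x, pair=False):
        # fixed return type: a single value
        return x

    # maybe_pair(x, pair=...) dispatches on the `pair` flag, so each branch
    # keeps a statically known return type for the TorchScript compiler.
    maybe_pair = boolean_dispatch(
        arg_name="pair",
        arg_index=1,
        default=False,
        if_true=_as_pair,
        if_false=_as_single,
        module_name=__name__,
        func_name="maybe_pair",
    )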
@@ -1106,24 +1194,50 @@ if TYPE_CHECKING:
     # There's no good way to use this type annotation without breaking JIT
     # overloads. So leave untyped for mypy for now.
 else:
+
     @overload
-    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
+    def tensordot(
+        a,
+        b,
+        dims: int = 2,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: Tuple[List[int], List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: List[List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
 
-def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
+def tensordot(  # noqa: F811
+    a,
+    b,
+    dims=2,
+    out: Optional[torch.Tensor] = None,
+):
     r"""Returns a contraction of a and b over multiple dimensions.
 
     :attr:`tensordot` implements a generalized matrix product.
@@ -1178,10 +1292,12 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
         return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
 
     if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
-        raise RuntimeError("tensordot expects dims to be int or "
-                           + "Tuple[List[int], List[int]] or "
-                           + "List[List[int]] containing two lists, but got "
-                           + f"dims={dims}")
+        raise RuntimeError(
+            "tensordot expects dims to be int or "
+            + "Tuple[List[int], List[int]] or "
+            + "List[List[int]] containing two lists, but got "
+            + f"dims={dims}"
+        )
 
     dims_a: List[int] = []
     dims_b: List[int] = []
@@ -1206,7 +1322,9 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
         if dims < 0:
             raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
         if dims > min(a.dim(), b.dim()):
-            raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
+            raise RuntimeError(
+                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
+            )
         dims_a = list(range(-dims, 0))
         dims_b = list(range(dims))
@@ -1287,7 +1405,7 @@ def block_diag(*tensors):
     return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
 
 
-def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
+def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
     # type: (Tensor, Tensor, float, str) -> (Tensor)
     r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
@@ -1331,12 +1449,13 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
     """
     if has_torch_function_variadic(x1, x2):
         return handle_torch_function(
-            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
-    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
+        )
+    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
         return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
-    elif compute_mode == 'use_mm_for_euclid_dist':
+    elif compute_mode == "use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
-    elif compute_mode == 'donot_use_mm_for_euclid_dist':
+    elif compute_mode == "donot_use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
     else:
         raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
@@ -1478,27 +1597,62 @@ else:
     # TODO: type dim as BroadcastingList when
     # https://github.com/pytorch/pytorch/issues/33782 is fixed
     @overload
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+    def norm(
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
 
-def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+def norm(  # noqa: F811
+    input,
+    p: Optional[Union[float, str]] = "fro",
+    dim=None,
+    keepdim=False,
+    out=None,
+    dtype=None,
+):
     r"""Returns the matrix norm or vector norm of a given tensor.
 
     .. warning::
@@ -1594,14 +1748,19 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
     if has_torch_function_unary(input):
         return handle_torch_function(
-            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
+            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
+        )
 
     # NB. All the repeated code and weird python is to please TorchScript.
     #     For a more compact implementation see the relevant function in `_refs/__init__.py`
 
     # We don't do this for MPS or sparse tensors
-    if input.layout == torch.strided and input.device.type in \
-            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
+    if input.layout == torch.strided and input.device.type in (
+        "cpu",
+        "cuda",
+        "meta",
+        torch.utils.backend_registration._privateuse1_backend_name,
+    ):
         if dim is not None:
             if isinstance(dim, (int, torch.SymInt)):
                 _dim = [dim]
@@ -1611,11 +1770,17 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
             _dim = None  # type: ignore[assignment]
 
         if isinstance(p, str):
-            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
+            if p == "fro" and (
+                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
+            ):
                 if out is None:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype
+                    )
                 else:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype, out=out
+                    )
 
             # Here we either call the nuclear norm, or we call matrix_norm with some arguments
             # that will throw an error
@@ -1624,14 +1789,18 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
             if out is None:
                 return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.matrix_norm(
+                    input, p, _dim, keepdim, dtype=dtype, out=out
+                )
         else:
             # NB. p should be Union[str, number], not Optional!
             _p = 2.0 if p is None else p
             if out is None:
                 return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.vector_norm(
+                    input, _p, _dim, keepdim, dtype=dtype, out=out
+                )
 
     ndim = input.dim()
@@ -1641,7 +1810,7 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
         if p == "fro":
             return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
         if not isinstance(p, str):
-            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            _dim = list(range(ndim))
             return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
 
     # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
@@ -1695,7 +1864,10 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
     else:
         return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
 
-def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
+
+def unravel_index(
+    indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
+) -> Tuple[Tensor, ...]:
     r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
     index into an arbitrary tensor of the specified shape.
@@ -1745,19 +1917,23 @@ def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size])
             tensor([[34], [78]]))
     """
     if has_torch_function_unary(indices):
-        return handle_torch_function(
-            unravel_index, (indices,), indices, shape=shape)
+        return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
     res_tensor = _unravel_index(indices, shape)
     return res_tensor.unbind(-1)
 
 
 def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
     torch._check_type(
-        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
-        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
+        not indices.is_complex()
+        and not indices.is_floating_point()
+        and not indices.dtype == torch.bool,
+        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
+    )
 
     torch._check_type(
         isinstance(shape, (int, torch.SymInt, Sequence)),
-        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
+        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
+    )
 
     if isinstance(shape, (int, torch.SymInt)):
         shape = torch.Size([shape])
@@ -1765,18 +1941,29 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
     for dim in shape:
         torch._check_type(
             isinstance(dim, (int, torch.SymInt)),
-            lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
+            lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
+        )
     shape = torch.Size(shape)
 
     torch._check_value(
         all(dim >= 0 for dim in shape),
-        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
+        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
+    )
 
-    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
+    coefs = list(
+        reversed(
+            list(
+                itertools.accumulate(
+                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
+                )
+            )
+        )
+    )
     return indices.unsqueeze(-1).floor_divide(
         torch.tensor(coefs, device=indices.device, dtype=torch.int64)
     ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
 
 
 def chain_matmul(*matrices, out=None):
     r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
     using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
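
The coefs computation in the _unravel_index hunk above builds the row-major strides of `shape` via a reversed cumulative product. A worked example of the public wrapper (standard torch.unravel_index behaviour, not part of the diff): for shape (2, 3, 4) the strides are [12, 4, 1], so flat index 17 maps to (1, 1, 1), since 17 = 1*12 + 1*4 + 1*1.

    import torch

    coords = torch.unravel_index(torch.tensor(17), (2, 3, 4))
    # coords == (tensor(1), tensor(1), tensor(1))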
@@ -1923,6 +2110,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
 
+
 if TYPE_CHECKING:
     _ListOrSeq = Sequence[Tensor]
 else:
@@ -1932,16 +2120,21 @@ else:
 def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
     get_infos_int = 1 if get_infos else 0
     if out_len - get_infos_int != 2:
-        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
+        raise TypeError(
+            f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
+        )
     if not isinstance(out, (tuple, list)):
-        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
+        raise TypeError(
+            f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
+        )
 
 
 def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
     # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1957,7 +2150,8 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
     # need to check for torch_function here so that we exit if
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1967,18 +2161,20 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
     else:
         return result[0], result[1]  # A_LU, pivots
 
+
 # The return type of lu depends on `get_infos`, so in order to resolve the output type
 # of lu in TorchScript we need to statically know the value of `get_infos`
 lu = boolean_dispatch(
-    arg_name='get_infos',
+    arg_name="get_infos",
     arg_index=2,
     default=False,
     if_true=_lu_with_infos,
     if_false=_lu_no_infos,
     module_name=__name__,
-    func_name='lu')
+    func_name="lu",
+)
 lu.__doc__ = _lu_impl.__doc__
 
 
 def align_tensors(*tensors):
-    raise RuntimeError('`align_tensors` not yet implemented.')
+    raise RuntimeError("`align_tensors` not yet implemented.")

torch/hub.py

@@ -8,22 +8,22 @@ import re
 import shutil
 import sys
 import tempfile
-import torch
 import uuid
 import warnings
 import zipfile
 from pathlib import Path
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional
 from typing_extensions import deprecated
 from urllib.error import HTTPError, URLError
-from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
+from urllib.request import Request, urlopen
 
+import torch
 from torch.serialization import MAP_LOCATION
 
 
 class _Faketqdm:  # type: ignore[no-redef]
-    def __init__(self, total=None, disable=False,
-                 unit=None, *args, **kwargs):
+    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
         self.total = total
         self.disable = disable
         self.n = 0
@@ -57,7 +57,8 @@ class _Faketqdm:  # type: ignore[no-redef]
         if self.disable:
             return
 
-        sys.stderr.write('\n')
+        sys.stderr.write("\n")
+
 
 try:
     from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
@@ -65,25 +66,30 @@ except ImportError:
     tqdm = _Faketqdm
 
 __all__ = [
-    'download_url_to_file',
-    'get_dir',
-    'help',
-    'list',
-    'load',
-    'load_state_dict_from_url',
-    'set_dir',
+    "download_url_to_file",
+    "get_dir",
+    "help",
+    "list",
+    "load",
+    "load_state_dict_from_url",
+    "set_dir",
 ]
 
 # matches bfd8deac from resnet18-bfd8deac.pth
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
 
-_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
-ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
-ENV_TORCH_HOME = 'TORCH_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-VAR_DEPENDENCY = 'dependencies'
-MODULE_HUBCONF = 'hubconf.py'
+_TRUSTED_REPO_OWNERS = (
+    "facebookresearch",
+    "facebookincubator",
+    "pytorch",
+    "fairinternal",
+)
+ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+VAR_DEPENDENCY = "dependencies"
+MODULE_HUBCONF = "hubconf.py"
 READ_DATA_CHUNK = 128 * 1024
 _hub_dir: Optional[str] = None
@@ -101,6 +107,7 @@ def _add_to_sys_path(path):
 def _import_module(name, path):
     import importlib.util
     from importlib.abc import Loader
+
     spec = importlib.util.spec_from_file_location(name, path)
     assert spec is not None
     module = importlib.util.module_from_spec(spec)
@@ -131,18 +138,20 @@ def _load_attr_from_module(module, func_name):
 def _get_torch_home():
     torch_home = os.path.expanduser(
-        os.getenv(ENV_TORCH_HOME,
-                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
-                                         DEFAULT_CACHE_DIR), 'torch')))
+        os.getenv(
+            ENV_TORCH_HOME,
+            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
+        )
+    )
     return torch_home
 
 
 def _parse_repo_info(github):
-    if ':' in github:
-        repo_info, ref = github.split(':')
+    if ":" in github:
+        repo_info, ref = github.split(":")
     else:
         repo_info, ref = github, None
-    repo_owner, repo_name = repo_info.split('/')
+    repo_owner, repo_name = repo_info.split("/")
 
     if ref is None:
         # The ref wasn't specified by the user, so we need to figure out the
@@ -150,16 +159,18 @@ def _parse_repo_info(github):
         # then it's the default branch, otherwise it's master.
         try:
             with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
-                ref = 'main'
+                ref = "main"
         except HTTPError as e:
             if e.code == 404:
-                ref = 'master'
+                ref = "master"
             else:
                 raise
         except URLError as e:
             # No internet connection, need to check for cache as last resort
             for possible_ref in ("main", "master"):
-                if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
+                if os.path.exists(
+                    f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
+                ):
                     ref = possible_ref
                     break
             if ref is None:
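
The probing above implements the "owner/repo:ref" convention of the public API: when ":ref" is omitted, torch.hub tries the repository's main branch, falls back to master, or falls back to whatever is already cached when offline. Illustrative usage (the entry point name assumes the torchvision hubconf):

    import torch

    # An explicit ref pins a tag; drop ":v0.10.0" to trigger the probe above.
    model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18")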
@ -172,35 +183,40 @@ def _parse_repo_info(github):
def _read_url(url): def _read_url(url):
with urlopen(url) as r: with urlopen(url) as r:
return r.read().decode(r.headers.get_content_charset('utf-8')) return r.read().decode(r.headers.get_content_charset("utf-8"))
def _validate_not_a_forked_repo(repo_owner, repo_name, ref): def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
# Use urlopen to avoid depending on local git. # Use urlopen to avoid depending on local git.
headers = {'Accept': 'application/vnd.github.v3+json'} headers = {"Accept": "application/vnd.github.v3+json"}
token = os.environ.get(ENV_GITHUB_TOKEN) token = os.environ.get(ENV_GITHUB_TOKEN)
if token is not None: if token is not None:
headers['Authorization'] = f'token {token}' headers["Authorization"] = f"token {token}"
for url_prefix in ( for url_prefix in (
f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches', f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'): f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
):
page = 0 page = 0
while True: while True:
page += 1 page += 1
url = f'{url_prefix}?per_page=100&page={page}' url = f"{url_prefix}?per_page=100&page={page}"
response = json.loads(_read_url(Request(url, headers=headers))) response = json.loads(_read_url(Request(url, headers=headers)))
# Empty response means no more data to process # Empty response means no more data to process
if not response: if not response:
break break
for br in response: for br in response:
if br['name'] == ref or br['commit']['sha'].startswith(ref): if br["name"] == ref or br["commit"]["sha"].startswith(ref):
return return
raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. ' raise ValueError(
'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.') f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
"If it's a commit from a forked repo, please call hub.load() with forked repo directly."
)
def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False): def _get_cache_or_reload(
github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
):
# Setup hub_dir to save downloaded files # Setup hub_dir to save downloaded files
hub_dir = get_dir() hub_dir = get_dir()
os.makedirs(hub_dir, exist_ok=True) os.makedirs(hub_dir, exist_ok=True)
@ -210,27 +226,33 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
# this causes confusion with path on both Linux and Windows. # this causes confusion with path on both Linux and Windows.
# Backslash is not allowed in Github branch name so no need to # Backslash is not allowed in Github branch name so no need to
# to worry about it. # to worry about it.
normalized_br = ref.replace('/', '_') normalized_br = ref.replace("/", "_")
# Github renames folder repo-v1.x.x to repo-1.x.x # Github renames folder repo-v1.x.x to repo-1.x.x
# We don't know the repo name before downloading the zip file # We don't know the repo name before downloading the zip file
# and inspect name from it. # and inspect name from it.
# To check if cached repo exists, we need to normalize folder names. # To check if cached repo exists, we need to normalize folder names.
owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br]) owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
repo_dir = os.path.join(hub_dir, owner_name_branch) repo_dir = os.path.join(hub_dir, owner_name_branch)
# Check that the repo is in the trusted list # Check that the repo is in the trusted list
_check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn) _check_repo_is_trusted(
repo_owner,
repo_name,
owner_name_branch,
trust_repo=trust_repo,
calling_fn=calling_fn,
)
use_cache = (not force_reload) and os.path.exists(repo_dir) use_cache = (not force_reload) and os.path.exists(repo_dir)
if use_cache: if use_cache:
if verbose: if verbose:
sys.stderr.write(f'Using cache found in {repo_dir}\n') sys.stderr.write(f"Using cache found in {repo_dir}\n")
else: else:
# Validate the tag/branch is from the original repo instead of a forked repo # Validate the tag/branch is from the original repo instead of a forked repo
if not skip_validation: if not skip_validation:
_validate_not_a_forked_repo(repo_owner, repo_name, ref) _validate_not_a_forked_repo(repo_owner, repo_name, ref)
cached_file = os.path.join(hub_dir, normalized_br + '.zip') cached_file = os.path.join(hub_dir, normalized_br + ".zip")
_remove_if_exists(cached_file) _remove_if_exists(cached_file)
try: try:
@ -250,7 +272,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
"refs/tags/tag_name as the ref. That might require using skip_validation=True." "refs/tags/tag_name as the ref. That might require using skip_validation=True."
) )
disambiguated_branch_ref = f"refs/heads/{ref}" disambiguated_branch_ref = f"refs/heads/{ref}"
url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref) url = _git_archive_link(
repo_owner, repo_name, ref=disambiguated_branch_ref
)
download_url_to_file(url, cached_file, progress=False) download_url_to_file(url, cached_file, progress=False)
else: else:
raise raise
@ -269,7 +293,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
return repo_dir return repo_dir

def _check_repo_is_trusted(
    repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
):
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")
@@ -282,7 +308,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = "_".join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
@@ -298,13 +324,15 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
            "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
            f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
            f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
            f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
        )
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
        )
        if response.lower() in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
@@ -321,6 +349,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,

def _check_module_exists(name):
    import importlib.util

    return importlib.util.find_spec(name) is not None
@@ -335,7 +364,7 @@ def _check_dependencies(m):

def _load_entry_from_hubconf(m, model):
    if not isinstance(model, str):
        raise ValueError("Invalid input: model should be the function name as a string")

    # Note that if a missing dependency is imported at top level of hubconf, it will
    # throw before this function. It's a chicken and egg situation where we have to
@@ -346,7 +375,7 @@ def _load_entry_from_hubconf(m, model):
    func = _load_attr_from_module(m, model)

    if func is None or not callable(func):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")

    return func
@@ -362,12 +391,12 @@ def get_dir():
    variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_HUB"):
        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), "hub")


def set_dir(d):
@@ -381,7 +410,9 @@ def set_dir(d):
    _hub_dir = os.path.expanduser(d)
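
# Example: the lookup order implemented by get_dir()/set_dir() above.
# A sketch: an explicit set_dir() call wins over $TORCH_HOME/hub, which wins
# over the ~/.cache/torch default; the /tmp path is hypothetical.
import torch

print(torch.hub.get_dir())  # e.g. ~/.cache/torch/hub by default
torch.hub.set_dir("/tmp/my_hub")
assert torch.hub.get_dir() == "/tmp/my_hub"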

def list(
    github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
):
    r"""
    List all callable entrypoints available in the repo specified by ``github``.
@@ -424,15 +455,25 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None, ver
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "list",
        verbose=verbose,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    # We treat functions that start with '_' as internal helper functions
    entrypoints = [
        f
        for f in dir(hub_module)
        if callable(getattr(hub_module, f)) and not f.startswith("_")
    ]

    return entrypoints
@@ -474,8 +515,14 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "help",
        verbose=True,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
@@ -486,9 +533,17 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
    return entry.__doc__

def load(
    repo_or_dir,
    model,
    *args,
    source="github",
    trust_repo=None,
    force_reload=False,
    verbose=True,
    skip_validation=False,
    **kwargs,
):
    r"""
    Load a model from a GitHub repo or a local directory.
@@ -559,13 +614,20 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo
    """
    source = source.lower()

    if source not in ("github", "local"):
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".'
        )

    if source == "github":
        repo_or_dir = _get_cache_or_reload(
            repo_or_dir,
            force_reload,
            trust_repo,
            "load",
            verbose=verbose,
            skip_validation=skip_validation,
        )

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model
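
# Example: the two source modes dispatched above. A minimal sketch; the
# local path is hypothetical and must already contain a hubconf.py.
import torch

# source="github" (default): the repo is fetched/cached via _get_cache_or_reload.
model = torch.hub.load("pytorch/vision", "resnet18")
# source="local": repo_or_dir is a directory, so no download happens.
model = torch.hub.load("/path/to/vision", "resnet18", source="local")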
@@ -601,8 +663,9 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
    return model

def download_url_to_file(
    url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
) -> None:
    r"""Download object at the given URL to a local path.

    Args:
@@ -623,7 +686,7 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, "getheaders"):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
@@ -637,20 +700,25 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
    # file permissions being applied to the downloaded file.
    dst = os.path.expanduser(dst)
    for seq in range(tempfile.TMP_MAX):
        tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
        try:
            f = open(tmp_dst, "w+b")
        except FileExistsError:
            continue
        break
    else:
        raise FileExistsError(errno.EEXIST, "No usable temporary file name found")

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            while True:
                buffer = u.read(READ_DATA_CHUNK)
                if len(buffer) == 0:
@@ -664,7 +732,9 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
        if hash_prefix is not None:
            digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
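
# Example: a download with the optional SHA256-prefix check implemented above.
# A sketch; the URL is the resnet18 checkpoint mentioned elsewhere in this
# file, and the hash prefix is the one embedded in its filename.
import torch

torch.hub.download_url_to_file(
    "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "/tmp/resnet18.pth",
    hash_prefix="5c106cde",  # any digest-prefix mismatch raises RuntimeError
    progress=True,
)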
@@ -683,23 +753,30 @@ def _is_legacy_zip_format(filename: str) -> bool:

@deprecated(
    "Falling back to the old format < 1.6. This support will be "
    "deprecated in favor of default zipfile format introduced in 1.6. "
    "Please redo torch.save() to save it in the new zipfile format.",
    category=FutureWarning,
)
def _legacy_zip_load(
    filename: str,
    model_dir: str,
    map_location: MAP_LOCATION,
    weights_only: bool,
) -> Dict[str, Any]:
    # Note: extractall() defaults to overwriting the file if it exists. No need to clean up beforehand.
    # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    # E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if len(members) != 1:
            raise RuntimeError("Only one file (not a dir) is allowed in the zipfile")
        f.extractall(model_dir)
        extracted_name = members[0].filename
        extracted_file = os.path.join(model_dir, extracted_name)

    return torch.load(
        extracted_file, map_location=map_location, weights_only=weights_only
    )

def load_state_dict_from_url(
@@ -742,12 +819,14 @@ def load_state_dict_from_url(
    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_MODEL_ZOO"):
        warnings.warn(
            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
        )

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, "checkpoints")

    os.makedirs(model_dir, exist_ok=True)
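
# Example: the cache layout produced above. A sketch assuming the default
# model_dir=None, so the checkpoint lands in <hub_dir>/checkpoints.
import torch

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    progress=False,
)
print(sorted(state_dict)[:3])  # a few parameter names from the checkpoint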


@@ -1,28 +1,34 @@
# mypy: allow-untyped-defs
import contextlib
import functools
import inspect
import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated

import torch
import torch._library as _library
from torch._library.custom_ops import (
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
    device_types_t,
)
from torch._ops import OpOverload

__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_fake",
    "get_ctx",
    "custom_op",
]

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
@@ -33,7 +39,8 @@ _impls: Set[str] = set()
_defs: Set[str] = set()

# prim is reserved by TorchScript interpreter
_reserved_namespaces = ["prim"]


def fallthrough_kernel():
    """
@@ -41,6 +48,7 @@ def fallthrough_kernel():
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")
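
# Example: what fallthrough_kernel is for. A hedged sketch -- passing this
# sentinel to Library.impl asks the dispatcher to fall through to the next
# dispatch key instead of running a Python kernel; the op/key pair here is
# illustrative only.
from torch.library import Library, fallthrough_kernel

my_lib = Library("aten", "IMPL")
my_lib.impl("sin", fallthrough_kernel, "AutocastCPU")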

class Library:
    """
    A class to create libraries that can be used to register new operators or
@@ -59,16 +67,22 @@ class Library:
        kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError("Unsupported kind: ", kind)

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                ns,
                " is a reserved namespace. Please try creating a library with another name.",
            )

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: Set[str] = set()
        self._op_impls: Set[str] = set()
@@ -79,13 +93,21 @@ class Library:
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). finalizers help the
        # situation because it lets us capture references and keeps them alive
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
@@ -102,7 +124,7 @@ class Library:
        Example::
            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
@@ -113,7 +135,9 @@ class Library:
        name = schema.split("(")[0]
        packet_name = name.split(".")[0] if "." in name else name
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        name = schema.split("(")[0]
@@ -131,7 +155,7 @@ class Library:
        return result

    def _register_fake(self, op_name, fn, _stacklevel=1):
        r"""Registers the fake impl for an operator defined in the library."""
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
@@ -141,7 +165,9 @@ class Library:
        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
@@ -154,8 +180,8 @@ class Library:
        handle = entry.abstract_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
@@ -165,8 +191,8 @@ class Library:
        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
@@ -175,19 +201,24 @@ class Library:
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                "as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
@@ -196,8 +227,8 @@ class Library:
        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
@@ -211,10 +242,12 @@ class Library:
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
@@ -222,37 +255,50 @@ class Library:
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "impl should be passed either a name or an OpOverload object as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        assert self.m is not None
        self.m.impl(
            name,
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)
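
# Example: the define/impl flow of this class end to end. A minimal sketch
# using a scratch namespace; the kernel itself is illustrative.
import torch
from torch.library import Library

scratch_lib = Library("my_scratch_ns", "DEF")  # hypothetical namespace
scratch_lib.define("mul_by_two(Tensor x) -> Tensor")


def _mul_by_two_cpu(x):
    return x * 2


scratch_lib.impl("mul_by_two", _mul_by_two_cpu, "CPU")
out = torch.ops.my_scratch_ns.mul_by_two(torch.ones(3))  # tensor([2., 2., 2.])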
@@ -283,7 +329,9 @@ class Library:
            delattr(namespace, name)


def _del_library(
    captured_impls, op_impls, captured_defs, op_defs, registration_handles
):
    captured_impls -= op_impls
    captured_defs -= op_defs
    for handle in registration_handles:
@@ -357,7 +405,8 @@ def define(qualname, schema, *, lib=None, tags=()):
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
@@ -366,7 +415,8 @@ def define(qualname, schema, *, lib=None, tags=()):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    lib.define(name + schema, alias_analysis="", tags=tags)
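
# Example: the free-function form of define, which (as above) creates a
# FRAGMENT library on the fly when lib is None. The namespace is illustrative.
import torch

torch.library.define("my_demo_ns::clamp_min_zero", "(Tensor x) -> Tensor")


@torch.library.impl("my_demo_ns::clamp_min_zero", "cpu")
def _clamp_min_zero_cpu(x):
    return x.clamp(min=0)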
@@ -375,10 +425,12 @@ def _(lib: Library, schema, alias_analysis=""):
    """The old torch.library.define.
    We're keeping this around for BC reasons
    """

    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f

    return wrap
@@ -460,9 +512,11 @@ def _device_type_to_key(device_type: str) -> str:

@impl.register
def _(lib: Library, name, dispatch_key=""):
    """Legacy torch.library.impl API. Kept around for BC"""

    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f

    return wrap
@@ -480,7 +534,9 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]

def register_kernel(
@@ -489,7 +545,8 @@ def register_kernel(
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
@@ -530,7 +587,9 @@ def register_kernel(
    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
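
# Example: register_kernel as a decorator, per the docstring above. A sketch
# with a hypothetical namespace; the schema is defined first so the kernel
# has something to attach to.
import numpy as np

import torch

torch.library.define("my_demo_ns::numpy_sin", "(Tensor x) -> Tensor")


@torch.library.register_kernel("my_demo_ns::numpy_sin", "cpu")
def _(x):
    return torch.from_numpy(np.sin(x.numpy()))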
@@ -549,7 +608,8 @@ def register_fake(
    /,
    *,
    lib: Optional[Library] = None,
    _stacklevel: int = 1,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel" or "abstract impl".
@@ -630,7 +690,9 @@ def register_fake(
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_fake(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
@@ -661,7 +723,14 @@ def register_fake(
    return register(func)

def register_autograd(
    op: _op_identifier,
    backward: Callable,
    /,
    *,
    setup_context: Optional[Callable] = None,
    lib=None,
) -> None:
    r"""Register a backward formula for this custom op.

    In order for an operator to work with autograd, you need to register
@@ -737,8 +806,12 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autograd(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
@@ -760,7 +833,8 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
        raise NotImplementedError(
            f"register_autograd with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    info = _library.autograd.Info(backward, setup_context)
    autograd_kernel = _library.autograd.make_autograd_impl(op, info)
@@ -788,8 +862,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
            return func(*args, **kwargs)

        maybe_pystub = torch._C._dispatch_pystub(
            op._schema.name, op._schema.overload_name
        )
        if maybe_pystub is None:
            if torch._library.utils.requires_set_python_module():
                namespace = op.namespace
@@ -800,7 +874,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
                    f'companion C++ `m.set_python_module("{actual_module_name}")` '
                    f"call, but we could not find one. Please add that "
                    f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
                    f"operator was registered in ({cpp_filename})"
                )
        else:
            pystub_module = maybe_pystub[0]
            if actual_module_name != pystub_module:
@@ -809,9 +884,11 @@ def _check_pystubs_once(func, qualname, actual_module_name):
                    f"Operator '{qualname}' specified that its python fake impl "
                    f"is in the Python module '{pystub_module}' but it was actually found "
                    f"in '{actual_module_name}'. Please either move the fake impl "
                    f"or correct the m.set_python_module call ({cpp_filename})"
                )
        checked = True
        return func(*args, **kwargs)

    return inner
@@ -929,4 +1006,7 @@ def opcheck(
    """
    import torch.testing._internal.optests as optests

    return optests.opcheck(
        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
    )
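
# Example: running opcheck against a custom op. A sketch reusing the
# hypothetical "my_demo_ns::numpy_sin" overload from above; with
# raise_exception=False the per-test outcomes come back as a dict.
import torch

x = torch.randn(3)
results = torch.library.opcheck(
    torch.ops.my_demo_ns.numpy_sin.default, (x,), raise_exception=False
)
print(results)  # e.g. {"test_schema": "SUCCESS", ...}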


@@ -23,19 +23,26 @@ instructions in the ``README.md`` in that directory.
import __future__  # noqa: F404

import collections
import contextlib
import functools
import types
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type

import torch
from torch._C import (
    _add_docstr,
    _get_function_stack_at,
    _has_torch_function,
    _has_torch_function_unary,
    _has_torch_function_variadic,
    _is_torch_function_mode_enabled,
    _len_torch_function_stack,
    _pop_torch_function_stack,
    _push_on_torch_function_stack,
)

__all__ = [
    "get_ignored_functions",
@@ -52,7 +59,8 @@ __all__ = [


def _disable_user_warnings(
    func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
) -> Callable:
    """
    Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
    given ``regex`` pattern.
@@ -75,8 +83,11 @@ def _disable_user_warnings(

    @wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", category=UserWarning, message=regex, module=module
            )
            return func(*args, **kwargs)

    return wrapper
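
# Example: the filtering performed by the decorator above. A sketch -- the
# module filter is widened to "" here because this toy warning is not raised
# from inside torch; within this file the bare decorator form is used instead.
import warnings


def _noisy():
    warnings.warn("foo is deprecated, please use bar instead", UserWarning)
    return 42


_quiet = _disable_user_warnings(_noisy, module="")
assert _quiet() == 42  # the matching UserWarning is filtered out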
@@ -470,8 +481,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
        torch.bernoulli: lambda input, generator=None, out=None: -1,
        torch.bilinear: lambda input1, input2, weight, bias: -1,
        torch.binary_cross_entropy_with_logits: (
            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
        ),
        torch.bincount: lambda input, weights=None, minlength=0: -1,
        torch.binomial: lambda count, prob, generator=None: -1,
        torch.bitwise_and: lambda input, other, out=None: -1,
@@ -489,9 +501,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.cat: lambda tensors, dim=0, out=None: -1,
        torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
        torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
        torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
        torch.ceil: lambda input, out=None: -1,
        torch.celu: lambda input, alpha=1.0, inplace=False: -1,
        torch.chain_matmul: lambda *matrices, out=None: -1,
        torch.channel_shuffle: lambda input, groups: -1,
        torch.cholesky: lambda input, upper=False, out=None: -1,
@@ -528,14 +540,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
        torch.corrcoef: lambda input: -1,
        torch.cos: lambda input, out=None: -1,
        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
        torch.cosh: lambda input, out=None: -1,
        torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
        torch.count_nonzero: lambda input: -1,
        torch.cross: lambda input, other, dim=None, out=None: -1,
        torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
        torch.ctc_loss: (
            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
        ),
        torch.cummax: lambda input, dim, out=None: -1,
        torch.cummin: lambda input, dim, out=None: -1,
        torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
@@ -570,10 +583,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
        torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
        torch.einsum: lambda equation, *operands: -1,
        torch.embedding: (
            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
        ),
        torch.embedding_bag: (
            lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1  # noqa: B950
        ),
        torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
        torch.eq: lambda input, other, out=None: -1,
        torch.equal: lambda input, other: -1,
@@ -585,14 +600,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.expm1: lambda input, out=None: -1,
        torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
        torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
        torch.fused_moving_avg_obs_fake_quant: (
            lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1  # noqa: B950
        ),
        torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
        torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,  # noqa: B950
        torch.fbgemm_linear_int8_weight_fp32_activation: (
            lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
        ),
        torch.fbgemm_linear_quantize_weight: lambda input: -1,
        torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
        torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
@@ -630,7 +646,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.fmod: lambda input, other, out=None: -1,
        torch.frac: lambda input, out=None: -1,
        torch.frexp: lambda input, out=None: -1,
        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,  # noqa: B950
        torch._functional_assert_async: lambda input, msg, dep_token: -1,
        torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
        torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
@@ -653,7 +669,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.greater: lambda input, other, out=None: -1,
        torch.hardshrink: lambda input, lambd=0.5: -1,
        torch.heaviside: lambda input, values, out=None: -1,
        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
        torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
        torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
        torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
@@ -677,8 +693,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.isreal: lambda tensor: -1,
        torch.isposinf: lambda input, out=None: -1,
        torch.isneginf: lambda input, out=None: -1,
        torch.instance_norm: (
            lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
        ),
        torch.int_repr: lambda input: -1,
        torch.inverse: lambda input, out=None: -1,
        torch.linalg.inv: lambda input, out=None: -1,
@@ -694,9 +711,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.is_signed: lambda input: -1,
        torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
        torch.isnan: lambda input: -1,
        torch.istft: (
            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1  # noqa: B950
        ),
        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
        torch.kron: lambda input, other: -1,
        torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
        torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
@@ -709,8 +727,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.less_equal: lambda input, other, out=None: -1,
        torch.lerp: lambda input, end, weight, out=None: -1,
        torch.lgamma: lambda input, out=None: -1,
        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,  # noqa: B950
        torch.log: lambda input, out=None: -1,
        torch.log_softmax: lambda input, dim, dtype=None: -1,
        torch.log10: lambda input, out=None: -1,
@@ -732,7 +749,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.less: lambda input, other, out=None: -1,
        torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
        torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,  # type: ignore[attr-defined]  # noqa: B950
        torch.masked_fill: lambda input, mask, value: -1,
        torch.masked_scatter: lambda input, mask, source: -1,
        torch.masked_select: lambda input, mask, out=None: -1,
@@ -754,8 +771,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
        torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
        torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
        torch.max_pool1d_with_indices: (
            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
        ),
        torch.mean: lambda input, dim=None: -1,
        torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
        torch.median: lambda input, dim=None: -1,
@@ -764,17 +782,21 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.min: lambda input, out=None: -1,
        torch.minimum: lambda input, other, out=None: -1,
        torch.fmin: lambda input, other, out=None: -1,
        torch.miopen_batch_norm: (
            lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
        ),
        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,  # noqa: B950
        torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
        torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
        torch.miopen_convolution_transpose: (
            lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
        ),
        torch.miopen_depthwise_convolution: (
            lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
        ),
        torch.miopen_rnn: (
            lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1  # noqa: B950
        ),
        torch.mm: lambda input, mat2, out=None: -1,
        torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
        torch.movedim: lambda input, source, destination: -1,
@@ -809,62 +831,76 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.avg_pool2d: (
    lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
),
torch.nn.functional.avg_pool3d: (
    lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
),
torch.nn.functional.batch_norm: (
    lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
),
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
torch.nn.functional.binary_cross_entropy: (
    lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.binary_cross_entropy_with_logits: (
    lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
),
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.cosine_embedding_loss: (
    lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.cross_entropy: (
    lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1  # noqa: B950
),
torch.nn.functional.ctc_loss: (
    lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
),
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.embedding: (
    lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
),
torch.nn.functional.embedding_bag: (
    lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1  # noqa: B950
),
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.functional.fractional_max_pool2d: (
    lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
),
torch.nn.functional.fractional_max_pool2d_with_indices: (
    lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
),
torch.nn.functional.fractional_max_pool3d: (
    lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
),
torch.nn.functional.fractional_max_pool3d_with_indices: (
    lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
torch.nn.functional.gelu: lambda input, approximate="none": -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1,  # noqa: B950
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
torch.nn.functional.hinge_embedding_loss: (
    lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.instance_norm: (
    lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1  # noqa: B950
),
torch.nn.functional.interpolate: (
    lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
),
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@@ -874,55 +910,65 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.margin_ranking_loss: (
    lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.max_pool1d: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool1d_with_indices: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool2d: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool2d_with_indices: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d_with_indices: (
    lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.multi_head_attention_forward: (
    lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1  # noqa: B950
),
torch.nn.functional.multi_margin_loss: (
    lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_margin_loss: (
    lambda input, target, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_soft_margin_loss: (
    lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.nll_loss: (
    lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
),
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
torch.nn.functional.poisson_nll_loss: (
    lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1  # noqa: B950
),
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,  # noqa: B950
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1,  # noqa: B950
torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
@@ -930,25 +976,29 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.softsign: lambda input: -1,
torch.nn.functional.tanhshrink: lambda input: -1,
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (
    lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
),
torch.nn.functional.triplet_margin_with_distance_loss: (
    lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1,  # noqa: B950
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,
torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
    -2,
    -1,
), keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
torch.orgqr: lambda input, tau: -1,
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
@@ -975,28 +1025,43 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.q_scale: lambda input: -1,
torch.q_zero_point: lambda input: -1,
torch.qr: lambda input, some=True, out=None: -1,
torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
torch.quantized_gru_cell: (
    lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
),
torch.quantized_lstm_cell: (
    lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
),
torch.quantized_max_pool1d: (
    lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
        1,
    ), ceil_mode=False: -1
),
torch.quantized_max_pool2d: (
    lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
        1,
        1,
    ), ceil_mode=False: -1
),
torch.quantized_max_pool3d: (
    lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
        1,
        1,
        1,
    ), ceil_mode=False: -1
),
torch.quantized_rnn_relu_cell: (
    lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
),
torch.quantized_rnn_tanh_cell: (
    lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
@@ -1014,16 +1079,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.repeat_interleave: lambda input, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.roll: lambda input, shifts, dims=None: -1,
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1,  # alias for torch.vstack
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
@@ -1031,7 +1096,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,  # noqa: B950
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@@ -1061,8 +1126,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.stack: lambda tensors, dim=0, out=None: -1,
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
torch.stft: (
    lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1  # noqa: B950
),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
@@ -1164,9 +1230,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
torch.triplet_margin_loss: (
    lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
),
torch.triu: lambda input, diagonal=0, out=None: -1,
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
@@ -1436,10 +1502,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
}

privateuse1_backend_name = (
    torch.utils.backend_registration._privateuse1_backend_name
)
if hasattr(Tensor, privateuse1_backend_name):
    ret[
        getattr(Tensor, privateuse1_backend_name)
    ] = lambda self, device=None, non_blocking=False, **kwargs: -1
    ret[
        getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
    ] = lambda self: -1  # noqa: B009

ret2 = {}
ignored = get_ignored_functions()
@@ -1458,11 +1530,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
# bitwise_<op> have dunder methods of the form __<op>__
# And so on.
subname = k.__name__[len("bitwise_") :]
names.extend(
    ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
)

for name in names:
    func = getattr(Tensor, name, None)
@@ -1472,6 +1542,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
ret.update(ret2)
return ret
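
For context, the dict assembled above maps every overridable callable to a dummy lambda whose signature mirrors the real operator; tests consume it through the public torch.overrides API. A minimal sketch of such a consumer (illustrative only, not part of this diff):

import inspect

import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
# The dummy lambda stands in for the real operator's public signature.
dummy = overrides[torch.add]
print(inspect.signature(dummy))  # (input, other, out=None)
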
def wrap_torch_function(dispatcher: Callable):
    """Wraps a given function with ``__torch_function__`` -related functionality.
@@ -1495,6 +1566,7 @@ def wrap_torch_function(dispatcher: Callable):
    >>> def func(a):  # This will make func dispatchable by __torch_function__
    ...     return a + 0
    """

    def inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
@@ -1508,7 +1580,10 @@ def wrap_torch_function(dispatcher: Callable):
    return inner
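
To make the decorator concrete, a usage sketch (``my_add`` is a made-up name; the dispatcher must return the arguments that should be checked for ``__torch_function__``):

import torch
from torch.overrides import wrap_torch_function

@wrap_torch_function(lambda a, b: (a, b))
def my_add(a, b):
    return a + b

print(my_add(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
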
def _get_overloaded_args(
    relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
) -> List[Any]:
    """Returns a list of arguments on which to call __torch_function__.

    Checks arguments in relevant_args for __torch_function__ implementations,
@@ -1559,8 +1634,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
        #
        # NB: Important to exclude _disabled_torch_function_impl, otherwise
        # https://github.com/pytorch/pytorch/issues/64687
        if (
            arg_type not in overloaded_types
            and hasattr(arg_type, "__torch_function__")
            and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
        ):
            # Create lists explicitly for the first type (usually the only one
            # done) to avoid setting up the iterator for overloaded_args.
            if overloaded_types:
@@ -1581,7 +1659,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
def handle_torch_function(
    public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
@@ -1636,11 +1715,16 @@ def handle_torch_function(
        # This call needs to become a classmethod call in the future.
        # See https://github.com/pytorch/pytorch/issues/63767
        torch_func_method = overloaded_arg.__torch_function__
        if (
            hasattr(torch_func_method, "__self__")
            and torch_func_method.__self__ is overloaded_arg
            and torch_func_method is not torch._C._disabled_torch_function_impl
        ):
            warnings.warn(
                "Defining your `__torch_function__ as a plain method is deprecated and "
                "will be an error in future, please define it as a classmethod.",
                DeprecationWarning,
            )

        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
@@ -1649,15 +1733,16 @@ def handle_torch_function(
        if result is not NotImplemented:
            return result

    func_name = f"{public_api.__module__}.{public_api.__name__}"
    msg = (
        f"no implementation found for '{func_name}' on types that implement "
        f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
    )
    if _is_torch_function_mode_enabled():
        msg += f" nor in mode {_get_current_function_mode()}"
    raise TypeError(msg)
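
For reference, a minimal Tensor-like that exercises the dispatch above (``ScalarTensor`` is illustrative; any function it does not handle falls through to the ``TypeError`` path):

import torch

class ScalarTensor:
    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.mean:
            # The "mean" of a scalar wrapper is its own value.
            return args[0].value
        return NotImplemented

print(torch.mean(ScalarTensor(3.0)))  # 3.0
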
has_torch_function = _add_docstr(
    _has_torch_function,
    r"""Check for __torch_function__ implementations in the elements of an iterable
@@ -1678,7 +1763,7 @@ has_torch_function = _add_docstr(
    ________
    torch.is_tensor_like
        Checks if something is a Tensor-like, including an exact ``Tensor``.
    """,
)

has_torch_function_unary = _add_docstr(
@@ -1689,7 +1774,7 @@ has_torch_function_unary = _add_docstr(
    call:
      `has_torch_function_unary(t)`
    which skips unnecessary packing and unpacking work.
    """,
)

has_torch_function_variadic = _add_docstr(
@@ -1703,11 +1788,14 @@ has_torch_function_variadic = _add_docstr(
    call:
      `has_torch_function_variadic(a, b)`
    which skips unnecessary packing and unpacking work.
    """,
)
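
The unary/variadic variants exist purely to skip tuple packing on the hot path; the intended pattern looks roughly like this (``my_op`` is illustrative):

import torch
from torch.overrides import handle_torch_function, has_torch_function_variadic

def my_op(a, b):
    # Fall back to __torch_function__ dispatch only when an argument needs it.
    if has_torch_function_variadic(a, b):
        return handle_torch_function(my_op, (a, b), a, b)
    return a + b
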
@functools.lru_cache(None)
def _get_overridable_functions() -> (
    Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
):
    overridable_funcs = collections.defaultdict(list)
    index = {}
    tested_namespaces = [
@@ -1725,21 +1813,21 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callab
    ignore = False
    # ignore private functions or functions that are deleted in torch.__init__
    if namespace is not torch.Tensor:
        if func_name.startswith("__"):
            continue
        elif func_name.startswith("_"):
            ignore = True
        elif func_name.endswith("_"):
            ignore = True
        elif not func_name[0].islower():
            ignore = True
        elif func_name == "unique_dim":
            continue
    else:
        func = getattr(namespace, func_name)
        if getattr(object, func_name, None) == func:
            continue
        if func_name == "__weakref__":
            continue
    func = getattr(namespace, func_name)
    if namespace is torch.Tensor and getattr(object, func_name, None) == func:
@@ -1757,9 +1845,13 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callab
    if ignore:
        continue
    if func.__get__ in get_ignored_functions():
        msg = (
            "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
            "but still has an explicit override"
        )
        assert func.__get__ not in get_testing_overrides(), msg.format(
            namespace, func.__name__
        )
        continue
    else:
        overridable_funcs[func].append(func.__get__)
@@ -1775,13 +1867,18 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callab
    # cannot be overriden by __torch_function__
    if func in get_ignored_functions():
        msg = (
            "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
            "but still has an explicit override"
        )
        assert func not in get_testing_overrides(), msg.format(
            namespace, func.__name__
        )
        continue
    overridable_funcs[namespace].append(func)
    return overridable_funcs, index
@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
    """List functions that are overridable via __torch_function__
@@ -1794,6 +1891,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
    """
    return _get_overridable_functions()[0]

@_disable_user_warnings
def resolve_name(f):
    """Get a human readable string name for a function passed to
@@ -1814,6 +1912,7 @@ def resolve_name(f):
        return str(f)
    return _get_overridable_functions()[1].get(f)

@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
    """Returns a set of the overridable methods on ``torch.Tensor``"""
@@ -1821,6 +1920,7 @@ def _get_tensor_methods() -> Set[Callable]:
    methods = set(overridable_funcs[torch.Tensor])
    return methods

@_disable_user_warnings
def is_tensor_method_or_property(func: Callable) -> bool:
    """
@@ -1846,6 +1946,7 @@ def is_tensor_method_or_property(func: Callable) -> bool:
    """
    return func in _get_tensor_methods() or func.__name__ == "__get__"

def is_tensor_like(inp):
    """
    Returns ``True`` if the passed-in input is a Tensor-like.
@@ -1882,6 +1983,7 @@ def is_tensor_like(inp):
    """
    return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
class TorchFunctionMode:
    """
    A ``TorchFunctionMode`` allows you to override the meaning of all
@@ -1912,6 +2014,7 @@ class TorchFunctionMode:
    ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
    API self-referential (beware of infinite loops, in this case!)
    """

    inner: "TorchFunctionMode"

    # Force metaclass to generate constructor at the base of the hierarchy
@@ -1930,7 +2033,9 @@ class TorchFunctionMode:
    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
        )
        instance = cls(*args, **kwargs)
        return instance
@@ -1944,6 +2049,7 @@ def _get_current_function_mode_stack():
    stack_len = _len_torch_function_stack()
    return [_get_function_stack_at(i) for i in range(stack_len)]

def _push_mode(mode):
    _push_on_torch_function_stack(mode)
@@ -1961,6 +2067,7 @@ def _pop_mode_temporarily():
    finally:
        _push_mode(old)

class BaseTorchFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
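
A minimal mode built on the machinery above, assuming only public torch.overrides APIs (``LogMode`` is illustrative):

import torch
from torch.overrides import TorchFunctionMode, resolve_name

class LogMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"intercepted {resolve_name(func)}")
        return func(*args, **kwargs)

with LogMode():
    torch.ones(2) + 1  # every torch call inside the block is logged
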

View File: torch/quasirandom.py

@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
from typing import Optional

import torch

class SobolEngine:
    r"""
@@ -48,8 +49,10 @@ class SobolEngine:
    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )
        self.seed = seed
        self.scramble = scramble
@@ -57,7 +60,9 @@ class SobolEngine:
        cpu = torch.device("cpu")
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
        if not self.scramble:
@@ -69,8 +74,12 @@ class SobolEngine:
        self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
        self.num_generated = 0
    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
@@ -92,12 +101,22 @@ class SobolEngine:
                result = self._first_point.to(dtype)
            else:
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                result = torch.cat((self._first_point.to(dtype), result), dim=-2)
        else:
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )
        self.num_generated += n
@@ -108,8 +127,12 @@ class SobolEngine:
        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
@@ -125,7 +148,8 @@ class SobolEngine:
        n = 2**m
        total_n = self.num_generated + n
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
@@ -151,9 +175,13 @@ class SobolEngine:
            n (Int): The number of steps to fast-forward by.
        """
        if self.num_generated == 0:
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self
@@ -166,8 +194,12 @@ class SobolEngine:
        cpu = torch.device("cpu")

        # Generate shift vector
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Generate lower triangular matrices (stacked across dimensions)
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@@ -176,9 +208,9 @@ class SobolEngine:
        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
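
A usage sketch for the engine as reformatted above (output shapes shown as comments):

import torch

engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=0)
points = engine.draw(4)      # torch.Size([4, 3]), values in [0, 1)
more = engine.draw_base2(2)  # 2**2 more points; 4 + 4 = 8 keeps the balance property
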

View File: torch/random.py

@@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import Generator

import torch
from torch._C import default_generator

def set_rng_state(new_state: torch.Tensor) -> None:
@@ -46,10 +46,12 @@ def manual_seed(seed) -> torch._C.Generator:
    torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)
@@ -69,10 +71,12 @@ def seed() -> int:
    torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)
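
For reference, the seeding entry points reformatted here behave as follows (a sketch):

import torch

torch.manual_seed(0)
a = torch.rand(3)
torch.manual_seed(0)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed, same sequence

fresh = torch.seed()  # reseeds from a non-deterministic source and returns the seed
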
@@ -95,7 +99,9 @@ def _seed_custom_device(seed) -> None:
    custom_device_mod = getattr(torch, custom_backend_name)
    _bad_fork_name = "_is_in_bad_fork"
    _seed_all_name = "manual_seed_all"
    if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
        custom_device_mod, _seed_all_name
    ):
        if not getattr(custom_device_mod, _bad_fork_name)():
            getattr(custom_device_mod, _seed_all_name)(seed)
    else:
@@ -117,7 +123,13 @@ _fork_rng_warned_already = False
 
 @contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
+def fork_rng(
+    devices=None,
+    enabled=True,
+    _caller="fork_rng",
+    _devices_kw="devices",
+    device_type="cuda",
+) -> Generator:
     """
     Forks the RNG, so that when you return, the RNG is reset
     to the state that it was previously in.
@@ -138,8 +150,10 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
     device_type = torch.device(device_type).type
     device_mod = getattr(torch, device_type, None)
     if device_mod is None:
-        raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
-                           "a module by `torch._register_device_module`.")
+        raise RuntimeError(
+            f"torch has no module of `{device_type}`, you should register "
+            + "a module by `torch._register_device_module`."
+        )
 
     global _fork_rng_warned_already
 
     # Internal arguments:
@@ -153,7 +167,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
     if devices is None:
         num_devices = device_mod.device_count()
         if num_devices > 1 and not _fork_rng_warned_already:
-            message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
+            message = (
+                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                 f"you have used {_caller} without explicitly specifying which devices are being used. "
                 f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                 f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
@@ -163,7 +178,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
                 "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                 f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
                 f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
-               f"`range(torch.{device_type}.device_count())`.")
+                f"`range(torch.{device_type}.device_count())`."
+            )
             warnings.warn(message)
             _fork_rng_warned_already = True
             devices = list(range(num_devices))
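
Since fork_rng is the main API reformatted here, a short usage sketch may help; it assumes only the behavior documented in the docstring above (RNG state restored on exit):

import torch

torch.manual_seed(0)
before = torch.get_rng_state()
with torch.random.fork_rng(devices=[]):  # devices=[] keeps this CPU-only
    torch.manual_seed(12345)
    _ = torch.randn(3)  # advances only the forked state
assert torch.equal(before, torch.get_rng_state())  # CPU RNG state was restored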

torch/return_types.py

@@ -1,8 +1,9 @@
-import torch
 import inspect
+
+import torch
 from torch.utils._pytree import register_pytree_node, SequenceKey
 
 __all__ = ["pytree_register_structseq", "all_return_types"]
 
 all_return_types = []
@@ -10,6 +11,7 @@ all_return_types = []
 # error: Module has no attribute "_return_types"
 return_types = torch._C._return_types  # type: ignore[attr-defined]
 
+
 def pytree_register_structseq(cls):
     def structseq_flatten(structseq):
         return list(structseq), None
@@ -28,14 +30,15 @@ def pytree_register_structseq(cls):
         flatten_with_keys_fn=structseq_flatten_with_keys,
     )
 
+
 for name in dir(return_types):
-    if name.startswith('__'):
+    if name.startswith("__"):
         continue
 
     _attr = getattr(return_types, name)
     globals()[name] = _attr
 
-    if not name.startswith('_'):
+    if not name.startswith("_"):
         __all__.append(name)
         all_return_types.append(_attr)
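
For context on what is being registered above: multi-output operators return structseq classes from torch.return_types, and this registration is what lets them flatten as pytrees. A small illustration using standard torch APIs:

import torch
from torch.utils._pytree import tree_flatten

out = torch.max(torch.tensor([[1.0, 3.0], [2.0, 0.0]]), dim=1)
print(type(out))  # <class 'torch.return_types.max'>
print(out.values, out.indices)  # named fields, but also tuple-like
leaves, spec = tree_flatten(out)  # flattens to [values, indices]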

torch/serialization.py (file diff suppressed because it is too large)

torch/torch_version.py

@@ -2,8 +2,9 @@
 from typing import Any, Iterable
 
-from ._vendor.packaging.version import InvalidVersion, Version
-from .version import __version__ as internal_version
+from torch._vendor.packaging.version import InvalidVersion, Version
+from torch.version import __version__ as internal_version
 
 __all__ = ["TorchVersion"]
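
The import fix above is mechanical, but for context: TorchVersion subclasses str while comparing by PEP 440 rules, so string operands work in version gates:

from torch.torch_version import TorchVersion

v = TorchVersion("2.3.1")
assert v >= "2.3" and v < "2.4"  # string operands are parsed as versions
assert TorchVersion("1.10.0a0+git1234") > "1.9"  # a pre-release of 1.10 still beats 1.9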

torch/types.py

@@ -1,9 +1,14 @@
 # mypy: allow-untyped-defs
 import builtins
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
 
+if TYPE_CHECKING:
+    from torch.autograd.graph import GradientEdge
+
 
 # Convenience aliases for common composite types that we need
 # to talk about in PyTorch
@@ -11,8 +16,8 @@ _TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
 _TensorOrTensorsOrGradEdge = Union[
     torch.Tensor,
     Sequence[torch.Tensor],
-    "torch.autograd.graph.GradientEdge",
-    Sequence["torch.autograd.graph.GradientEdge"],
+    "GradientEdge",
+    Sequence["GradientEdge"],
 ]
 
 # In some cases, these basic types are shadowed by corresponding
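
The new TYPE_CHECKING block imports GradientEdge for annotations only, avoiding a runtime import (and a potential cycle), while string forward references keep the alias checkable. The same pattern in miniature (a hedged sketch; EdgeOrEdges is a hypothetical alias, not a torch name):

from typing import Sequence, TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never at runtime.
    from torch.autograd.graph import GradientEdge

# String forward references resolve lazily, so no runtime import is needed.
EdgeOrEdges = Union["GradientEdge", Sequence["GradientEdge"]]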