Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[BE] enable UFMT for top-level files torch/*.py (#127707)
Part of #123062. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707. Approved by: https://github.com/ezyang

Parent: cc231a8e2b
Commit: dd143d44cc
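Note: UFMT runs usort (import sorting) followed by black (code formatting), so the hunks below are mechanical style changes with no behavioral effect: single-quoted strings become double-quoted, relative imports become sorted absolute imports, and long signatures and calls are rewrapped one argument per line. A minimal before/after sketch, taken from the torch/functional.py hunk further down:

    # Before: relative import, manual wrapping, single quotes
    from .overrides import (
        has_torch_function, has_torch_function_unary, has_torch_function_variadic,
        handle_torch_function)

    # After: absolute import, names sorted one per line, trailing comma
    from torch.overrides import (
        handle_torch_function,
        has_torch_function,
        has_torch_function_unary,
        has_torch_function_variadic,
    )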
.lintrunner.toml

@@ -1556,7 +1556,6 @@ exclude_patterns = [
     'torch/distributed/tensor/parallel/style.py',
     'torch/fft/__init__.py',
     'torch/func/__init__.py',
-    'torch/functional.py',
     'torch/futures/__init__.py',
     'torch/fx/__init__.py',
     'torch/fx/_compatibility.py',
@@ -1642,8 +1641,6 @@ exclude_patterns = [
     'torch/fx/subgraph_rewriter.py',
     'torch/fx/tensor_type.py',
     'torch/fx/traceback.py',
-    'torch/hub.py',
-    'torch/library.py',
     'torch/linalg/__init__.py',
     'torch/monitor/__init__.py',
     'torch/nested/__init__.py',
@@ -1767,11 +1764,6 @@ exclude_patterns = [
     'torch/nn/utils/rnn.py',
     'torch/nn/utils/spectral_norm.py',
     'torch/nn/utils/weight_norm.py',
-    'torch/overrides.py',
-    'torch/quasirandom.py',
-    'torch/random.py',
-    'torch/return_types.py',
-    'torch/serialization.py',
     'torch/signal/__init__.py',
     'torch/signal/windows/__init__.py',
     'torch/signal/windows/windows.py',
torch/_guards.py

@@ -2,7 +2,6 @@
 from __future__ import annotations
 
 import contextlib
-
 import dataclasses
 import enum
 import functools
@@ -31,6 +30,7 @@ from torch.utils import _pytree as pytree
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary
 
+
 log = logging.getLogger(__name__)
 
 
@@ -40,7 +40,6 @@ if TYPE_CHECKING:
     # Import the following modules during type checking to enable code intelligence features,
     # such as auto-completion in tools like pylance, even when these modules are not explicitly
     # imported in user code.
 
     import torch
 
-
@@ -176,7 +175,7 @@ class Guard:
     def sort_key(self):
         # Put the duplicate input guards at the end. The duplicate guards have
         # two sources while guard.name only considers one source.
-        from ._dynamo.guards import GuardBuilder
+        from torch._dynamo.guards import GuardBuilder
 
         is_duplicate_input = (
             isinstance(self.create_fn, functools.partial)
torch/_lobpcg.py

@@ -7,9 +7,8 @@
 from typing import Dict, Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 __all__ = ["lobpcg"]
torch/_lowrank.py

@@ -6,9 +6,8 @@ __all__ = ["svd_lowrank", "pca_lowrank"]
 from typing import Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 def get_approximate_basis(
torch/_tensor.py

@@ -761,22 +761,22 @@ class Tensor(torch._C.TensorBase):
         return torch.norm(self, p, dim, keepdim, dtype=dtype)
 
     def solve(self, other):
-        from ._linalg_utils import solve
+        from torch._linalg_utils import solve
 
         return solve(self, other)
 
     def lstsq(self, other):
-        from ._linalg_utils import lstsq
+        from torch._linalg_utils import lstsq
 
         return lstsq(self, other)
 
     def eig(self, eigenvectors=False):
-        from ._linalg_utils import eig
+        from torch._linalg_utils import eig
 
         return eig(self, eigenvectors=eigenvectors)
 
     def symeig(self, eigenvectors=False):
-        from ._linalg_utils import _symeig
+        from torch._linalg_utils import _symeig
 
         return _symeig(self, eigenvectors=eigenvectors)
 
torch/functional.py

@@ -1,47 +1,46 @@
 # mypy: allow-untyped-defs
-from typing import (
-    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
-)
-import operator
 import itertools
+import operator
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
-from torch._C import _add_docstr
 import torch.nn.functional as F
-from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import (
-    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
-    handle_torch_function)
-from ._jit_internal import boolean_dispatch
-from ._jit_internal import _overload as overload
+from torch import _VF, Tensor
+from torch._C import _add_docstr
+from torch._jit_internal import _overload as overload, boolean_dispatch
+from torch._lowrank import pca_lowrank, svd_lowrank
+from torch.overrides import (
+    handle_torch_function,
+    has_torch_function,
+    has_torch_function_unary,
+    has_torch_function_variadic,
+)
 
-Tensor = torch.Tensor
-from torch import _VF
 
 __all__ = [
-    'atleast_1d',
-    'atleast_2d',
-    'atleast_3d',
-    'align_tensors',
-    'broadcast_shapes',
-    'broadcast_tensors',
-    'cartesian_prod',
-    'block_diag',
-    'cdist',
-    'chain_matmul',
-    'einsum',
-    'istft',
-    'lu',
-    'norm',
-    'meshgrid',
-    'pca_lowrank',
-    'split',
-    'stft',
-    'svd_lowrank',
-    'tensordot',
-    'unique',
-    'unique_consecutive',
-    'unravel_index',
+    "atleast_1d",
+    "atleast_2d",
+    "atleast_3d",
+    "align_tensors",
+    "broadcast_shapes",
+    "broadcast_tensors",
+    "cartesian_prod",
+    "block_diag",
+    "cdist",
+    "chain_matmul",
+    "einsum",
+    "istft",
+    "lu",
+    "norm",
+    "meshgrid",
+    "pca_lowrank",
+    "split",
+    "stft",
+    "svd_lowrank",
+    "tensordot",
+    "unique",
+    "unique_consecutive",
+    "unravel_index",
 ]
 
 
@@ -124,16 +123,25 @@ def broadcast_shapes(*shapes):
         if isinstance(shape, (tuple, list)):
             for i in range(-1, -1 - len(shape), -1):
                 if shape[i] < 0:
-                    raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
+                    raise RuntimeError(
+                        f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
+                    )
                 # NB: result is initialized to 1 so this is effectively an
                 # equals one test
-                if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
+                if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
+                    shape[i] == result[i]
+                ):
                     continue
                 if result[i] != 1:
-                    raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
+                    raise RuntimeError(
+                        "Shape mismatch: objects cannot be broadcast to a single shape"
+                    )
                 result[i] = shape[i]
         else:
-            raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+                shape,
+            )
         return torch.Size(result)
     else:
         # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
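The rewrapped broadcast_shapes logic above is behavior-preserving; a quick sketch of what it computes (standard torch.broadcast_shapes API, shapes chosen for illustration):

    import torch

    # Right-aligned dims must match or be 1; the result takes the larger extent.
    torch.broadcast_shapes((2, 1, 3), (5, 3))   # torch.Size([2, 5, 3])
    # Mismatched non-1 dims raise the "Shape mismatch" error wrapped above.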
@@ -188,7 +196,8 @@ def split(
     """
     if has_torch_function_unary(tensor):
         return handle_torch_function(
-            split, (tensor,), tensor, split_size_or_sections, dim=dim)
+            split, (tensor,), tensor, split_size_or_sections, dim=dim
+        )
     # Overwriting reason:
     # This dispatches to two ATen functions depending on the type of
     # split_size_or_sections. The branching code is in _tensor.py, which we
@@ -335,10 +344,13 @@ def einsum(*args: Any) -> Tensor:
         [ 0.3311,  5.5201, -3.0356]])
     """
     import torch.backends.opt_einsum as opt_einsum
+
     # This wrapper exists to support variadic args.
     if len(args) < 2:
-        raise ValueError('einsum(): must specify the equation string and at least one operand, '
-                         'or at least one operand and its subscripts list')
+        raise ValueError(
+            "einsum(): must specify the equation string and at least one operand, "
+            "or at least one operand and its subscripts list"
+        )
 
     equation = None
     operands = None
@@ -350,19 +362,21 @@ def einsum(*args: Any) -> Tensor:
     # input operands into a tensorlist (List[Tensor]).
     def parse_subscript(n: int) -> str:
         if n == Ellipsis:
-            return '...'
+            return "..."
         if n >= 0 and n < 26:
-            return chr(ord('A') + n)
+            return chr(ord("A") + n)
         if n >= 26 and n < 52:
-            return chr(ord('a') + n - 26)
-        raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+            return chr(ord("a") + n - 26)
+        raise ValueError(
+            "einsum(): subscript in subscript list is not within the valid range [0, 52)"
+        )
 
     # Parse subscripts for input operands
-    equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+    equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
 
     # Parse optional output subscripts (provided when the number of arguments is odd)
     if len(args) % 2 == 1:
-        equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+        equation += "->" + "".join(parse_subscript(s) for s in args[-1])
         operands = args[:-1:2]
     else:
         operands = args[::2]
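parse_subscript above backs einsum's sublist calling convention, where integer subscripts replace the equation string; a small sketch (standard torch.einsum API, tensors are illustrative):

    import torch

    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    # Equivalent to torch.einsum("ab,bc->ac", x, y); integers 0..51 map to
    # 'A'..'Z' and then 'a'..'z' via parse_subscript.
    out = torch.einsum(x, [0, 1], y, [1, 2], [0, 2])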
@@ -388,7 +402,9 @@ def einsum(*args: Any) -> Tensor:
     path = None
     if opt_einsum.is_available():
         _opt_einsum = opt_einsum.get_opt_einsum()
-        tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
+        tupled_path = _opt_einsum.contract_path(
+            equation, *operands, optimize=opt_einsum.strategy
+        )[0]
         # flatten path for dispatching to C++
         path = [item for pair in tupled_path for item in pair]
     return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
@@ -397,10 +413,13 @@ def einsum(*args: Any) -> Tensor:
 # This wrapper exists to support variadic args.
 if TYPE_CHECKING:
     # The JIT doesn't understand Union, so only add type annotation for mypy
-    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
-                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+    def meshgrid(
+        *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
+    ) -> Tuple[Tensor, ...]:
         return _meshgrid(*tensors, indexing=indexing)
+
 else:
+
     def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
         r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
 
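The TYPE_CHECKING split above exists because TorchScript cannot parse the Union annotation; runtime behavior is unchanged. An illustrative use of the indexing kwarg handled below:

    import torch

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5])
    gx, gy = torch.meshgrid(a, b, indexing="ij")   # gx, gy have shape (3, 2)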
@@ -509,15 +528,22 @@ def _meshgrid(*tensors, indexing: Optional[str]):
     # kwarg for forward compatibility reasons.
     #
     # Remove this two weeks after landing.
-    kwargs = {} if indexing is None else {'indexing': indexing}
+    kwargs = {} if indexing is None else {"indexing": indexing}
     return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
 
-def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
-         win_length: Optional[int] = None, window: Optional[Tensor] = None,
-         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
-         onesided: Optional[bool] = None,
-         return_complex: Optional[bool] = None) -> Tensor:
+def stft(
+    input: Tensor,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[Tensor] = None,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> Tensor:
     r"""Short-time Fourier transform (STFT).
 
     .. warning::
@@ -652,9 +678,19 @@ def stft(
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
-            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
-            onesided=onesided, return_complex=return_complex)
+            stft,
+            (input,),
+            input,
+            n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            normalized=normalized,
+            onesided=onesided,
+            return_complex=return_complex,
+        )
     # NOTE: Do not edit. This code will be removed once the forward-compatibility
     # period is over for PR #73432
     if center:
@@ -663,8 +699,16 @@ def stft(
         pad = int(n_fft // 2)
         input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
         input = input.view(input.shape[-signal_dim:])
-    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
-                    normalized, onesided, return_complex)
+    return _VF.stft(  # type: ignore[attr-defined]
+        input,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        normalized,
+        onesided,
+        return_complex,
+    )
 
 
 istft = _add_docstr(
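The stft wrapper above pads the signal (when center=True) before dispatching to _VF.stft; a minimal call sketch (standard torch.stft API, sizes illustrative):

    import torch

    signal = torch.randn(1024)
    spec = torch.stft(signal, n_fft=64, hop_length=16, return_complex=True)
    # spec has shape (n_fft // 2 + 1, n_frames), since onesided defaults to
    # True for real input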
@@ -746,7 +790,8 @@ Args:
 Returns:
     Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
         `B?` is an optional batch dimension from the input tensor.
-""")
+""",
+)
 
 
 if TYPE_CHECKING:
@@ -758,9 +803,13 @@ else:
     _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
 
 
-def _unique_impl(input: Tensor, sorted: bool = True,
-                 return_inverse: bool = False, return_counts: bool = False,
-                 dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_impl(
+    input: Tensor,
+    sorted: bool = True,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
 
     Returns the unique elements of the input tensor.
@@ -896,8 +945,14 @@ def _unique_impl(
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique,
+            (input,),
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
 
     if dim is not None:
         output, inverse_indices, counts = _VF.unique_dim(
@@ -917,9 +972,12 @@ def _unique_impl(
     return output, inverse_indices, counts
 
 
-def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
-                             return_counts: bool = False,
-                             dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_consecutive_impl(
+    input: Tensor,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""Eliminates all but the first element from every consecutive group of equivalent elements.
 
     .. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -971,14 +1029,22 @@ def _unique_consecutive_impl(
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique_consecutive, (input,), input, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique_consecutive,
+            (input,),
+            input,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
     output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
-        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
+    )
     return output, inverse_indices, counts
 
 
-def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_counts(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
@@ -988,7 +1054,9 @@ def _return_counts(
     return output, counts
 
 
-def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_output(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
@@ -998,59 +1066,72 @@ def _return_output(
     return output
 
 
-def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_inverse(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_impl(
+        input, sorted, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_return_counts,
     if_false=_return_output,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 _return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_unique_impl,
     if_false=_return_inverse,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 
 unique = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_return_inverse_true,
     if_false=_return_inverse_false,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 unique.__doc__ = _unique_impl.__doc__
 
 
-def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_counts(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, _, counts = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, counts
 
 
-def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_output(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
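The boolean_dispatch tables above and below let TorchScript statically resolve unique's return type from the return_inverse/return_counts flags; all four call shapes remain valid after the reformat (standard torch.unique API):

    import torch

    x = torch.tensor([1, 3, 2, 3])
    vals = torch.unique(x)
    vals, inverse = torch.unique(x, return_inverse=True)
    vals, counts = torch.unique(x, return_counts=True)
    vals, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)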
@@ -1060,45 +1141,52 @@ def _consecutive_return_output(
     return output
 
 
-def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_inverse(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _consecutive_return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_consecutive_return_counts,
     if_false=_consecutive_return_output,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 _consecutive_return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_unique_consecutive_impl,
     if_false=_consecutive_return_inverse,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 
 unique_consecutive = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_consecutive_return_inverse_true,
     if_false=_consecutive_return_inverse_false,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
 
 if TYPE_CHECKING:
@@ -1106,24 +1194,50 @@ if TYPE_CHECKING:
     # There's no good way to use this type annotation without breaking JIT
     # overloads. So leave untyped for mypy for now.
 else:
+
     @overload
-    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
+    def tensordot(
+        a,
+        b,
+        dims: int = 2,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: Tuple[List[int], List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: List[List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
 
-def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
+def tensordot(  # noqa: F811
+    a,
+    b,
+    dims=2,
+    out: Optional[torch.Tensor] = None,
+):
     r"""Returns a contraction of a and b over multiple dimensions.
 
     :attr:`tensordot` implements a generalized matrix product.
@@ -1178,10 +1292,12 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
         return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
 
     if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
-        raise RuntimeError("tensordot expects dims to be int or "
-                           + "Tuple[List[int], List[int]] or "
-                           + "List[List[int]] containing two lists, but got "
-                           + f"dims={dims}")
+        raise RuntimeError(
+            "tensordot expects dims to be int or "
+            + "Tuple[List[int], List[int]] or "
+            + "List[List[int]] containing two lists, but got "
+            + f"dims={dims}"
+        )
 
     dims_a: List[int] = []
     dims_b: List[int] = []
@@ -1206,7 +1322,9 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
     if dims < 0:
         raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
     if dims > min(a.dim(), b.dim()):
-        raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
+        raise RuntimeError(
+            f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
+        )
     dims_a = list(range(-dims, 0))
     dims_b = list(range(dims))
 
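Both accepted forms of dims survive the rewrap above; a sketch of the two spellings (standard torch.tensordot API, shapes illustrative):

    import torch

    a = torch.randn(3, 4, 5)
    b = torch.randn(4, 5, 6)
    # dims as int: contract the last 2 dims of a with the first 2 of b
    torch.tensordot(a, b, dims=2).shape                  # (3, 6)
    # dims as a pair of lists: the same contraction, spelled out
    torch.tensordot(a, b, dims=([1, 2], [0, 1])).shape   # (3, 6)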
@@ -1287,7 +1405,7 @@ def block_diag(*tensors):
     return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
 
 
-def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
+def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
     # type: (Tensor, Tensor, float, str) -> (Tensor)
     r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
 
@@ -1331,12 +1449,13 @@ def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
     """
     if has_torch_function_variadic(x1, x2):
         return handle_torch_function(
-            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
-    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
+        )
+    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
         return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
-    elif compute_mode == 'use_mm_for_euclid_dist':
+    elif compute_mode == "use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
-    elif compute_mode == 'donot_use_mm_for_euclid_dist':
+    elif compute_mode == "donot_use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
     else:
         raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
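The compute_mode strings above select whether the Euclidean case may use a matrix-multiply based computation; a usage sketch (standard torch.cdist API, shapes illustrative):

    import torch

    x1 = torch.randn(2, 5, 3)   # (batch, P, M)
    x2 = torch.randn(2, 4, 3)   # (batch, R, M)
    torch.cdist(x1, x2, p=2).shape   # (2, 5, 4) pairwise distances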
@@ -1478,27 +1597,62 @@ else:
     # TODO: type dim as BroadcastingList when
     # https://github.com/pytorch/pytorch/issues/33782 is fixed
     @overload
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+    def norm(
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
 
-def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+def norm(  # noqa: F811
+    input,
+    p: Optional[Union[float, str]] = "fro",
+    dim=None,
+    keepdim=False,
+    out=None,
+    dtype=None,
+):
     r"""Returns the matrix norm or vector norm of a given tensor.
 
     .. warning::
@@ -1594,14 +1748,19 @@ def norm(  # noqa: F811
 
     if has_torch_function_unary(input):
         return handle_torch_function(
-            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
+            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
+        )
 
     # NB. All the repeated code and weird python is to please TorchScript.
     # For a more compact implementation see the relevant function in `_refs/__init__.py`
 
     # We don't do this for MPS or sparse tensors
-    if input.layout == torch.strided and input.device.type in \
-            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
+    if input.layout == torch.strided and input.device.type in (
+        "cpu",
+        "cuda",
+        "meta",
+        torch.utils.backend_registration._privateuse1_backend_name,
+    ):
         if dim is not None:
             if isinstance(dim, (int, torch.SymInt)):
                 _dim = [dim]
@@ -1611,11 +1770,17 @@ def norm(  # noqa: F811
             _dim = None  # type: ignore[assignment]
 
         if isinstance(p, str):
-            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
+            if p == "fro" and (
+                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
+            ):
                 if out is None:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype
+                    )
                 else:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype, out=out
+                    )
 
             # Here we either call the nuclear norm, or we call matrix_norm with some arguments
             # that will throw an error
@@ -1624,14 +1789,18 @@ def norm(  # noqa: F811
             if out is None:
                 return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.matrix_norm(
+                    input, p, _dim, keepdim, dtype=dtype, out=out
+                )
         else:
             # NB. p should be Union[str, number], not Optional!
             _p = 2.0 if p is None else p
             if out is None:
                 return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.vector_norm(
+                    input, _p, _dim, keepdim, dtype=dtype, out=out
+                )
 
     ndim = input.dim()
 
@@ -1641,7 +1810,7 @@ def norm(  # noqa: F811
     if p == "fro":
         return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
     if not isinstance(p, str):
-        _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+        _dim = list(range(ndim))
     return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
 
 # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
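As the dispatch above shows, torch.norm routes strided CPU/CUDA inputs to torch.linalg; a sketch of the equivalent direct calls it reaches (standard torch.linalg API, inputs illustrative):

    import torch

    t = torch.randn(4, 4)
    torch.linalg.vector_norm(t, 2)        # the norm(p="fro", dim=None) path above
    torch.linalg.matrix_norm(t, "nuc")    # matrix norms such as the nuclear norm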
@@ -1695,7 +1864,10 @@ def norm(  # noqa: F811
     else:
         return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
 
-def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
+
+def unravel_index(
+    indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
+) -> Tuple[Tensor, ...]:
     r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
     index into an arbitrary tensor of the specified shape.
 
@@ -1745,19 +1917,23 @@ def unravel_index(
         tensor([[34], [78]]))
     """
     if has_torch_function_unary(indices):
-        return handle_torch_function(
-            unravel_index, (indices,), indices, shape=shape)
+        return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
     res_tensor = _unravel_index(indices, shape)
     return res_tensor.unbind(-1)
 
 
 def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
     torch._check_type(
-        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
-        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
+        not indices.is_complex()
+        and not indices.is_floating_point()
+        and not indices.dtype == torch.bool,
+        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
+    )
 
     torch._check_type(
         isinstance(shape, (int, torch.SymInt, Sequence)),
-        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
+        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
+    )
 
     if isinstance(shape, (int, torch.SymInt)):
         shape = torch.Size([shape])
@@ -1765,18 +1941,29 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
         for dim in shape:
             torch._check_type(
                 isinstance(dim, (int, torch.SymInt)),
-                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
+                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
+            )
         shape = torch.Size(shape)
 
     torch._check_value(
         all(dim >= 0 for dim in shape),
-        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
+        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
+    )
 
-    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
+    coefs = list(
+        reversed(
+            list(
+                itertools.accumulate(
+                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
+                )
+            )
+        )
+    )
     return indices.unsqueeze(-1).floor_divide(
         torch.tensor(coefs, device=indices.device, dtype=torch.int64)
     ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
 
 
 def chain_matmul(*matrices, out=None):
     r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
     using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
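The coefs chain above builds row-major strides via a reversed cumulative product; a worked sketch (standard torch.unravel_index API):

    import torch

    # For shape (5, 4, 3): coefs = [12, 3, 1], so flat index 10 -> (0, 3, 1)
    # because 10 // 12 % 5 = 0, 10 // 3 % 4 = 3, 10 // 1 % 3 = 1.
    torch.unravel_index(torch.tensor(10), (5, 4, 3))
    # (tensor(0), tensor(3), tensor(1))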
@@ -1923,6 +2110,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
 
+
 if TYPE_CHECKING:
     _ListOrSeq = Sequence[Tensor]
 else:
@@ -1932,16 +2120,21 @@ else:
 def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
     get_infos_int = 1 if get_infos else 0
     if out_len - get_infos_int != 2:
-        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
+        raise TypeError(
+            f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
+        )
     if not isinstance(out, (tuple, list)):
-        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
+        raise TypeError(
+            f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
+        )
 
 
 def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
     # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1957,7 +2150,8 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
     # need to check for torch_function here so that we exit if
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1967,18 +2161,20 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
     else:
         return result[0], result[1]  # A_LU, pivots
 
+
 # The return type of lu depends on `get_infos`, so in order to resolve the output type
 # of lu in TorchScript we need to statically know the value of `get_infos`
 lu = boolean_dispatch(
-    arg_name='get_infos',
+    arg_name="get_infos",
     arg_index=2,
     default=False,
     if_true=_lu_with_infos,
     if_false=_lu_no_infos,
     module_name=__name__,
-    func_name='lu')
+    func_name="lu",
+)
 lu.__doc__ = _lu_impl.__doc__
 
 
 def align_tensors(*tensors):
-    raise RuntimeError('`align_tensors` not yet implemented.')
+    raise RuntimeError("`align_tensors` not yet implemented.")
torch/hub.py (257 changed lines)
@@ -8,22 +8,22 @@ import re
 import shutil
 import sys
 import tempfile
-import torch
 import uuid
 import warnings
 import zipfile
 from pathlib import Path
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional
 from typing_extensions import deprecated
 from urllib.error import HTTPError, URLError
-from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
+from urllib.request import Request, urlopen
+
+import torch
 from torch.serialization import MAP_LOCATION
 
-class _Faketqdm:  # type: ignore[no-redef]
-
-    def __init__(self, total=None, disable=False,
-                 unit=None, *args, **kwargs):
+
+class _Faketqdm:  # type: ignore[no-redef]
+    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
         self.total = total
         self.disable = disable
         self.n = 0
|
@ -57,7 +57,8 @@ class _Faketqdm: # type: ignore[no-redef]
|
|||
if self.disable:
|
||||
return
|
||||
|
||||
sys.stderr.write('\n')
|
||||
sys.stderr.write("\n")
|
||||
|
||||
|
||||
try:
|
||||
from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper
|
||||
|
|
@@ -65,25 +66,30 @@ except ImportError:
     tqdm = _Faketqdm
 
 __all__ = [
-    'download_url_to_file',
-    'get_dir',
-    'help',
-    'list',
-    'load',
-    'load_state_dict_from_url',
-    'set_dir',
+    "download_url_to_file",
+    "get_dir",
+    "help",
+    "list",
+    "load",
+    "load_state_dict_from_url",
+    "set_dir",
 ]
 
 # matches bfd8deac from resnet18-bfd8deac.pth
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
 
-_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
-ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
-ENV_TORCH_HOME = 'TORCH_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-VAR_DEPENDENCY = 'dependencies'
-MODULE_HUBCONF = 'hubconf.py'
+_TRUSTED_REPO_OWNERS = (
+    "facebookresearch",
+    "facebookincubator",
+    "pytorch",
+    "fairinternal",
+)
+ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+VAR_DEPENDENCY = "dependencies"
+MODULE_HUBCONF = "hubconf.py"
 READ_DATA_CHUNK = 128 * 1024
 _hub_dir: Optional[str] = None
 
@@ -101,6 +107,7 @@ def _add_to_sys_path(path):
 def _import_module(name, path):
     import importlib.util
     from importlib.abc import Loader
+
     spec = importlib.util.spec_from_file_location(name, path)
     assert spec is not None
     module = importlib.util.module_from_spec(spec)
@@ -131,18 +138,20 @@ def _load_attr_from_module(module, func_name):
 
 def _get_torch_home():
     torch_home = os.path.expanduser(
-        os.getenv(ENV_TORCH_HOME,
-                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
-                               DEFAULT_CACHE_DIR), 'torch')))
+        os.getenv(
+            ENV_TORCH_HOME,
+            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
+        )
+    )
     return torch_home
 
 
 def _parse_repo_info(github):
-    if ':' in github:
-        repo_info, ref = github.split(':')
+    if ":" in github:
+        repo_info, ref = github.split(":")
     else:
         repo_info, ref = github, None
-    repo_owner, repo_name = repo_info.split('/')
+    repo_owner, repo_name = repo_info.split("/")
 
     if ref is None:
         # The ref wasn't specified by the user, so we need to figure out the
@@ -150,16 +159,18 @@ def _parse_repo_info(github):
     # then it's the default branch, otherwise it's master.
     try:
         with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
-            ref = 'main'
+            ref = "main"
     except HTTPError as e:
         if e.code == 404:
-            ref = 'master'
+            ref = "master"
         else:
             raise
     except URLError as e:
         # No internet connection, need to check for cache as last resort
         for possible_ref in ("main", "master"):
-            if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
+            if os.path.exists(
+                f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
+            ):
                 ref = possible_ref
                 break
         if ref is None:
@@ -172,35 +183,40 @@ def _parse_repo_info(github):
 
 def _read_url(url):
     with urlopen(url) as r:
-        return r.read().decode(r.headers.get_content_charset('utf-8'))
+        return r.read().decode(r.headers.get_content_charset("utf-8"))
 
 
 def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
     # Use urlopen to avoid depending on local git.
-    headers = {'Accept': 'application/vnd.github.v3+json'}
+    headers = {"Accept": "application/vnd.github.v3+json"}
     token = os.environ.get(ENV_GITHUB_TOKEN)
     if token is not None:
-        headers['Authorization'] = f'token {token}'
+        headers["Authorization"] = f"token {token}"
     for url_prefix in (
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
+    ):
         page = 0
         while True:
             page += 1
-            url = f'{url_prefix}?per_page=100&page={page}'
+            url = f"{url_prefix}?per_page=100&page={page}"
             response = json.loads(_read_url(Request(url, headers=headers)))
             # Empty response means no more data to process
            if not response:
                 break
             for br in response:
-                if br['name'] == ref or br['commit']['sha'].startswith(ref):
+                if br["name"] == ref or br["commit"]["sha"].startswith(ref):
                     return
 
-    raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
-                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
+    raise ValueError(
+        f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
+        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
+    )
 
 
-def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
+def _get_cache_or_reload(
+    github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
+):
     # Setup hub_dir to save downloaded files
     hub_dir = get_dir()
     os.makedirs(hub_dir, exist_ok=True)
||||
# Setup hub_dir to save downloaded files
|
||||
hub_dir = get_dir()
|
||||
os.makedirs(hub_dir, exist_ok=True)
|
||||
|
|
@ -210,27 +226,33 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
|
|||
# this causes confusion with path on both Linux and Windows.
|
||||
# Backslash is not allowed in Github branch name so no need to
|
||||
# to worry about it.
|
||||
normalized_br = ref.replace('/', '_')
|
||||
normalized_br = ref.replace("/", "_")
|
||||
# Github renames folder repo-v1.x.x to repo-1.x.x
|
||||
# We don't know the repo name before downloading the zip file
|
||||
# and inspect name from it.
|
||||
# To check if cached repo exists, we need to normalize folder names.
|
||||
owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
|
||||
owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
|
||||
repo_dir = os.path.join(hub_dir, owner_name_branch)
|
||||
# Check that the repo is in the trusted list
|
||||
_check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
|
||||
_check_repo_is_trusted(
|
||||
repo_owner,
|
||||
repo_name,
|
||||
owner_name_branch,
|
||||
trust_repo=trust_repo,
|
||||
calling_fn=calling_fn,
|
||||
)
|
||||
|
||||
use_cache = (not force_reload) and os.path.exists(repo_dir)
|
||||
|
||||
if use_cache:
|
||||
if verbose:
|
||||
sys.stderr.write(f'Using cache found in {repo_dir}\n')
|
||||
sys.stderr.write(f"Using cache found in {repo_dir}\n")
|
||||
else:
|
||||
# Validate the tag/branch is from the original repo instead of a forked repo
|
||||
if not skip_validation:
|
||||
_validate_not_a_forked_repo(repo_owner, repo_name, ref)
|
||||
|
||||
cached_file = os.path.join(hub_dir, normalized_br + '.zip')
|
||||
cached_file = os.path.join(hub_dir, normalized_br + ".zip")
|
||||
_remove_if_exists(cached_file)
|
||||
|
||||
try:
|
||||
|
|
@@ -250,7 +272,9 @@ def _get_cache_or_reload(
                 "refs/tags/tag_name as the ref. That might require using skip_validation=True."
             )
             disambiguated_branch_ref = f"refs/heads/{ref}"
-            url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
+            url = _git_archive_link(
+                repo_owner, repo_name, ref=disambiguated_branch_ref
+            )
             download_url_to_file(url, cached_file, progress=False)
         else:
             raise
@@ -269,7 +293,9 @@ def _get_cache_or_reload(
     return repo_dir
 
 
-def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
+def _check_repo_is_trusted(
+    repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
+):
     hub_dir = get_dir()
     filepath = os.path.join(hub_dir, "trusted_list")
 
@@ -282,7 +308,7 @@ def _check_repo_is_trusted(
     # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
     trusted_repos_legacy = next(os.walk(hub_dir))[1]
 
-    owner_name = '_'.join([repo_owner, repo_name])
+    owner_name = "_".join([repo_owner, repo_name])
     is_trusted = (
         owner_name in trusted_repos
         or owner_name_branch in trusted_repos_legacy
@@ -298,13 +324,15 @@ def _check_repo_is_trusted(
             "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
             f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
             f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
-            f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
+            f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
+        )
         return
 
     if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
         response = input(
             f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
-            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
+            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
+        )
         if response.lower() in ("y", "yes"):
             if is_trusted:
                 print("The repository is already trusted.")
@@ -321,6 +349,7 @@ def _check_repo_is_trusted(
 
 def _check_module_exists(name):
     import importlib.util
+
     return importlib.util.find_spec(name) is not None
 
 
@@ -335,7 +364,7 @@ def _check_dependencies(m):
 
 def _load_entry_from_hubconf(m, model):
     if not isinstance(model, str):
-        raise ValueError('Invalid input: model should be a string of function name')
+        raise ValueError("Invalid input: model should be a string of function name")
 
     # Note that if a missing dependency is imported at top level of hubconf, it will
     # throw before this function. It's a chicken and egg situation where we have to
@@ -346,7 +375,7 @@ def _load_entry_from_hubconf(m, model):
     func = _load_attr_from_module(m, model)
 
     if func is None or not callable(func):
-        raise RuntimeError(f'Cannot find callable {model} in hubconf')
+        raise RuntimeError(f"Cannot find callable {model} in hubconf")
 
     return func
 
@@ -362,12 +391,12 @@ def get_dir():
     variable is not set.
     """
     # Issue warning to move data if old env is set
-    if os.getenv('TORCH_HUB'):
-        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_HUB"):
+        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
 
     if _hub_dir is not None:
         return _hub_dir
-    return os.path.join(_get_torch_home(), 'hub')
+    return os.path.join(_get_torch_home(), "hub")
 
 
 def set_dir(d):
@@ -381,7 +410,9 @@ def set_dir(d):
     _hub_dir = os.path.expanduser(d)
 
 
-def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
+def list(
+    github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
+):
     r"""
     List all callable entrypoints available in the repo specified by ``github``.
 
@@ -424,15 +455,25 @@ def list(
     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
     >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
-                                    skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github,
+        force_reload,
+        trust_repo,
+        "list",
+        verbose=verbose,
+        skip_validation=skip_validation,
+    )
 
     with _add_to_sys_path(repo_dir):
         hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
         hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
 
     # We take functions starts with '_' as internal helper functions
-    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+    entrypoints = [
+        f
+        for f in dir(hub_module)
+        if callable(getattr(hub_module, f)) and not f.startswith("_")
+    ]
 
     return entrypoints
 
@@ -474,8 +515,14 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
     >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
-                                    skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github,
+        force_reload,
+        trust_repo,
+        "help",
+        verbose=True,
+        skip_validation=skip_validation,
+    )
 
     with _add_to_sys_path(repo_dir):
         hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
@@ -486,9 +533,17 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
     return entry.__doc__
 
 
-def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
-         skip_validation=False,
-         **kwargs):
+def load(
+    repo_or_dir,
+    model,
+    *args,
+    source="github",
+    trust_repo=None,
+    force_reload=False,
+    verbose=True,
+    skip_validation=False,
+    **kwargs,
+):
     r"""
     Load a model from a github repo or a local directory.
 
|
@ -559,13 +614,20 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo
|
|||
"""
|
||||
source = source.lower()
|
||||
|
||||
if source not in ('github', 'local'):
|
||||
if source not in ("github", "local"):
|
||||
raise ValueError(
|
||||
f'Unknown source: "{source}". Allowed values: "github" | "local".')
|
||||
f'Unknown source: "{source}". Allowed values: "github" | "local".'
|
||||
)
|
||||
|
||||
if source == 'github':
|
||||
repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
|
||||
verbose=verbose, skip_validation=skip_validation)
|
||||
if source == "github":
|
||||
repo_or_dir = _get_cache_or_reload(
|
||||
repo_or_dir,
|
||||
force_reload,
|
||||
trust_repo,
|
||||
"load",
|
||||
verbose=verbose,
|
||||
skip_validation=skip_validation,
|
||||
)
|
||||
|
||||
model = _load_local(repo_or_dir, model, *args, **kwargs)
|
||||
return model
|
||||
|
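A usage sketch of the reformatted entry point (standard torch.hub API; the repo and model names follow the file's own doctest examples, and trust_repo feeds the trusted-repo check above):

    import torch

    model = torch.hub.load("pytorch/vision", "resnet18", trust_repo=True)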
|
@ -601,8 +663,9 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
||||
progress: bool = True) -> None:
|
||||
def download_url_to_file(
|
||||
url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
|
||||
) -> None:
|
||||
r"""Download object at the given URL to a local path.
|
||||
|
||||
Args:
|
||||
|
|
@@ -623,7 +686,7 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
     req = Request(url, headers={"User-Agent": "torch.hub"})
     u = urlopen(req)
     meta = u.info()
-    if hasattr(meta, 'getheaders'):
+    if hasattr(meta, "getheaders"):
         content_length = meta.getheaders("Content-Length")
     else:
         content_length = meta.get_all("Content-Length")

@@ -637,20 +700,25 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
     # file permissions being applied to the downloaded file.
     dst = os.path.expanduser(dst)
     for seq in range(tempfile.TMP_MAX):
-        tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
+        tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
         try:
-            f = open(tmp_dst, 'w+b')
+            f = open(tmp_dst, "w+b")
         except FileExistsError:
             continue
         break
     else:
-        raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
+        raise FileExistsError(errno.EEXIST, "No usable temporary file name found")

     try:
         if hash_prefix is not None:
             sha256 = hashlib.sha256()
-        with tqdm(total=file_size, disable=not progress,
-                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+        with tqdm(
+            total=file_size,
+            disable=not progress,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
             while True:
                 buffer = u.read(READ_DATA_CHUNK)
                 if len(buffer) == 0:

@@ -663,8 +731,10 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
         f.close()
         if hash_prefix is not None:
             digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
-            if digest[:len(hash_prefix)] != hash_prefix:
-                raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
+            if digest[: len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(
+                    f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
+                )
         shutil.move(f.name, dst)
     finally:
         f.close()

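A usage sketch of the hash check above: `hash_prefix` is compared against the leading hex digits of the downloaded file's SHA-256, the same convention as the `-5c106cde` style suffix in checkpoint names such as the resnet18 file mentioned below (the URL and destination here are illustrative):

import torch

torch.hub.download_url_to_file(
    "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "/tmp/resnet18.pth",
    hash_prefix="5c106cde",  # RuntimeError if the sha256 doesn't start with this
)
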
@@ -683,23 +753,30 @@ def _is_legacy_zip_format(filename: str) -> bool:


 @deprecated(
-    'Falling back to the old format < 1.6. This support will be '
-    'deprecated in favor of default zipfile format introduced in 1.6. '
-    'Please redo torch.save() to save it in the new zipfile format.',
+    "Falling back to the old format < 1.6. This support will be "
+    "deprecated in favor of default zipfile format introduced in 1.6. "
+    "Please redo torch.save() to save it in the new zipfile format.",
     category=FutureWarning,
 )
-def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
+def _legacy_zip_load(
+    filename: str,
+    model_dir: str,
+    map_location: MAP_LOCATION,
+    weights_only: bool,
+) -> Dict[str, Any]:
     # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
     # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
     # E.g. resnet18-5c106cde.pth which is widely used.
     with zipfile.ZipFile(filename) as f:
         members = f.infolist()
         if len(members) != 1:
-            raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
         f.extractall(model_dir)
         extraced_name = members[0].filename
         extracted_file = os.path.join(model_dir, extraced_name)
-    return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
+    return torch.load(
+        extracted_file, map_location=map_location, weights_only=weights_only
+    )


 def load_state_dict_from_url(

@@ -742,12 +819,14 @@ def load_state_dict_from_url(

     """
     # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_MODEL_ZOO"):
+        warnings.warn(
+            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
+        )

     if model_dir is None:
         hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
+        model_dir = os.path.join(hub_dir, "checkpoints")

     os.makedirs(model_dir, exist_ok=True)

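A minimal sketch of the caching behavior above (the default `model_dir` resolves to `<hub_dir>/checkpoints`; the URL is illustrative):

import torch

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    progress=True,  # downloads once, then reuses the cached copy
)
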

torch/library.py (242 lines changed)

@@ -1,28 +1,34 @@
 # mypy: allow-untyped-defs
-from ._ops import OpOverload
-from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
-from typing_extensions import deprecated
-import traceback
-import torch
-import weakref
+import contextlib
 import functools
 import inspect
 import re
-import contextlib
 import sys
-from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
+import traceback
+import weakref
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing_extensions import deprecated
+
+import torch
 import torch._library as _library
+from torch._library.custom_ops import (
+    _maybe_get_opdef,
+    custom_op,
+    CustomOpDef,
+    device_types_t,
+)
+from torch._ops import OpOverload


 __all__ = [
-    'Library',
-    'impl',
-    'define',
-    'fallthrough_kernel',
-    'impl_abstract',
-    'register_fake',
-    'get_ctx',
-    'custom_op',
+    "Library",
+    "impl",
+    "define",
+    "fallthrough_kernel",
+    "impl_abstract",
+    "register_fake",
+    "get_ctx",
+    "custom_op",
 ]

 # Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered

@@ -33,7 +39,8 @@ _impls: Set[str] = set()
 _defs: Set[str] = set()

 # prim is reserved by TorchScript interpreter
-_reserved_namespaces = ['prim']
+_reserved_namespaces = ["prim"]

+
 def fallthrough_kernel():
     """

@@ -41,6 +48,7 @@ def fallthrough_kernel():
     """
     raise NotImplementedError("fallthrough_kernel() should never be called.")

+
 class Library:
     """
     A class to create libraries that can be used to register new operators or

@@ -59,16 +67,22 @@ class Library:
         kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
         dispatch_key: PyTorch dispatch key (default: "")
     """

     def __init__(self, ns, kind, dispatch_key=""):
-        if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
+        if kind not in ("IMPL", "DEF", "FRAGMENT"):
             raise ValueError("Unsupported kind: ", kind)

-        if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
-            raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.")
+        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
+            raise ValueError(
+                ns,
+                " is a reserved namespace. Please try creating a library with another name.",
+            )

         frame = traceback.extract_stack(limit=3)[0]
         filename, lineno = frame.filename, frame.lineno
-        self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
+        self.m: Optional[Any] = torch._C._dispatch_library(
+            kind, ns, dispatch_key, filename, lineno
+        )
         self.ns = ns
         self._op_defs: Set[str] = set()
         self._op_impls: Set[str] = set()

@@ -79,13 +93,21 @@ class Library:
         # Python __del__ can lead to weird things (globals and locals may already
         # be gone when __del__ actually gets called!). finalizers help the
         # situation because it lets us capture references and keeps them alive
-        weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
+        weakref.finalize(
+            self,
+            _del_library,
+            _impls,
+            self._op_impls,
+            _defs,
+            self._op_defs,
+            self._registration_handles,
+        )

     def __repr__(self):
         return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"

     def define(self, schema, alias_analysis="", *, tags=()):
-        r'''Defines a new operator and its semantics in the ns namespace.
+        r"""Defines a new operator and its semantics in the ns namespace.

         Args:
             schema: function schema to define a new operator.

@@ -102,7 +124,7 @@ class Library:
         Example::
             >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
-        '''
+        """
         # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
         # AliasAnalysis type in C++
         if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:

@@ -113,7 +135,9 @@ class Library:

         name = schema.split("(")[0]
         packet_name = name.split(".")[0] if "." in name else name
-        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name)
+        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
+            getattr(torch.ops, self.ns), packet_name
+        )

         result = self.m.define(schema, alias_analysis, tuple(tags))
         name = schema.split("(")[0]

@@ -131,7 +155,7 @@ class Library:
         return result

     def _register_fake(self, op_name, fn, _stacklevel=1):
-        r'''Registers the fake impl for an operator defined in the library.'''
+        r"""Registers the fake impl for an operator defined in the library."""
         source = torch._library.utils.get_source(_stacklevel + 1)
         frame = sys._getframe(_stacklevel)
         caller_module = inspect.getmodule(frame)

@@ -141,7 +165,9 @@ class Library:

         # TODO(rzou): We're gonna need to stage this change with torchvision,
         # since torchvision is github first.
-        if caller_module_name is not None and caller_module_name.startswith("torchvision."):
+        if caller_module_name is not None and caller_module_name.startswith(
+            "torchvision."
+        ):
             caller_module_name = None

         qualname = f"{self.ns}::{op_name}"

@@ -154,8 +180,8 @@ class Library:
         handle = entry.abstract_impl.register(func_to_register, source)
         self._registration_handles.append(handle)

-    def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
-        r'''Register the operator to use the AOTI-compiled implementation.
+    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
+        r"""Register the operator to use the AOTI-compiled implementation.

         Args:
             op_name: operator name (along with the overload) or OpOverload object.

@@ -165,8 +191,8 @@ class Library:
         Example::
             >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
-        '''
-        if dispatch_key == '':
+        """
+        if dispatch_key == "":
             dispatch_key = self.dispatch_key
         assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

@@ -175,19 +201,24 @@ class Library:
         elif isinstance(op_name, OpOverload):
             name = op_name._schema.name
             overload_name = op_name._schema.overload_name
-            if overload_name != '':
-                name = name + '.' + overload_name
+            if overload_name != "":
+                name = name + "." + overload_name
         else:
-            raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
-                               "as the first argument")
+            raise RuntimeError(
+                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
+                "as the first argument"
+            )

         key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
         if key in _impls:
             # TODO: in future, add more info about where the existing function is registered (this info is
             # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
-            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
-                               "'s behavior for {} dispatch key and {} namespace.".
-                               format(name.split("::")[-1], dispatch_key, self.ns))
+            raise RuntimeError(
+                "This is not allowed since there's already a kernel registered from python overriding {}"
+                "'s behavior for {} dispatch key and {} namespace.".format(
+                    name.split("::")[-1], dispatch_key, self.ns
+                )
+            )

         assert self.m is not None
         impl_fn: Callable = self.m.impl_with_aoti_compile

@@ -196,8 +227,8 @@ class Library:
         _impls.add(key)
         self._op_impls.add(key)

-    def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
-        r'''Registers the function implementation for an operator defined in the library.
+    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
+        r"""Registers the function implementation for an operator defined in the library.

         Args:
             op_name: operator name (along with the overload) or OpOverload object.

@@ -211,10 +242,12 @@ class Library:
             >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
-        '''
+        """
         if not callable(fn):
-            raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
-        if dispatch_key == '':
+            raise TypeError(
+                f"Input function is required to be a callable but found type {type(fn)}"
+            )
+        if dispatch_key == "":
             dispatch_key = self.dispatch_key

         if isinstance(op_name, str):

@@ -222,37 +255,50 @@ class Library:
         elif isinstance(op_name, OpOverload):
             name = op_name._schema.name
             overload_name = op_name._schema.overload_name
-            if overload_name != '':
-                name = name + '.' + overload_name
+            if overload_name != "":
+                name = name + "." + overload_name
         else:
-            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
+            raise RuntimeError(
+                "impl should be passed either a name or an OpOverload object as the first argument"
+            )

         key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
         if key in _impls:
             # TODO: in future, add more info about where the existing function is registered (this info is
             # today already returned by the C++ warning when impl is called but we error out before that)
-            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
-                               "'s behavior for {} dispatch key and {} namespace.".
-                               format(name.split("::")[-1], dispatch_key, self.ns))
+            raise RuntimeError(
+                "This is not allowed since there's already a kernel registered from python overriding {}"
+                "'s behavior for {} dispatch key and {} namespace.".format(
+                    name.split("::")[-1], dispatch_key, self.ns
+                )
+            )

         if dispatch_key == "Meta":
             dispatcher_op_name = name
-            if '::' not in dispatcher_op_name:
-                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
+            if "::" not in dispatcher_op_name:
+                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

             # Internally, we shouldn't be registering meta kernels for any operators that
             # have CompositeImplicitAutograd kernels.
             # Instead, we should be letting those decompositions run, and writing meta kernels
             # only for the base operators.
-            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
+            if torch._C._dispatch_has_kernel_for_dispatch_key(
+                dispatcher_op_name, "CompositeImplicitAutograd"
+            ):
                 raise RuntimeError(
                     f"We should not register a meta kernel directly to the operator '{name}',"
                     " because it has a CompositeImplicitAutograd kernel in core."
                     " Instead we should let the operator decompose, and ensure that we have meta kernels"
-                    " for the base ops that it decomposes into.")
+                    " for the base ops that it decomposes into."
+                )

         assert self.m is not None
-        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn, with_keyset)
+        self.m.impl(
+            name,
+            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
+            fn,
+            with_keyset,
+        )

         _impls.add(key)
         self._op_impls.add(key)

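Putting the `define`/`impl` pair above together, a minimal sketch (the `mylib` namespace and op name are hypothetical, following the docstring examples):

import torch
from torch.library import Library

my_lib = Library("mylib", "DEF")
my_lib.define("my_sum(Tensor self) -> Tensor")

def my_sum_cpu(self):
    return torch.sum(self)

# A second impl for the same (namespace, op, dispatch key) would hit the
# "already a kernel registered" RuntimeError shown above.
my_lib.impl("my_sum", my_sum_cpu, "CPU")

print(torch.ops.mylib.my_sum(torch.ones(3)))  # tensor(3.)
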
@@ -283,7 +329,9 @@ class Library:
             delattr(namespace, name)


-def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
+def _del_library(
+    captured_impls, op_impls, captured_defs, op_defs, registration_handles
+):
     captured_impls -= op_impls
     captured_defs -= op_defs
     for handle in registration_handles:

@@ -357,7 +405,8 @@ def define(qualname, schema, *, lib=None, tags=()):
     if not isinstance(qualname, str):
         raise ValueError(
             f"define(qualname, schema): expected qualname "
-            f"to be instance of str, got {type(qualname)}")
+            f"to be instance of str, got {type(qualname)}"
+        )
     namespace, name = torch._library.utils.parse_namespace(qualname)
     if lib is None:
         lib = Library(namespace, "FRAGMENT")

@@ -366,7 +415,8 @@ def define(qualname, schema, *, lib=None, tags=()):
         raise ValueError(
             f"define(qualname, schema, ...): expected schema "
             f'to look like e.g. "(Tensor x) -> Tensor" but '
-            f'got "{schema}"')
+            f'got "{schema}"'
+        )
     lib.define(name + schema, alias_analysis="", tags=tags)

@@ -375,10 +425,12 @@ def _(lib: Library, schema, alias_analysis=""):
     """The old torch.library.define.
     We're keeping this around for BC reasons
     """
+
     def wrap(f):
         name = lib.define(schema, alias_analysis)
         lib.impl(name, f)
         return f
+
     return wrap

@@ -460,9 +512,11 @@ def _device_type_to_key(device_type: str) -> str:
 @impl.register
 def _(lib: Library, name, dispatch_key=""):
     """Legacy torch.library.impl API. Kept around for BC"""
+
     def wrap(f):
         lib.impl(name, f, dispatch_key)
         return f
+
     return wrap

@@ -480,16 +534,19 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
     return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


-_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
+_op_identifier = Union[
+    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
+]


 def register_kernel(
-        op: _op_identifier,
-        device_types: device_types_t,
-        func: Optional[Callable] = None,
-        /,
-        *,
-        lib: Optional[Library] = None):
+    op: _op_identifier,
+    device_types: device_types_t,
+    func: Optional[Callable] = None,
+    /,
+    *,
+    lib: Optional[Library] = None,
+):
     """Register an implementation for a device type for this operator.

     Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

@@ -530,7 +587,9 @@ def register_kernel(

     """

-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
         raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
     if isinstance(op, torch._ops.OpOverload):
         op = op._name

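A sketch of the per-device registration that `register_kernel` layers on top of `custom_op` (the op name is hypothetical and the CUDA kernel is a stand-in):

import numpy as np
import torch

@torch.library.custom_op("mylib::numpy_cos", mutates_args=(), device_types="cpu")
def numpy_cos(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(np.cos(x.numpy()))

# Add a CUDA kernel for the same op without touching the CPU path.
@torch.library.register_kernel("mylib::numpy_cos", "cuda")
def _(x):
    return torch.cos(x)
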
@@ -544,12 +603,13 @@ def register_kernel(


 def register_fake(
-        op: _op_identifier,
-        func: Optional[Callable] = None,
-        /,
-        *,
-        lib: Optional[Library] = None,
-        _stacklevel: int = 1):
+    op: _op_identifier,
+    func: Optional[Callable] = None,
+    /,
+    *,
+    lib: Optional[Library] = None,
+    _stacklevel: int = 1,
+):
     r"""Register a FakeTensor implementation ("fake impl") for this operator.

     Also sometimes known as a "meta kernel", "abstract impl".

@@ -630,7 +690,9 @@ def register_fake(
     >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

     """
-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
         raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
     if isinstance(op, torch._ops.OpOverload):
         op = op._name

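The `custom_nonzero` example referenced in the docstring above boils down to a fake impl that allocates a data-dependent (unbacked) output size; a condensed sketch:

import torch

@torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
def custom_nonzero(x: torch.Tensor) -> torch.Tensor:
    return torch.nonzero(x)

@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    ctx = torch.library.get_ctx()
    nnz = ctx.new_dynamic_size()  # number of nonzeros isn't known at trace time
    return x.new_empty([nnz, x.dim()], dtype=torch.long)
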
@@ -661,7 +723,14 @@ def register_fake(
         return register(func)


-def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
+def register_autograd(
+    op: _op_identifier,
+    backward: Callable,
+    /,
+    *,
+    setup_context: Optional[Callable] = None,
+    lib=None,
+) -> None:
     r"""Register a backward formula for this custom op.

     In order for an operator to work with autograd, you need to register

@@ -737,8 +806,12 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
     >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

     """
-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
-        raise ValueError(f"register_autograd(op): got unexpected type for op: {type(op)}")
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
+        raise ValueError(
+            f"register_autograd(op): got unexpected type for op: {type(op)}"
+        )
     if isinstance(op, torch._ops.OpOverload):
         op = op._name
     opdef = _maybe_get_opdef(op)

@@ -760,7 +833,8 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
         raise NotImplementedError(
             f"register_autograd with kwarg-only Tensor args. In the original "
             f"definition of the op, please make your tensors not kwarg-only. "
-            f"Got: {schema}")
+            f"Got: {schema}"
+        )

     info = _library.autograd.Info(backward, setup_context)
     autograd_kernel = _library.autograd.make_autograd_impl(op, info)

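The `setup_context`/`backward` split validated above looks like this in practice (condensed from the `numpy_sin` pattern in the docstring; names are illustrative):

import numpy as np
import torch

@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

torch.library.register_autograd(
    "mylib::numpy_sin", backward, setup_context=setup_context
)
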
@@ -788,8 +862,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
             return func(*args, **kwargs)

         maybe_pystub = torch._C._dispatch_pystub(
-            op._schema.name,
-            op._schema.overload_name)
+            op._schema.name, op._schema.overload_name
+        )
         if maybe_pystub is None:
             if torch._library.utils.requires_set_python_module():
                 namespace = op.namespace

@@ -800,7 +874,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
                     f'companion C++ `m.set_python_module("{actual_module_name}")` '
                     f"call, but we could not find one. Please add that to "
                     f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
-                    f"operator was registered in ({cpp_filename})")
+                    f"operator was registered in ({cpp_filename})"
+                )
         else:
             pystub_module = maybe_pystub[0]
             if actual_module_name != pystub_module:

@@ -809,9 +884,11 @@ def _check_pystubs_once(func, qualname, actual_module_name):
                     f"Operator '{qualname}' specified that its python fake impl "
                     f"is in the Python module '{pystub_module}' but it was actually found "
                     f"in '{actual_module_name}'. Please either move the fake impl "
-                    f"or correct the m.set_python_module call ({cpp_filename})")
+                    f"or correct the m.set_python_module call ({cpp_filename})"
+                )
         checked = True
         return func(*args, **kwargs)

     return inner

@@ -929,4 +1006,7 @@ def opcheck(

     """
     import torch.testing._internal.optests as optests
-    return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
+
+    return optests.opcheck(
+        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+    )

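`opcheck` ties the registrations above together; a minimal sketch on a hypothetical custom op:

import torch

@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor) -> torch.Tensor:
    return 2 * x

@scale.register_fake
def _(x):
    return torch.empty_like(x)

# Runs the schema / fake-tensor / autograd-registration test utils and
# raises on the first failure.
torch.library.opcheck(scale, (torch.randn(3),))
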

torch/overrides.py

@@ -23,19 +23,26 @@ instructions in the ``README.md`` in that directory.
 import __future__  # noqa: F404

 import collections
+import contextlib
 import functools
 import types
 import warnings
-from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
 from functools import wraps
-import contextlib
+from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type

 import torch
 from torch._C import (
-    _has_torch_function, _has_torch_function_unary,
-    _has_torch_function_variadic, _add_docstr,
-    _push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
-    _is_torch_function_mode_enabled)
+    _add_docstr,
+    _get_function_stack_at,
+    _has_torch_function,
+    _has_torch_function_unary,
+    _has_torch_function_variadic,
+    _is_torch_function_mode_enabled,
+    _len_torch_function_stack,
+    _pop_torch_function_stack,
+    _push_on_torch_function_stack,
+)


 __all__ = [
     "get_ignored_functions",

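The `_has_torch_function*` bindings imported above back the public `__torch_function__` dispatch helpers; a minimal sketch of the pattern they support:

import torch
from torch.overrides import handle_torch_function, has_torch_function

def my_norm(x):
    # Defer to x's __torch_function__ if any argument overrides it.
    if has_torch_function((x,)):
        return handle_torch_function(my_norm, (x,), x)
    return torch.sqrt(torch.sum(x * x))
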
@@ -52,7 +59,8 @@ __all__ = [


 def _disable_user_warnings(
-        func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
+    func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
+) -> Callable:
     """
     Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
     given ``regex`` pattern.

@@ -75,8 +83,11 @@ def _disable_user_warnings(
     @wraps(func)
     def wrapper(*args, **kwargs):
         with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
+            warnings.filterwarnings(
+                "ignore", category=UserWarning, message=regex, module=module
+            )
             return func(*args, **kwargs)

     return wrapper

@@ -470,8 +481,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
         torch.bernoulli: lambda input, generator=None, out=None: -1,
         torch.bilinear: lambda input1, input2, weight, bias: -1,
-        torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
-                                                 reduction='mean', pos_weight=None: -1),
+        torch.binary_cross_entropy_with_logits: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+        ),
         torch.bincount: lambda input, weights=None, minlength=0: -1,
         torch.binomial: lambda count, prob, generator=None: -1,
         torch.bitwise_and: lambda input, other, out=None: -1,

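Each entry in this table maps a torch API to a dummy lambda with a matching signature, so tests can introspect it; for example (a sketch using the `torch.bernoulli` entry shown above):

import inspect
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
dummy = overrides[torch.bernoulli]
print(inspect.signature(dummy))  # (input, generator=None, out=None)
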
@@ -489,11 +501,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.cat: lambda tensors, dim=0, out=None: -1,
         torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
         torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.concatenate
-        torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
+        torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
         torch.ceil: lambda input, out=None: -1,
-        torch.celu: lambda input, alpha=1., inplace=False: -1,
+        torch.celu: lambda input, alpha=1.0, inplace=False: -1,
         torch.chain_matmul: lambda *matrices, out=None: -1,
-        torch.channel_shuffle: lambda input, groups : -1,
+        torch.channel_shuffle: lambda input, groups: -1,
         torch.cholesky: lambda input, upper=False, out=None: -1,
         torch.linalg.cholesky: lambda input, out=None: -1,
         torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,

@@ -528,14 +540,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
         torch.corrcoef: lambda input: -1,
         torch.cos: lambda input, out=None: -1,
-        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
+        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
         torch.cosh: lambda input, out=None: -1,
         torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
         torch.count_nonzero: lambda input: -1,
         torch.cross: lambda input, other, dim=None, out=None: -1,
         torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
-        torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
-                         zero_infinity=False: -1),
+        torch.ctc_loss: (
+            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+        ),
         torch.cummax: lambda input, dim, out=None: -1,
         torch.cummin: lambda input, dim, out=None: -1,
         torch.cumprod: lambda input, dim, out=None, dtype=None: -1,

@@ -570,10 +583,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
         torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
         torch.einsum: lambda equation, *operands: -1,
-        torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
-                          sparse=False: -1),
-        torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
-                              mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
+        torch.embedding: (
+            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
+        ),
+        torch.embedding_bag: (
+            lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1  # noqa: B950
+        ),
         torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
         torch.eq: lambda input, other, out=None: -1,
         torch.equal: lambda input, other: -1,

@@ -585,14 +600,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.expm1: lambda input, out=None: -1,
         torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
         torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
-        torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
-                                                running_max, scale, zero_point, quant_min, quant_max, ch_axis,
-                                                per_row_fake_quant=False, symmetric_quant=False: -1),
+        torch.fused_moving_avg_obs_fake_quant: (
+            lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1  # noqa: B950
+        ),
         torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
         torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
-        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
-        torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
-                                                          weight_zero_point, bias: -1),
+        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,  # noqa: B950
+        torch.fbgemm_linear_int8_weight_fp32_activation: (
+            lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
+        ),
         torch.fbgemm_linear_quantize_weight: lambda input: -1,
         torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
         torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,

@@ -630,7 +646,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.fmod: lambda input, other, out=None: -1,
         torch.frac: lambda input, out=None: -1,
         torch.frexp: lambda input, out=None: -1,
-        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
+        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,  # noqa: B950
         torch._functional_assert_async: lambda input, msg, dep_token: -1,
         torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
         torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,

@@ -653,7 +669,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.greater: lambda input, other, out=None: -1,
         torch.hardshrink: lambda input, lambd=0.5: -1,
         torch.heaviside: lambda input, values, out=None: -1,
-        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
+        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
         torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
         torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
         torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,

@@ -677,8 +693,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.isreal: lambda tensor: -1,
         torch.isposinf: lambda input, out=None: -1,
         torch.isneginf: lambda input, out=None: -1,
-        torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
-                              cudnn_enabled: -1),
+        torch.instance_norm: (
+            lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
+        ),
         torch.int_repr: lambda input: -1,
         torch.inverse: lambda input, out=None: -1,
         torch.linalg.inv: lambda input, out=None: -1,

@@ -694,9 +711,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.is_signed: lambda input: -1,
         torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
         torch.isnan: lambda input: -1,
-        torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
-                      normalized=False, onesided=None, length=None, return_complex=False: -1),
-        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
+        torch.istft: (
+            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1  # noqa: B950
+        ),
+        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
         torch.kron: lambda input, other: -1,
         torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
         torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,

@@ -709,8 +727,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.less_equal: lambda input, other, out=None: -1,
         torch.lerp: lambda input, end, weight, out=None: -1,
         torch.lgamma: lambda input, out=None: -1,
-        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
-                      tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
+        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,  # noqa: B950
         torch.log: lambda input, out=None: -1,
         torch.log_softmax: lambda input, dim, dtype=None: -1,
         torch.log10: lambda input, out=None: -1,

@@ -732,7 +749,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.less: lambda input, other, out=None: -1,
         torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
         torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
-        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,  # type: ignore[attr-defined]  # noqa: B950
+        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,  # type: ignore[attr-defined]  # noqa: B950
         torch.masked_fill: lambda input, mask, value: -1,
         torch.masked_scatter: lambda input, mask, source: -1,
         torch.masked_select: lambda input, mask, out=None: -1,

@@ -754,8 +771,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
         torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
         torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
-        torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                        return_indices=False, ceil_mode=False: -1),
+        torch.max_pool1d_with_indices: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
         torch.mean: lambda input, dim=None: -1,
         torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
         torch.median: lambda input, dim=None: -1,

@@ -764,17 +782,21 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.min: lambda input, out=None: -1,
         torch.minimum: lambda input, other, out=None: -1,
         torch.fmin: lambda input, other, out=None: -1,
-        torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
-                                  exponential_average_factor, epsilon: -1),
-        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
+        torch.miopen_batch_norm: (
+            lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
+        ),
+        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,  # noqa: B950
         torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
         torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
-        torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
-                                             groups, benchmark, deterministic: -1),
-        torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
-                                             deterministic: -1),
-        torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
-                           dropout, train, bidirectional, batch_sizes, dropout_state: -1),
+        torch.miopen_convolution_transpose: (
+            lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
+        ),
+        torch.miopen_depthwise_convolution: (
+            lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
+        ),
+        torch.miopen_rnn: (
+            lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1  # noqa: B950
+        ),
         torch.mm: lambda input, mat2, out=None: -1,
         torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
         torch.movedim: lambda input, source, destination: -1,

@@ -793,7 +815,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
         torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
         torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
-        torch.native_channel_shuffle: lambda input, groups : -1,
+        torch.native_channel_shuffle: lambda input, groups: -1,
         torch.ne: lambda input, other, out=None: -1,
         torch.not_equal: lambda input, other, out=None: -1,
         torch.neg: lambda input, out=None: -1,

@@ -809,62 +831,76 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
         torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
         torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
-        torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
-                                         count_include_pad=True, divisor_override=None: -1),
-        torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
-                                         count_include_pad=True, divisor_override=None: -1),
-        torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
-                                         momentum=0.1, eps=1e-05: -1),
+        torch.nn.functional.avg_pool2d: (
+            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.avg_pool3d: (
+            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.batch_norm: (
+            lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
+        ),
         torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
-        torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
-                                                   reduction="mean": -1),
-        torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
-                                                               reduce=None, reduction="mean", pos_weight=None: -1),
+        torch.nn.functional.binary_cross_entropy: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.binary_cross_entropy_with_logits: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+        ),
         torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
-        torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
-                                                    reduce=None, reduction='mean': -1),
-        torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
-                                            reduce=None, reduction="mean", label_smoothing=0.0: -1),
-        torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
-                                       reduction='mean', zero_infinity=False: -1),
+        torch.nn.functional.cosine_embedding_loss: (
+            lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.cross_entropy: (
+            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1  # noqa: B950
+        ),
+        torch.nn.functional.ctc_loss: (
+            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+        ),
         torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
-        torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
-                                        scale_grad_by_freq=False, sparse=False: -1),
-        torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
-                                            scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
-                                            include_last_offset=False, padding_idx=None: -1),
+        torch.nn.functional.embedding: (
+            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
+        ),
+        torch.nn.functional.embedding_bag: (
+            lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1  # noqa: B950
+        ),
         torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
         torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
-        torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
-                                                    return_indices=False, _random_samples=None: -1),
+        torch.nn.functional.fractional_max_pool2d: (
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
         torch.nn.functional.fractional_max_pool2d_with_indices: (
-            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
-            _random_samples=None: -1),
-        torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
-                                                    return_indices=False, _random_samples=None: -1),
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.fractional_max_pool3d: (
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
         torch.nn.functional.fractional_max_pool3d_with_indices: (
-            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
-            _random_samples=None: -1),
-        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
-        torch.nn.functional.gelu: lambda input, approximate='none': -1,
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
+        torch.nn.functional.gelu: lambda input, approximate="none": -1,
         torch.nn.functional.glu: lambda input, dim=-1: -1,
-        torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
+        torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1,  # noqa: B950
         torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
         torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
         torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
-        torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
-        torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
-                                                   reduction='mean': -1),
-        torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
-                                            use_input_stats=True, momentum=0.1, eps=1e-05: -1),
-        torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
-                                          recompute_scale_factor=None, antialias=False: -1),
-        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
-        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+        torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
+        torch.nn.functional.hinge_embedding_loss: (
+            lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.instance_norm: (
+            lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1  # noqa: B950
+        ),
+        torch.nn.functional.interpolate: (
+            lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
+        ),
+        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
+        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
         torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
         torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
         torch.nn.functional.linear: lambda input, weight, bias=None: -1,

@ -874,55 +910,65 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
||||
reduce=None, reduction='mean': -1),
|
||||
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False: -1),
|
||||
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
return_indices=False, ceil_mode=False: -1),
|
||||
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
ceil_mode=False, return_indices=False: -1),
|
||||
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
return_indices=False, ceil_mode=False: -1),
|
||||
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
return_indices=False, ceil_mode=False: -1),
|
||||
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||
return_indices=False, ceil_mode=False: -1),
|
||||
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
||||
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
||||
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.margin_ranking_loss: (
lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.max_pool1d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool1d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool2d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
),
torch.nn.functional.max_pool2d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_pool3d_with_indices: (
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
),
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.multi_head_attention_forward: (
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
reduction='mean': -1),
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean': -1),
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
),
torch.nn.functional.multi_margin_loss: (
lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_margin_loss: (
lambda input, target, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.multilabel_soft_margin_loss: (
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
),
torch.nn.functional.nll_loss: (
lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
),
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
eps=1e-08, reduce=None, reduction='mean': -1),
torch.nn.functional.poisson_nll_loss: (
lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
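These hunks only re-wrap entries in the testing-overrides table: each overridable callable maps to a dummy lambda whose signature must keep mirroring the real function, so the reformatter may not change signatures. A minimal sketch of how the table is consumed (standard torch.overrides API; the assert is illustrative):

import inspect
import torch
from torch.overrides import get_testing_overrides

dummies = get_testing_overrides()
# Each value is a signature-compatible stand-in that just returns -1
sig = inspect.signature(dummies[torch.nn.functional.relu])
assert "inplace" in sig.parameters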
@ -930,25 +976,29 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.softsign: lambda input: -1,
torch.nn.functional.tanhshrink: lambda input: -1,
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.triplet_margin_loss: (
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.nn.functional.triplet_margin_with_distance_loss: (
lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
-2,
-1,
), keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
torch.orgqr: lambda input, tau: -1,
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
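The norm family above keeps distinct defaults per API, which is why each entry carries its own dummy signature. A quick sketch of the difference (plain torch usage):

import torch

a = torch.eye(3)
frob = torch.linalg.matrix_norm(a)           # ord="fro" by default
spect = torch.linalg.matrix_norm(a, ord=2)   # spectral norm
vec = torch.linalg.vector_norm(a.flatten())  # ord=2 by default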
@ -975,28 +1025,43 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.q_scale: lambda input: -1,
torch.q_zero_point: lambda input: -1,
torch.qr: lambda input, some=True, out=None: -1,
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
dilation=(1,), ceil_mode=False: -1),
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
dilation=(1, 1), ceil_mode=False: -1),
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
dilation=(1, 1, 1), ceil_mode=False: -1),
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_gru_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_lstm_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_max_pool1d: (
lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
1,
), ceil_mode=False: -1
),
torch.quantized_max_pool2d: (
lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
1,
1,
), ceil_mode=False: -1
),
torch.quantized_max_pool3d: (
lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
1,
1,
1,
), ceil_mode=False: -1
),
torch.quantized_rnn_relu_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.quantized_rnn_tanh_cell: (
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
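The quantized_* entries track torch's quantized kernels, whose long argument lists forced the one-line # noqa form above. A small usage sketch for one of them (standard torch API):

import torch

x = torch.randn(4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(qx.int_repr())  # the underlying int8 storage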
@ -1014,16 +1079,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.repeat_interleave: lambda input, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.roll: lambda input, shifts, dims=None: -1,
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
@ -1031,7 +1096,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@ -1061,8 +1126,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.stack: lambda tensors, dim=0, out=None: -1,
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
torch.stft: (
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
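The stft entry's long signature is what requires the one-line # noqa form. A minimal call for reference (standard torch API):

import torch

x = torch.randn(400)
spec = torch.stft(x, n_fft=64, hop_length=16,
                  window=torch.hann_window(64), return_complex=True)
print(spec.shape)  # (n_fft // 2 + 1, num_frames) for the default onesided output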
@ -1164,9 +1230,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
size_average=None, reduce=None, reduction='mean': -1),
torch.triplet_margin_loss: (
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
),
torch.triu: lambda input, diagonal=0, out=None: -1,
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
@ -1436,10 +1502,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
}

privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
privateuse1_backend_name = (
torch.utils.backend_registration._privateuse1_backend_name
)
if hasattr(Tensor, privateuse1_backend_name):
ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1 # noqa: B009
ret[
getattr(Tensor, privateuse1_backend_name)
] = lambda self, device=None, non_blocking=False, **kwargs: -1
ret[
getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
] = lambda self: -1 # noqa: B009

ret2 = {}
ignored = get_ignored_functions()
@ -1457,12 +1529,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
if k.__name__.startswith("bitwise_"):
# bitwise_<op> have dunder methods of the form __<op>__
# And so on.
subname = k.__name__[len("bitwise_"):]
names.extend([
"__" + subname + "__",
"__i" + subname + "__",
"__r" + subname + "__"
])
subname = k.__name__[len("bitwise_") :]
names.extend(
["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
)

for name in names:
func = getattr(Tensor, name, None)
@ -1472,6 +1542,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
ret.update(ret2)
return ret


def wrap_torch_function(dispatcher: Callable):
"""Wraps a given function with ``__torch_function__`` -related functionality.
@ -1495,6 +1566,7 @@ def wrap_torch_function(dispatcher: Callable):
>>> def func(a): # This will make func dispatchable by __torch_function__
... return a + 0
"""

def inner(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
|
return inner


def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
def _get_overloaded_args(
relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
) -> List[Any]:
"""Returns a list of arguments on which to call __torch_function__.

Checks arguments in relevant_args for __torch_function__ implementations,
@ -1559,8 +1634,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
#
# NB: Important to exclude _disabled_torch_function_impl, otherwise
# https://github.com/pytorch/pytorch/issues/64687
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
if (
arg_type not in overloaded_types
and hasattr(arg_type, "__torch_function__")
and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
):
# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
@ -1581,7 +1659,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
def handle_torch_function(
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
) -> Any:
"""Implement a function with checks for ``__torch_function__`` overrides.

See torch::autograd::handle_torch_function for the equivalent of this
@ -1636,11 +1715,16 @@ def handle_torch_function(
# This call needs to become a classmethod call in the future.
# See https://github.com/pytorch/pytorch/issues/63767
torch_func_method = overloaded_arg.__torch_function__
if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
torch_func_method is not torch._C._disabled_torch_function_impl:
warnings.warn("Defining your `__torch_function__` as a plain method is deprecated and "
"will be an error in the future; please define it as a classmethod.",
DeprecationWarning)
if (
hasattr(torch_func_method, "__self__")
and torch_func_method.__self__ is overloaded_arg
and torch_func_method is not torch._C._disabled_torch_function_impl
):
warnings.warn(
"Defining your `__torch_function__` as a plain method is deprecated and "
"will be an error in the future; please define it as a classmethod.",
DeprecationWarning,
)

# Use `public_api` instead of `implementation` so __torch_function__
# implementations can do equality/identity comparisons.
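The warning above nudges implementers toward classmethods; a minimal conforming override might look like this (hypothetical Wrapper class, not part of the diff; kwargs unwrapping elided for brevity):

import torch

class Wrapper:
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap arguments, run the real op, rewrap the result
        args = tuple(a.data if isinstance(a, Wrapper) else a for a in args)
        return cls(func(*args, **kwargs))

out = torch.add(Wrapper(torch.ones(2)), Wrapper(torch.ones(2)))
print(type(out))  # Wrapper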
@ -1649,15 +1733,16 @@ def handle_torch_function(
if result is not NotImplemented:
return result

func_name = f'{public_api.__module__}.{public_api.__name__}'
func_name = f"{public_api.__module__}.{public_api.__name__}"
msg = (
f"no implementation found for '{func_name}' on types that implement "
f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
)
if _is_torch_function_mode_enabled():
msg += f" nor in mode {_get_current_function_mode()}"
raise TypeError(msg)


has_torch_function = _add_docstr(
_has_torch_function,
r"""Check for __torch_function__ implementations in the elements of an iterable
@ -1678,7 +1763,7 @@ has_torch_function = _add_docstr(
________
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
"""
""",
)

has_torch_function_unary = _add_docstr(
@ -1689,7 +1774,7 @@ has_torch_function_unary = _add_docstr(
call:
`has_torch_function_unary(t)`
which skips unnecessary packing and unpacking work.
"""
""",
)

has_torch_function_variadic = _add_docstr(
@ -1703,11 +1788,14 @@ has_torch_function_variadic = _add_docstr(
call:
`has_torch_function_variadic(a, b)`
which skips unnecessary packing and unpacking work.
"""
""",
)


@functools.lru_cache(None)
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
def _get_overridable_functions() -> (
Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
):
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
@ -1725,21 +1813,21 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
ignore = False
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
if func_name.startswith('__'):
if func_name.startswith("__"):
continue
elif func_name.startswith('_'):
elif func_name.startswith("_"):
ignore = True
elif func_name.endswith('_'):
elif func_name.endswith("_"):
ignore = True
elif not func_name[0].islower():
ignore = True
elif func_name == 'unique_dim':
elif func_name == "unique_dim":
continue
else:
func = getattr(namespace, func_name)
if getattr(object, func_name, None) == func:
continue
if func_name == '__weakref__':
if func_name == "__weakref__":
continue
func = getattr(namespace, func_name)
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
@ -1757,9 +1845,13 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
if ignore:
continue
if func.__get__ in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
msg = (
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override"
)
assert func.__get__ not in get_testing_overrides(), msg.format(
namespace, func.__name__
)
continue
else:
overridable_funcs[func].append(func.__get__)
@ -1775,13 +1867,18 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
# cannot be overridden by __torch_function__
if func in get_ignored_functions():
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override")
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
msg = (
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
"but still has an explicit override"
)
assert func not in get_testing_overrides(), msg.format(
namespace, func.__name__
)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs, index


@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
@ -1794,6 +1891,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""
return _get_overridable_functions()[0]


@_disable_user_warnings
def resolve_name(f):
"""Get a human readable string name for a function passed to
@ -1814,13 +1912,15 @@ def resolve_name(f):
return str(f)
return _get_overridable_functions()[1].get(f)


@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
""" Returns a set of the overridable methods on ``torch.Tensor`` """
"""Returns a set of the overridable methods on ``torch.Tensor``"""
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
return methods


@_disable_user_warnings
def is_tensor_method_or_property(func: Callable) -> bool:
"""
@ -1846,6 +1946,7 @@ def is_tensor_method_or_property(func: Callable) -> bool:
"""
return func in _get_tensor_methods() or func.__name__ == "__get__"


def is_tensor_like(inp):
"""
Returns ``True`` if the passed-in input is a Tensor-like.
@ -1882,6 +1983,7 @@ def is_tensor_like(inp):
"""
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")


class TorchFunctionMode:
"""
A ``TorchFunctionMode`` allows you to override the meaning of all
@ -1912,6 +2014,7 @@ class TorchFunctionMode:
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""

inner: "TorchFunctionMode"

# Force metaclass to generate constructor at the base of the hierarchy
@ -1930,7 +2033,9 @@ class TorchFunctionMode:
@classmethod
def push(cls, *args, **kwargs):
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
warnings.warn(
"`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
)
instance = cls(*args, **kwargs)
return instance
@ -1944,6 +2049,7 @@ def _get_current_function_mode_stack():
stack_len = _len_torch_function_stack()
return [_get_function_stack_at(i) for i in range(stack_len)]


def _push_mode(mode):
_push_on_torch_function_stack(mode)
@ -1961,6 +2067,7 @@ def _pop_mode_temporarily():
finally:
_push_mode(old)


class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
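This file's hunks end at BaseTorchFunctionMode. For context, a minimal mode that logs every intercepted call (uses resolve_name from the same module; illustrative, not part of the diff):

import torch
from torch.overrides import TorchFunctionMode, resolve_name

class LoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(f"intercepted: {resolve_name(func)}")
        return func(*args, **(kwargs or {}))

with LoggingMode():
    torch.add(torch.ones(2), torch.ones(2))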
@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import torch
from typing import Optional

import torch


class SobolEngine:
r"""
@ -48,8 +49,10 @@ class SobolEngine:
def __init__(self, dimension, scramble=False, seed=None):
if dimension > self.MAXDIM or dimension < 1:
raise ValueError("Supported range of dimensionality "
f"for SobolEngine is [1, {self.MAXDIM}]")
raise ValueError(
"Supported range of dimensionality "
f"for SobolEngine is [1, {self.MAXDIM}]"
)

self.seed = seed
self.scramble = scramble
@ -57,7 +60,9 @@ class SobolEngine:
cpu = torch.device("cpu")

self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
self.sobolstate = torch.zeros(
dimension, self.MAXBIT, device=cpu, dtype=torch.long
)
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

if not self.scramble:
@ -66,11 +71,15 @@ class SobolEngine:
self._scramble()

self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
self.num_generated = 0

def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
def draw(
self,
n: int = 1,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@ -92,12 +101,22 @@ class SobolEngine:
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
self.quasi,
n - 1,
self.sobolstate,
self.dimension,
self.num_generated,
dtype=dtype,
)
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
self.quasi,
n,
self.sobolstate,
self.dimension,
self.num_generated - 1,
dtype=dtype,
)

self.num_generated += n
@ -108,8 +127,12 @@ class SobolEngine:
return result

def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
def draw_base2(
self,
m: int,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@ -122,15 +145,16 @@ class SobolEngine:
returned tensor.
Default: ``None``
"""
n = 2 ** m
n = 2**m
total_n = self.num_generated + n
if not (total_n & (total_n - 1) == 0):
raise ValueError("The balance properties of Sobol' points require "
f"n to be a power of 2. {self.num_generated} points have been "
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
"If you still want to do this, please use "
"'SobolEngine.draw()' instead."
)
raise ValueError(
"The balance properties of Sobol' points require "
f"n to be a power of 2. {self.num_generated} points have been "
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
"If you still want to do this, please use "
"'SobolEngine.draw()' instead."
)
return self.draw(n=n, out=out, dtype=dtype)

def reset(self):
@ -151,9 +175,13 @@ class SobolEngine:
n (Int): The number of steps to fast-forward by.
"""
if self.num_generated == 0:
torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
torch._sobol_engine_ff_(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
)
else:
torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
torch._sobol_engine_ff_(
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
)
self.num_generated += n
return self
@ -166,8 +194,12 @@ class SobolEngine:
cpu = torch.device("cpu")

# Generate shift vector
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
shift_ints = torch.randint(
2, (self.dimension, self.MAXBIT), device=cpu, generator=g
)
self.shift = torch.mv(
shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
)

# Generate lower triangular matrices (stacked across dimensions)
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@ -176,9 +208,9 @@ class SobolEngine:
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

def __repr__(self):
fmt_string = [f'dimension={self.dimension}']
fmt_string = [f"dimension={self.dimension}"]
if self.scramble:
fmt_string += ['scramble=True']
fmt_string += ["scramble=True"]
if self.seed is not None:
fmt_string += [f'seed={self.seed}']
return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
fmt_string += [f"seed={self.seed}"]
return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
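These hunks appear to be torch/quasirandom.py's SobolEngine; the reformat is behavior-preserving. A short usage sketch of the API touched above:

import torch

engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=0)
first = engine.draw(4)       # 4 x 3 tensor of points in [0, 1)
more = engine.draw_base2(2)  # 2**2 more points; 4 + 4 = 8 keeps n a power of 2
engine.reset()               # rewind to the start of the sequence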
@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Generator
import warnings
from typing import Generator

from torch._C import default_generator
import torch
from torch._C import default_generator


def set_rng_state(new_state: torch.Tensor) -> None:
@ -46,10 +46,12 @@ def manual_seed(seed) -> torch._C.Generator:
torch.cuda.manual_seed_all(seed)

import torch.mps

if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)

import torch.xpu

if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
@ -69,10 +71,12 @@ def seed() -> int:
torch.cuda.manual_seed_all(seed)

import torch.mps

if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)

import torch.xpu

if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
@ -95,7 +99,9 @@ def _seed_custom_device(seed) -> None:
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
custom_device_mod, _seed_all_name
):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
@ -117,7 +123,13 @@ _fork_rng_warned_already = False
@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
def fork_rng(
devices=None,
enabled=True,
_caller="fork_rng",
_devices_kw="devices",
device_type="cuda",
) -> Generator:
"""
Forks the RNG, so that when you return, the RNG is reset
to the state that it was previously in.
@ -138,8 +150,10 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
device_type = torch.device(device_type).type
device_mod = getattr(torch, device_type, None)
if device_mod is None:
raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
"a module by `torch._register_device_module`.")
raise RuntimeError(
f"torch has no module of `{device_type}`, you should register "
+ "a module by `torch._register_device_module`."
)
global _fork_rng_warned_already

# Internal arguments:
@ -153,17 +167,19 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
if devices is None:
num_devices = device_mod.device_count()
if num_devices > 1 and not _fork_rng_warned_already:
message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
f"you have used {_caller} without explicitly specifying which devices are being used. "
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
f" making use of a few {device_type.upper()} devices, set the environment variable "
f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
"with the set of devices you are actually using. For example, if you are using CPU only, "
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
f"`range(torch.{device_type}.device_count())`.")
message = (
f"{device_type.upper()} reports that you have {num_devices} available devices, and "
f"you have used {_caller} without explicitly specifying which devices are being used. "
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
f" making use of a few {device_type.upper()} devices, set the environment variable "
f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
"with the set of devices you are actually using. For example, if you are using CPU only, "
f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
f"`range(torch.{device_type}.device_count())`."
)
warnings.warn(message)
_fork_rng_warned_already = True
devices = list(range(num_devices))
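These hunks appear to be torch/random.py; fork_rng's behavior is unchanged by the re-wrap. A small sketch (devices=[] skips per-device RNG init):

import torch

torch.manual_seed(0)
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(1234)  # any reseeding here is confined to the block
    _ = torch.randn(3)
x = torch.randn(3)           # resumes the pre-block CPU RNG stream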
@ -1,8 +1,9 @@
import torch
import inspect

import torch
from torch.utils._pytree import register_pytree_node, SequenceKey


__all__ = ["pytree_register_structseq", "all_return_types"]

all_return_types = []
@ -10,6 +11,7 @@ all_return_types = []
# error: Module has no attribute "_return_types"
return_types = torch._C._return_types # type: ignore[attr-defined]


def pytree_register_structseq(cls):
def structseq_flatten(structseq):
return list(structseq), None
@ -28,14 +30,15 @@ def pytree_register_structseq(cls):
flatten_with_keys_fn=structseq_flatten_with_keys,
)


for name in dir(return_types):
if name.startswith('__'):
if name.startswith("__"):
continue

_attr = getattr(return_types, name)
globals()[name] = _attr

if not name.startswith('_'):
if not name.startswith("_"):
__all__.append(name)
all_return_types.append(_attr)
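These hunks appear to be torch/return_types.py, which registers each structseq return type with pytree; a quick sketch of the effect:

import torch
from torch.utils._pytree import tree_flatten

out = torch.topk(torch.arange(6.0), k=2)  # a torch.return_types.topk structseq
leaves, spec = tree_flatten(out)          # flattens like an ordinary sequence
assert len(leaves) == 2                   # values and indices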
File diff suppressed because it is too large
@ -2,8 +2,9 @@
from typing import Any, Iterable

from ._vendor.packaging.version import InvalidVersion, Version
from .version import __version__ as internal_version
from torch._vendor.packaging.version import InvalidVersion, Version
from torch.version import __version__ as internal_version


__all__ = ["TorchVersion"]
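These hunks appear to be torch/torch_version.py, switching relative imports to absolute ones; TorchVersion itself is unchanged. Usage for reference (comparisons accept plain strings):

from torch.torch_version import TorchVersion

v = TorchVersion("2.3.1")
assert v > "2.2" and v < "3.0"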
@ -1,9 +1,14 @@
# mypy: allow-untyped-defs
import builtins
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
from torch.autograd.graph import GradientEdge


# Convenience aliases for common composite types that we need
# to talk about in PyTorch

@ -11,8 +16,8 @@ _TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
_TensorOrTensorsOrGradEdge = Union[
torch.Tensor,
Sequence[torch.Tensor],
"torch.autograd.graph.GradientEdge",
Sequence["torch.autograd.graph.GradientEdge"],
"GradientEdge",
Sequence["GradientEdge"],
]

# In some cases, these basic types are shadowed by corresponding
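The TYPE_CHECKING guard above imports GradientEdge only for static analysis; at runtime the quoted forward-reference is enough. The same pattern in isolation (illustrative alias name):

from typing import TYPE_CHECKING, Sequence, Union

import torch

if TYPE_CHECKING:
    from torch.autograd.graph import GradientEdge  # never imported at runtime

# The quoted name keeps runtime import graphs (and cycles) small
EdgeOrTensor = Union["GradientEdge", torch.Tensor, Sequence["GradientEdge"]]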