mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[BE] enable UFMT for top-level files torch/*.py (#127707)
Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707 Approved by: https://github.com/ezyang
This commit is contained in:
parent
cc231a8e2b
commit
dd143d44cc
|
|
@ -1556,7 +1556,6 @@ exclude_patterns = [
|
||||||
'torch/distributed/tensor/parallel/style.py',
|
'torch/distributed/tensor/parallel/style.py',
|
||||||
'torch/fft/__init__.py',
|
'torch/fft/__init__.py',
|
||||||
'torch/func/__init__.py',
|
'torch/func/__init__.py',
|
||||||
'torch/functional.py',
|
|
||||||
'torch/futures/__init__.py',
|
'torch/futures/__init__.py',
|
||||||
'torch/fx/__init__.py',
|
'torch/fx/__init__.py',
|
||||||
'torch/fx/_compatibility.py',
|
'torch/fx/_compatibility.py',
|
||||||
|
|
@ -1642,8 +1641,6 @@ exclude_patterns = [
|
||||||
'torch/fx/subgraph_rewriter.py',
|
'torch/fx/subgraph_rewriter.py',
|
||||||
'torch/fx/tensor_type.py',
|
'torch/fx/tensor_type.py',
|
||||||
'torch/fx/traceback.py',
|
'torch/fx/traceback.py',
|
||||||
'torch/hub.py',
|
|
||||||
'torch/library.py',
|
|
||||||
'torch/linalg/__init__.py',
|
'torch/linalg/__init__.py',
|
||||||
'torch/monitor/__init__.py',
|
'torch/monitor/__init__.py',
|
||||||
'torch/nested/__init__.py',
|
'torch/nested/__init__.py',
|
||||||
|
|
@ -1767,11 +1764,6 @@ exclude_patterns = [
|
||||||
'torch/nn/utils/rnn.py',
|
'torch/nn/utils/rnn.py',
|
||||||
'torch/nn/utils/spectral_norm.py',
|
'torch/nn/utils/spectral_norm.py',
|
||||||
'torch/nn/utils/weight_norm.py',
|
'torch/nn/utils/weight_norm.py',
|
||||||
'torch/overrides.py',
|
|
||||||
'torch/quasirandom.py',
|
|
||||||
'torch/random.py',
|
|
||||||
'torch/return_types.py',
|
|
||||||
'torch/serialization.py',
|
|
||||||
'torch/signal/__init__.py',
|
'torch/signal/__init__.py',
|
||||||
'torch/signal/windows/__init__.py',
|
'torch/signal/windows/__init__.py',
|
||||||
'torch/signal/windows/windows.py',
|
'torch/signal/windows/windows.py',
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
|
|
@ -31,6 +30,7 @@ from torch.utils import _pytree as pytree
|
||||||
from torch.utils._traceback import CapturedTraceback
|
from torch.utils._traceback import CapturedTraceback
|
||||||
from torch.utils.weak import WeakTensorKeyDictionary
|
from torch.utils.weak import WeakTensorKeyDictionary
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,7 +40,6 @@ if TYPE_CHECKING:
|
||||||
# Import the following modules during type checking to enable code intelligence features,
|
# Import the following modules during type checking to enable code intelligence features,
|
||||||
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
||||||
# imported in user code.
|
# imported in user code.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -176,7 +175,7 @@ class Guard:
|
||||||
def sort_key(self):
|
def sort_key(self):
|
||||||
# Put the duplicate input guards at the end. The duplicate guards have
|
# Put the duplicate input guards at the end. The duplicate guards have
|
||||||
# two sources while guard.name only considers one source.
|
# two sources while guard.name only considers one source.
|
||||||
from ._dynamo.guards import GuardBuilder
|
from torch._dynamo.guards import GuardBuilder
|
||||||
|
|
||||||
is_duplicate_input = (
|
is_duplicate_input = (
|
||||||
isinstance(self.create_fn, functools.partial)
|
isinstance(self.create_fn, functools.partial)
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,8 @@
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import _linalg_utils as _utils, Tensor
|
||||||
from . import _linalg_utils as _utils
|
from torch.overrides import handle_torch_function, has_torch_function
|
||||||
from .overrides import handle_torch_function, has_torch_function
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["lobpcg"]
|
__all__ = ["lobpcg"]
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,8 @@ __all__ = ["svd_lowrank", "pca_lowrank"]
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import _linalg_utils as _utils, Tensor
|
||||||
from . import _linalg_utils as _utils
|
from torch.overrides import handle_torch_function, has_torch_function
|
||||||
from .overrides import handle_torch_function, has_torch_function
|
|
||||||
|
|
||||||
|
|
||||||
def get_approximate_basis(
|
def get_approximate_basis(
|
||||||
|
|
|
||||||
|
|
@ -761,22 +761,22 @@ class Tensor(torch._C.TensorBase):
|
||||||
return torch.norm(self, p, dim, keepdim, dtype=dtype)
|
return torch.norm(self, p, dim, keepdim, dtype=dtype)
|
||||||
|
|
||||||
def solve(self, other):
|
def solve(self, other):
|
||||||
from ._linalg_utils import solve
|
from torch._linalg_utils import solve
|
||||||
|
|
||||||
return solve(self, other)
|
return solve(self, other)
|
||||||
|
|
||||||
def lstsq(self, other):
|
def lstsq(self, other):
|
||||||
from ._linalg_utils import lstsq
|
from torch._linalg_utils import lstsq
|
||||||
|
|
||||||
return lstsq(self, other)
|
return lstsq(self, other)
|
||||||
|
|
||||||
def eig(self, eigenvectors=False):
|
def eig(self, eigenvectors=False):
|
||||||
from ._linalg_utils import eig
|
from torch._linalg_utils import eig
|
||||||
|
|
||||||
return eig(self, eigenvectors=eigenvectors)
|
return eig(self, eigenvectors=eigenvectors)
|
||||||
|
|
||||||
def symeig(self, eigenvectors=False):
|
def symeig(self, eigenvectors=False):
|
||||||
from ._linalg_utils import _symeig
|
from torch._linalg_utils import _symeig
|
||||||
|
|
||||||
return _symeig(self, eigenvectors=eigenvectors)
|
return _symeig(self, eigenvectors=eigenvectors)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,47 +1,46 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing import (
|
|
||||||
List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
|
|
||||||
)
|
|
||||||
import operator
|
|
||||||
import itertools
|
import itertools
|
||||||
|
import operator
|
||||||
|
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._C import _add_docstr
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from ._lowrank import svd_lowrank, pca_lowrank
|
from torch import _VF, Tensor
|
||||||
from .overrides import (
|
from torch._C import _add_docstr
|
||||||
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
|
from torch._jit_internal import _overload as overload, boolean_dispatch
|
||||||
handle_torch_function)
|
from torch._lowrank import pca_lowrank, svd_lowrank
|
||||||
from ._jit_internal import boolean_dispatch
|
from torch.overrides import (
|
||||||
from ._jit_internal import _overload as overload
|
handle_torch_function,
|
||||||
|
has_torch_function,
|
||||||
|
has_torch_function_unary,
|
||||||
|
has_torch_function_variadic,
|
||||||
|
)
|
||||||
|
|
||||||
Tensor = torch.Tensor
|
|
||||||
from torch import _VF
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'atleast_1d',
|
"atleast_1d",
|
||||||
'atleast_2d',
|
"atleast_2d",
|
||||||
'atleast_3d',
|
"atleast_3d",
|
||||||
'align_tensors',
|
"align_tensors",
|
||||||
'broadcast_shapes',
|
"broadcast_shapes",
|
||||||
'broadcast_tensors',
|
"broadcast_tensors",
|
||||||
'cartesian_prod',
|
"cartesian_prod",
|
||||||
'block_diag',
|
"block_diag",
|
||||||
'cdist',
|
"cdist",
|
||||||
'chain_matmul',
|
"chain_matmul",
|
||||||
'einsum',
|
"einsum",
|
||||||
'istft',
|
"istft",
|
||||||
'lu',
|
"lu",
|
||||||
'norm',
|
"norm",
|
||||||
'meshgrid',
|
"meshgrid",
|
||||||
'pca_lowrank',
|
"pca_lowrank",
|
||||||
'split',
|
"split",
|
||||||
'stft',
|
"stft",
|
||||||
'svd_lowrank',
|
"svd_lowrank",
|
||||||
'tensordot',
|
"tensordot",
|
||||||
'unique',
|
"unique",
|
||||||
'unique_consecutive',
|
"unique_consecutive",
|
||||||
'unravel_index',
|
"unravel_index",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -124,16 +123,25 @@ def broadcast_shapes(*shapes):
|
||||||
if isinstance(shape, (tuple, list)):
|
if isinstance(shape, (tuple, list)):
|
||||||
for i in range(-1, -1 - len(shape), -1):
|
for i in range(-1, -1 - len(shape), -1):
|
||||||
if shape[i] < 0:
|
if shape[i] < 0:
|
||||||
raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
|
raise RuntimeError(
|
||||||
|
f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
|
||||||
|
)
|
||||||
# NB: result is initialized to 1 so this is effectively an
|
# NB: result is initialized to 1 so this is effectively an
|
||||||
# equals one test
|
# equals one test
|
||||||
if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
|
if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
|
||||||
|
shape[i] == result[i]
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if result[i] != 1:
|
if result[i] != 1:
|
||||||
raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
|
raise RuntimeError(
|
||||||
|
"Shape mismatch: objects cannot be broadcast to a single shape"
|
||||||
|
)
|
||||||
result[i] = shape[i]
|
result[i] = shape[i]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
|
raise RuntimeError(
|
||||||
|
"Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
|
||||||
|
shape,
|
||||||
|
)
|
||||||
return torch.Size(result)
|
return torch.Size(result)
|
||||||
else:
|
else:
|
||||||
# with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
|
# with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
|
||||||
|
|
@ -188,7 +196,8 @@ def split(
|
||||||
"""
|
"""
|
||||||
if has_torch_function_unary(tensor):
|
if has_torch_function_unary(tensor):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
split, (tensor,), tensor, split_size_or_sections, dim=dim)
|
split, (tensor,), tensor, split_size_or_sections, dim=dim
|
||||||
|
)
|
||||||
# Overwriting reason:
|
# Overwriting reason:
|
||||||
# This dispatches to two ATen functions depending on the type of
|
# This dispatches to two ATen functions depending on the type of
|
||||||
# split_size_or_sections. The branching code is in _tensor.py, which we
|
# split_size_or_sections. The branching code is in _tensor.py, which we
|
||||||
|
|
@ -335,10 +344,13 @@ def einsum(*args: Any) -> Tensor:
|
||||||
[ 0.3311, 5.5201, -3.0356]])
|
[ 0.3311, 5.5201, -3.0356]])
|
||||||
"""
|
"""
|
||||||
import torch.backends.opt_einsum as opt_einsum
|
import torch.backends.opt_einsum as opt_einsum
|
||||||
|
|
||||||
# This wrapper exists to support variadic args.
|
# This wrapper exists to support variadic args.
|
||||||
if len(args) < 2:
|
if len(args) < 2:
|
||||||
raise ValueError('einsum(): must specify the equation string and at least one operand, '
|
raise ValueError(
|
||||||
'or at least one operand and its subscripts list')
|
"einsum(): must specify the equation string and at least one operand, "
|
||||||
|
"or at least one operand and its subscripts list"
|
||||||
|
)
|
||||||
|
|
||||||
equation = None
|
equation = None
|
||||||
operands = None
|
operands = None
|
||||||
|
|
@ -350,19 +362,21 @@ def einsum(*args: Any) -> Tensor:
|
||||||
# input operands into a tensorlist (List[Tensor]).
|
# input operands into a tensorlist (List[Tensor]).
|
||||||
def parse_subscript(n: int) -> str:
|
def parse_subscript(n: int) -> str:
|
||||||
if n == Ellipsis:
|
if n == Ellipsis:
|
||||||
return '...'
|
return "..."
|
||||||
if n >= 0 and n < 26:
|
if n >= 0 and n < 26:
|
||||||
return chr(ord('A') + n)
|
return chr(ord("A") + n)
|
||||||
if n >= 26 and n < 52:
|
if n >= 26 and n < 52:
|
||||||
return chr(ord('a') + n - 26)
|
return chr(ord("a") + n - 26)
|
||||||
raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
|
raise ValueError(
|
||||||
|
"einsum(): subscript in subscript list is not within the valid range [0, 52)"
|
||||||
|
)
|
||||||
|
|
||||||
# Parse subscripts for input operands
|
# Parse subscripts for input operands
|
||||||
equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
|
equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
|
||||||
|
|
||||||
# Parse optional output subscripts (provided when the number of arguments is odd)
|
# Parse optional output subscripts (provided when the number of arguments is odd)
|
||||||
if len(args) % 2 == 1:
|
if len(args) % 2 == 1:
|
||||||
equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
|
equation += "->" + "".join(parse_subscript(s) for s in args[-1])
|
||||||
operands = args[:-1:2]
|
operands = args[:-1:2]
|
||||||
else:
|
else:
|
||||||
operands = args[::2]
|
operands = args[::2]
|
||||||
|
|
@ -388,7 +402,9 @@ def einsum(*args: Any) -> Tensor:
|
||||||
path = None
|
path = None
|
||||||
if opt_einsum.is_available():
|
if opt_einsum.is_available():
|
||||||
_opt_einsum = opt_einsum.get_opt_einsum()
|
_opt_einsum = opt_einsum.get_opt_einsum()
|
||||||
tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
|
tupled_path = _opt_einsum.contract_path(
|
||||||
|
equation, *operands, optimize=opt_einsum.strategy
|
||||||
|
)[0]
|
||||||
# flatten path for dispatching to C++
|
# flatten path for dispatching to C++
|
||||||
path = [item for pair in tupled_path for item in pair]
|
path = [item for pair in tupled_path for item in pair]
|
||||||
return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined]
|
return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined]
|
||||||
|
|
@ -397,10 +413,13 @@ def einsum(*args: Any) -> Tensor:
|
||||||
# This wrapper exists to support variadic args.
|
# This wrapper exists to support variadic args.
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# The JIT doesn't understand Union, so only add type annotation for mypy
|
# The JIT doesn't understand Union, so only add type annotation for mypy
|
||||||
def meshgrid(*tensors: Union[Tensor, List[Tensor]],
|
def meshgrid(
|
||||||
indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
|
*tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
|
||||||
|
) -> Tuple[Tensor, ...]:
|
||||||
return _meshgrid(*tensors, indexing=indexing)
|
return _meshgrid(*tensors, indexing=indexing)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
|
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
|
||||||
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
|
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
|
||||||
|
|
||||||
|
|
@ -509,15 +528,22 @@ def _meshgrid(*tensors, indexing: Optional[str]):
|
||||||
# kwarg for forward compatibility reasons.
|
# kwarg for forward compatibility reasons.
|
||||||
#
|
#
|
||||||
# Remove this two weeks after landing.
|
# Remove this two weeks after landing.
|
||||||
kwargs = {} if indexing is None else {'indexing': indexing}
|
kwargs = {} if indexing is None else {"indexing": indexing}
|
||||||
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
|
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
def stft(
|
||||||
win_length: Optional[int] = None, window: Optional[Tensor] = None,
|
input: Tensor,
|
||||||
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
|
n_fft: int,
|
||||||
|
hop_length: Optional[int] = None,
|
||||||
|
win_length: Optional[int] = None,
|
||||||
|
window: Optional[Tensor] = None,
|
||||||
|
center: bool = True,
|
||||||
|
pad_mode: str = "reflect",
|
||||||
|
normalized: bool = False,
|
||||||
onesided: Optional[bool] = None,
|
onesided: Optional[bool] = None,
|
||||||
return_complex: Optional[bool] = None) -> Tensor:
|
return_complex: Optional[bool] = None,
|
||||||
|
) -> Tensor:
|
||||||
r"""Short-time Fourier transform (STFT).
|
r"""Short-time Fourier transform (STFT).
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
@ -652,9 +678,19 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
||||||
"""
|
"""
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
|
stft,
|
||||||
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
|
(input,),
|
||||||
onesided=onesided, return_complex=return_complex)
|
input,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
normalized=normalized,
|
||||||
|
onesided=onesided,
|
||||||
|
return_complex=return_complex,
|
||||||
|
)
|
||||||
# NOTE: Do not edit. This code will be removed once the forward-compatibility
|
# NOTE: Do not edit. This code will be removed once the forward-compatibility
|
||||||
# period is over for PR #73432
|
# period is over for PR #73432
|
||||||
if center:
|
if center:
|
||||||
|
|
@ -663,8 +699,16 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
||||||
pad = int(n_fft // 2)
|
pad = int(n_fft // 2)
|
||||||
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
|
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
|
||||||
input = input.view(input.shape[-signal_dim:])
|
input = input.view(input.shape[-signal_dim:])
|
||||||
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
|
return _VF.stft( # type: ignore[attr-defined]
|
||||||
normalized, onesided, return_complex)
|
input,
|
||||||
|
n_fft,
|
||||||
|
hop_length,
|
||||||
|
win_length,
|
||||||
|
window,
|
||||||
|
normalized,
|
||||||
|
onesided,
|
||||||
|
return_complex,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
istft = _add_docstr(
|
istft = _add_docstr(
|
||||||
|
|
@ -746,7 +790,8 @@ Args:
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
|
Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
|
||||||
`B?` is an optional batch dimension from the input tensor.
|
`B?` is an optional batch dimension from the input tensor.
|
||||||
""")
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -758,9 +803,13 @@ else:
|
||||||
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
|
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
|
||||||
|
|
||||||
|
|
||||||
def _unique_impl(input: Tensor, sorted: bool = True,
|
def _unique_impl(
|
||||||
return_inverse: bool = False, return_counts: bool = False,
|
input: Tensor,
|
||||||
dim: Optional[int] = None) -> _unique_impl_out:
|
sorted: bool = True,
|
||||||
|
return_inverse: bool = False,
|
||||||
|
return_counts: bool = False,
|
||||||
|
dim: Optional[int] = None,
|
||||||
|
) -> _unique_impl_out:
|
||||||
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
|
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
|
||||||
|
|
||||||
Returns the unique elements of the input tensor.
|
Returns the unique elements of the input tensor.
|
||||||
|
|
@ -896,8 +945,14 @@ def _unique_impl(input: Tensor, sorted: bool = True,
|
||||||
"""
|
"""
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
|
unique,
|
||||||
return_counts=return_counts, dim=dim)
|
(input,),
|
||||||
|
input,
|
||||||
|
sorted=sorted,
|
||||||
|
return_inverse=return_inverse,
|
||||||
|
return_counts=return_counts,
|
||||||
|
dim=dim,
|
||||||
|
)
|
||||||
|
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
output, inverse_indices, counts = _VF.unique_dim(
|
output, inverse_indices, counts = _VF.unique_dim(
|
||||||
|
|
@ -917,9 +972,12 @@ def _unique_impl(input: Tensor, sorted: bool = True,
|
||||||
return output, inverse_indices, counts
|
return output, inverse_indices, counts
|
||||||
|
|
||||||
|
|
||||||
def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
|
def _unique_consecutive_impl(
|
||||||
|
input: Tensor,
|
||||||
|
return_inverse: bool = False,
|
||||||
return_counts: bool = False,
|
return_counts: bool = False,
|
||||||
dim: Optional[int] = None) -> _unique_impl_out:
|
dim: Optional[int] = None,
|
||||||
|
) -> _unique_impl_out:
|
||||||
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
|
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
|
||||||
|
|
||||||
.. note:: This function is different from :func:`torch.unique` in the sense that this function
|
.. note:: This function is different from :func:`torch.unique` in the sense that this function
|
||||||
|
|
@ -971,14 +1029,22 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
|
||||||
"""
|
"""
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
unique_consecutive, (input,), input, return_inverse=return_inverse,
|
unique_consecutive,
|
||||||
return_counts=return_counts, dim=dim)
|
(input,),
|
||||||
|
input,
|
||||||
|
return_inverse=return_inverse,
|
||||||
|
return_counts=return_counts,
|
||||||
|
dim=dim,
|
||||||
|
)
|
||||||
output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
|
output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
|
||||||
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
|
input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
|
||||||
|
)
|
||||||
return output, inverse_indices, counts
|
return output, inverse_indices, counts
|
||||||
|
|
||||||
|
|
||||||
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
def _return_counts(
|
||||||
|
input, sorted=True, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
|
|
@ -988,7 +1054,9 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False
|
||||||
return output, counts
|
return output, counts
|
||||||
|
|
||||||
|
|
||||||
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
def _return_output(
|
||||||
|
input, sorted=True, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
|
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
|
|
@ -998,59 +1066,72 @@ def _return_output(input, sorted=True, return_inverse=False, return_counts=False
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
def _return_inverse(
|
||||||
|
input, sorted=True, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
|
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
|
||||||
|
|
||||||
output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
|
output, inverse_indices, _ = _unique_impl(
|
||||||
|
input, sorted, return_inverse, return_counts, dim
|
||||||
|
)
|
||||||
return output, inverse_indices
|
return output, inverse_indices
|
||||||
|
|
||||||
|
|
||||||
_return_inverse_false = boolean_dispatch(
|
_return_inverse_false = boolean_dispatch(
|
||||||
arg_name='return_counts',
|
arg_name="return_counts",
|
||||||
arg_index=3,
|
arg_index=3,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_return_counts,
|
if_true=_return_counts,
|
||||||
if_false=_return_output,
|
if_false=_return_output,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique')
|
func_name="unique",
|
||||||
|
)
|
||||||
|
|
||||||
_return_inverse_true = boolean_dispatch(
|
_return_inverse_true = boolean_dispatch(
|
||||||
arg_name='return_counts',
|
arg_name="return_counts",
|
||||||
arg_index=3,
|
arg_index=3,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_unique_impl,
|
if_true=_unique_impl,
|
||||||
if_false=_return_inverse,
|
if_false=_return_inverse,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique')
|
func_name="unique",
|
||||||
|
)
|
||||||
|
|
||||||
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
|
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
|
||||||
# resolve the output type in TorchScript we need to statically know the value of both parameters
|
# resolve the output type in TorchScript we need to statically know the value of both parameters
|
||||||
|
|
||||||
unique = boolean_dispatch(
|
unique = boolean_dispatch(
|
||||||
arg_name='return_inverse',
|
arg_name="return_inverse",
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_return_inverse_true,
|
if_true=_return_inverse_true,
|
||||||
if_false=_return_inverse_false,
|
if_false=_return_inverse_false,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique')
|
func_name="unique",
|
||||||
|
)
|
||||||
unique.__doc__ = _unique_impl.__doc__
|
unique.__doc__ = _unique_impl.__doc__
|
||||||
|
|
||||||
|
|
||||||
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
|
def _consecutive_return_counts(
|
||||||
|
input, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
||||||
|
|
||||||
output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
output, _, counts = _unique_consecutive_impl(
|
||||||
|
input, return_inverse, return_counts, dim
|
||||||
|
)
|
||||||
return output, counts
|
return output, counts
|
||||||
|
|
||||||
|
|
||||||
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
|
def _consecutive_return_output(
|
||||||
|
input, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
|
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
|
|
@ -1060,45 +1141,52 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False,
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
|
def _consecutive_return_inverse(
|
||||||
|
input, return_inverse=False, return_counts=False, dim=None
|
||||||
|
):
|
||||||
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
||||||
|
|
||||||
output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
|
output, inverse_indices, _ = _unique_consecutive_impl(
|
||||||
|
input, return_inverse, return_counts, dim
|
||||||
|
)
|
||||||
return output, inverse_indices
|
return output, inverse_indices
|
||||||
|
|
||||||
|
|
||||||
_consecutive_return_inverse_false = boolean_dispatch(
|
_consecutive_return_inverse_false = boolean_dispatch(
|
||||||
arg_name='return_counts',
|
arg_name="return_counts",
|
||||||
arg_index=1,
|
arg_index=1,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_consecutive_return_counts,
|
if_true=_consecutive_return_counts,
|
||||||
if_false=_consecutive_return_output,
|
if_false=_consecutive_return_output,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique_consecutive')
|
func_name="unique_consecutive",
|
||||||
|
)
|
||||||
|
|
||||||
_consecutive_return_inverse_true = boolean_dispatch(
|
_consecutive_return_inverse_true = boolean_dispatch(
|
||||||
arg_name='return_counts',
|
arg_name="return_counts",
|
||||||
arg_index=1,
|
arg_index=1,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_unique_consecutive_impl,
|
if_true=_unique_consecutive_impl,
|
||||||
if_false=_consecutive_return_inverse,
|
if_false=_consecutive_return_inverse,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique_consecutive')
|
func_name="unique_consecutive",
|
||||||
|
)
|
||||||
|
|
||||||
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
|
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
|
||||||
# resolve the output type in TorchScript we need to statically know the value of both parameters
|
# resolve the output type in TorchScript we need to statically know the value of both parameters
|
||||||
|
|
||||||
unique_consecutive = boolean_dispatch(
|
unique_consecutive = boolean_dispatch(
|
||||||
arg_name='return_inverse',
|
arg_name="return_inverse",
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_consecutive_return_inverse_true,
|
if_true=_consecutive_return_inverse_true,
|
||||||
if_false=_consecutive_return_inverse_false,
|
if_false=_consecutive_return_inverse_false,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='unique_consecutive')
|
func_name="unique_consecutive",
|
||||||
|
)
|
||||||
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
|
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -1106,24 +1194,50 @@ if TYPE_CHECKING:
|
||||||
# There's no good way to use this type annotation without breaking JIT
|
# There's no good way to use this type annotation without breaking JIT
|
||||||
# overloads. So leave untyped for mypy for now.
|
# overloads. So leave untyped for mypy for now.
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
|
def tensordot(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
dims: int = 2,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
|
def tensordot( # noqa: F811
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
dims: Tuple[List[int], List[int]],
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
|
def tensordot( # noqa: F811
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
dims: List[List[int]],
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None): # noqa: F811
|
def tensordot( # noqa: F811
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
dims: torch.Tensor,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
|
def tensordot( # noqa: F811
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
dims=2,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
r"""Returns a contraction of a and b over multiple dimensions.
|
r"""Returns a contraction of a and b over multiple dimensions.
|
||||||
|
|
||||||
:attr:`tensordot` implements a generalized matrix product.
|
:attr:`tensordot` implements a generalized matrix product.
|
||||||
|
|
@ -1178,10 +1292,12 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
|
||||||
return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
|
return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
|
||||||
|
|
||||||
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
|
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
|
||||||
raise RuntimeError("tensordot expects dims to be int or "
|
raise RuntimeError(
|
||||||
|
"tensordot expects dims to be int or "
|
||||||
+ "Tuple[List[int], List[int]] or "
|
+ "Tuple[List[int], List[int]] or "
|
||||||
+ "List[List[int]] containing two lists, but got "
|
+ "List[List[int]] containing two lists, but got "
|
||||||
+ f"dims={dims}")
|
+ f"dims={dims}"
|
||||||
|
)
|
||||||
|
|
||||||
dims_a: List[int] = []
|
dims_a: List[int] = []
|
||||||
dims_b: List[int] = []
|
dims_b: List[int] = []
|
||||||
|
|
@ -1206,7 +1322,9 @@ def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
|
||||||
if dims < 0:
|
if dims < 0:
|
||||||
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
|
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
|
||||||
if dims > min(a.dim(), b.dim()):
|
if dims > min(a.dim(), b.dim()):
|
||||||
raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
|
raise RuntimeError(
|
||||||
|
f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
|
||||||
|
)
|
||||||
dims_a = list(range(-dims, 0))
|
dims_a = list(range(-dims, 0))
|
||||||
dims_b = list(range(dims))
|
dims_b = list(range(dims))
|
||||||
|
|
||||||
|
|
@ -1287,7 +1405,7 @@ def block_diag(*tensors):
|
||||||
return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
|
return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
|
def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
|
||||||
# type: (Tensor, Tensor, float, str) -> (Tensor)
|
# type: (Tensor, Tensor, float, str) -> (Tensor)
|
||||||
r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
|
r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
|
||||||
|
|
||||||
|
|
@ -1331,12 +1449,13 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
|
||||||
"""
|
"""
|
||||||
if has_torch_function_variadic(x1, x2):
|
if has_torch_function_variadic(x1, x2):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
|
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
|
||||||
if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
|
)
|
||||||
|
if compute_mode == "use_mm_for_euclid_dist_if_necessary":
|
||||||
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
|
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
|
||||||
elif compute_mode == 'use_mm_for_euclid_dist':
|
elif compute_mode == "use_mm_for_euclid_dist":
|
||||||
return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
|
return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
|
||||||
elif compute_mode == 'donot_use_mm_for_euclid_dist':
|
elif compute_mode == "donot_use_mm_for_euclid_dist":
|
||||||
return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
|
return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
|
raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
|
||||||
|
|
@ -1478,27 +1597,62 @@ else:
|
||||||
# TODO: type dim as BroadcastingList when
|
# TODO: type dim as BroadcastingList when
|
||||||
# https://github.com/pytorch/pytorch/issues/33782 is fixed
|
# https://github.com/pytorch/pytorch/issues/33782 is fixed
|
||||||
@overload
|
@overload
|
||||||
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
|
def norm(
|
||||||
|
input,
|
||||||
|
p="fro",
|
||||||
|
dim=None,
|
||||||
|
keepdim=False,
|
||||||
|
out=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
# type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
|
# type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
|
def norm( # noqa: F811
|
||||||
|
input,
|
||||||
|
p="fro",
|
||||||
|
dim=None,
|
||||||
|
keepdim=False,
|
||||||
|
out=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
# type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
|
# type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
|
def norm( # noqa: F811
|
||||||
|
input,
|
||||||
|
p="fro",
|
||||||
|
dim=None,
|
||||||
|
keepdim=False,
|
||||||
|
out=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
# type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
|
# type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload # noqa: F811
|
@overload
|
||||||
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
|
def norm( # noqa: F811
|
||||||
|
input,
|
||||||
|
p="fro",
|
||||||
|
dim=None,
|
||||||
|
keepdim=False,
|
||||||
|
out=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
# type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
|
# type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
|
def norm( # noqa: F811
|
||||||
|
input,
|
||||||
|
p: Optional[Union[float, str]] = "fro",
|
||||||
|
dim=None,
|
||||||
|
keepdim=False,
|
||||||
|
out=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
r"""Returns the matrix norm or vector norm of a given tensor.
|
r"""Returns the matrix norm or vector norm of a given tensor.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
@ -1594,14 +1748,19 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
|
||||||
|
|
||||||
if has_torch_function_unary(input):
|
if has_torch_function_unary(input):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
|
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
# NB. All the repeated code and weird python is to please TorchScript.
|
# NB. All the repeated code and weird python is to please TorchScript.
|
||||||
# For a more compact implementation see the relevant function in `_refs/__init__.py`
|
# For a more compact implementation see the relevant function in `_refs/__init__.py`
|
||||||
|
|
||||||
# We don't do this for MPS or sparse tensors
|
# We don't do this for MPS or sparse tensors
|
||||||
if input.layout == torch.strided and input.device.type in \
|
if input.layout == torch.strided and input.device.type in (
|
||||||
("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
|
"cpu",
|
||||||
|
"cuda",
|
||||||
|
"meta",
|
||||||
|
torch.utils.backend_registration._privateuse1_backend_name,
|
||||||
|
):
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
if isinstance(dim, (int, torch.SymInt)):
|
if isinstance(dim, (int, torch.SymInt)):
|
||||||
_dim = [dim]
|
_dim = [dim]
|
||||||
|
|
@ -1611,11 +1770,17 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
|
||||||
_dim = None # type: ignore[assignment]
|
_dim = None # type: ignore[assignment]
|
||||||
|
|
||||||
if isinstance(p, str):
|
if isinstance(p, str):
|
||||||
if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
|
if p == "fro" and (
|
||||||
|
dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
|
||||||
|
):
|
||||||
if out is None:
|
if out is None:
|
||||||
return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
|
return torch.linalg.vector_norm(
|
||||||
|
input, 2, _dim, keepdim, dtype=dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
|
return torch.linalg.vector_norm(
|
||||||
|
input, 2, _dim, keepdim, dtype=dtype, out=out
|
||||||
|
)
|
||||||
|
|
||||||
# Here we either call the nuclear norm, or we call matrix_norm with some arguments
|
# Here we either call the nuclear norm, or we call matrix_norm with some arguments
|
||||||
# that will throw an error
|
# that will throw an error
|
||||||
|
|
@ -1624,14 +1789,18 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
|
||||||
if out is None:
|
if out is None:
|
||||||
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
|
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
|
return torch.linalg.matrix_norm(
|
||||||
|
input, p, _dim, keepdim, dtype=dtype, out=out
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# NB. p should be Union[str, number], not Optional!
|
# NB. p should be Union[str, number], not Optional!
|
||||||
_p = 2.0 if p is None else p
|
_p = 2.0 if p is None else p
|
||||||
if out is None:
|
if out is None:
|
||||||
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
|
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
|
return torch.linalg.vector_norm(
|
||||||
|
input, _p, _dim, keepdim, dtype=dtype, out=out
|
||||||
|
)
|
||||||
|
|
||||||
ndim = input.dim()
|
ndim = input.dim()
|
||||||
|
|
||||||
|
|
@ -1641,7 +1810,7 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
|
||||||
if p == "fro":
|
if p == "fro":
|
||||||
return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
|
return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
|
||||||
if not isinstance(p, str):
|
if not isinstance(p, str):
|
||||||
_dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
|
_dim = list(range(ndim))
|
||||||
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
|
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
|
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
|
||||||
|
|
@ -1695,7 +1864,10 @@ def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False,
|
||||||
else:
|
else:
|
||||||
return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined]
|
return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined]
|
||||||
|
|
||||||
def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
|
|
||||||
|
def unravel_index(
|
||||||
|
indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
|
||||||
|
) -> Tuple[Tensor, ...]:
|
||||||
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
|
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
|
||||||
index into an arbitrary tensor of the specified shape.
|
index into an arbitrary tensor of the specified shape.
|
||||||
|
|
||||||
|
|
@ -1745,19 +1917,23 @@ def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size])
|
||||||
tensor([[34], [78]]))
|
tensor([[34], [78]]))
|
||||||
"""
|
"""
|
||||||
if has_torch_function_unary(indices):
|
if has_torch_function_unary(indices):
|
||||||
return handle_torch_function(
|
return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
|
||||||
unravel_index, (indices,), indices, shape=shape)
|
|
||||||
res_tensor = _unravel_index(indices, shape)
|
res_tensor = _unravel_index(indices, shape)
|
||||||
return res_tensor.unbind(-1)
|
return res_tensor.unbind(-1)
|
||||||
|
|
||||||
|
|
||||||
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
|
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
|
||||||
torch._check_type(
|
torch._check_type(
|
||||||
not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
|
not indices.is_complex()
|
||||||
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
|
and not indices.is_floating_point()
|
||||||
|
and not indices.dtype == torch.bool,
|
||||||
|
lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
|
||||||
|
)
|
||||||
|
|
||||||
torch._check_type(
|
torch._check_type(
|
||||||
isinstance(shape, (int, torch.SymInt, Sequence)),
|
isinstance(shape, (int, torch.SymInt, Sequence)),
|
||||||
lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
|
lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(shape, (int, torch.SymInt)):
|
if isinstance(shape, (int, torch.SymInt)):
|
||||||
shape = torch.Size([shape])
|
shape = torch.Size([shape])
|
||||||
|
|
@ -1765,18 +1941,29 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
|
||||||
for dim in shape:
|
for dim in shape:
|
||||||
torch._check_type(
|
torch._check_type(
|
||||||
isinstance(dim, (int, torch.SymInt)),
|
isinstance(dim, (int, torch.SymInt)),
|
||||||
lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
|
lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
|
||||||
|
)
|
||||||
shape = torch.Size(shape)
|
shape = torch.Size(shape)
|
||||||
|
|
||||||
torch._check_value(
|
torch._check_value(
|
||||||
all(dim >= 0 for dim in shape),
|
all(dim >= 0 for dim in shape),
|
||||||
lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
|
lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
|
||||||
|
)
|
||||||
|
|
||||||
coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
|
coefs = list(
|
||||||
|
reversed(
|
||||||
|
list(
|
||||||
|
itertools.accumulate(
|
||||||
|
reversed(shape[1:] + torch.Size([1])), func=operator.mul
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
return indices.unsqueeze(-1).floor_divide(
|
return indices.unsqueeze(-1).floor_divide(
|
||||||
torch.tensor(coefs, device=indices.device, dtype=torch.int64)
|
torch.tensor(coefs, device=indices.device, dtype=torch.int64)
|
||||||
) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
|
) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
|
||||||
|
|
||||||
|
|
||||||
def chain_matmul(*matrices, out=None):
|
def chain_matmul(*matrices, out=None):
|
||||||
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
|
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
|
||||||
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
|
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
|
||||||
|
|
@ -1923,6 +2110,7 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
|
||||||
# If get_infos is True, then we don't need to check for errors and vice versa
|
# If get_infos is True, then we don't need to check for errors and vice versa
|
||||||
return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
|
return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
_ListOrSeq = Sequence[Tensor]
|
_ListOrSeq = Sequence[Tensor]
|
||||||
else:
|
else:
|
||||||
|
|
@ -1932,16 +2120,21 @@ else:
|
||||||
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
|
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
|
||||||
get_infos_int = 1 if get_infos else 0
|
get_infos_int = 1 if get_infos else 0
|
||||||
if out_len - get_infos_int != 2:
|
if out_len - get_infos_int != 2:
|
||||||
raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
|
raise TypeError(
|
||||||
|
f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
|
||||||
|
)
|
||||||
if not isinstance(out, (tuple, list)):
|
if not isinstance(out, (tuple, list)):
|
||||||
raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
|
raise TypeError(
|
||||||
|
f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
|
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
|
||||||
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
|
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
|
||||||
if has_torch_function_unary(A):
|
if has_torch_function_unary(A):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
|
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
|
||||||
|
)
|
||||||
result = _lu_impl(A, pivot, get_infos, out)
|
result = _lu_impl(A, pivot, get_infos, out)
|
||||||
if out is not None:
|
if out is not None:
|
||||||
_check_list_size(len(out), get_infos, out)
|
_check_list_size(len(out), get_infos, out)
|
||||||
|
|
@ -1957,7 +2150,8 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
|
||||||
# need to check for torch_function here so that we exit if
|
# need to check for torch_function here so that we exit if
|
||||||
if has_torch_function_unary(A):
|
if has_torch_function_unary(A):
|
||||||
return handle_torch_function(
|
return handle_torch_function(
|
||||||
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
|
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
|
||||||
|
)
|
||||||
result = _lu_impl(A, pivot, get_infos, out)
|
result = _lu_impl(A, pivot, get_infos, out)
|
||||||
if out is not None:
|
if out is not None:
|
||||||
_check_list_size(len(out), get_infos, out)
|
_check_list_size(len(out), get_infos, out)
|
||||||
|
|
@ -1967,18 +2161,20 @@ def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
|
||||||
else:
|
else:
|
||||||
return result[0], result[1] # A_LU, pivots
|
return result[0], result[1] # A_LU, pivots
|
||||||
|
|
||||||
|
|
||||||
# The return type of lu depends on `get_infos`, so in order to resolve the output type
|
# The return type of lu depends on `get_infos`, so in order to resolve the output type
|
||||||
# of lu in TorchScript we need to statically know the value of `get_infos`
|
# of lu in TorchScript we need to statically know the value of `get_infos`
|
||||||
lu = boolean_dispatch(
|
lu = boolean_dispatch(
|
||||||
arg_name='get_infos',
|
arg_name="get_infos",
|
||||||
arg_index=2,
|
arg_index=2,
|
||||||
default=False,
|
default=False,
|
||||||
if_true=_lu_with_infos,
|
if_true=_lu_with_infos,
|
||||||
if_false=_lu_no_infos,
|
if_false=_lu_no_infos,
|
||||||
module_name=__name__,
|
module_name=__name__,
|
||||||
func_name='lu')
|
func_name="lu",
|
||||||
|
)
|
||||||
lu.__doc__ = _lu_impl.__doc__
|
lu.__doc__ = _lu_impl.__doc__
|
||||||
|
|
||||||
|
|
||||||
def align_tensors(*tensors):
|
def align_tensors(*tensors):
|
||||||
raise RuntimeError('`align_tensors` not yet implemented.')
|
raise RuntimeError("`align_tensors` not yet implemented.")
|
||||||
|
|
|
||||||
253
torch/hub.py
253
torch/hub.py
|
|
@ -8,22 +8,22 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch
|
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Any
|
from typing import Any, Dict, Optional
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
from urllib.error import HTTPError, URLError
|
from urllib.error import HTTPError, URLError
|
||||||
from urllib.request import urlopen, Request
|
|
||||||
from urllib.parse import urlparse # noqa: F401
|
from urllib.parse import urlparse # noqa: F401
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.serialization import MAP_LOCATION
|
from torch.serialization import MAP_LOCATION
|
||||||
|
|
||||||
class _Faketqdm: # type: ignore[no-redef]
|
|
||||||
|
|
||||||
def __init__(self, total=None, disable=False,
|
class _Faketqdm: # type: ignore[no-redef]
|
||||||
unit=None, *args, **kwargs):
|
def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
|
||||||
self.total = total
|
self.total = total
|
||||||
self.disable = disable
|
self.disable = disable
|
||||||
self.n = 0
|
self.n = 0
|
||||||
|
|
@ -57,7 +57,8 @@ class _Faketqdm: # type: ignore[no-redef]
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
|
|
||||||
sys.stderr.write('\n')
|
sys.stderr.write("\n")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper
|
from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper
|
||||||
|
|
@ -65,25 +66,30 @@ except ImportError:
|
||||||
tqdm = _Faketqdm
|
tqdm = _Faketqdm
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'download_url_to_file',
|
"download_url_to_file",
|
||||||
'get_dir',
|
"get_dir",
|
||||||
'help',
|
"help",
|
||||||
'list',
|
"list",
|
||||||
'load',
|
"load",
|
||||||
'load_state_dict_from_url',
|
"load_state_dict_from_url",
|
||||||
'set_dir',
|
"set_dir",
|
||||||
]
|
]
|
||||||
|
|
||||||
# matches bfd8deac from resnet18-bfd8deac.pth
|
# matches bfd8deac from resnet18-bfd8deac.pth
|
||||||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
|
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
|
||||||
|
|
||||||
_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
|
_TRUSTED_REPO_OWNERS = (
|
||||||
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
|
"facebookresearch",
|
||||||
ENV_TORCH_HOME = 'TORCH_HOME'
|
"facebookincubator",
|
||||||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
"pytorch",
|
||||||
DEFAULT_CACHE_DIR = '~/.cache'
|
"fairinternal",
|
||||||
VAR_DEPENDENCY = 'dependencies'
|
)
|
||||||
MODULE_HUBCONF = 'hubconf.py'
|
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
|
||||||
|
ENV_TORCH_HOME = "TORCH_HOME"
|
||||||
|
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
|
||||||
|
DEFAULT_CACHE_DIR = "~/.cache"
|
||||||
|
VAR_DEPENDENCY = "dependencies"
|
||||||
|
MODULE_HUBCONF = "hubconf.py"
|
||||||
READ_DATA_CHUNK = 128 * 1024
|
READ_DATA_CHUNK = 128 * 1024
|
||||||
_hub_dir: Optional[str] = None
|
_hub_dir: Optional[str] = None
|
||||||
|
|
||||||
|
|
@ -101,6 +107,7 @@ def _add_to_sys_path(path):
|
||||||
def _import_module(name, path):
|
def _import_module(name, path):
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from importlib.abc import Loader
|
from importlib.abc import Loader
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(name, path)
|
spec = importlib.util.spec_from_file_location(name, path)
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
|
@ -131,18 +138,20 @@ def _load_attr_from_module(module, func_name):
|
||||||
|
|
||||||
def _get_torch_home():
|
def _get_torch_home():
|
||||||
torch_home = os.path.expanduser(
|
torch_home = os.path.expanduser(
|
||||||
os.getenv(ENV_TORCH_HOME,
|
os.getenv(
|
||||||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
|
ENV_TORCH_HOME,
|
||||||
DEFAULT_CACHE_DIR), 'torch')))
|
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
|
||||||
|
)
|
||||||
|
)
|
||||||
return torch_home
|
return torch_home
|
||||||
|
|
||||||
|
|
||||||
def _parse_repo_info(github):
|
def _parse_repo_info(github):
|
||||||
if ':' in github:
|
if ":" in github:
|
||||||
repo_info, ref = github.split(':')
|
repo_info, ref = github.split(":")
|
||||||
else:
|
else:
|
||||||
repo_info, ref = github, None
|
repo_info, ref = github, None
|
||||||
repo_owner, repo_name = repo_info.split('/')
|
repo_owner, repo_name = repo_info.split("/")
|
||||||
|
|
||||||
if ref is None:
|
if ref is None:
|
||||||
# The ref wasn't specified by the user, so we need to figure out the
|
# The ref wasn't specified by the user, so we need to figure out the
|
||||||
|
|
@ -150,16 +159,18 @@ def _parse_repo_info(github):
|
||||||
# then it's the default branch, otherwise it's master.
|
# then it's the default branch, otherwise it's master.
|
||||||
try:
|
try:
|
||||||
with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
|
with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
|
||||||
ref = 'main'
|
ref = "main"
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
ref = 'master'
|
ref = "master"
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
except URLError as e:
|
except URLError as e:
|
||||||
# No internet connection, need to check for cache as last resort
|
# No internet connection, need to check for cache as last resort
|
||||||
for possible_ref in ("main", "master"):
|
for possible_ref in ("main", "master"):
|
||||||
if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
|
if os.path.exists(
|
||||||
|
f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
|
||||||
|
):
|
||||||
ref = possible_ref
|
ref = possible_ref
|
||||||
break
|
break
|
||||||
if ref is None:
|
if ref is None:
|
||||||
|
|
@ -172,35 +183,40 @@ def _parse_repo_info(github):
|
||||||
|
|
||||||
def _read_url(url):
|
def _read_url(url):
|
||||||
with urlopen(url) as r:
|
with urlopen(url) as r:
|
||||||
return r.read().decode(r.headers.get_content_charset('utf-8'))
|
return r.read().decode(r.headers.get_content_charset("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
|
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
|
||||||
# Use urlopen to avoid depending on local git.
|
# Use urlopen to avoid depending on local git.
|
||||||
headers = {'Accept': 'application/vnd.github.v3+json'}
|
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||||
token = os.environ.get(ENV_GITHUB_TOKEN)
|
token = os.environ.get(ENV_GITHUB_TOKEN)
|
||||||
if token is not None:
|
if token is not None:
|
||||||
headers['Authorization'] = f'token {token}'
|
headers["Authorization"] = f"token {token}"
|
||||||
for url_prefix in (
|
for url_prefix in (
|
||||||
f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
|
f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
|
||||||
f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
|
f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
|
||||||
|
):
|
||||||
page = 0
|
page = 0
|
||||||
while True:
|
while True:
|
||||||
page += 1
|
page += 1
|
||||||
url = f'{url_prefix}?per_page=100&page={page}'
|
url = f"{url_prefix}?per_page=100&page={page}"
|
||||||
response = json.loads(_read_url(Request(url, headers=headers)))
|
response = json.loads(_read_url(Request(url, headers=headers)))
|
||||||
# Empty response means no more data to process
|
# Empty response means no more data to process
|
||||||
if not response:
|
if not response:
|
||||||
break
|
break
|
||||||
for br in response:
|
for br in response:
|
||||||
if br['name'] == ref or br['commit']['sha'].startswith(ref):
|
if br["name"] == ref or br["commit"]["sha"].startswith(ref):
|
||||||
return
|
return
|
||||||
|
|
||||||
raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
|
raise ValueError(
|
||||||
'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
|
f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
|
||||||
|
"If it's a commit from a forked repo, please call hub.load() with forked repo directly."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
|
def _get_cache_or_reload(
|
||||||
|
github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
|
||||||
|
):
|
||||||
# Setup hub_dir to save downloaded files
|
# Setup hub_dir to save downloaded files
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
os.makedirs(hub_dir, exist_ok=True)
|
os.makedirs(hub_dir, exist_ok=True)
|
||||||
|
|
@ -210,27 +226,33 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
|
||||||
# this causes confusion with path on both Linux and Windows.
|
# this causes confusion with path on both Linux and Windows.
|
||||||
# Backslash is not allowed in Github branch name so no need to
|
# Backslash is not allowed in Github branch name so no need to
|
||||||
# to worry about it.
|
# to worry about it.
|
||||||
normalized_br = ref.replace('/', '_')
|
normalized_br = ref.replace("/", "_")
|
||||||
# Github renames folder repo-v1.x.x to repo-1.x.x
|
# Github renames folder repo-v1.x.x to repo-1.x.x
|
||||||
# We don't know the repo name before downloading the zip file
|
# We don't know the repo name before downloading the zip file
|
||||||
# and inspect name from it.
|
# and inspect name from it.
|
||||||
# To check if cached repo exists, we need to normalize folder names.
|
# To check if cached repo exists, we need to normalize folder names.
|
||||||
owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
|
owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
|
||||||
repo_dir = os.path.join(hub_dir, owner_name_branch)
|
repo_dir = os.path.join(hub_dir, owner_name_branch)
|
||||||
# Check that the repo is in the trusted list
|
# Check that the repo is in the trusted list
|
||||||
_check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
|
_check_repo_is_trusted(
|
||||||
|
repo_owner,
|
||||||
|
repo_name,
|
||||||
|
owner_name_branch,
|
||||||
|
trust_repo=trust_repo,
|
||||||
|
calling_fn=calling_fn,
|
||||||
|
)
|
||||||
|
|
||||||
use_cache = (not force_reload) and os.path.exists(repo_dir)
|
use_cache = (not force_reload) and os.path.exists(repo_dir)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
if verbose:
|
if verbose:
|
||||||
sys.stderr.write(f'Using cache found in {repo_dir}\n')
|
sys.stderr.write(f"Using cache found in {repo_dir}\n")
|
||||||
else:
|
else:
|
||||||
# Validate the tag/branch is from the original repo instead of a forked repo
|
# Validate the tag/branch is from the original repo instead of a forked repo
|
||||||
if not skip_validation:
|
if not skip_validation:
|
||||||
_validate_not_a_forked_repo(repo_owner, repo_name, ref)
|
_validate_not_a_forked_repo(repo_owner, repo_name, ref)
|
||||||
|
|
||||||
cached_file = os.path.join(hub_dir, normalized_br + '.zip')
|
cached_file = os.path.join(hub_dir, normalized_br + ".zip")
|
||||||
_remove_if_exists(cached_file)
|
_remove_if_exists(cached_file)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -250,7 +272,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
|
||||||
"refs/tags/tag_name as the ref. That might require using skip_validation=True."
|
"refs/tags/tag_name as the ref. That might require using skip_validation=True."
|
||||||
)
|
)
|
||||||
disambiguated_branch_ref = f"refs/heads/{ref}"
|
disambiguated_branch_ref = f"refs/heads/{ref}"
|
||||||
url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
|
url = _git_archive_link(
|
||||||
|
repo_owner, repo_name, ref=disambiguated_branch_ref
|
||||||
|
)
|
||||||
download_url_to_file(url, cached_file, progress=False)
|
download_url_to_file(url, cached_file, progress=False)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
@ -269,7 +293,9 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T
|
||||||
return repo_dir
|
return repo_dir
|
||||||
|
|
||||||
|
|
||||||
def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
|
def _check_repo_is_trusted(
|
||||||
|
repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
|
||||||
|
):
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
filepath = os.path.join(hub_dir, "trusted_list")
|
filepath = os.path.join(hub_dir, "trusted_list")
|
||||||
|
|
||||||
|
|
@ -282,7 +308,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
|
||||||
# if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
|
# if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
|
||||||
trusted_repos_legacy = next(os.walk(hub_dir))[1]
|
trusted_repos_legacy = next(os.walk(hub_dir))[1]
|
||||||
|
|
||||||
owner_name = '_'.join([repo_owner, repo_name])
|
owner_name = "_".join([repo_owner, repo_name])
|
||||||
is_trusted = (
|
is_trusted = (
|
||||||
owner_name in trusted_repos
|
owner_name in trusted_repos
|
||||||
or owner_name_branch in trusted_repos_legacy
|
or owner_name_branch in trusted_repos_legacy
|
||||||
|
|
@ -298,13 +324,15 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
|
||||||
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
|
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
|
||||||
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
|
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
|
||||||
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
|
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
|
||||||
f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
|
f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
|
if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
|
||||||
response = input(
|
response = input(
|
||||||
f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
|
f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
|
||||||
"Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
|
"Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
|
||||||
|
)
|
||||||
if response.lower() in ("y", "yes"):
|
if response.lower() in ("y", "yes"):
|
||||||
if is_trusted:
|
if is_trusted:
|
||||||
print("The repository is already trusted.")
|
print("The repository is already trusted.")
|
||||||
|
|
@ -321,6 +349,7 @@ def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo,
|
||||||
|
|
||||||
def _check_module_exists(name):
|
def _check_module_exists(name):
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
return importlib.util.find_spec(name) is not None
|
return importlib.util.find_spec(name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -335,7 +364,7 @@ def _check_dependencies(m):
|
||||||
|
|
||||||
def _load_entry_from_hubconf(m, model):
|
def _load_entry_from_hubconf(m, model):
|
||||||
if not isinstance(model, str):
|
if not isinstance(model, str):
|
||||||
raise ValueError('Invalid input: model should be a string of function name')
|
raise ValueError("Invalid input: model should be a string of function name")
|
||||||
|
|
||||||
# Note that if a missing dependency is imported at top level of hubconf, it will
|
# Note that if a missing dependency is imported at top level of hubconf, it will
|
||||||
# throw before this function. It's a chicken and egg situation where we have to
|
# throw before this function. It's a chicken and egg situation where we have to
|
||||||
|
|
@ -346,7 +375,7 @@ def _load_entry_from_hubconf(m, model):
|
||||||
func = _load_attr_from_module(m, model)
|
func = _load_attr_from_module(m, model)
|
||||||
|
|
||||||
if func is None or not callable(func):
|
if func is None or not callable(func):
|
||||||
raise RuntimeError(f'Cannot find callable {model} in hubconf')
|
raise RuntimeError(f"Cannot find callable {model} in hubconf")
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
@ -362,12 +391,12 @@ def get_dir():
|
||||||
variable is not set.
|
variable is not set.
|
||||||
"""
|
"""
|
||||||
# Issue warning to move data if old env is set
|
# Issue warning to move data if old env is set
|
||||||
if os.getenv('TORCH_HUB'):
|
if os.getenv("TORCH_HUB"):
|
||||||
warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
|
warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
|
||||||
|
|
||||||
if _hub_dir is not None:
|
if _hub_dir is not None:
|
||||||
return _hub_dir
|
return _hub_dir
|
||||||
return os.path.join(_get_torch_home(), 'hub')
|
return os.path.join(_get_torch_home(), "hub")
|
||||||
|
|
||||||
|
|
||||||
def set_dir(d):
|
def set_dir(d):
|
||||||
|
|
@ -381,7 +410,9 @@ def set_dir(d):
|
||||||
_hub_dir = os.path.expanduser(d)
|
_hub_dir = os.path.expanduser(d)
|
||||||
|
|
||||||
|
|
||||||
def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
|
def list(
|
||||||
|
github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
List all callable entrypoints available in the repo specified by ``github``.
|
List all callable entrypoints available in the repo specified by ``github``.
|
||||||
|
|
||||||
|
|
@ -424,15 +455,25 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None, ver
|
||||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
|
||||||
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
|
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
|
||||||
"""
|
"""
|
||||||
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
|
repo_dir = _get_cache_or_reload(
|
||||||
skip_validation=skip_validation)
|
github,
|
||||||
|
force_reload,
|
||||||
|
trust_repo,
|
||||||
|
"list",
|
||||||
|
verbose=verbose,
|
||||||
|
skip_validation=skip_validation,
|
||||||
|
)
|
||||||
|
|
||||||
with _add_to_sys_path(repo_dir):
|
with _add_to_sys_path(repo_dir):
|
||||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||||
|
|
||||||
# We take functions starts with '_' as internal helper functions
|
# We take functions starts with '_' as internal helper functions
|
||||||
entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
|
entrypoints = [
|
||||||
|
f
|
||||||
|
for f in dir(hub_module)
|
||||||
|
if callable(getattr(hub_module, f)) and not f.startswith("_")
|
||||||
|
]
|
||||||
|
|
||||||
return entrypoints
|
return entrypoints
|
||||||
|
|
||||||
|
|
@ -474,8 +515,14 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
|
||||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
|
||||||
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
|
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
|
||||||
"""
|
"""
|
||||||
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
|
repo_dir = _get_cache_or_reload(
|
||||||
skip_validation=skip_validation)
|
github,
|
||||||
|
force_reload,
|
||||||
|
trust_repo,
|
||||||
|
"help",
|
||||||
|
verbose=True,
|
||||||
|
skip_validation=skip_validation,
|
||||||
|
)
|
||||||
|
|
||||||
with _add_to_sys_path(repo_dir):
|
with _add_to_sys_path(repo_dir):
|
||||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||||
|
|
@ -486,9 +533,17 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
|
||||||
return entry.__doc__
|
return entry.__doc__
|
||||||
|
|
||||||
|
|
||||||
def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
|
def load(
|
||||||
|
repo_or_dir,
|
||||||
|
model,
|
||||||
|
*args,
|
||||||
|
source="github",
|
||||||
|
trust_repo=None,
|
||||||
|
force_reload=False,
|
||||||
|
verbose=True,
|
||||||
skip_validation=False,
|
skip_validation=False,
|
||||||
**kwargs):
|
**kwargs,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Load a model from a github repo or a local directory.
|
Load a model from a github repo or a local directory.
|
||||||
|
|
||||||
|
|
@ -559,13 +614,20 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo
|
||||||
"""
|
"""
|
||||||
source = source.lower()
|
source = source.lower()
|
||||||
|
|
||||||
if source not in ('github', 'local'):
|
if source not in ("github", "local"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'Unknown source: "{source}". Allowed values: "github" | "local".')
|
f'Unknown source: "{source}". Allowed values: "github" | "local".'
|
||||||
|
)
|
||||||
|
|
||||||
if source == 'github':
|
if source == "github":
|
||||||
repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
|
repo_or_dir = _get_cache_or_reload(
|
||||||
verbose=verbose, skip_validation=skip_validation)
|
repo_or_dir,
|
||||||
|
force_reload,
|
||||||
|
trust_repo,
|
||||||
|
"load",
|
||||||
|
verbose=verbose,
|
||||||
|
skip_validation=skip_validation,
|
||||||
|
)
|
||||||
|
|
||||||
model = _load_local(repo_or_dir, model, *args, **kwargs)
|
model = _load_local(repo_or_dir, model, *args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
@ -601,8 +663,9 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
def download_url_to_file(
|
||||||
progress: bool = True) -> None:
|
url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
|
||||||
|
) -> None:
|
||||||
r"""Download object at the given URL to a local path.
|
r"""Download object at the given URL to a local path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -623,7 +686,7 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
||||||
req = Request(url, headers={"User-Agent": "torch.hub"})
|
req = Request(url, headers={"User-Agent": "torch.hub"})
|
||||||
u = urlopen(req)
|
u = urlopen(req)
|
||||||
meta = u.info()
|
meta = u.info()
|
||||||
if hasattr(meta, 'getheaders'):
|
if hasattr(meta, "getheaders"):
|
||||||
content_length = meta.getheaders("Content-Length")
|
content_length = meta.getheaders("Content-Length")
|
||||||
else:
|
else:
|
||||||
content_length = meta.get_all("Content-Length")
|
content_length = meta.get_all("Content-Length")
|
||||||
|
|
@ -637,20 +700,25 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
||||||
# file permissions being applied to the downloaded file.
|
# file permissions being applied to the downloaded file.
|
||||||
dst = os.path.expanduser(dst)
|
dst = os.path.expanduser(dst)
|
||||||
for seq in range(tempfile.TMP_MAX):
|
for seq in range(tempfile.TMP_MAX):
|
||||||
tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
|
tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
|
||||||
try:
|
try:
|
||||||
f = open(tmp_dst, 'w+b')
|
f = open(tmp_dst, "w+b")
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
|
raise FileExistsError(errno.EEXIST, "No usable temporary file name found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hash_prefix is not None:
|
if hash_prefix is not None:
|
||||||
sha256 = hashlib.sha256()
|
sha256 = hashlib.sha256()
|
||||||
with tqdm(total=file_size, disable=not progress,
|
with tqdm(
|
||||||
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
|
total=file_size,
|
||||||
|
disable=not progress,
|
||||||
|
unit="B",
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1024,
|
||||||
|
) as pbar:
|
||||||
while True:
|
while True:
|
||||||
buffer = u.read(READ_DATA_CHUNK)
|
buffer = u.read(READ_DATA_CHUNK)
|
||||||
if len(buffer) == 0:
|
if len(buffer) == 0:
|
||||||
|
|
@ -664,7 +732,9 @@ def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
|
||||||
if hash_prefix is not None:
|
if hash_prefix is not None:
|
||||||
digest = sha256.hexdigest() # type: ignore[possibly-undefined]
|
digest = sha256.hexdigest() # type: ignore[possibly-undefined]
|
||||||
if digest[: len(hash_prefix)] != hash_prefix:
|
if digest[: len(hash_prefix)] != hash_prefix:
|
||||||
raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
|
raise RuntimeError(
|
||||||
|
f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
|
||||||
|
)
|
||||||
shutil.move(f.name, dst)
|
shutil.move(f.name, dst)
|
||||||
finally:
|
finally:
|
||||||
f.close()
|
f.close()
|
||||||
|
|
@ -683,23 +753,30 @@ def _is_legacy_zip_format(filename: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
'Falling back to the old format < 1.6. This support will be '
|
"Falling back to the old format < 1.6. This support will be "
|
||||||
'deprecated in favor of default zipfile format introduced in 1.6. '
|
"deprecated in favor of default zipfile format introduced in 1.6. "
|
||||||
'Please redo torch.save() to save it in the new zipfile format.',
|
"Please redo torch.save() to save it in the new zipfile format.",
|
||||||
category=FutureWarning,
|
category=FutureWarning,
|
||||||
)
|
)
|
||||||
def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
|
def _legacy_zip_load(
|
||||||
|
filename: str,
|
||||||
|
model_dir: str,
|
||||||
|
map_location: MAP_LOCATION,
|
||||||
|
weights_only: bool,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
|
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
|
||||||
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
|
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
|
||||||
# E.g. resnet18-5c106cde.pth which is widely used.
|
# E.g. resnet18-5c106cde.pth which is widely used.
|
||||||
with zipfile.ZipFile(filename) as f:
|
with zipfile.ZipFile(filename) as f:
|
||||||
members = f.infolist()
|
members = f.infolist()
|
||||||
if len(members) != 1:
|
if len(members) != 1:
|
||||||
raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
|
raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
|
||||||
f.extractall(model_dir)
|
f.extractall(model_dir)
|
||||||
extraced_name = members[0].filename
|
extraced_name = members[0].filename
|
||||||
extracted_file = os.path.join(model_dir, extraced_name)
|
extracted_file = os.path.join(model_dir, extraced_name)
|
||||||
return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
|
return torch.load(
|
||||||
|
extracted_file, map_location=map_location, weights_only=weights_only
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_url(
|
def load_state_dict_from_url(
|
||||||
|
|
@ -742,12 +819,14 @@ def load_state_dict_from_url(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Issue warning to move data if old env is set
|
# Issue warning to move data if old env is set
|
||||||
if os.getenv('TORCH_MODEL_ZOO'):
|
if os.getenv("TORCH_MODEL_ZOO"):
|
||||||
warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
warnings.warn(
|
||||||
|
"TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
|
||||||
|
)
|
||||||
|
|
||||||
if model_dir is None:
|
if model_dir is None:
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
model_dir = os.path.join(hub_dir, 'checkpoints')
|
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
|
||||||
222
torch/library.py
222
torch/library.py
|
|
@ -1,28 +1,34 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from ._ops import OpOverload
|
import contextlib
|
||||||
from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
|
|
||||||
from typing_extensions import deprecated
|
|
||||||
import traceback
|
|
||||||
import torch
|
|
||||||
import weakref
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
import contextlib
|
|
||||||
import sys
|
import sys
|
||||||
from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
|
import traceback
|
||||||
|
import weakref
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch._library as _library
|
import torch._library as _library
|
||||||
|
from torch._library.custom_ops import (
|
||||||
|
_maybe_get_opdef,
|
||||||
|
custom_op,
|
||||||
|
CustomOpDef,
|
||||||
|
device_types_t,
|
||||||
|
)
|
||||||
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Library',
|
"Library",
|
||||||
'impl',
|
"impl",
|
||||||
'define',
|
"define",
|
||||||
'fallthrough_kernel',
|
"fallthrough_kernel",
|
||||||
'impl_abstract',
|
"impl_abstract",
|
||||||
'register_fake',
|
"register_fake",
|
||||||
'get_ctx',
|
"get_ctx",
|
||||||
'custom_op',
|
"custom_op",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
|
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
|
||||||
|
|
@ -33,7 +39,8 @@ _impls: Set[str] = set()
|
||||||
_defs: Set[str] = set()
|
_defs: Set[str] = set()
|
||||||
|
|
||||||
# prim is reserved by TorchScript interpreter
|
# prim is reserved by TorchScript interpreter
|
||||||
_reserved_namespaces = ['prim']
|
_reserved_namespaces = ["prim"]
|
||||||
|
|
||||||
|
|
||||||
def fallthrough_kernel():
|
def fallthrough_kernel():
|
||||||
"""
|
"""
|
||||||
|
|
@ -41,6 +48,7 @@ def fallthrough_kernel():
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("fallthrough_kernel() should never be called.")
|
raise NotImplementedError("fallthrough_kernel() should never be called.")
|
||||||
|
|
||||||
|
|
||||||
class Library:
|
class Library:
|
||||||
"""
|
"""
|
||||||
A class to create libraries that can be used to register new operators or
|
A class to create libraries that can be used to register new operators or
|
||||||
|
|
@ -59,16 +67,22 @@ class Library:
|
||||||
kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
|
kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
|
||||||
dispatch_key: PyTorch dispatch key (default: "")
|
dispatch_key: PyTorch dispatch key (default: "")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ns, kind, dispatch_key=""):
|
def __init__(self, ns, kind, dispatch_key=""):
|
||||||
if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
|
if kind not in ("IMPL", "DEF", "FRAGMENT"):
|
||||||
raise ValueError("Unsupported kind: ", kind)
|
raise ValueError("Unsupported kind: ", kind)
|
||||||
|
|
||||||
if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
|
if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
|
||||||
raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.")
|
raise ValueError(
|
||||||
|
ns,
|
||||||
|
" is a reserved namespace. Please try creating a library with another name.",
|
||||||
|
)
|
||||||
|
|
||||||
frame = traceback.extract_stack(limit=3)[0]
|
frame = traceback.extract_stack(limit=3)[0]
|
||||||
filename, lineno = frame.filename, frame.lineno
|
filename, lineno = frame.filename, frame.lineno
|
||||||
self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
|
self.m: Optional[Any] = torch._C._dispatch_library(
|
||||||
|
kind, ns, dispatch_key, filename, lineno
|
||||||
|
)
|
||||||
self.ns = ns
|
self.ns = ns
|
||||||
self._op_defs: Set[str] = set()
|
self._op_defs: Set[str] = set()
|
||||||
self._op_impls: Set[str] = set()
|
self._op_impls: Set[str] = set()
|
||||||
|
|
@ -79,13 +93,21 @@ class Library:
|
||||||
# Python __del__ can lead to weird things (globals and locals may already
|
# Python __del__ can lead to weird things (globals and locals may already
|
||||||
# be gone when __del__ actually gets called!). finalizers help the
|
# be gone when __del__ actually gets called!). finalizers help the
|
||||||
# situation because it lets us capture references and keeps them alive
|
# situation because it lets us capture references and keeps them alive
|
||||||
weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
|
weakref.finalize(
|
||||||
|
self,
|
||||||
|
_del_library,
|
||||||
|
_impls,
|
||||||
|
self._op_impls,
|
||||||
|
_defs,
|
||||||
|
self._op_defs,
|
||||||
|
self._registration_handles,
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
|
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
|
||||||
|
|
||||||
def define(self, schema, alias_analysis="", *, tags=()):
|
def define(self, schema, alias_analysis="", *, tags=()):
|
||||||
r'''Defines a new operator and its semantics in the ns namespace.
|
r"""Defines a new operator and its semantics in the ns namespace.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schema: function schema to define a new operator.
|
schema: function schema to define a new operator.
|
||||||
|
|
@ -102,7 +124,7 @@ class Library:
|
||||||
Example::
|
Example::
|
||||||
>>> my_lib = Library("mylib", "DEF")
|
>>> my_lib = Library("mylib", "DEF")
|
||||||
>>> my_lib.define("sum(Tensor self) -> Tensor")
|
>>> my_lib.define("sum(Tensor self) -> Tensor")
|
||||||
'''
|
"""
|
||||||
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
|
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
|
||||||
# AliasAnalysis type in C++
|
# AliasAnalysis type in C++
|
||||||
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
|
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
|
||||||
|
|
@ -113,7 +135,9 @@ class Library:
|
||||||
|
|
||||||
name = schema.split("(")[0]
|
name = schema.split("(")[0]
|
||||||
packet_name = name.split(".")[0] if "." in name else name
|
packet_name = name.split(".")[0] if "." in name else name
|
||||||
has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name)
|
has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
|
||||||
|
getattr(torch.ops, self.ns), packet_name
|
||||||
|
)
|
||||||
|
|
||||||
result = self.m.define(schema, alias_analysis, tuple(tags))
|
result = self.m.define(schema, alias_analysis, tuple(tags))
|
||||||
name = schema.split("(")[0]
|
name = schema.split("(")[0]
|
||||||
|
|
@ -131,7 +155,7 @@ class Library:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _register_fake(self, op_name, fn, _stacklevel=1):
|
def _register_fake(self, op_name, fn, _stacklevel=1):
|
||||||
r'''Registers the fake impl for an operator defined in the library.'''
|
r"""Registers the fake impl for an operator defined in the library."""
|
||||||
source = torch._library.utils.get_source(_stacklevel + 1)
|
source = torch._library.utils.get_source(_stacklevel + 1)
|
||||||
frame = sys._getframe(_stacklevel)
|
frame = sys._getframe(_stacklevel)
|
||||||
caller_module = inspect.getmodule(frame)
|
caller_module = inspect.getmodule(frame)
|
||||||
|
|
@ -141,7 +165,9 @@ class Library:
|
||||||
|
|
||||||
# TODO(rzou): We're gonna need to stage this change with torchvision,
|
# TODO(rzou): We're gonna need to stage this change with torchvision,
|
||||||
# since torchvision is github first.
|
# since torchvision is github first.
|
||||||
if caller_module_name is not None and caller_module_name.startswith("torchvision."):
|
if caller_module_name is not None and caller_module_name.startswith(
|
||||||
|
"torchvision."
|
||||||
|
):
|
||||||
caller_module_name = None
|
caller_module_name = None
|
||||||
|
|
||||||
qualname = f"{self.ns}::{op_name}"
|
qualname = f"{self.ns}::{op_name}"
|
||||||
|
|
@ -154,8 +180,8 @@ class Library:
|
||||||
handle = entry.abstract_impl.register(func_to_register, source)
|
handle = entry.abstract_impl.register(func_to_register, source)
|
||||||
self._registration_handles.append(handle)
|
self._registration_handles.append(handle)
|
||||||
|
|
||||||
def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
|
def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
|
||||||
r'''Register the operator to use the AOTI-compiled implementation.
|
r"""Register the operator to use the AOTI-compiled implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op_name: operator name (along with the overload) or OpOverload object.
|
op_name: operator name (along with the overload) or OpOverload object.
|
||||||
|
|
@ -165,8 +191,8 @@ class Library:
|
||||||
Example::
|
Example::
|
||||||
>>> my_lib = Library("aten", "IMPL")
|
>>> my_lib = Library("aten", "IMPL")
|
||||||
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
|
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
|
||||||
'''
|
"""
|
||||||
if dispatch_key == '':
|
if dispatch_key == "":
|
||||||
dispatch_key = self.dispatch_key
|
dispatch_key = self.dispatch_key
|
||||||
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
|
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
|
||||||
|
|
||||||
|
|
@ -175,19 +201,24 @@ class Library:
|
||||||
elif isinstance(op_name, OpOverload):
|
elif isinstance(op_name, OpOverload):
|
||||||
name = op_name._schema.name
|
name = op_name._schema.name
|
||||||
overload_name = op_name._schema.overload_name
|
overload_name = op_name._schema.overload_name
|
||||||
if overload_name != '':
|
if overload_name != "":
|
||||||
name = name + '.' + overload_name
|
name = name + "." + overload_name
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
|
raise RuntimeError(
|
||||||
"as the first argument")
|
"_impl_with_aoti_compile should be passed either a name or an OpOverload object "
|
||||||
|
"as the first argument"
|
||||||
|
)
|
||||||
|
|
||||||
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
|
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
|
||||||
if key in _impls:
|
if key in _impls:
|
||||||
# TODO: in future, add more info about where the existing function is registered (this info is
|
# TODO: in future, add more info about where the existing function is registered (this info is
|
||||||
# today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
|
# today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
|
||||||
raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
|
raise RuntimeError(
|
||||||
"'s behavior for {} dispatch key and {} namespace.".
|
"This is not allowed since there's already a kernel registered from python overriding {}"
|
||||||
format(name.split("::")[-1], dispatch_key, self.ns))
|
"'s behavior for {} dispatch key and {} namespace.".format(
|
||||||
|
name.split("::")[-1], dispatch_key, self.ns
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
assert self.m is not None
|
assert self.m is not None
|
||||||
impl_fn: Callable = self.m.impl_with_aoti_compile
|
impl_fn: Callable = self.m.impl_with_aoti_compile
|
||||||
|
|
@ -196,8 +227,8 @@ class Library:
|
||||||
_impls.add(key)
|
_impls.add(key)
|
||||||
self._op_impls.add(key)
|
self._op_impls.add(key)
|
||||||
|
|
||||||
def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
|
def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
|
||||||
r'''Registers the function implementation for an operator defined in the library.
|
r"""Registers the function implementation for an operator defined in the library.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op_name: operator name (along with the overload) or OpOverload object.
|
op_name: operator name (along with the overload) or OpOverload object.
|
||||||
|
|
@ -211,10 +242,12 @@ class Library:
|
||||||
>>> def div_cpu(self, other):
|
>>> def div_cpu(self, other):
|
||||||
>>> return self * (1 / other)
|
>>> return self * (1 / other)
|
||||||
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
|
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
|
||||||
'''
|
"""
|
||||||
if not callable(fn):
|
if not callable(fn):
|
||||||
raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
|
raise TypeError(
|
||||||
if dispatch_key == '':
|
f"Input function is required to be a callable but found type {type(fn)}"
|
||||||
|
)
|
||||||
|
if dispatch_key == "":
|
||||||
dispatch_key = self.dispatch_key
|
dispatch_key = self.dispatch_key
|
||||||
|
|
||||||
if isinstance(op_name, str):
|
if isinstance(op_name, str):
|
||||||
|
|
@ -222,37 +255,50 @@ class Library:
|
||||||
elif isinstance(op_name, OpOverload):
|
elif isinstance(op_name, OpOverload):
|
||||||
name = op_name._schema.name
|
name = op_name._schema.name
|
||||||
overload_name = op_name._schema.overload_name
|
overload_name = op_name._schema.overload_name
|
||||||
if overload_name != '':
|
if overload_name != "":
|
||||||
name = name + '.' + overload_name
|
name = name + "." + overload_name
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
|
raise RuntimeError(
|
||||||
|
"impl should be passed either a name or an OpOverload object as the first argument"
|
||||||
|
)
|
||||||
|
|
||||||
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
|
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
|
||||||
if key in _impls:
|
if key in _impls:
|
||||||
# TODO: in future, add more info about where the existing function is registered (this info is
|
# TODO: in future, add more info about where the existing function is registered (this info is
|
||||||
# today already returned by the C++ warning when impl is called but we error out before that)
|
# today already returned by the C++ warning when impl is called but we error out before that)
|
||||||
raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
|
raise RuntimeError(
|
||||||
"'s behavior for {} dispatch key and {} namespace.".
|
"This is not allowed since there's already a kernel registered from python overriding {}"
|
||||||
format(name.split("::")[-1], dispatch_key, self.ns))
|
"'s behavior for {} dispatch key and {} namespace.".format(
|
||||||
|
name.split("::")[-1], dispatch_key, self.ns
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if dispatch_key == "Meta":
|
if dispatch_key == "Meta":
|
||||||
dispatcher_op_name = name
|
dispatcher_op_name = name
|
||||||
if '::' not in dispatcher_op_name:
|
if "::" not in dispatcher_op_name:
|
||||||
dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
|
dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
|
||||||
|
|
||||||
# Internally, we shouldn't be registering meta kernels for any operators that
|
# Internally, we shouldn't be registering meta kernels for any operators that
|
||||||
# have CompositeImplicitAutograd kernels.
|
# have CompositeImplicitAutograd kernels.
|
||||||
# Instead, we should be letting those decompositions run, and writing meta kernels
|
# Instead, we should be letting those decompositions run, and writing meta kernels
|
||||||
# only for the base operators.
|
# only for the base operators.
|
||||||
if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
|
if torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||||
|
dispatcher_op_name, "CompositeImplicitAutograd"
|
||||||
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"We should not register a meta kernel directly to the operator '{name}',"
|
f"We should not register a meta kernel directly to the operator '{name}',"
|
||||||
" because it has a CompositeImplicitAutograd kernel in core."
|
" because it has a CompositeImplicitAutograd kernel in core."
|
||||||
" Instead we should let the operator decompose, and ensure that we have meta kernels"
|
" Instead we should let the operator decompose, and ensure that we have meta kernels"
|
||||||
" for the base ops that it decomposes into.")
|
" for the base ops that it decomposes into."
|
||||||
|
)
|
||||||
|
|
||||||
assert self.m is not None
|
assert self.m is not None
|
||||||
self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn, with_keyset)
|
self.m.impl(
|
||||||
|
name,
|
||||||
|
dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
|
||||||
|
fn,
|
||||||
|
with_keyset,
|
||||||
|
)
|
||||||
|
|
||||||
_impls.add(key)
|
_impls.add(key)
|
||||||
self._op_impls.add(key)
|
self._op_impls.add(key)
|
||||||
|
|
@ -283,7 +329,9 @@ class Library:
|
||||||
delattr(namespace, name)
|
delattr(namespace, name)
|
||||||
|
|
||||||
|
|
||||||
def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
|
def _del_library(
|
||||||
|
captured_impls, op_impls, captured_defs, op_defs, registration_handles
|
||||||
|
):
|
||||||
captured_impls -= op_impls
|
captured_impls -= op_impls
|
||||||
captured_defs -= op_defs
|
captured_defs -= op_defs
|
||||||
for handle in registration_handles:
|
for handle in registration_handles:
|
||||||
|
|
@ -357,7 +405,8 @@ def define(qualname, schema, *, lib=None, tags=()):
|
||||||
if not isinstance(qualname, str):
|
if not isinstance(qualname, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"define(qualname, schema): expected qualname "
|
f"define(qualname, schema): expected qualname "
|
||||||
f"to be instance of str, got {type(qualname)}")
|
f"to be instance of str, got {type(qualname)}"
|
||||||
|
)
|
||||||
namespace, name = torch._library.utils.parse_namespace(qualname)
|
namespace, name = torch._library.utils.parse_namespace(qualname)
|
||||||
if lib is None:
|
if lib is None:
|
||||||
lib = Library(namespace, "FRAGMENT")
|
lib = Library(namespace, "FRAGMENT")
|
||||||
|
|
@ -366,7 +415,8 @@ def define(qualname, schema, *, lib=None, tags=()):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"define(qualname, schema, ...): expected schema "
|
f"define(qualname, schema, ...): expected schema "
|
||||||
f'to look like e.g. "(Tensor x) -> Tensor" but '
|
f'to look like e.g. "(Tensor x) -> Tensor" but '
|
||||||
f'got "{schema}"')
|
f'got "{schema}"'
|
||||||
|
)
|
||||||
lib.define(name + schema, alias_analysis="", tags=tags)
|
lib.define(name + schema, alias_analysis="", tags=tags)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -375,10 +425,12 @@ def _(lib: Library, schema, alias_analysis=""):
|
||||||
"""The old torch.library.define.
|
"""The old torch.library.define.
|
||||||
We're keeping this around for BC reasons
|
We're keeping this around for BC reasons
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap(f):
|
def wrap(f):
|
||||||
name = lib.define(schema, alias_analysis)
|
name = lib.define(schema, alias_analysis)
|
||||||
lib.impl(name, f)
|
lib.impl(name, f)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -460,9 +512,11 @@ def _device_type_to_key(device_type: str) -> str:
|
||||||
@impl.register
|
@impl.register
|
||||||
def _(lib: Library, name, dispatch_key=""):
|
def _(lib: Library, name, dispatch_key=""):
|
||||||
"""Legacy torch.library.impl API. Kept around for BC"""
|
"""Legacy torch.library.impl API. Kept around for BC"""
|
||||||
|
|
||||||
def wrap(f):
|
def wrap(f):
|
||||||
lib.impl(name, f, dispatch_key)
|
lib.impl(name, f, dispatch_key)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -480,7 +534,9 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||||
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
|
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
|
||||||
|
|
||||||
|
|
||||||
_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
|
_op_identifier = Union[
|
||||||
|
str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def register_kernel(
|
def register_kernel(
|
||||||
|
|
@ -489,7 +545,8 @@ def register_kernel(
|
||||||
func: Optional[Callable] = None,
|
func: Optional[Callable] = None,
|
||||||
/,
|
/,
|
||||||
*,
|
*,
|
||||||
lib: Optional[Library] = None):
|
lib: Optional[Library] = None,
|
||||||
|
):
|
||||||
"""Register an implementation for a device type for this operator.
|
"""Register an implementation for a device type for this operator.
|
||||||
|
|
||||||
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
|
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
|
||||||
|
|
@ -530,7 +587,9 @@ def register_kernel(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
|
if not isinstance(
|
||||||
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||||
|
):
|
||||||
raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
|
raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
|
||||||
if isinstance(op, torch._ops.OpOverload):
|
if isinstance(op, torch._ops.OpOverload):
|
||||||
op = op._name
|
op = op._name
|
||||||
|
|
@ -549,7 +608,8 @@ def register_fake(
|
||||||
/,
|
/,
|
||||||
*,
|
*,
|
||||||
lib: Optional[Library] = None,
|
lib: Optional[Library] = None,
|
||||||
_stacklevel: int = 1):
|
_stacklevel: int = 1,
|
||||||
|
):
|
||||||
r"""Register a FakeTensor implementation ("fake impl") for this operator.
|
r"""Register a FakeTensor implementation ("fake impl") for this operator.
|
||||||
|
|
||||||
Also sometimes known as a "meta kernel", "abstract impl".
|
Also sometimes known as a "meta kernel", "abstract impl".
|
||||||
|
|
@ -630,7 +690,9 @@ def register_fake(
|
||||||
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
|
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
|
if not isinstance(
|
||||||
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||||
|
):
|
||||||
raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
|
raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
|
||||||
if isinstance(op, torch._ops.OpOverload):
|
if isinstance(op, torch._ops.OpOverload):
|
||||||
op = op._name
|
op = op._name
|
||||||
|
|
@ -661,7 +723,14 @@ def register_fake(
|
||||||
return register(func)
|
return register(func)
|
||||||
|
|
||||||
|
|
||||||
def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
|
def register_autograd(
|
||||||
|
op: _op_identifier,
|
||||||
|
backward: Callable,
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
setup_context: Optional[Callable] = None,
|
||||||
|
lib=None,
|
||||||
|
) -> None:
|
||||||
r"""Register a backward formula for this custom op.
|
r"""Register a backward formula for this custom op.
|
||||||
|
|
||||||
In order for an operator to work with autograd, you need to register
|
In order for an operator to work with autograd, you need to register
|
||||||
|
|
@ -737,8 +806,12 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
|
||||||
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
|
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
|
if not isinstance(
|
||||||
raise ValueError(f"register_autograd(op): got unexpected type for op: {type(op)}")
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"register_autograd(op): got unexpected type for op: {type(op)}"
|
||||||
|
)
|
||||||
if isinstance(op, torch._ops.OpOverload):
|
if isinstance(op, torch._ops.OpOverload):
|
||||||
op = op._name
|
op = op._name
|
||||||
opdef = _maybe_get_opdef(op)
|
opdef = _maybe_get_opdef(op)
|
||||||
|
|
@ -760,7 +833,8 @@ def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_contex
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"register_autograd with kwarg-only Tensor args. In the original "
|
f"register_autograd with kwarg-only Tensor args. In the original "
|
||||||
f"definition of the op, please make your tensors not kwarg-only. "
|
f"definition of the op, please make your tensors not kwarg-only. "
|
||||||
f"Got: {schema}")
|
f"Got: {schema}"
|
||||||
|
)
|
||||||
|
|
||||||
info = _library.autograd.Info(backward, setup_context)
|
info = _library.autograd.Info(backward, setup_context)
|
||||||
autograd_kernel = _library.autograd.make_autograd_impl(op, info)
|
autograd_kernel = _library.autograd.make_autograd_impl(op, info)
|
||||||
|
|
@ -788,8 +862,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
maybe_pystub = torch._C._dispatch_pystub(
|
maybe_pystub = torch._C._dispatch_pystub(
|
||||||
op._schema.name,
|
op._schema.name, op._schema.overload_name
|
||||||
op._schema.overload_name)
|
)
|
||||||
if maybe_pystub is None:
|
if maybe_pystub is None:
|
||||||
if torch._library.utils.requires_set_python_module():
|
if torch._library.utils.requires_set_python_module():
|
||||||
namespace = op.namespace
|
namespace = op.namespace
|
||||||
|
|
@ -800,7 +874,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||||
f'companion C++ `m.set_python_module("{actual_module_name}")` '
|
f'companion C++ `m.set_python_module("{actual_module_name}")` '
|
||||||
f"call, but we could not find one. Please add that to "
|
f"call, but we could not find one. Please add that to "
|
||||||
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
|
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
|
||||||
f"operator was registered in ({cpp_filename})")
|
f"operator was registered in ({cpp_filename})"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pystub_module = maybe_pystub[0]
|
pystub_module = maybe_pystub[0]
|
||||||
if actual_module_name != pystub_module:
|
if actual_module_name != pystub_module:
|
||||||
|
|
@ -809,9 +884,11 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||||
f"Operator '{qualname}' specified that its python fake impl "
|
f"Operator '{qualname}' specified that its python fake impl "
|
||||||
f"is in the Python module '{pystub_module}' but it was actually found "
|
f"is in the Python module '{pystub_module}' but it was actually found "
|
||||||
f"in '{actual_module_name}'. Please either move the fake impl "
|
f"in '{actual_module_name}'. Please either move the fake impl "
|
||||||
f"or correct the m.set_python_module call ({cpp_filename})")
|
f"or correct the m.set_python_module call ({cpp_filename})"
|
||||||
|
)
|
||||||
checked = True
|
checked = True
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -929,4 +1006,7 @@ def opcheck(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import torch.testing._internal.optests as optests
|
import torch.testing._internal.optests as optests
|
||||||
return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
|
|
||||||
|
return optests.opcheck(
|
||||||
|
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,19 +23,26 @@ instructions in the ``README.md`` in that directory.
|
||||||
import __future__ # noqa: F404
|
import __future__ # noqa: F404
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import types
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import contextlib
|
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._C import (
|
from torch._C import (
|
||||||
_has_torch_function, _has_torch_function_unary,
|
_add_docstr,
|
||||||
_has_torch_function_variadic, _add_docstr,
|
_get_function_stack_at,
|
||||||
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
|
_has_torch_function,
|
||||||
_is_torch_function_mode_enabled)
|
_has_torch_function_unary,
|
||||||
|
_has_torch_function_variadic,
|
||||||
|
_is_torch_function_mode_enabled,
|
||||||
|
_len_torch_function_stack,
|
||||||
|
_pop_torch_function_stack,
|
||||||
|
_push_on_torch_function_stack,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_ignored_functions",
|
"get_ignored_functions",
|
||||||
|
|
@ -52,7 +59,8 @@ __all__ = [
|
||||||
|
|
||||||
|
|
||||||
def _disable_user_warnings(
|
def _disable_user_warnings(
|
||||||
func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
|
func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
|
||||||
|
) -> Callable:
|
||||||
"""
|
"""
|
||||||
Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
|
Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
|
||||||
given ``regex`` pattern.
|
given ``regex`` pattern.
|
||||||
|
|
@ -75,8 +83,11 @@ def _disable_user_warnings(
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
|
warnings.filterwarnings(
|
||||||
|
"ignore", category=UserWarning, message=regex, module=module
|
||||||
|
)
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,8 +481,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
|
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
|
||||||
torch.bernoulli: lambda input, generator=None, out=None: -1,
|
torch.bernoulli: lambda input, generator=None, out=None: -1,
|
||||||
torch.bilinear: lambda input1, input2, weight, bias: -1,
|
torch.bilinear: lambda input1, input2, weight, bias: -1,
|
||||||
torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
|
torch.binary_cross_entropy_with_logits: (
|
||||||
reduction='mean', pos_weight=None: -1),
|
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
|
||||||
|
),
|
||||||
torch.bincount: lambda input, weights=None, minlength=0: -1,
|
torch.bincount: lambda input, weights=None, minlength=0: -1,
|
||||||
torch.binomial: lambda count, prob, generator=None: -1,
|
torch.binomial: lambda count, prob, generator=None: -1,
|
||||||
torch.bitwise_and: lambda input, other, out=None: -1,
|
torch.bitwise_and: lambda input, other, out=None: -1,
|
||||||
|
|
@ -489,9 +501,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.cat: lambda tensors, dim=0, out=None: -1,
|
torch.cat: lambda tensors, dim=0, out=None: -1,
|
||||||
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
|
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
|
||||||
torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate
|
torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate
|
||||||
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
|
torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
|
||||||
torch.ceil: lambda input, out=None: -1,
|
torch.ceil: lambda input, out=None: -1,
|
||||||
torch.celu: lambda input, alpha=1., inplace=False: -1,
|
torch.celu: lambda input, alpha=1.0, inplace=False: -1,
|
||||||
torch.chain_matmul: lambda *matrices, out=None: -1,
|
torch.chain_matmul: lambda *matrices, out=None: -1,
|
||||||
torch.channel_shuffle: lambda input, groups: -1,
|
torch.channel_shuffle: lambda input, groups: -1,
|
||||||
torch.cholesky: lambda input, upper=False, out=None: -1,
|
torch.cholesky: lambda input, upper=False, out=None: -1,
|
||||||
|
|
@ -528,14 +540,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
|
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
|
||||||
torch.corrcoef: lambda input: -1,
|
torch.corrcoef: lambda input: -1,
|
||||||
torch.cos: lambda input, out=None: -1,
|
torch.cos: lambda input, out=None: -1,
|
||||||
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
|
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
|
||||||
torch.cosh: lambda input, out=None: -1,
|
torch.cosh: lambda input, out=None: -1,
|
||||||
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
|
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
|
||||||
torch.count_nonzero: lambda input: -1,
|
torch.count_nonzero: lambda input: -1,
|
||||||
torch.cross: lambda input, other, dim=None, out=None: -1,
|
torch.cross: lambda input, other, dim=None, out=None: -1,
|
||||||
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
|
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
|
||||||
torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
|
torch.ctc_loss: (
|
||||||
zero_infinity=False: -1),
|
lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
|
||||||
|
),
|
||||||
torch.cummax: lambda input, dim, out=None: -1,
|
torch.cummax: lambda input, dim, out=None: -1,
|
||||||
torch.cummin: lambda input, dim, out=None: -1,
|
torch.cummin: lambda input, dim, out=None: -1,
|
||||||
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
|
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
|
||||||
|
|
@ -570,10 +583,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
|
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
|
||||||
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
|
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
|
||||||
torch.einsum: lambda equation, *operands: -1,
|
torch.einsum: lambda equation, *operands: -1,
|
||||||
torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
|
torch.embedding: (
|
||||||
sparse=False: -1),
|
lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
|
||||||
torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
|
),
|
||||||
mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
|
torch.embedding_bag: (
|
||||||
|
lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
||||||
torch.eq: lambda input, other, out=None: -1,
|
torch.eq: lambda input, other, out=None: -1,
|
||||||
torch.equal: lambda input, other: -1,
|
torch.equal: lambda input, other: -1,
|
||||||
|
|
@ -585,14 +600,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.expm1: lambda input, out=None: -1,
|
torch.expm1: lambda input, out=None: -1,
|
||||||
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
||||||
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
||||||
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
|
torch.fused_moving_avg_obs_fake_quant: (
|
||||||
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
|
lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950
|
||||||
per_row_fake_quant=False, symmetric_quant=False: -1),
|
),
|
||||||
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
||||||
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
||||||
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
|
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950
|
||||||
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
|
torch.fbgemm_linear_int8_weight_fp32_activation: (
|
||||||
weight_zero_point, bias: -1),
|
lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
|
||||||
|
),
|
||||||
torch.fbgemm_linear_quantize_weight: lambda input: -1,
|
torch.fbgemm_linear_quantize_weight: lambda input: -1,
|
||||||
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
|
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
|
||||||
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
|
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
|
||||||
|
|
@ -630,7 +646,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.fmod: lambda input, other, out=None: -1,
|
torch.fmod: lambda input, other, out=None: -1,
|
||||||
torch.frac: lambda input, out=None: -1,
|
torch.frac: lambda input, out=None: -1,
|
||||||
torch.frexp: lambda input, out=None: -1,
|
torch.frexp: lambda input, out=None: -1,
|
||||||
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950
|
||||||
torch._functional_assert_async: lambda input, msg, dep_token: -1,
|
torch._functional_assert_async: lambda input, msg, dep_token: -1,
|
||||||
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
|
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
|
||||||
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
|
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
|
||||||
|
|
@ -653,7 +669,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.greater: lambda input, other, out=None: -1,
|
torch.greater: lambda input, other, out=None: -1,
|
||||||
torch.hardshrink: lambda input, lambd=0.5: -1,
|
torch.hardshrink: lambda input, lambd=0.5: -1,
|
||||||
torch.heaviside: lambda input, values, out=None: -1,
|
torch.heaviside: lambda input, values, out=None: -1,
|
||||||
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
|
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
|
||||||
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
|
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
|
||||||
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
|
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
|
||||||
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
|
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
|
||||||
|
|
@ -677,8 +693,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.isreal: lambda tensor: -1,
|
torch.isreal: lambda tensor: -1,
|
||||||
torch.isposinf: lambda input, out=None: -1,
|
torch.isposinf: lambda input, out=None: -1,
|
||||||
torch.isneginf: lambda input, out=None: -1,
|
torch.isneginf: lambda input, out=None: -1,
|
||||||
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
|
torch.instance_norm: (
|
||||||
cudnn_enabled: -1),
|
lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
|
||||||
|
),
|
||||||
torch.int_repr: lambda input: -1,
|
torch.int_repr: lambda input: -1,
|
||||||
torch.inverse: lambda input, out=None: -1,
|
torch.inverse: lambda input, out=None: -1,
|
||||||
torch.linalg.inv: lambda input, out=None: -1,
|
torch.linalg.inv: lambda input, out=None: -1,
|
||||||
|
|
@ -694,9 +711,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.is_signed: lambda input: -1,
|
torch.is_signed: lambda input: -1,
|
||||||
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
|
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
|
||||||
torch.isnan: lambda input: -1,
|
torch.isnan: lambda input: -1,
|
||||||
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
|
torch.istft: (
|
||||||
normalized=False, onesided=None, length=None, return_complex=False: -1),
|
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950
|
||||||
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
|
),
|
||||||
|
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
|
||||||
torch.kron: lambda input, other: -1,
|
torch.kron: lambda input, other: -1,
|
||||||
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
|
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
|
||||||
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
|
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
|
||||||
|
|
@ -709,8 +727,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.less_equal: lambda input, other, out=None: -1,
|
torch.less_equal: lambda input, other, out=None: -1,
|
||||||
torch.lerp: lambda input, end, weight, out=None: -1,
|
torch.lerp: lambda input, end, weight, out=None: -1,
|
||||||
torch.lgamma: lambda input, out=None: -1,
|
torch.lgamma: lambda input, out=None: -1,
|
||||||
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
|
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950
|
||||||
tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
|
|
||||||
torch.log: lambda input, out=None: -1,
|
torch.log: lambda input, out=None: -1,
|
||||||
torch.log_softmax: lambda input, dim, dtype=None: -1,
|
torch.log_softmax: lambda input, dim, dtype=None: -1,
|
||||||
torch.log10: lambda input, out=None: -1,
|
torch.log10: lambda input, out=None: -1,
|
||||||
|
|
@ -732,7 +749,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.less: lambda input, other, out=None: -1,
|
torch.less: lambda input, other, out=None: -1,
|
||||||
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
|
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
|
||||||
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
|
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
|
||||||
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
|
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950
|
||||||
torch.masked_fill: lambda input, mask, value: -1,
|
torch.masked_fill: lambda input, mask, value: -1,
|
||||||
torch.masked_scatter: lambda input, mask, source: -1,
|
torch.masked_scatter: lambda input, mask, source: -1,
|
||||||
torch.masked_select: lambda input, mask, out=None: -1,
|
torch.masked_select: lambda input, mask, out=None: -1,
|
||||||
|
|
@ -754,8 +771,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
||||||
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
||||||
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
||||||
torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
torch.max_pool1d_with_indices: (
|
||||||
return_indices=False, ceil_mode=False: -1),
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
|
||||||
|
),
|
||||||
torch.mean: lambda input, dim=None: -1,
|
torch.mean: lambda input, dim=None: -1,
|
||||||
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
|
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
|
||||||
torch.median: lambda input, dim=None: -1,
|
torch.median: lambda input, dim=None: -1,
|
||||||
|
|
@ -764,17 +782,21 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.min: lambda input, out=None: -1,
|
torch.min: lambda input, out=None: -1,
|
||||||
torch.minimum: lambda input, other, out=None: -1,
|
torch.minimum: lambda input, other, out=None: -1,
|
||||||
torch.fmin: lambda input, other, out=None: -1,
|
torch.fmin: lambda input, other, out=None: -1,
|
||||||
torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
|
torch.miopen_batch_norm: (
|
||||||
exponential_average_factor, epsilon: -1),
|
lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
|
||||||
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
|
),
|
||||||
|
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950
|
||||||
torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
|
torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
|
||||||
torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
|
torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
|
||||||
torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
|
torch.miopen_convolution_transpose: (
|
||||||
groups, benchmark, deterministic: -1),
|
lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
|
||||||
torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
|
),
|
||||||
deterministic: -1),
|
torch.miopen_depthwise_convolution: (
|
||||||
torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
|
lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
|
||||||
dropout, train, bidirectional, batch_sizes, dropout_state: -1),
|
),
|
||||||
|
torch.miopen_rnn: (
|
||||||
|
lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.mm: lambda input, mat2, out=None: -1,
|
torch.mm: lambda input, mat2, out=None: -1,
|
||||||
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
|
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
|
||||||
torch.movedim: lambda input, source, destination: -1,
|
torch.movedim: lambda input, source, destination: -1,
|
||||||
|
|
@ -809,62 +831,76 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
|
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
|
||||||
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
|
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
|
||||||
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
||||||
torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
|
torch.nn.functional.avg_pool2d: (
|
||||||
count_include_pad=True, divisor_override=None: -1),
|
lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
|
||||||
torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
|
),
|
||||||
count_include_pad=True, divisor_override=None: -1),
|
torch.nn.functional.avg_pool3d: (
|
||||||
torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
|
lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
|
||||||
momentum=0.1, eps=1e-05: -1),
|
),
|
||||||
|
torch.nn.functional.batch_norm: (
|
||||||
|
lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
|
||||||
|
),
|
||||||
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
|
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
|
||||||
torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
|
torch.nn.functional.binary_cross_entropy: (
|
||||||
reduction="mean": -1),
|
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
|
),
|
||||||
reduce=None, reduction="mean", pos_weight=None: -1),
|
torch.nn.functional.binary_cross_entropy_with_logits: (
|
||||||
|
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
|
||||||
|
),
|
||||||
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
|
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
|
||||||
torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
torch.nn.functional.cosine_embedding_loss: (
|
||||||
reduce=None, reduction='mean': -1),
|
lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
|
),
|
||||||
reduce=None, reduction="mean", label_smoothing=0.0: -1),
|
torch.nn.functional.cross_entropy: (
|
||||||
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
|
lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950
|
||||||
reduction='mean', zero_infinity=False: -1),
|
),
|
||||||
|
torch.nn.functional.ctc_loss: (
|
||||||
|
lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
|
||||||
|
),
|
||||||
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
|
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||||
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
|
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||||
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
|
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||||
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
|
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
|
||||||
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
|
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
|
||||||
torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
|
torch.nn.functional.embedding: (
|
||||||
scale_grad_by_freq=False, sparse=False: -1),
|
lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
|
||||||
torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
|
),
|
||||||
scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
|
torch.nn.functional.embedding_bag: (
|
||||||
include_last_offset=False, padding_idx=None: -1),
|
lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
||||||
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
|
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
|
||||||
torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
|
torch.nn.functional.fractional_max_pool2d: (
|
||||||
return_indices=False, _random_samples=None: -1),
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.nn.functional.fractional_max_pool2d_with_indices: (
|
torch.nn.functional.fractional_max_pool2d_with_indices: (
|
||||||
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
|
||||||
_random_samples=None: -1),
|
),
|
||||||
torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
|
torch.nn.functional.fractional_max_pool3d: (
|
||||||
return_indices=False, _random_samples=None: -1),
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.nn.functional.fractional_max_pool3d_with_indices: (
|
torch.nn.functional.fractional_max_pool3d_with_indices: (
|
||||||
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
|
||||||
_random_samples=None: -1),
|
),
|
||||||
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
|
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
|
||||||
torch.nn.functional.gelu: lambda input, approximate='none': -1,
|
torch.nn.functional.gelu: lambda input, approximate="none": -1,
|
||||||
torch.nn.functional.glu: lambda input, dim=-1: -1,
|
torch.nn.functional.glu: lambda input, dim=-1: -1,
|
||||||
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
|
torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950
|
||||||
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
|
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
|
||||||
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
|
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
|
||||||
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
|
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
|
||||||
torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
|
torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
|
||||||
torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
|
torch.nn.functional.hinge_embedding_loss: (
|
||||||
reduction='mean': -1),
|
lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
|
),
|
||||||
use_input_stats=True, momentum=0.1, eps=1e-05: -1),
|
torch.nn.functional.instance_norm: (
|
||||||
torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
|
lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950
|
||||||
recompute_scale_factor=None, antialias=False: -1),
|
),
|
||||||
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
|
torch.nn.functional.interpolate: (
|
||||||
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950
|
||||||
|
),
|
||||||
|
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950
|
||||||
|
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
|
||||||
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
||||||
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
|
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
|
||||||
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
|
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
|
||||||
|
|
@ -874,55 +910,65 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
torch.nn.functional.margin_ranking_loss: (
|
||||||
reduce=None, reduction='mean': -1),
|
lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
),
|
||||||
ceil_mode=False, return_indices=False: -1),
|
torch.nn.functional.max_pool1d: (
|
||||||
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
|
||||||
return_indices=False, ceil_mode=False: -1),
|
),
|
||||||
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
torch.nn.functional.max_pool1d_with_indices: (
|
||||||
ceil_mode=False, return_indices=False: -1),
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
|
||||||
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
),
|
||||||
return_indices=False, ceil_mode=False: -1),
|
torch.nn.functional.max_pool2d: (
|
||||||
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
|
||||||
return_indices=False, ceil_mode=False: -1),
|
),
|
||||||
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
torch.nn.functional.max_pool2d_with_indices: (
|
||||||
return_indices=False, ceil_mode=False: -1),
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
|
||||||
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
),
|
||||||
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
torch.nn.functional.max_pool3d: (
|
||||||
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
|
||||||
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
),
|
||||||
|
torch.nn.functional.max_pool3d_with_indices: (
|
||||||
|
lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
|
||||||
|
),
|
||||||
|
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
|
||||||
|
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
|
||||||
|
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
|
||||||
|
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
|
||||||
torch.nn.functional.multi_head_attention_forward: (
|
torch.nn.functional.multi_head_attention_forward: (
|
||||||
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
|
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
|
||||||
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
|
),
|
||||||
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
|
torch.nn.functional.multi_margin_loss: (
|
||||||
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
|
lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
|
),
|
||||||
reduce=None, reduction='mean': -1),
|
torch.nn.functional.multilabel_margin_loss: (
|
||||||
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
|
lambda input, target, size_average=None, reduce=None, reduction="mean": -1
|
||||||
reduction='mean': -1),
|
),
|
||||||
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
|
torch.nn.functional.multilabel_soft_margin_loss: (
|
||||||
reduce=None, reduction='mean': -1),
|
lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
|
||||||
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
|
),
|
||||||
reduce=None, reduction='mean': -1),
|
torch.nn.functional.nll_loss: (
|
||||||
|
lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
|
||||||
|
),
|
||||||
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
|
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
|
||||||
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
|
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
|
||||||
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
|
torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
|
||||||
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
|
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
|
||||||
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
|
torch.nn.functional.poisson_nll_loss: (
|
||||||
eps=1e-08, reduce=None, reduction='mean': -1),
|
lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.nn.functional.prelu: lambda input, weight: -1,
|
torch.nn.functional.prelu: lambda input, weight: -1,
|
||||||
torch.nn.functional.relu: lambda input, inplace=False: -1,
|
torch.nn.functional.relu: lambda input, inplace=False: -1,
|
||||||
torch.nn.functional.relu6: lambda input, inplace=False: -1,
|
torch.nn.functional.relu6: lambda input, inplace=False: -1,
|
||||||
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
|
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
|
||||||
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
|
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
|
||||||
torch.nn.functional.selu: lambda input, inplace=False: -1,
|
torch.nn.functional.selu: lambda input, inplace=False: -1,
|
||||||
torch.nn.functional.silu: lambda input, inplace=False: -1,
|
torch.nn.functional.silu: lambda input, inplace=False: -1,
|
||||||
torch.nn.functional.mish: lambda input, inplace=False: -1,
|
torch.nn.functional.mish: lambda input, inplace=False: -1,
|
||||||
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
|
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
|
||||||
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
|
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
|
||||||
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
|
torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
|
||||||
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
|
||||||
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
||||||
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
||||||
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
|
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
|
||||||
|
|
@ -930,25 +976,29 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.nn.functional.softsign: lambda input: -1,
|
torch.nn.functional.softsign: lambda input: -1,
|
||||||
torch.nn.functional.tanhshrink: lambda input: -1,
|
torch.nn.functional.tanhshrink: lambda input: -1,
|
||||||
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
|
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
|
||||||
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
|
torch.nn.functional.triplet_margin_loss: (
|
||||||
swap=False, size_average=None, reduce=None, reduction='mean': -1),
|
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
|
||||||
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
|
),
|
||||||
distance_function=None, margin=1.0,
|
torch.nn.functional.triplet_margin_with_distance_loss: (
|
||||||
swap=False, reduction='mean': -1),
|
lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
|
||||||
|
),
|
||||||
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
|
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
|
||||||
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
|
torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
|
||||||
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
|
torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
|
||||||
torch.nn.init.constant_: lambda tensor, val: -1,
|
torch.nn.init.constant_: lambda tensor, val: -1,
|
||||||
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
|
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
|
||||||
torch.nonzero: lambda input, as_tuple=False: -1,
|
torch.nonzero: lambda input, as_tuple=False: -1,
|
||||||
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
|
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
|
||||||
torch.argwhere: lambda input: -1,
|
torch.argwhere: lambda input: -1,
|
||||||
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
|
torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
|
||||||
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
|
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
|
||||||
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
|
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
|
||||||
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
|
torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
|
||||||
|
-2,
|
||||||
|
-1,
|
||||||
|
), keepdim=False, out=None, dtype=None: -1,
|
||||||
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
|
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
|
||||||
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
|
torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
|
||||||
torch.numel: lambda input: -1,
|
torch.numel: lambda input: -1,
|
||||||
torch.orgqr: lambda input, tau: -1,
|
torch.orgqr: lambda input, tau: -1,
|
||||||
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
|
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
|
||||||
|
|
@ -975,28 +1025,43 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.q_scale: lambda input: -1,
|
torch.q_scale: lambda input: -1,
|
||||||
torch.q_zero_point: lambda input: -1,
|
torch.q_zero_point: lambda input: -1,
|
||||||
torch.qr: lambda input, some=True, out=None: -1,
|
torch.qr: lambda input, some=True, out=None: -1,
|
||||||
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
|
torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
|
||||||
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
|
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
|
||||||
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
|
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
|
||||||
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
|
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
|
||||||
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
|
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
|
||||||
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
|
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
|
||||||
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
|
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
|
||||||
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
torch.quantized_gru_cell: (
|
||||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
torch.quantized_lstm_cell: (
|
||||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
|
||||||
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
|
),
|
||||||
dilation=(1,), ceil_mode=False: -1),
|
torch.quantized_max_pool1d: (
|
||||||
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
|
lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
|
||||||
dilation=(1, 1), ceil_mode=False: -1),
|
1,
|
||||||
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
|
), ceil_mode=False: -1
|
||||||
dilation=(1, 1, 1), ceil_mode=False: -1),
|
),
|
||||||
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
torch.quantized_max_pool2d: (
|
||||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
|
||||||
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
1,
|
||||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
1,
|
||||||
|
), ceil_mode=False: -1
|
||||||
|
),
|
||||||
|
torch.quantized_max_pool3d: (
|
||||||
|
lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
), ceil_mode=False: -1
|
||||||
|
),
|
||||||
|
torch.quantized_rnn_relu_cell: (
|
||||||
|
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
|
||||||
|
),
|
||||||
|
torch.quantized_rnn_tanh_cell: (
|
||||||
|
lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.rad2deg: lambda input, out=None: -1,
|
torch.rad2deg: lambda input, out=None: -1,
|
||||||
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
||||||
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
||||||
|
|
@ -1014,16 +1079,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.repeat_interleave: lambda input, dim=None: -1,
|
torch.repeat_interleave: lambda input, dim=None: -1,
|
||||||
torch.reshape: lambda input, shape: -1,
|
torch.reshape: lambda input, shape: -1,
|
||||||
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
|
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
|
||||||
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
|
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
|
||||||
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
||||||
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
|
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
|
||||||
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
||||||
torch.roll: lambda input, shifts, dims=None: -1,
|
torch.roll: lambda input, shifts, dims=None: -1,
|
||||||
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
|
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
|
||||||
torch.round: lambda input, out=None: -1,
|
torch.round: lambda input, out=None: -1,
|
||||||
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
|
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
|
||||||
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
||||||
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
|
torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
|
||||||
torch.rsqrt: lambda input, out=None: -1,
|
torch.rsqrt: lambda input, out=None: -1,
|
||||||
torch.rsub: lambda input, other, alpha=1: -1,
|
torch.rsub: lambda input, other, alpha=1: -1,
|
||||||
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
||||||
|
|
@ -1031,7 +1096,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.scatter_add: lambda input, dim, index, src: -1,
|
torch.scatter_add: lambda input, dim, index, src: -1,
|
||||||
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
|
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
|
||||||
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
||||||
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
|
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
|
||||||
torch.select: lambda input, dim, index: -1,
|
torch.select: lambda input, dim, index: -1,
|
||||||
torch.select_scatter: lambda input, src, dim, index: -1,
|
torch.select_scatter: lambda input, src, dim, index: -1,
|
||||||
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
|
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
|
||||||
|
|
@ -1061,8 +1126,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.stack: lambda tensors, dim=0, out=None: -1,
|
torch.stack: lambda tensors, dim=0, out=None: -1,
|
||||||
torch.std: lambda input, dim=None: -1,
|
torch.std: lambda input, dim=None: -1,
|
||||||
torch.std_mean: lambda input, dim=None: -1,
|
torch.std_mean: lambda input, dim=None: -1,
|
||||||
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
|
torch.stft: (
|
||||||
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
|
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
|
||||||
|
),
|
||||||
torch.sub: lambda input, other, out=None: -1,
|
torch.sub: lambda input, other, out=None: -1,
|
||||||
torch.subtract: lambda input, other, out=None: -1,
|
torch.subtract: lambda input, other, out=None: -1,
|
||||||
torch.sum: lambda input, dim=None: -1,
|
torch.sum: lambda input, dim=None: -1,
|
||||||
|
|
@ -1164,9 +1230,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
|
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
|
||||||
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
|
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
|
||||||
torch.tril: lambda input, diagonal=0, out=None: -1,
|
torch.tril: lambda input, diagonal=0, out=None: -1,
|
||||||
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
|
torch.triplet_margin_loss: (
|
||||||
|
lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
|
||||||
size_average=None, reduce=None, reduction='mean': -1),
|
),
|
||||||
torch.triu: lambda input, diagonal=0, out=None: -1,
|
torch.triu: lambda input, diagonal=0, out=None: -1,
|
||||||
torch.true_divide: lambda input, other: -1,
|
torch.true_divide: lambda input, other: -1,
|
||||||
torch.trunc: lambda input, out=None: -1,
|
torch.trunc: lambda input, out=None: -1,
|
||||||
|
|
@ -1436,10 +1502,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
|
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
|
privateuse1_backend_name = (
|
||||||
|
torch.utils.backend_registration._privateuse1_backend_name
|
||||||
|
)
|
||||||
if hasattr(Tensor, privateuse1_backend_name):
|
if hasattr(Tensor, privateuse1_backend_name):
|
||||||
ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
|
ret[
|
||||||
ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1 # noqa: B009
|
getattr(Tensor, privateuse1_backend_name)
|
||||||
|
] = lambda self, device=None, non_blocking=False, **kwargs: -1
|
||||||
|
ret[
|
||||||
|
getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
|
||||||
|
] = lambda self: -1 # noqa: B009
|
||||||
|
|
||||||
ret2 = {}
|
ret2 = {}
|
||||||
ignored = get_ignored_functions()
|
ignored = get_ignored_functions()
|
||||||
|
|
@ -1458,11 +1530,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
# bitwise_<op> have dunder methods of the form __<op>__
|
# bitwise_<op> have dunder methods of the form __<op>__
|
||||||
# And so on.
|
# And so on.
|
||||||
subname = k.__name__[len("bitwise_") :]
|
subname = k.__name__[len("bitwise_") :]
|
||||||
names.extend([
|
names.extend(
|
||||||
"__" + subname + "__",
|
["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
|
||||||
"__i" + subname + "__",
|
)
|
||||||
"__r" + subname + "__"
|
|
||||||
])
|
|
||||||
|
|
||||||
for name in names:
|
for name in names:
|
||||||
func = getattr(Tensor, name, None)
|
func = getattr(Tensor, name, None)
|
||||||
|
|
@ -1472,6 +1542,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
ret.update(ret2)
|
ret.update(ret2)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def wrap_torch_function(dispatcher: Callable):
|
def wrap_torch_function(dispatcher: Callable):
|
||||||
"""Wraps a given function with ``__torch_function__`` -related functionality.
|
"""Wraps a given function with ``__torch_function__`` -related functionality.
|
||||||
|
|
||||||
|
|
@ -1495,6 +1566,7 @@ def wrap_torch_function(dispatcher: Callable):
|
||||||
>>> def func(a): # This will make func dispatchable by __torch_function__
|
>>> def func(a): # This will make func dispatchable by __torch_function__
|
||||||
... return a + 0
|
... return a + 0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def inner(func):
|
def inner(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
|
@ -1508,7 +1580,10 @@ def wrap_torch_function(dispatcher: Callable):
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
|
|
||||||
|
def _get_overloaded_args(
|
||||||
|
relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
|
||||||
|
) -> List[Any]:
|
||||||
"""Returns a list of arguments on which to call __torch_function__.
|
"""Returns a list of arguments on which to call __torch_function__.
|
||||||
|
|
||||||
Checks arguments in relevant_args for __torch_function__ implementations,
|
Checks arguments in relevant_args for __torch_function__ implementations,
|
||||||
|
|
@ -1559,8 +1634,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
|
||||||
#
|
#
|
||||||
# NB: Important to exclude _disabled_torch_function_impl, otherwise
|
# NB: Important to exclude _disabled_torch_function_impl, otherwise
|
||||||
# https://github.com/pytorch/pytorch/issues/64687
|
# https://github.com/pytorch/pytorch/issues/64687
|
||||||
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
|
if (
|
||||||
arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
|
arg_type not in overloaded_types
|
||||||
|
and hasattr(arg_type, "__torch_function__")
|
||||||
|
and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
|
||||||
|
):
|
||||||
# Create lists explicitly for the first type (usually the only one
|
# Create lists explicitly for the first type (usually the only one
|
||||||
# done) to avoid setting up the iterator for overloaded_args.
|
# done) to avoid setting up the iterator for overloaded_args.
|
||||||
if overloaded_types:
|
if overloaded_types:
|
||||||
|
|
@ -1581,7 +1659,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[An
|
||||||
|
|
||||||
|
|
||||||
def handle_torch_function(
|
def handle_torch_function(
|
||||||
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
|
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
|
||||||
|
) -> Any:
|
||||||
"""Implement a function with checks for ``__torch_function__`` overrides.
|
"""Implement a function with checks for ``__torch_function__`` overrides.
|
||||||
|
|
||||||
See torch::autograd::handle_torch_function for the equivalent of this
|
See torch::autograd::handle_torch_function for the equivalent of this
|
||||||
|
|
@ -1636,11 +1715,16 @@ def handle_torch_function(
|
||||||
# This call needs to become a classmethod call in the future.
|
# This call needs to become a classmethod call in the future.
|
||||||
# See https://github.com/pytorch/pytorch/issues/63767
|
# See https://github.com/pytorch/pytorch/issues/63767
|
||||||
torch_func_method = overloaded_arg.__torch_function__
|
torch_func_method = overloaded_arg.__torch_function__
|
||||||
if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
|
if (
|
||||||
torch_func_method is not torch._C._disabled_torch_function_impl:
|
hasattr(torch_func_method, "__self__")
|
||||||
warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
|
and torch_func_method.__self__ is overloaded_arg
|
||||||
|
and torch_func_method is not torch._C._disabled_torch_function_impl
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"Defining your `__torch_function__ as a plain method is deprecated and "
|
||||||
"will be an error in future, please define it as a classmethod.",
|
"will be an error in future, please define it as a classmethod.",
|
||||||
DeprecationWarning)
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
|
||||||
# Use `public_api` instead of `implementation` so __torch_function__
|
# Use `public_api` instead of `implementation` so __torch_function__
|
||||||
# implementations can do equality/identity comparisons.
|
# implementations can do equality/identity comparisons.
|
||||||
|
|
@ -1649,15 +1733,16 @@ def handle_torch_function(
|
||||||
if result is not NotImplemented:
|
if result is not NotImplemented:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
func_name = f'{public_api.__module__}.{public_api.__name__}'
|
func_name = f"{public_api.__module__}.{public_api.__name__}"
|
||||||
msg = (
|
msg = (
|
||||||
f"no implementation found for '{func_name}' on types that implement "
|
f"no implementation found for '{func_name}' on types that implement "
|
||||||
f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
|
f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
|
||||||
)
|
)
|
||||||
if _is_torch_function_mode_enabled():
|
if _is_torch_function_mode_enabled():
|
||||||
msg += f" nor in mode {_get_current_function_mode()}"
|
msg += f" nor in mode {_get_current_function_mode()}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
|
||||||
has_torch_function = _add_docstr(
|
has_torch_function = _add_docstr(
|
||||||
_has_torch_function,
|
_has_torch_function,
|
||||||
r"""Check for __torch_function__ implementations in the elements of an iterable
|
r"""Check for __torch_function__ implementations in the elements of an iterable
|
||||||
|
|
@ -1678,7 +1763,7 @@ has_torch_function = _add_docstr(
|
||||||
________
|
________
|
||||||
torch.is_tensor_like
|
torch.is_tensor_like
|
||||||
Checks if something is a Tensor-like, including an exact ``Tensor``.
|
Checks if something is a Tensor-like, including an exact ``Tensor``.
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
has_torch_function_unary = _add_docstr(
|
has_torch_function_unary = _add_docstr(
|
||||||
|
|
@ -1689,7 +1774,7 @@ has_torch_function_unary = _add_docstr(
|
||||||
call:
|
call:
|
||||||
`has_torch_function_unary(t)`
|
`has_torch_function_unary(t)`
|
||||||
which skips unnecessary packing and unpacking work.
|
which skips unnecessary packing and unpacking work.
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
has_torch_function_variadic = _add_docstr(
|
has_torch_function_variadic = _add_docstr(
|
||||||
|
|
@ -1703,11 +1788,14 @@ has_torch_function_variadic = _add_docstr(
|
||||||
call:
|
call:
|
||||||
`has_torch_function_variadic(a, b)`
|
`has_torch_function_variadic(a, b)`
|
||||||
which skips unnecessary packing and unpacking work.
|
which skips unnecessary packing and unpacking work.
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
|
def _get_overridable_functions() -> (
|
||||||
|
Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
|
||||||
|
):
|
||||||
overridable_funcs = collections.defaultdict(list)
|
overridable_funcs = collections.defaultdict(list)
|
||||||
index = {}
|
index = {}
|
||||||
tested_namespaces = [
|
tested_namespaces = [
|
||||||
|
|
@ -1725,21 +1813,21 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
|
||||||
ignore = False
|
ignore = False
|
||||||
# ignore private functions or functions that are deleted in torch.__init__
|
# ignore private functions or functions that are deleted in torch.__init__
|
||||||
if namespace is not torch.Tensor:
|
if namespace is not torch.Tensor:
|
||||||
if func_name.startswith('__'):
|
if func_name.startswith("__"):
|
||||||
continue
|
continue
|
||||||
elif func_name.startswith('_'):
|
elif func_name.startswith("_"):
|
||||||
ignore = True
|
ignore = True
|
||||||
elif func_name.endswith('_'):
|
elif func_name.endswith("_"):
|
||||||
ignore = True
|
ignore = True
|
||||||
elif not func_name[0].islower():
|
elif not func_name[0].islower():
|
||||||
ignore = True
|
ignore = True
|
||||||
elif func_name == 'unique_dim':
|
elif func_name == "unique_dim":
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
func = getattr(namespace, func_name)
|
func = getattr(namespace, func_name)
|
||||||
if getattr(object, func_name, None) == func:
|
if getattr(object, func_name, None) == func:
|
||||||
continue
|
continue
|
||||||
if func_name == '__weakref__':
|
if func_name == "__weakref__":
|
||||||
continue
|
continue
|
||||||
func = getattr(namespace, func_name)
|
func = getattr(namespace, func_name)
|
||||||
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
|
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
|
||||||
|
|
@ -1757,9 +1845,13 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
|
||||||
if ignore:
|
if ignore:
|
||||||
continue
|
continue
|
||||||
if func.__get__ in get_ignored_functions():
|
if func.__get__ in get_ignored_functions():
|
||||||
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
msg = (
|
||||||
"but still has an explicit override")
|
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
||||||
assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
|
"but still has an explicit override"
|
||||||
|
)
|
||||||
|
assert func.__get__ not in get_testing_overrides(), msg.format(
|
||||||
|
namespace, func.__name__
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
overridable_funcs[func].append(func.__get__)
|
overridable_funcs[func].append(func.__get__)
|
||||||
|
|
@ -1775,13 +1867,18 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
|
||||||
|
|
||||||
# cannot be overriden by __torch_function__
|
# cannot be overriden by __torch_function__
|
||||||
if func in get_ignored_functions():
|
if func in get_ignored_functions():
|
||||||
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
msg = (
|
||||||
"but still has an explicit override")
|
"{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
|
||||||
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
|
"but still has an explicit override"
|
||||||
|
)
|
||||||
|
assert func not in get_testing_overrides(), msg.format(
|
||||||
|
namespace, func.__name__
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
overridable_funcs[namespace].append(func)
|
overridable_funcs[namespace].append(func)
|
||||||
return overridable_funcs, index
|
return overridable_funcs, index
|
||||||
|
|
||||||
|
|
||||||
@_disable_user_warnings
|
@_disable_user_warnings
|
||||||
def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||||
"""List functions that are overridable via __torch_function__
|
"""List functions that are overridable via __torch_function__
|
||||||
|
|
@ -1794,6 +1891,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]:
|
||||||
"""
|
"""
|
||||||
return _get_overridable_functions()[0]
|
return _get_overridable_functions()[0]
|
||||||
|
|
||||||
|
|
||||||
@_disable_user_warnings
|
@_disable_user_warnings
|
||||||
def resolve_name(f):
|
def resolve_name(f):
|
||||||
"""Get a human readable string name for a function passed to
|
"""Get a human readable string name for a function passed to
|
||||||
|
|
@ -1814,6 +1912,7 @@ def resolve_name(f):
|
||||||
return str(f)
|
return str(f)
|
||||||
return _get_overridable_functions()[1].get(f)
|
return _get_overridable_functions()[1].get(f)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def _get_tensor_methods() -> Set[Callable]:
|
def _get_tensor_methods() -> Set[Callable]:
|
||||||
"""Returns a set of the overridable methods on ``torch.Tensor``"""
|
"""Returns a set of the overridable methods on ``torch.Tensor``"""
|
||||||
|
|
@ -1821,6 +1920,7 @@ def _get_tensor_methods() -> Set[Callable]:
|
||||||
methods = set(overridable_funcs[torch.Tensor])
|
methods = set(overridable_funcs[torch.Tensor])
|
||||||
return methods
|
return methods
|
||||||
|
|
||||||
|
|
||||||
@_disable_user_warnings
|
@_disable_user_warnings
|
||||||
def is_tensor_method_or_property(func: Callable) -> bool:
|
def is_tensor_method_or_property(func: Callable) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -1846,6 +1946,7 @@ def is_tensor_method_or_property(func: Callable) -> bool:
|
||||||
"""
|
"""
|
||||||
return func in _get_tensor_methods() or func.__name__ == "__get__"
|
return func in _get_tensor_methods() or func.__name__ == "__get__"
|
||||||
|
|
||||||
|
|
||||||
def is_tensor_like(inp):
|
def is_tensor_like(inp):
|
||||||
"""
|
"""
|
||||||
Returns ``True`` if the passed-in input is a Tensor-like.
|
Returns ``True`` if the passed-in input is a Tensor-like.
|
||||||
|
|
@ -1882,6 +1983,7 @@ def is_tensor_like(inp):
|
||||||
"""
|
"""
|
||||||
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
|
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
|
||||||
|
|
||||||
|
|
||||||
class TorchFunctionMode:
|
class TorchFunctionMode:
|
||||||
"""
|
"""
|
||||||
A ``TorchFunctionMode`` allows you to override the meaning of all
|
A ``TorchFunctionMode`` allows you to override the meaning of all
|
||||||
|
|
@ -1912,6 +2014,7 @@ class TorchFunctionMode:
|
||||||
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
|
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
|
||||||
API self-referential (beware of infinite loops, in this case!)
|
API self-referential (beware of infinite loops, in this case!)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inner: "TorchFunctionMode"
|
inner: "TorchFunctionMode"
|
||||||
|
|
||||||
# Force metaclass to generate constructor at the base of the hierarchy
|
# Force metaclass to generate constructor at the base of the hierarchy
|
||||||
|
|
@ -1930,7 +2033,9 @@ class TorchFunctionMode:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def push(cls, *args, **kwargs):
|
def push(cls, *args, **kwargs):
|
||||||
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
|
warnings.warn(
|
||||||
|
"`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
|
||||||
|
)
|
||||||
instance = cls(*args, **kwargs)
|
instance = cls(*args, **kwargs)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
@ -1944,6 +2049,7 @@ def _get_current_function_mode_stack():
|
||||||
stack_len = _len_torch_function_stack()
|
stack_len = _len_torch_function_stack()
|
||||||
return [_get_function_stack_at(i) for i in range(stack_len)]
|
return [_get_function_stack_at(i) for i in range(stack_len)]
|
||||||
|
|
||||||
|
|
||||||
def _push_mode(mode):
|
def _push_mode(mode):
|
||||||
_push_on_torch_function_stack(mode)
|
_push_on_torch_function_stack(mode)
|
||||||
|
|
||||||
|
|
@ -1961,6 +2067,7 @@ def _pop_mode_temporarily():
|
||||||
finally:
|
finally:
|
||||||
_push_mode(old)
|
_push_mode(old)
|
||||||
|
|
||||||
|
|
||||||
class BaseTorchFunctionMode(TorchFunctionMode):
|
class BaseTorchFunctionMode(TorchFunctionMode):
|
||||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import torch
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class SobolEngine:
|
class SobolEngine:
|
||||||
r"""
|
r"""
|
||||||
|
|
@ -48,8 +49,10 @@ class SobolEngine:
|
||||||
|
|
||||||
def __init__(self, dimension, scramble=False, seed=None):
|
def __init__(self, dimension, scramble=False, seed=None):
|
||||||
if dimension > self.MAXDIM or dimension < 1:
|
if dimension > self.MAXDIM or dimension < 1:
|
||||||
raise ValueError("Supported range of dimensionality "
|
raise ValueError(
|
||||||
f"for SobolEngine is [1, {self.MAXDIM}]")
|
"Supported range of dimensionality "
|
||||||
|
f"for SobolEngine is [1, {self.MAXDIM}]"
|
||||||
|
)
|
||||||
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.scramble = scramble
|
self.scramble = scramble
|
||||||
|
|
@ -57,7 +60,9 @@ class SobolEngine:
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
|
|
||||||
self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
|
self.sobolstate = torch.zeros(
|
||||||
|
dimension, self.MAXBIT, device=cpu, dtype=torch.long
|
||||||
|
)
|
||||||
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
|
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
|
||||||
|
|
||||||
if not self.scramble:
|
if not self.scramble:
|
||||||
|
|
@ -69,8 +74,12 @@ class SobolEngine:
|
||||||
self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
|
self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
|
||||||
self.num_generated = 0
|
self.num_generated = 0
|
||||||
|
|
||||||
def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
|
def draw(
|
||||||
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
self,
|
||||||
|
n: int = 1,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
|
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
|
||||||
Note that the samples are dependent on the previous samples. The size
|
Note that the samples are dependent on the previous samples. The size
|
||||||
|
|
@ -92,12 +101,22 @@ class SobolEngine:
|
||||||
result = self._first_point.to(dtype)
|
result = self._first_point.to(dtype)
|
||||||
else:
|
else:
|
||||||
result, self.quasi = torch._sobol_engine_draw(
|
result, self.quasi = torch._sobol_engine_draw(
|
||||||
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
|
self.quasi,
|
||||||
|
n - 1,
|
||||||
|
self.sobolstate,
|
||||||
|
self.dimension,
|
||||||
|
self.num_generated,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
|
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
|
||||||
else:
|
else:
|
||||||
result, self.quasi = torch._sobol_engine_draw(
|
result, self.quasi = torch._sobol_engine_draw(
|
||||||
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
|
self.quasi,
|
||||||
|
n,
|
||||||
|
self.sobolstate,
|
||||||
|
self.dimension,
|
||||||
|
self.num_generated - 1,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_generated += n
|
self.num_generated += n
|
||||||
|
|
@ -108,8 +127,12 @@ class SobolEngine:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
|
def draw_base2(
|
||||||
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
self,
|
||||||
|
m: int,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
|
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
|
||||||
Note that the samples are dependent on the previous samples. The size
|
Note that the samples are dependent on the previous samples. The size
|
||||||
|
|
@ -125,7 +148,8 @@ class SobolEngine:
|
||||||
n = 2**m
|
n = 2**m
|
||||||
total_n = self.num_generated + n
|
total_n = self.num_generated + n
|
||||||
if not (total_n & (total_n - 1) == 0):
|
if not (total_n & (total_n - 1) == 0):
|
||||||
raise ValueError("The balance properties of Sobol' points require "
|
raise ValueError(
|
||||||
|
"The balance properties of Sobol' points require "
|
||||||
f"n to be a power of 2. {self.num_generated} points have been "
|
f"n to be a power of 2. {self.num_generated} points have been "
|
||||||
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
|
f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
|
||||||
"If you still want to do this, please use "
|
"If you still want to do this, please use "
|
||||||
|
|
@ -151,9 +175,13 @@ class SobolEngine:
|
||||||
n (Int): The number of steps to fast-forward by.
|
n (Int): The number of steps to fast-forward by.
|
||||||
"""
|
"""
|
||||||
if self.num_generated == 0:
|
if self.num_generated == 0:
|
||||||
torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
|
torch._sobol_engine_ff_(
|
||||||
|
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
|
torch._sobol_engine_ff_(
|
||||||
|
self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
|
||||||
|
)
|
||||||
self.num_generated += n
|
self.num_generated += n
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -166,8 +194,12 @@ class SobolEngine:
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
|
|
||||||
# Generate shift vector
|
# Generate shift vector
|
||||||
shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
|
shift_ints = torch.randint(
|
||||||
self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
|
2, (self.dimension, self.MAXBIT), device=cpu, generator=g
|
||||||
|
)
|
||||||
|
self.shift = torch.mv(
|
||||||
|
shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
|
||||||
|
)
|
||||||
|
|
||||||
# Generate lower triangular matrices (stacked across dimensions)
|
# Generate lower triangular matrices (stacked across dimensions)
|
||||||
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
|
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
|
||||||
|
|
@ -176,9 +208,9 @@ class SobolEngine:
|
||||||
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
|
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
fmt_string = [f'dimension={self.dimension}']
|
fmt_string = [f"dimension={self.dimension}"]
|
||||||
if self.scramble:
|
if self.scramble:
|
||||||
fmt_string += ['scramble=True']
|
fmt_string += ["scramble=True"]
|
||||||
if self.seed is not None:
|
if self.seed is not None:
|
||||||
fmt_string += [f'seed={self.seed}']
|
fmt_string += [f"seed={self.seed}"]
|
||||||
return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
|
return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import contextlib
|
import contextlib
|
||||||
from typing import Generator
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
from torch._C import default_generator
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch._C import default_generator
|
||||||
|
|
||||||
|
|
||||||
def set_rng_state(new_state: torch.Tensor) -> None:
|
def set_rng_state(new_state: torch.Tensor) -> None:
|
||||||
|
|
@ -46,10 +46,12 @@ def manual_seed(seed) -> torch._C.Generator:
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
if not torch.mps._is_in_bad_fork():
|
if not torch.mps._is_in_bad_fork():
|
||||||
torch.mps.manual_seed(seed)
|
torch.mps.manual_seed(seed)
|
||||||
|
|
||||||
import torch.xpu
|
import torch.xpu
|
||||||
|
|
||||||
if not torch.xpu._is_in_bad_fork():
|
if not torch.xpu._is_in_bad_fork():
|
||||||
torch.xpu.manual_seed_all(seed)
|
torch.xpu.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
@ -69,10 +71,12 @@ def seed() -> int:
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
if not torch.mps._is_in_bad_fork():
|
if not torch.mps._is_in_bad_fork():
|
||||||
torch.mps.manual_seed(seed)
|
torch.mps.manual_seed(seed)
|
||||||
|
|
||||||
import torch.xpu
|
import torch.xpu
|
||||||
|
|
||||||
if not torch.xpu._is_in_bad_fork():
|
if not torch.xpu._is_in_bad_fork():
|
||||||
torch.xpu.manual_seed_all(seed)
|
torch.xpu.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
@ -95,7 +99,9 @@ def _seed_custom_device(seed) -> None:
|
||||||
custom_device_mod = getattr(torch, custom_backend_name)
|
custom_device_mod = getattr(torch, custom_backend_name)
|
||||||
_bad_fork_name = "_is_in_bad_fork"
|
_bad_fork_name = "_is_in_bad_fork"
|
||||||
_seed_all_name = "manual_seed_all"
|
_seed_all_name = "manual_seed_all"
|
||||||
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
|
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
|
||||||
|
custom_device_mod, _seed_all_name
|
||||||
|
):
|
||||||
if not getattr(custom_device_mod, _bad_fork_name)():
|
if not getattr(custom_device_mod, _bad_fork_name)():
|
||||||
getattr(custom_device_mod, _seed_all_name)(seed)
|
getattr(custom_device_mod, _seed_all_name)(seed)
|
||||||
else:
|
else:
|
||||||
|
|
@ -117,7 +123,13 @@ _fork_rng_warned_already = False
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
|
def fork_rng(
|
||||||
|
devices=None,
|
||||||
|
enabled=True,
|
||||||
|
_caller="fork_rng",
|
||||||
|
_devices_kw="devices",
|
||||||
|
device_type="cuda",
|
||||||
|
) -> Generator:
|
||||||
"""
|
"""
|
||||||
Forks the RNG, so that when you return, the RNG is reset
|
Forks the RNG, so that when you return, the RNG is reset
|
||||||
to the state that it was previously in.
|
to the state that it was previously in.
|
||||||
|
|
@ -138,8 +150,10 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
|
||||||
device_type = torch.device(device_type).type
|
device_type = torch.device(device_type).type
|
||||||
device_mod = getattr(torch, device_type, None)
|
device_mod = getattr(torch, device_type, None)
|
||||||
if device_mod is None:
|
if device_mod is None:
|
||||||
raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
|
raise RuntimeError(
|
||||||
"a module by `torch._register_device_module`.")
|
f"torch has no module of `{device_type}`, you should register "
|
||||||
|
+ "a module by `torch._register_device_module`."
|
||||||
|
)
|
||||||
global _fork_rng_warned_already
|
global _fork_rng_warned_already
|
||||||
|
|
||||||
# Internal arguments:
|
# Internal arguments:
|
||||||
|
|
@ -153,7 +167,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
|
||||||
if devices is None:
|
if devices is None:
|
||||||
num_devices = device_mod.device_count()
|
num_devices = device_mod.device_count()
|
||||||
if num_devices > 1 and not _fork_rng_warned_already:
|
if num_devices > 1 and not _fork_rng_warned_already:
|
||||||
message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
|
message = (
|
||||||
|
f"{device_type.upper()} reports that you have {num_devices} available devices, and "
|
||||||
f"you have used {_caller} without explicitly specifying which devices are being used. "
|
f"you have used {_caller} without explicitly specifying which devices are being used. "
|
||||||
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
|
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
|
||||||
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
|
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
|
||||||
|
|
@ -163,7 +178,8 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
|
||||||
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
|
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
|
||||||
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
|
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
|
||||||
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
|
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
|
||||||
f"`range(torch.{device_type}.device_count())`.")
|
f"`range(torch.{device_type}.device_count())`."
|
||||||
|
)
|
||||||
warnings.warn(message)
|
warnings.warn(message)
|
||||||
_fork_rng_warned_already = True
|
_fork_rng_warned_already = True
|
||||||
devices = list(range(num_devices))
|
devices = list(range(num_devices))
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import torch
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.utils._pytree import register_pytree_node, SequenceKey
|
from torch.utils._pytree import register_pytree_node, SequenceKey
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["pytree_register_structseq", "all_return_types"]
|
__all__ = ["pytree_register_structseq", "all_return_types"]
|
||||||
|
|
||||||
all_return_types = []
|
all_return_types = []
|
||||||
|
|
@ -10,6 +11,7 @@ all_return_types = []
|
||||||
# error: Module has no attribute "_return_types"
|
# error: Module has no attribute "_return_types"
|
||||||
return_types = torch._C._return_types # type: ignore[attr-defined]
|
return_types = torch._C._return_types # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def pytree_register_structseq(cls):
|
def pytree_register_structseq(cls):
|
||||||
def structseq_flatten(structseq):
|
def structseq_flatten(structseq):
|
||||||
return list(structseq), None
|
return list(structseq), None
|
||||||
|
|
@ -28,14 +30,15 @@ def pytree_register_structseq(cls):
|
||||||
flatten_with_keys_fn=structseq_flatten_with_keys,
|
flatten_with_keys_fn=structseq_flatten_with_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
for name in dir(return_types):
|
for name in dir(return_types):
|
||||||
if name.startswith('__'):
|
if name.startswith("__"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_attr = getattr(return_types, name)
|
_attr = getattr(return_types, name)
|
||||||
globals()[name] = _attr
|
globals()[name] = _attr
|
||||||
|
|
||||||
if not name.startswith('_'):
|
if not name.startswith("_"):
|
||||||
__all__.append(name)
|
__all__.append(name)
|
||||||
all_return_types.append(_attr)
|
all_return_types.append(_attr)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -2,8 +2,9 @@
|
||||||
|
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
|
|
||||||
from ._vendor.packaging.version import InvalidVersion, Version
|
from torch._vendor.packaging.version import InvalidVersion, Version
|
||||||
from .version import __version__ as internal_version
|
from torch.version import __version__ as internal_version
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["TorchVersion"]
|
__all__ = ["TorchVersion"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,14 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import builtins
|
import builtins
|
||||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch.autograd.graph import GradientEdge
|
||||||
|
|
||||||
|
|
||||||
# Convenience aliases for common composite types that we need
|
# Convenience aliases for common composite types that we need
|
||||||
# to talk about in PyTorch
|
# to talk about in PyTorch
|
||||||
|
|
||||||
|
|
@ -11,8 +16,8 @@ _TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||||
_TensorOrTensorsOrGradEdge = Union[
|
_TensorOrTensorsOrGradEdge = Union[
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
Sequence[torch.Tensor],
|
Sequence[torch.Tensor],
|
||||||
"torch.autograd.graph.GradientEdge",
|
"GradientEdge",
|
||||||
Sequence["torch.autograd.graph.GradientEdge"],
|
Sequence["GradientEdge"],
|
||||||
]
|
]
|
||||||
|
|
||||||
# In some cases, these basic types are shadowed by corresponding
|
# In some cases, these basic types are shadowed by corresponding
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user