mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR adds `linalg.vander`, the linalg version of `torch.vander`. We add autograd support and support for batched inputs. We also take this chance to improve the docs (TODO: Check that they render correctly!) and add an OpInfo. **Discussion**: The current default for the `increasing` kwarg is extremely odd as it is the opposite of the classical definition (see [wiki](https://en.wikipedia.org/wiki/Vandermonde_matrix)). This is reflected in the docs, where I make explicit both the odd defaults that we use and the classical definition. See also [this stackoverflow post](https://stackoverflow.com/a/71758047/5280578), which shows how people are confused by these defaults. My take on this would be to correct the default to be `increasing=True` and document the divergence with NumPy (as we do for other `linalg` functions) as: - It is what people expect - It gives the correct determinant called "the Vandermonde determinant" rather than (-1)^{n-1} times the Vandermonde det (ugh). - [Minor] It is more efficient (no `flip` needed) - Since it's under `linalg.vander`, it's strictly not a drop-in replacement for `np.vander`. We will deprecate `torch.vander` in a PR after this one in this stack (once we settle on what's the correct default). Thoughts? mruberry cc kgryte rgommers as they might have some context for the defaults of NumPy. Fixes https://github.com/pytorch/pytorch/issues/60197 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76303 Approved by: https://github.com/albanD
1918 lines
99 KiB
Python
1918 lines
99 KiB
Python
"""
|
|
Python implementation of ``__torch_function__``
|
|
|
|
While most of the torch API and handling for ``__torch_function__`` happens
|
|
at the C++ level, some of the torch API is written in Python so we need
|
|
python-level handling for ``__torch_function__`` overrides as well. The main
|
|
developer-facing functions in this file are handle_torch_function and
|
|
has_torch_function. See torch/functional.py and test/test_overrides.py
|
|
for usage examples.
|
|
|
|
Note
|
|
----
|
|
heavily inspired by NumPy's ``__array_function__`` (see:
|
|
https://github.com/pytorch/pytorch/issues/24015 and
|
|
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
|
|
)
|
|
|
|
If changing this file in a way that can affect ``__torch_function__`` overhead,
|
|
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
|
|
instructions in the ``README.md`` in that directory.
|
|
"""
|
|
|
|
import __future__
|
|
|
|
import collections
|
|
import functools
|
|
import types
|
|
import warnings
|
|
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Iterator
|
|
import contextlib
|
|
|
|
import torch
|
|
from torch._C import (
|
|
_has_torch_function, _has_torch_function_unary,
|
|
_has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode)
|
|
|
|
# Public API of this module: the names exported by a star-import.
# NOTE(review): get_default_nowrap_functions is defined in this file but is
# not listed here -- presumably intentional; confirm.
__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
]
|
|
|
|
@functools.lru_cache(None)
def get_ignored_functions() -> Set[Callable]:
    """
    Return the public torch API entries that ``__torch_function__`` cannot override.

    The result is memoized by ``functools.lru_cache``, so every call returns
    the same set object.

    Returns
    -------
    Set[Callable]
        A set of functions that are publicly available in the torch API but
        cannot be overridden with ``__torch_function__``. Mostly this is
        because none of the arguments of these functions are tensors or
        tensor-likes.

    Examples
    --------
    >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
    True
    >>> torch.add in torch.overrides.get_ignored_functions()
    False
    """
    # Short local aliases for the namespaces that are referenced repeatedly.
    tensor_cls = torch.Tensor
    functional = torch.nn.functional
    nn_init = torch.nn.init

    ignored: Set[Callable] = {
        # ---- top-level ``torch`` namespace ----
        torch.typename,
        torch.is_tensor,
        torch.is_storage,
        torch.set_default_tensor_type,
        torch.set_rng_state,
        torch.get_rng_state,
        torch.manual_seed,
        torch.initial_seed,
        torch.seed,
        torch.save,
        torch.load,
        torch.set_printoptions,
        torch.fork,
        torch.get_default_dtype,
        torch.get_num_interop_threads,
        torch.get_num_threads,
        torch.init_num_threads,
        torch.import_ir_module,
        torch.import_ir_module_from_buffer,
        torch.is_anomaly_enabled,
        torch.is_grad_enabled,
        torch.merge_type_from_type_comment,
        torch.parse_ir,
        torch.parse_schema,
        torch.parse_type_comment,
        torch.set_anomaly_enabled,
        torch.set_flush_denormal,
        torch.set_num_interop_threads,
        torch.set_num_threads,
        torch.wait,
        torch.as_tensor,
        torch.from_numpy,
        torch.get_device,
        torch.tensor,
        torch.default_generator,
        torch.has_cuda,
        torch.has_cudnn,
        torch.has_lapack,
        torch.device,
        torch.dtype,
        torch.finfo,
        torch.has_mkl,
        torch.has_mps,
        torch.has_mkldnn,
        torch.has_openmp,
        torch.iinfo,
        torch.memory_format,
        torch.qscheme,
        torch.set_grad_enabled,
        torch.no_grad,
        torch.enable_grad,
        torch.inference_mode,
        torch.is_inference_mode_enabled,
        torch.layout,
        torch.align_tensors,
        torch.arange,
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,
        torch.cudnn_convolution,
        torch.cudnn_convolution_transpose,
        torch.cudnn_convolution_relu,
        torch.cudnn_convolution_add_relu,
        torch.cudnn_grid_sampler,
        torch.cudnn_is_acceptable,
        torch.empty,
        torch.empty_strided,
        torch.empty_quantized,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.from_file,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.mkldnn_adaptive_avg_pool2d,
        torch.mkldnn_convolution,
        torch.mkldnn_max_pool2d,
        torch.mkldnn_max_pool3d,
        torch.mkldnn_linear_backward_weights,
        torch.nested_tensor,
        torch.normal,
        torch.ones,
        torch.promote_types,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_csr_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch._jit_internal.boolean_dispatch,
        # ---- ``torch.nn.functional`` ----
        functional.assert_int_or_pair,
        functional.upsample,
        functional.upsample_bilinear,
        functional.upsample_nearest,
        functional.has_torch_function,
        functional.has_torch_function_unary,
        functional.has_torch_function_variadic,
        functional.handle_torch_function,
        functional.sigmoid,
        functional.hardsigmoid,
        functional.tanh,
        # ---- ``torch.nn.init`` ----
        # Doesn't actually take or return tensor arguments
        nn_init.calculate_gain,
        # These are deprecated; don't test them
        nn_init.uniform,
        nn_init.normal,
        nn_init.constant,
        nn_init.eye,
        nn_init.dirac,
        nn_init.xavier_uniform,
        nn_init.xavier_normal,
        nn_init.kaiming_uniform,
        nn_init.kaiming_normal,
        nn_init.orthogonal,
        nn_init.sparse,
        # ---- helpers defined in this module ----
        has_torch_function,
        handle_torch_function,
        # ---- autocast, determinism, warnings, vitals and misc ----
        torch.set_autocast_enabled,
        torch.is_autocast_enabled,
        torch.clear_autocast_cache,
        torch.set_autocast_cpu_enabled,
        torch.is_autocast_cpu_enabled,
        torch.set_autocast_cpu_dtype,
        torch.get_autocast_cpu_dtype,
        torch.get_autocast_gpu_dtype,
        torch.set_autocast_gpu_dtype,
        torch.autocast_increment_nesting,
        torch.autocast_decrement_nesting,
        torch.is_autocast_cache_enabled,
        torch.set_autocast_cache_enabled,
        functional.hardswish,
        torch.is_vulkan_available,
        torch.are_deterministic_algorithms_enabled,
        torch.use_deterministic_algorithms,
        torch.is_deterministic_algorithms_warn_only_enabled,
        torch.set_deterministic_debug_mode,
        torch.get_deterministic_debug_mode,
        torch.unify_type_list,
        torch.is_warn_always_enabled,
        torch.set_warn_always,
        torch.vitals_enabled,
        torch.set_vital,
        torch.read_vitals,
        torch.frombuffer,
        torch.asarray,
        # ---- ``torch.Tensor`` attributes ----
        tensor_cls.__delitem__,
        tensor_cls.__dir__,
        tensor_cls.__getattribute__,
        tensor_cls.__init__,
        tensor_cls.__iter__,
        tensor_cls.__init_subclass__,
        tensor_cls.__delattr__,
        tensor_cls.__setattr__,
        tensor_cls.__torch_function__,
        tensor_cls.__torch_dispatch__,
        tensor_cls.__new__,
        tensor_cls.__class__,
        tensor_cls.__subclasshook__,
        tensor_cls.as_subclass,
        tensor_cls.reinforce,
        tensor_cls.new,
        tensor_cls.new_tensor,
        tensor_cls.new_empty,
        tensor_cls.new_empty_strided,
        tensor_cls.new_zeros,
        tensor_cls.new_ones,
        tensor_cls.new_full,
        tensor_cls._make_subclass,
        tensor_cls.stride,
        tensor_cls.unflatten,
        tensor_cls.to_sparse_coo,
        tensor_cls.to_sparse_csr,
        tensor_cls._reduce_ex_internal,
        tensor_cls._fix_weakref,
        tensor_cls._make_wrapper_subclass,
        tensor_cls._python_dispatch.__get__,
        tensor_cls._conj,
        tensor_cls._conj_physical,
        tensor_cls._neg_view,
        tensor_cls._is_zerotensor,
        tensor_cls._addmm_activation,
        tensor_cls._nested_tensor_layer_norm,
        tensor_cls.to_padded_tensor,
    }
    return ignored
|
|
|
|
|
|
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return the accessors whose results the default ``Tensor.__torch_function__``
    (the one that preserves subclasses) must not wrap in a subclass.

    These functions typically represent field accesses (i.e., fetching a
    Tensor that is stored somewhere on the Tensor) rather than computation.
    Callers expect object identity to hold across repeated accesses
    (e.g., ``a.grad is a.grad``), which cannot be upheld if the result is
    wrapped on the fly every time. Furthermore, the stored tensor might
    already be the subclass, in which case wrapping really ought not to happen.

    Not ALL property accessors behave this way; for example ``Tensor.T``
    actually creates a new transposed tensor on the fly, so those calls SHOULD
    be interposed on (check the implementation of the function to see which
    case applies). Additionally, a property accessor that doesn't return a
    Tensor doesn't have to be on this list (though it is harmless if it is).
    """
    tensor_cls = torch.Tensor
    # Attribute names on ``torch.Tensor`` whose ``__get__`` is exempt from wrapping.
    field_names = ("_base", "grad", "_grad")
    return {getattr(tensor_cls, field).__get__ for field in field_names}
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def get_testing_overrides() -> Dict[Callable, Callable]:
|
|
"""Return a dict containing dummy overrides for all overridable functions
|
|
|
|
Returns
|
|
-------
|
|
Dict[Callable, Callable]
|
|
A dictionary that maps overridable functions in the PyTorch API to
|
|
lambda functions that have the same signature as the real function
|
|
and unconditionally return -1. These lambda functions are useful
|
|
for testing API coverage for a type that defines ``__torch_function__``.
|
|
|
|
Examples
|
|
--------
|
|
>>> import inspect
|
|
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
|
|
>>> inspect.signature(my_add)
|
|
<Signature (input, other, out=None)>
|
|
"""
|
|
# Every function in the PyTorchAPI that can be overriden needs an entry
|
|
# in this dict.
|
|
#
|
|
# Optimally we would use inspect to get the function signature and define
|
|
# the lambda function procedurally but that is blocked by generating
|
|
# function signatures for native kernels that can be consumed by inspect.
|
|
# See Issue #28233.
|
|
Tensor = torch.Tensor
|
|
ret: Dict[Callable, Callable] = {
|
|
torch.abs: lambda input, out=None: -1,
|
|
torch.absolute: lambda input, out=None: -1,
|
|
torch.adaptive_avg_pool1d: lambda input, output_size: -1,
|
|
torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
|
|
torch.acos: lambda input, out=None: -1,
|
|
torch.adjoint: lambda input: -1,
|
|
torch.arccos: lambda input, out=None: -1,
|
|
torch.acosh: lambda input, out=None: -1,
|
|
torch.arccosh: lambda input, out=None: -1,
|
|
torch.add: lambda input, other, out=None: -1,
|
|
torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
|
|
torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
|
|
torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
|
|
torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
|
torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
|
|
torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
|
|
torch.affine_grid_generator: lambda theta, size, align_corners: -1,
|
|
torch.all: lambda input, dim=None: -1,
|
|
torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
|
|
torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
|
|
torch.amax: lambda input, dim=None: -1,
|
|
torch.amin: lambda input, dim=None: -1,
|
|
torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
|
|
torch.angle: lambda input, out=None: -1,
|
|
torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
|
|
torch.argmax: lambda input: -1,
|
|
torch.argmin: lambda input: -1,
|
|
torch.argsort: lambda input, dim=None: -1,
|
|
torch.asin: lambda input, out=None: -1,
|
|
torch._assert_async: lambda input: -1,
|
|
torch.arcsin: lambda input, out=None: -1,
|
|
torch.asinh: lambda input, out=None: -1,
|
|
torch.arcsinh: lambda input, out=None: -1,
|
|
torch.atan: lambda input, out=None: -1,
|
|
torch.arctan: lambda input, out=None: -1,
|
|
torch.atan2: lambda input, other, out=None: -1,
|
|
torch.arctan2: lambda input, other, out=None: -1,
|
|
torch.atanh: lambda input, out=None: -1,
|
|
torch.arctanh: lambda input, out=None: -1,
|
|
torch.atleast_1d: lambda *tensors: -1,
|
|
torch.atleast_2d: lambda *tensors: -1,
|
|
torch.atleast_3d: lambda *tensors: -1,
|
|
torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
|
|
torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
|
|
torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
|
|
torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
|
|
torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
|
|
torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
|
|
torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
|
|
torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
|
|
torch.batch_norm_stats: lambda input, eps: -1,
|
|
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
|
|
torch.bernoulli: lambda input, generator=None, out=None: -1,
|
|
torch.bilinear: lambda input1, input2, weight, bias: -1,
|
|
torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
|
|
reduction='mean', pos_weight=None: -1),
|
|
torch.bincount: lambda input, weights=None, minlength=0: -1,
|
|
torch.binomial: lambda count, prob, generator=None: -1,
|
|
torch.bitwise_and: lambda input, other, out=None: -1,
|
|
torch.bitwise_not: lambda input, out=None: -1,
|
|
torch.bitwise_or: lambda input, other, out=None: -1,
|
|
torch.bitwise_xor: lambda input, other, out=None: -1,
|
|
torch.bitwise_left_shift: lambda input, other, out=None: -1,
|
|
torch.bitwise_right_shift: lambda input, other, out=None: -1,
|
|
torch.block_diag: lambda *tensors: -1,
|
|
torch.bmm: lambda input, mat2, out=None: -1,
|
|
torch.broadcast_tensors: lambda *tensors: -1,
|
|
torch.broadcast_to: lambda self, size: -1,
|
|
torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
|
|
torch.cartesian_prod: lambda *tensors: -1,
|
|
torch.cat: lambda tensors, dim=0, out=None: -1,
|
|
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
|
|
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
|
|
torch.ceil: lambda input, out=None: -1,
|
|
torch.celu: lambda input, alhpa=1., inplace=False: -1,
|
|
torch.chain_matmul: lambda *matrices, out=None: -1,
|
|
torch.channel_shuffle: lambda input, groups : -1,
|
|
torch.cholesky: lambda input, upper=False, out=None: -1,
|
|
torch.linalg.cholesky: lambda input, out=None: -1,
|
|
torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
|
|
torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
|
|
torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
|
|
torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
|
|
torch.chunk: lambda input, chunks, dim=0: -1,
|
|
torch.clamp: lambda input, min=None, max=None, out=None: -1,
|
|
torch.clip: lambda input, min=None, max=None, out=None: -1,
|
|
torch.clamp_min: lambda input, min, out=None: -1,
|
|
torch.clamp_max: lambda input, max, out=None: -1,
|
|
torch.column_stack: lambda tensors, out=None: -1,
|
|
torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
|
|
torch.clone: lambda input: -1,
|
|
torch.combinations: lambda input, r=2, with_replacement=False: -1,
|
|
torch.complex: lambda real, imag: -1,
|
|
torch.copysign: lambda input, other, out=None: -1,
|
|
torch.polar: lambda abs, ang: -1,
|
|
torch.linalg.cond: lambda input, ord=None: -1,
|
|
torch.conj: lambda input, out=None: -1,
|
|
torch.conj_physical: lambda input, out=None: -1,
|
|
torch.resolve_conj: lambda input, out=None: -1,
|
|
torch.resolve_neg: lambda input, out=None: -1,
|
|
torch.constant_pad_nd: lambda input, pad, value=0: -1,
|
|
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
|
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
|
torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
|
|
torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
|
|
torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
|
|
torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
|
|
torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
|
|
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
|
|
torch.corrcoef: lambda input: -1,
|
|
torch.cos: lambda input, out=None: -1,
|
|
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
|
|
torch.cosh: lambda input, out=None: -1,
|
|
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
|
|
torch.count_nonzero: lambda input: -1,
|
|
torch.cross: lambda input, other, dim=None, out=None: -1,
|
|
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
|
|
torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
|
|
zero_infinity=False: -1),
|
|
torch.cummax: lambda input, dim, out=None: -1,
|
|
torch.cummin: lambda input, dim, out=None: -1,
|
|
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
|
|
torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
|
|
torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
|
|
torch.logcumsumexp: lambda input, dim, out=None: -1,
|
|
torch.deg2rad: lambda input, out=None: -1,
|
|
torch.dequantize: lambda input: -1,
|
|
torch.det: lambda input: -1,
|
|
torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined]
|
|
torch.detach: lambda input: -1,
|
|
torch.diag: lambda input, diagonal=0, out=None: -1,
|
|
torch.diag_embed: lambda input, diagonal=0, out=None: -1,
|
|
torch.diagflat: lambda input, offset=0: -1,
|
|
torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
|
|
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
|
|
torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
|
|
torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
|
|
torch.digamma: lambda input, out=None: -1,
|
|
torch.dist: lambda input, other, p=2: -1,
|
|
torch.div: lambda input, other, rounding_mode=None, out=None: -1,
|
|
torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
|
|
torch.dot: lambda input, other, out=None: -1,
|
|
torch.dropout: lambda input, p, train, inplace=False: -1,
|
|
torch.dsmm: lambda input, mat2: -1,
|
|
torch.hsmm: lambda mat1, mat2: -1,
|
|
torch.dsplit: lambda input, indices_or_sections: -1,
|
|
torch.dstack: lambda tensors, out=None: -1,
|
|
torch.eig: lambda input, eigenvectors=False, out=None: -1,
|
|
torch.linalg.eig: lambda input, out=None: -1,
|
|
torch.linalg.eigvals: lambda input, out=None: -1,
|
|
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
|
|
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
|
|
torch.einsum: lambda equation, *operands: -1,
|
|
torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
|
|
sparse=False: -1),
|
|
torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
|
|
mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
|
|
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
|
torch.eq: lambda input, other, out=None: -1,
|
|
torch.equal: lambda input, other: -1,
|
|
torch.erf: lambda input, out=None: -1,
|
|
torch.erfc: lambda input, out=None: -1,
|
|
torch.erfinv: lambda input, out=None: -1,
|
|
torch.exp: lambda input, out=None: -1,
|
|
torch.exp2: lambda input, out=None: -1,
|
|
torch.expm1: lambda input, out=None: -1,
|
|
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
|
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
|
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
|
|
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
|
|
per_row_fake_quant=False, symmetric_quant=False: -1),
|
|
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
|
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
|
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
|
|
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
|
|
weight_zero_point, bias: -1),
|
|
torch.fbgemm_linear_quantize_weight: lambda input: -1,
|
|
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
|
|
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
|
|
torch.feature_alpha_dropout: lambda input, p, train: -1,
|
|
torch.feature_dropout: lambda input, p, train: -1,
|
|
torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
|
|
torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
|
|
torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
|
|
torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
|
|
torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
|
|
torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
|
|
torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
|
|
torch.fft.fftshift: lambda input, dim=None: -1,
|
|
torch.fft.ifftshift: lambda input, dim=None: -1,
|
|
torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
|
|
torch.fix: lambda input, out=None: -1,
|
|
torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
|
|
torch.flip: lambda input, dims: -1,
|
|
torch.fliplr: lambda input: -1,
|
|
torch.flipud: lambda input: -1,
|
|
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
|
|
torch.floor: lambda input, out=None: -1,
|
|
torch.floor_divide: lambda input, other: -1,
|
|
torch.float_power: lambda input, exponent, out=None: -1,
|
|
torch.fmod: lambda input, other, out=None: -1,
|
|
torch.frac: lambda input, out=None: -1,
|
|
torch.frexp: lambda input, out=None: -1,
|
|
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
|
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
|
|
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
|
|
torch.gcd: lambda input, other, out=None: -1,
|
|
torch.ge: lambda input, other, out=None: -1,
|
|
torch.greater_equal: lambda input, other, out=None: -1,
|
|
torch.geqrf: lambda input, out=None: -1,
|
|
torch.i0: lambda input, out=None: -1,
|
|
torch.inner: lambda input, other, out=None: -1,
|
|
torch.outer: lambda input, vec2, out=None: -1,
|
|
torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer
|
|
torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
|
|
torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
|
|
torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
|
|
torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
|
|
torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
|
|
torch.gru: lambda input, hx, params, has_biases, num_layers, gropout, train, bidirectional, batch_first: -1,
|
|
torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
|
torch.gt: lambda input, other, out=None: -1,
|
|
torch.greater: lambda input, other, out=None: -1,
|
|
torch.hardshrink: lambda input, lambd=0.5: -1,
|
|
torch.heaviside: lambda input, values, out=None: -1,
|
|
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
|
|
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
|
|
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
|
|
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
|
|
torch.linalg.householder_product: lambda input, tau: -1,
|
|
torch.hspmm: lambda mat1, mat2, out=None: -1,
|
|
torch.hsplit: lambda input, indices_or_sections: -1,
|
|
torch.hstack: lambda tensors, out=None: -1,
|
|
torch.hypot: lambda input, other, out=None: -1,
|
|
torch.igamma: lambda input, other, out=None: -1,
|
|
torch.igammac: lambda input, other, out=None: -1,
|
|
torch.imag: lambda input, out=None: -1,
|
|
torch.index_add: lambda input, dim, index, source: -1,
|
|
torch.index_copy: lambda input, dim, index, source: -1,
|
|
torch.index_put: lambda input, indices, values, accumulate=False: -1,
|
|
torch.index_select: lambda input, dim, index, out=None: -1,
|
|
torch.index_fill: lambda input, dim, index, value: -1,
|
|
torch._index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
|
|
torch.isfinite: lambda tensor: -1,
|
|
torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
|
|
torch.isinf: lambda tensor: -1,
|
|
torch.isreal: lambda tensor: -1,
|
|
torch.isposinf: lambda input, out=None: -1,
|
|
torch.isneginf: lambda input, out=None: -1,
|
|
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
|
|
cudnn_enabled: -1),
|
|
torch.int_repr: lambda input: -1,
|
|
torch.inverse: lambda input, out=None: -1,
|
|
torch.linalg.inv: lambda input, out=None: -1,
|
|
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
|
|
torch.is_complex: lambda input: -1,
|
|
torch.is_conj: lambda input: -1,
|
|
torch.is_neg: lambda input: -1,
|
|
torch.is_distributed: lambda input: -1,
|
|
torch.is_inference: lambda input: -1,
|
|
torch.is_floating_point: lambda input: -1,
|
|
torch.is_nonzero: lambda input: -1,
|
|
torch.is_same_size: lambda input, other: -1,
|
|
torch.is_signed: lambda input: -1,
|
|
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
|
|
torch.isnan: lambda input: -1,
|
|
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
|
|
normalized=False, onesided=None, length=None, return_complex=False: -1),
|
|
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
|
|
torch.kron: lambda input, other: -1,
|
|
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
|
|
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
|
|
torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
|
|
torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
|
|
torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
|
|
torch.lcm: lambda input, other, out=None: -1,
|
|
torch.ldexp: lambda input, other, out=None: -1,
|
|
torch.le: lambda input, other, out=None: -1,
|
|
torch.less_equal: lambda input, other, out=None: -1,
|
|
torch.lerp: lambda input, end, weight, out=None: -1,
|
|
torch.lgamma: lambda input, out=None: -1,
|
|
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
|
|
tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
|
|
torch.log: lambda input, out=None: -1,
|
|
torch.log_softmax: lambda input, dim, dtype=None: -1,
|
|
torch.log10: lambda input, out=None: -1,
|
|
torch.log1p: lambda input, out=None: -1,
|
|
torch.log2: lambda input, out=None: -1,
|
|
torch.logaddexp: lambda input, other, out=None: -1,
|
|
torch.logaddexp2: lambda input, other, out=None: -1,
|
|
torch.logdet: lambda input: -1,
|
|
torch.xlogy: lambda x, y, out=None: -1,
|
|
torch.logical_and: lambda input, other, out=None: -1,
|
|
torch.logical_not: lambda input, out=None: -1,
|
|
torch.logical_or: lambda input, other, out=None: -1,
|
|
torch.logical_xor: lambda input, other, out=None: -1,
|
|
torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
|
|
torch.logit: lambda input, eps=None: -1,
|
|
torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
|
|
torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
|
|
torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
|
torch.lstsq: lambda input, A, out=None: -1,
|
|
torch.lt: lambda input, other, out=None: -1,
|
|
torch.less: lambda input, other, out=None: -1,
|
|
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
|
|
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
|
|
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
|
|
torch.masked_fill: lambda input, mask, value: -1,
|
|
torch.masked_scatter: lambda input, mask, source: -1,
|
|
torch.masked_select: lambda input, mask, out=None: -1,
|
|
torch.matmul: lambda input, other, out=None: -1,
|
|
torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
|
|
torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
|
|
torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul
|
|
torch.matrix_power: lambda input, n: -1,
|
|
torch.linalg.matrix_power: lambda input, n, out=None: -1,
|
|
torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
|
|
torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
|
|
torch.linalg.multi_dot: lambda tensors, out=None: -1,
|
|
torch.matrix_exp: lambda input: -1,
|
|
torch.linalg.matrix_exp: lambda input: -1,
|
|
torch.max: lambda input, out=None: -1,
|
|
torch.maximum: lambda input, other, out=None: -1,
|
|
torch.fmax: lambda input, other, out=None: -1,
|
|
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
|
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
|
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
|
|
torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
return_indices=False, ceil_mode=False: -1),
|
|
torch.mean: lambda input, dim=None: -1,
|
|
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
|
|
torch.median: lambda input, dim=None: -1,
|
|
torch.nanmedian: lambda input, dim=None: -1,
|
|
torch.meshgrid: lambda *tensors, **kwargs: -1,
|
|
torch.min: lambda input, out=None: -1,
|
|
torch.minimum: lambda input, other, out=None: -1,
|
|
torch.fmin: lambda input, other, out=None: -1,
|
|
torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
|
|
exponential_average_factor, epsilon: -1),
|
|
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
|
|
torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
|
|
groups, benchmark, deterministic: -1),
|
|
torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
|
|
deterministic: -1),
|
|
torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
|
|
dropout, train, bidirectional, batch_sizes, dropout_state: -1),
|
|
torch.mm: lambda input, mat2, out=None: -1,
|
|
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
|
|
torch.movedim: lambda input, source, destination: -1,
|
|
torch.moveaxis: lambda input, source, destination: -1,
|
|
torch.msort: lambda input, descending=False, out=None: -1,
|
|
torch.mul: lambda input, other, out=None: -1,
|
|
torch.multiply: lambda input, other, out=None: -1,
|
|
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
|
|
torch.mv: lambda input, vec, out=None: -1,
|
|
torch.mvlgamma: lambda input, p: -1,
|
|
torch.narrow: lambda input, dim, start, length: -1,
|
|
torch.narrow_copy: lambda input, dim, start, length: -1,
|
|
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
|
|
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
|
|
torch.native_dropout: lambda input, p, train: -1,
|
|
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
|
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
|
|
torch.native_norm: lambda input, p=2: -1,
|
|
torch.native_norm: lambda input, p=2: -1,
|
|
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
|
|
torch.native_channel_shuffle: lambda input, groups : -1,
|
|
torch.ne: lambda input, other, out=None: -1,
|
|
torch.not_equal: lambda input, other, out=None: -1,
|
|
torch.neg: lambda input, out=None: -1,
|
|
torch.negative: lambda input, out=None: -1,
|
|
torch.nextafter: lambda input, other, out=None: -1,
|
|
torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
|
|
torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
|
|
torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
|
|
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
|
|
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
|
torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
|
|
count_include_pad=True, divisor_override=None: -1),
|
|
torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
|
|
count_include_pad=True, divisor_override=None: -1),
|
|
torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
|
|
momentum=0.1, eps=1e-05: -1),
|
|
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
|
|
torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
|
|
reduction="mean": -1),
|
|
torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
|
|
reduce=None, reduction="mean", pos_weight=None: -1),
|
|
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
|
|
torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
|
reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
|
|
reduce=None, reduction="mean", label_smoothing=0.0: -1),
|
|
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
|
|
reduction='mean', zero_infinity=False: -1),
|
|
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
|
|
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
|
|
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
|
|
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
|
|
torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
|
|
scale_grad_by_freq=False, sparse=False: -1),
|
|
torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
|
|
scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
|
|
include_last_offset=False, padding_idx=None: -1),
|
|
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
|
|
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
|
|
torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
|
|
return_indices=False, _random_samples=None: -1),
|
|
torch.nn.functional.fractional_max_pool2d_with_indices: (
|
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
|
|
_random_samples=None: -1),
|
|
torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
|
|
return_indices=False, _random_samples=None: -1),
|
|
torch.nn.functional.fractional_max_pool3d_with_indices: (
|
|
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
|
|
_random_samples=None: -1),
|
|
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
|
|
torch.nn.functional.gelu: lambda input, approximate='none': -1,
|
|
torch.nn.functional.glu: lambda input, dim=-1: -1,
|
|
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
|
|
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
|
|
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
|
|
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
|
|
torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
|
|
torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
|
|
reduction='mean': -1),
|
|
torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
|
|
use_input_stats=True, momentum=0.1, eps=1e-05: -1),
|
|
torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
|
|
recompute_scale_factor=None, antialias=False: -1),
|
|
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
|
|
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
|
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
|
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
|
|
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
|
|
torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
|
|
torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
|
torch.nn.functional.logsigmoid: lambda input: -1,
|
|
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
|
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
|
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
|
reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
ceil_mode=False, return_indices=False: -1),
|
|
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
return_indices=False, ceil_mode=False: -1),
|
|
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
ceil_mode=False, return_indices=False: -1),
|
|
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
return_indices=False, ceil_mode=False: -1),
|
|
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
return_indices=False, ceil_mode=False: -1),
|
|
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
|
return_indices=False, ceil_mode=False: -1),
|
|
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
|
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
|
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
|
|
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
|
torch.nn.functional.multi_head_attention_forward: (
|
|
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
|
|
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
|
|
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
|
|
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None: -1),
|
|
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
|
|
reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
|
|
reduction='mean': -1),
|
|
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
|
|
reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
|
|
reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
|
|
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
|
|
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
|
|
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
|
|
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
|
|
eps=1e-08, reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.prelu: lambda input, weight: -1,
|
|
torch.nn.functional.relu: lambda input, inplace=False: -1,
|
|
torch.nn.functional.relu6: lambda input, inplace=False: -1,
|
|
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
|
|
torch.nn.functional.selu: lambda input, inplace=False: -1,
|
|
torch.nn.functional.silu: lambda input, inplace=False: -1,
|
|
torch.nn.functional.mish: lambda input, inplace=False: -1,
|
|
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
|
|
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
|
|
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
|
|
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
|
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
|
|
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
|
|
torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
|
|
torch.nn.functional.softsign: lambda input: -1,
|
|
torch.nn.functional.tanhshrink: lambda input: -1,
|
|
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
|
|
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
|
|
swap=False, size_average=None, reduce=None, reduction='mean': -1),
|
|
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
|
|
distance_function=None, margin=1.0,
|
|
swap=False, reduction='mean': -1),
|
|
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
|
|
torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
|
|
torch.nn.init.constant_: lambda tensor, val: -1,
|
|
torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
|
|
torch.nn.init.constant_: lambda tensor, val: -1,
|
|
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
|
|
torch.nonzero: lambda input, as_tuple=False: -1,
|
|
torch.argwhere: lambda input: -1,
|
|
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
|
|
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
|
|
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
|
|
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
|
|
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
|
|
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
|
|
torch.numel: lambda input: -1,
|
|
torch.orgqr: lambda input, tau: -1,
|
|
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
|
|
torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
|
|
torch.permute: lambda self, dim: -1,
|
|
torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
|
|
torch.pdist: lambda input, p=2: -1,
|
|
torch.pinverse: lambda input, rcond=1e-15: -1,
|
|
torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
|
|
torch.pixel_shuffle: lambda input, upscale_factor: -1,
|
|
torch.pixel_unshuffle: lambda input, downscale_factor: -1,
|
|
torch.poisson: lambda input, generator=None: -1,
|
|
torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
|
|
torch.polygamma: lambda input, n, out=None: -1,
|
|
torch.positive: lambda input, out=None: -1,
|
|
torch.prelu: lambda input, weight: -1,
|
|
torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
|
torch.pow: lambda input, exponent, out=None: -1,
|
|
torch.prod: lambda input, dtype=None: -1,
|
|
torch.put: lambda input, index, source, accumulate=False: -1,
|
|
torch.q_per_channel_axis: lambda input: -1,
|
|
torch.q_per_channel_scales: lambda input: -1,
|
|
torch.q_per_channel_zero_points: lambda input: -1,
|
|
torch.q_scale: lambda input: -1,
|
|
torch.q_zero_point: lambda input: -1,
|
|
torch.qr: lambda input, some=True, out=None: -1,
|
|
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
|
|
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
|
|
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
|
|
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
|
|
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
|
|
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
|
|
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
|
|
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
|
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
|
|
|
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
|
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
|
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
|
|
dilation=(1,), ceil_mode=False: -1),
|
|
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
|
|
dilation=(1, 1), ceil_mode=False: -1),
|
|
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
|
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
|
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
|
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
|
torch.rad2deg: lambda input, out=None: -1,
|
|
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
|
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
|
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
|
torch.ravel: lambda input: -1,
|
|
torch.real: lambda input, out=None: -1,
|
|
torch.vdot: lambda input, other, out=None: -1,
|
|
torch.view_as_real: lambda input: -1,
|
|
torch.view_as_complex: lambda input: -1,
|
|
torch.reciprocal: lambda input, out=None: -1,
|
|
torch.relu: lambda input, inplace=False: -1,
|
|
torch.remainder: lambda input, other, out=None: -1,
|
|
torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
|
|
torch.repeat_interleave: lambda input, dim=None: -1,
|
|
torch.reshape: lambda input, shape: -1,
|
|
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
|
|
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
|
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
|
|
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
|
|
torch.roll: lambda input, shifts, dims=None: -1,
|
|
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
|
|
torch.round: lambda input, out=None: -1,
|
|
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
|
|
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
|
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
|
|
torch.rsqrt: lambda input, out=None: -1,
|
|
torch.rsub: lambda input, other, alpha=1: -1,
|
|
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
|
torch.scatter: lambda input, dim, index, src: -1,
|
|
torch.scatter_add: lambda input, dim, index, src: -1,
|
|
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
|
|
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
|
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
|
|
torch.select: lambda input, dim, index: -1,
|
|
torch.select_scatter: lambda input, src, dim, index: -1,
|
|
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
|
|
torch.selu: lambda input, inplace=False: -1,
|
|
torch.sigmoid: lambda input, out=None: -1,
|
|
torch.sign: lambda input, out=None: -1,
|
|
torch.signbit: lambda input, out=None: -1,
|
|
torch.sgn: lambda input, out=None: -1,
|
|
torch.sin: lambda input, out=None: -1,
|
|
torch.sinc: lambda input, out=None: -1,
|
|
torch.sinh: lambda input, out=None: -1,
|
|
torch.slogdet: lambda input: -1,
|
|
torch.linalg.slogdet: lambda input: -1,
|
|
torch.smm: lambda input, mat2: -1,
|
|
torch.spmm: lambda input, mat2: -1,
|
|
torch.softmax: lambda input, dim, dtype=None: -1,
|
|
torch.solve: lambda input, A, out=None: -1,
|
|
torch.linalg.solve: lambda input, other, out=None: -1,
|
|
torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
|
|
torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
|
|
torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
|
|
torch.sqrt: lambda input, out=None: -1,
|
|
torch.square: lambda input, out=None: -1,
|
|
torch.squeeze: lambda input, dim=None, out=None: -1,
|
|
torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
|
torch.stack: lambda tensors, dim=0, out=None: -1,
|
|
torch.std: lambda input, dim=None: -1,
|
|
torch.std_mean: lambda input, dim=None: -1,
|
|
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
|
|
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
|
|
torch.sub: lambda input, other, out=None: -1,
|
|
torch.subtract: lambda input, other, out=None: -1,
|
|
torch.sum: lambda input, dim=None: -1,
|
|
torch.nansum: lambda input, dim=None: -1,
|
|
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
|
|
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
|
|
torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
|
|
torch.linalg.svdvals: lambda input, out=None: -1,
|
|
torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
|
|
torch.swapaxes: lambda input, dim0, dim1: -1,
|
|
torch.swapdims: lambda input, axis0, axis1: -1,
|
|
torch.special.entr: lambda input: -1,
|
|
torch.special.erf: lambda input: -1,
|
|
torch.special.erfc: lambda input: -1,
|
|
torch.special.erfcx: lambda input: -1,
|
|
torch.special.erfinv: lambda input: -1,
|
|
torch.special.exp2: lambda input: -1,
|
|
torch.special.expm1: lambda input: -1,
|
|
torch.special.expit: lambda input: -1,
|
|
torch.special.polygamma: lambda input, n, out=None: -1,
|
|
torch.special.digamma: lambda input: -1,
|
|
torch.special.psi: lambda input: -1,
|
|
torch.special.gammainc: lambda input, other, out=None: -1,
|
|
torch.special.gammaincc: lambda input, other, out=None: -1,
|
|
torch.special.gammaln: lambda input: -1,
|
|
torch.special.i0: lambda input: -1,
|
|
torch.special.i0e: lambda input: -1,
|
|
torch.special.i1: lambda input: -1,
|
|
torch.special.i1e: lambda input: -1,
|
|
torch.special.logit: lambda input: -1,
|
|
torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
|
|
torch.special.log1p: lambda input: -1,
|
|
torch.special.log_softmax: lambda input, dim, dtype=None: -1,
|
|
torch.special.round: lambda input: -1,
|
|
torch.special.sinc: lambda input: -1,
|
|
torch.special.softmax: lambda input, dim, dtype=None: -1,
|
|
torch.special.multigammaln: lambda input, p: -1,
|
|
torch.special.ndtri: lambda input: -1,
|
|
torch.special.ndtr: lambda input: -1,
|
|
torch.special.log_ndtr: lambda input: -1,
|
|
torch.special.xlogy: lambda input, other, out=None: -1,
|
|
torch.special.xlog1py: lambda input, other, out=None: -1,
|
|
torch.special.zeta: lambda self, other, out=None: -1,
|
|
torch.t: lambda input: -1,
|
|
torch.take: lambda input, index: -1,
|
|
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
|
|
torch.tan: lambda input, out=None: -1,
|
|
torch.tanh: lambda input, out=None: -1,
|
|
torch.linalg.tensorinv: lambda a, ind=2: -1,
|
|
torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
|
|
torch.tensordot: lambda a, b, dims=2, out=None: -1,
|
|
torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
|
|
torch.threshold: lambda input, threshold, value, inplace=False: -1,
|
|
torch.tile: lambda input, dims: -1,
|
|
torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
|
|
torch.trace: lambda input: -1,
|
|
torch.transpose: lambda input, dim0, dim1: -1,
|
|
torch.trapz: lambda y, x=None, dim=-1: -1,
|
|
torch.trapezoid: lambda y, x=None, dim=-1: -1,
|
|
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
|
|
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
|
|
torch.tril: lambda input, diagonal=0, out=None: -1,
|
|
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
|
|
|
|
size_average=None, reduce=None, reduction='mean': -1),
|
|
torch.triu: lambda input, diagonal=0, out=None: -1,
|
|
torch.true_divide: lambda input, other: -1,
|
|
torch.trunc: lambda input, out=None: -1,
|
|
torch.unbind: lambda input, dim=0: -1,
|
|
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
|
|
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
|
|
torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
|
|
torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
|
|
torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
|
|
torch.unsqueeze: lambda input, dim, out=None: -1,
|
|
torch.linalg.vander: lambda x, N=None: -1,
|
|
torch.var: lambda input, dim=None: -1,
|
|
torch.var_mean: lambda input, dim=None: -1,
|
|
torch.vsplit: lambda input, indices_or_sections: -1,
|
|
torch.vstack: lambda tensors, out=None: -1,
|
|
torch.where: lambda condition, x=None, y=None: -1,
|
|
torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
|
torch._fw_primal_copy: lambda self, level: -1,
|
|
torch._make_dual_copy: lambda primal, tangent, level: -1,
|
|
torch.view_as_real_copy: lambda self: -1,
|
|
torch.view_as_complex_copy: lambda self: -1,
|
|
torch._conj_copy: lambda self: -1,
|
|
torch._neg_view_copy: lambda self: -1,
|
|
torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
|
|
torch._sparse_broadcast_to_copy: lambda self, size: -1,
|
|
torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
|
|
torch.expand_copy: lambda self, size, *, implicit=False: -1,
|
|
torch.narrow_copy: lambda self, dim, start, length: -1,
|
|
torch.permute_copy: lambda self, dims: -1,
|
|
torch._reshape_alias_copy: lambda self, size, stride: -1,
|
|
torch.select_copy: lambda self, dim, index: -1,
|
|
torch.detach_copy: lambda self: -1,
|
|
torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
|
|
torch.split_copy: lambda self, split_size, dim=0: -1,
|
|
torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
|
|
torch.squeeze_copy: lambda self: -1,
|
|
torch.squeeze_copy: lambda self, dim: -1,
|
|
torch.t_copy: lambda self: -1,
|
|
torch.transpose_copy: lambda self, dim0, dim1: -1,
|
|
torch.unsqueeze_copy: lambda self, dim: -1,
|
|
torch._indices_copy: lambda self: -1,
|
|
torch._values_copy: lambda self: -1,
|
|
torch.indices_copy: lambda self: -1,
|
|
torch.values_copy: lambda self: -1,
|
|
torch.crow_indices_copy: lambda self: -1,
|
|
torch.col_indices_copy: lambda self: -1,
|
|
torch.unbind_copy: lambda self, dim=0: -1,
|
|
torch.view_copy: lambda self, size: -1,
|
|
torch.view_copy: lambda self, dtype: -1,
|
|
torch.unfold_copy: lambda self, dimension, size, step: -1,
|
|
torch.alias_copy: lambda self: -1,
|
|
Tensor.__floordiv__: lambda self, other: -1,
|
|
Tensor.__rfloordiv__: lambda self, other: -1,
|
|
Tensor.__ifloordiv__: lambda self, other: -1,
|
|
Tensor.__truediv__: lambda self, other: -1,
|
|
Tensor.__rtruediv__: lambda self, other: -1,
|
|
Tensor.__itruediv__: lambda self, other: -1,
|
|
Tensor.__lshift__: lambda self, other: -1,
|
|
Tensor.__rlshift__: lambda self, other: -1,
|
|
Tensor.__ilshift__: lambda self, other: -1,
|
|
Tensor.__rshift__: lambda self, other: -1,
|
|
Tensor.__rrshift__: lambda self, other: -1,
|
|
Tensor.__irshift__: lambda self, other: -1,
|
|
Tensor.__and__: lambda self, other: -1,
|
|
Tensor.__or__: lambda self, other: -1,
|
|
Tensor.__xor__: lambda self, other: -1,
|
|
Tensor.__float__: lambda self: -1,
|
|
Tensor.__complex__: lambda self: -1,
|
|
Tensor.__array__: lambda self, dtype: -1,
|
|
Tensor.__bool__: lambda self: -1,
|
|
Tensor.__contains__: lambda self, other: -1,
|
|
Tensor.__neg__: lambda self: -1,
|
|
Tensor.__invert__: lambda self: -1,
|
|
Tensor.__mod__: lambda self, other: -1,
|
|
Tensor.__rmod__: lambda self, other: -1,
|
|
Tensor.__imod__: lambda self, other: -1,
|
|
Tensor.__array_wrap__: lambda self, array: -1,
|
|
Tensor.__getitem__: lambda self, idx: -1,
|
|
Tensor.__deepcopy__: lambda self, memo: -1,
|
|
Tensor.__int__: lambda self: -1,
|
|
Tensor.__long__: lambda self: -1,
|
|
Tensor.__hash__: lambda self: -1,
|
|
Tensor.__index__: lambda self: -1,
|
|
Tensor.__len__: lambda self: -1,
|
|
Tensor.__format__: lambda self, format_spec: -1,
|
|
Tensor.__reduce_ex__: lambda self, proto: -1,
|
|
Tensor.__reversed__: lambda self: -1,
|
|
Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
|
|
Tensor.__setitem__: lambda self, k, v: -1,
|
|
Tensor.__setstate__: lambda self, d: -1,
|
|
Tensor.T.__get__: lambda self: -1,
|
|
Tensor.H.__get__: lambda self: -1,
|
|
Tensor.mT.__get__: lambda self: -1,
|
|
Tensor.mH.__get__: lambda self: -1,
|
|
Tensor._backward_hooks.__get__: lambda self: -1,
|
|
Tensor._base.__get__: lambda self: -1,
|
|
Tensor._cdata.__get__: lambda self: -1,
|
|
Tensor.grad.__get__: lambda self: -1,
|
|
Tensor._grad.__get__: lambda self: -1,
|
|
Tensor._grad_fn.__get__: lambda self: -1,
|
|
Tensor.grad_fn.__get__: lambda self: -1,
|
|
Tensor._version.__get__: lambda self: -1,
|
|
Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
|
|
Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
|
|
Tensor.data.__get__: lambda self: -1,
|
|
Tensor.device.__get__: lambda self: -1,
|
|
Tensor.dtype.__get__: lambda self: -1,
|
|
Tensor.is_cuda.__get__: lambda self: -1,
|
|
Tensor.is_xpu.__get__: lambda self: -1,
|
|
Tensor.is_ipu.__get__: lambda self: -1,
|
|
Tensor.is_leaf.__get__: lambda self: -1,
|
|
Tensor.retains_grad.__get__: lambda self: -1,
|
|
Tensor.is_meta.__get__: lambda self: -1,
|
|
Tensor.is_mps.__get__: lambda self: -1,
|
|
Tensor.is_nested.__get__: lambda self: -1,
|
|
Tensor.is_ort.__get__: lambda self: -1,
|
|
Tensor.is_mkldnn.__get__: lambda self: -1,
|
|
Tensor.is_quantized.__get__: lambda self: -1,
|
|
Tensor.is_sparse.__get__: lambda self: -1,
|
|
Tensor.is_sparse_csr.__get__: lambda self: -1,
|
|
Tensor.is_vulkan.__get__: lambda self: -1,
|
|
Tensor.layout.__get__: lambda self: -1,
|
|
Tensor.name.__get__: lambda self: -1,
|
|
Tensor.names.__get__: lambda self: -1,
|
|
Tensor.ndim.__get__: lambda self: -1,
|
|
Tensor.output_nr.__get__: lambda self: -1,
|
|
Tensor.requires_grad.__get__: lambda self: -1,
|
|
Tensor.shape.__get__: lambda self: -1,
|
|
Tensor.volatile.__get__: lambda self: -1,
|
|
Tensor.real.__get__: lambda self: -1,
|
|
Tensor.imag.__get__: lambda self: -1,
|
|
Tensor.__cuda_array_interface__.__get__: lambda self: -1,
|
|
Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
|
|
Tensor._coalesced_: lambda self: -1,
|
|
Tensor._dimI: lambda self: -1,
|
|
Tensor._dimV: lambda self: -1,
|
|
Tensor._indices: lambda self: -1,
|
|
Tensor._is_view: lambda self: -1,
|
|
Tensor._nnz: lambda self: -1,
|
|
Tensor.crow_indices: lambda self: -1,
|
|
Tensor.col_indices: lambda self: -1,
|
|
Tensor._update_names: lambda self, names, inplace: -1,
|
|
Tensor._values: lambda self: -1,
|
|
Tensor.adjoint: lambda self: -1,
|
|
Tensor.align_as: lambda self, other: -1,
|
|
Tensor.align_to: lambda self, order, ellipsis_idx: -1,
|
|
Tensor.apply_: lambda self, callable: -1,
|
|
Tensor.as_strided: lambda self, size, stride: -1,
|
|
Tensor.as_strided_: lambda self, size, stride: -1,
|
|
Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
|
|
Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
|
|
Tensor.coalesce: lambda self: -1,
|
|
Tensor._coalesced_: lambda self, coalesced: -1,
|
|
Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
|
|
Tensor.copy_: lambda self, src, non_blocking=False: -1,
|
|
Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.data_ptr: lambda self: -1,
|
|
Tensor.dense_dim: lambda self: -1,
|
|
Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
|
|
Tensor.dim: lambda self: -1,
|
|
Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.element_size: lambda self: -1,
|
|
Tensor.expand: lambda self, size: -1,
|
|
Tensor.expand_as: lambda self, other: -1,
|
|
Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
|
|
Tensor.fill_: lambda self, value: -1,
|
|
Tensor.fill_diagonal_: lambda self, value: -1,
|
|
Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.geometric_: lambda self, p, *, generator=None: -1,
|
|
Tensor.get_device: lambda self: -1,
|
|
Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.has_names: lambda self: -1,
|
|
Tensor.indices: lambda self: -1,
|
|
Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.is_coalesced: lambda self: -1,
|
|
Tensor.is_contiguous: lambda self: -1,
|
|
Tensor.is_inference: lambda self: -1,
|
|
Tensor.is_pinned: lambda self: -1,
|
|
Tensor.is_set_to: lambda self, tensor: -1,
|
|
Tensor.is_shared: lambda self: -1,
|
|
Tensor.item: lambda self: -1,
|
|
Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
|
|
Tensor.log_softmax: lambda self, dim: -1,
|
|
Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.map_: lambda self, tensor, callable: -1,
|
|
Tensor.map2_: lambda self, x, y, callable: -1,
|
|
Tensor.mm: lambda self, mat2: -1,
|
|
Tensor.narrow_copy: lambda self, dimension, start, length: -1,
|
|
Tensor.ndimension: lambda self: -1,
|
|
Tensor.nelement: lambda self: -1,
|
|
Tensor.normal_: lambda self: -1,
|
|
Tensor.numpy: lambda self: -1,
|
|
Tensor.permute: lambda self, dim: -1,
|
|
Tensor.pin_memory: lambda self: -1,
|
|
Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
|
|
Tensor.qscheme: lambda self: -1,
|
|
Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
|
|
Tensor.record_stream: lambda self, stream: -1,
|
|
Tensor.refine_names: lambda self, names: -1,
|
|
Tensor.register_hook: lambda self, hook: -1,
|
|
Tensor.rename: lambda self, name: -1,
|
|
Tensor.repeat: lambda self, *size: -1,
|
|
Tensor.requires_grad_: lambda self, requires_grad=True: -1,
|
|
Tensor.reshape_as: lambda self, other: -1,
|
|
Tensor.resize: lambda self, *size: -1,
|
|
Tensor.resize_: lambda self, size: -1,
|
|
Tensor.resize_as: lambda self, other: -1,
|
|
Tensor.resize_as_sparse_: lambda self, other: -1,
|
|
Tensor.retain_grad: lambda self: -1,
|
|
Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
|
|
Tensor.select_scatter: lambda self, src, dim, index: -1,
|
|
Tensor.share_memory_: lambda self: -1,
|
|
Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
|
|
Tensor.size: lambda self: -1,
|
|
Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
|
|
Tensor.sparse_dim: lambda self: -1,
|
|
Tensor.sparse_mask: lambda self, mask: -1,
|
|
Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
|
|
Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
|
|
Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
|
Tensor.storage: lambda self: -1,
|
|
Tensor._storage: lambda self: -1,
|
|
Tensor.storage_offset: lambda self: -1,
|
|
Tensor.storage_type: lambda self: -1,
|
|
Tensor.sum_to_size: lambda self, size: -1,
|
|
Tensor.tile: lambda self, *reps: -1,
|
|
Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
|
|
Tensor.to_dense: lambda self, dtype=None: -1,
|
|
Tensor._to_dense: lambda self, dtype=None: -1,
|
|
Tensor.to_sparse: lambda self: -1,
|
|
Tensor.tolist: lambda self: -1,
|
|
Tensor.to_mkldnn: lambda self: -1,
|
|
Tensor.type_as: lambda self, other: -1,
|
|
Tensor.unfold: lambda self, dimension, size, step: -1,
|
|
Tensor.uniform_: lambda self, from_=0, to=1: -1,
|
|
Tensor.values: lambda self: -1,
|
|
Tensor.view: lambda self, shape: -1,
|
|
Tensor.view_as: lambda self, other: -1,
|
|
Tensor.zero_: lambda self: -1,
|
|
Tensor.__dlpack__: lambda self, stream=None: -1,
|
|
Tensor.__dlpack_device__: lambda self: -1,
|
|
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
|
|
}
|
|
|
|
ret2 = {}
|
|
ignored = get_ignored_functions()
|
|
|
|
for k, v in ret.items():
|
|
# Generate methods like __add__ and add_ by default from add
|
|
names = [
|
|
k.__name__, # Default method
|
|
k.__name__ + "_", # Inplace variant
|
|
"__" + k.__name__ + "__", # Dunder method
|
|
"__i" + k.__name__ + "__", # Inplace dunder method
|
|
"__r" + k.__name__ + "__", # Reverse dunder method
|
|
]
|
|
|
|
if k.__name__.startswith("bitwise_"):
|
|
# bitwise_<op> have dunder methods of the form __<op>__
|
|
# And so on.
|
|
subname = k.__name__[len("bitwise_"):]
|
|
names.extend([
|
|
"__" + subname + "__",
|
|
"__i" + subname + "__",
|
|
"__r" + subname + "__"
|
|
])
|
|
|
|
for name in names:
|
|
func = getattr(Tensor, name, None)
|
|
if callable(func) and func not in ret and func not in ignored:
|
|
ret2[func] = v
|
|
|
|
ret.update(ret2)
|
|
return ret
|
|
|
|
def wrap_torch_function(dispatcher: Callable):
    """Build a decorator that adds ``__torch_function__`` dispatch to a function.

    Parameters
    ----------
    dispatcher: Callable
        Called with the same arguments as the decorated function; must return
        an iterable of the Tensor-likes that were passed in.

    Note
    ----
    This decorator may reduce the performance of your code. Generally, it's enough to express
    your code as a series of functions that, themselves, support __torch_function__. If you
    find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
    low-level library and you also need it to work for Tensor-likes, then this function is available.

    Examples
    --------
    >>> def dispatcher(a): # Must have the same signature as func
    ...     return (a,)
    >>> @torch.overrides.wrap_torch_function(dispatcher)
    ... def func(a): # This will make func dispatchable by __torch_function__
    ...     return a + 0
    """
    def decorate(func):
        @functools.wraps(func)
        def dispatching(*args, **kwargs):
            tensor_likes = dispatcher(*args, **kwargs)
            # Fast path: nothing overrides __torch_function__, call through.
            if not has_torch_function(tensor_likes):
                return func(*args, **kwargs)
            # The wrapper itself is handed over as the public API object.
            return handle_torch_function(dispatching, tensor_likes, *args, **kwargs)

        return dispatching

    return decorate
|
|
|
|
def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
|
|
"""Returns a list of arguments on which to call __torch_function__.
|
|
|
|
Checks arguments in relevant_args for __torch_function__ implementations,
|
|
storing references to the arguments and their types in overloaded_args and
|
|
overloaded_types in order of calling precedence. Only distinct types are
|
|
considered. If a type is a subclass of another type it will have higher
|
|
precedence, otherwise the precedence order is the same as the order of
|
|
arguments in relevant_args, that is, from left-to-right in the argument list.
|
|
|
|
The precedence-determining algorithm implemented in this function is
|
|
described in `NEP-0018`_.
|
|
|
|
See torch::append_overloaded_arg for the equivalent function in the C++
|
|
implementation.
|
|
|
|
Parameters
|
|
----------
|
|
relevant_args : iterable of array-like
|
|
Iterable of array-like arguments to check for __torch_function__
|
|
methods.
|
|
|
|
Returns
|
|
-------
|
|
overloaded_args : list
|
|
Arguments from relevant_args on which to call __torch_function__
|
|
methods, in the order in which they should be called.
|
|
|
|
.. _NEP-0018:
|
|
https://numpy.org/neps/nep-0018-array-function-protocol.html
|
|
"""
|
|
# If torch function is not enabled, there are no overloaded types
|
|
if not torch._C._is_torch_function_enabled():
|
|
return []
|
|
# Runtime is O(num_arguments * num_unique_types)
|
|
overloaded_types: Set[Type] = set()
|
|
overloaded_args: List[Any] = []
|
|
for arg in relevant_args:
|
|
arg_type = type(arg)
|
|
# We only collect arguments if they have a unique type, which ensures
|
|
# reasonable performance even with a long list of possibly overloaded
|
|
# arguments.
|
|
#
|
|
# NB: Important to exclude _disabled_torch_function_impl, otherwise
|
|
# https://github.com/pytorch/pytorch/issues/64687
|
|
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
|
|
arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
|
|
# Create lists explicitly for the first type (usually the only one
|
|
# done) to avoid setting up the iterator for overloaded_args.
|
|
if overloaded_types:
|
|
overloaded_types.add(arg_type)
|
|
# By default, insert argument at the end, but if it is
|
|
# subclass of another argument, insert it before that argument.
|
|
# This ensures "subclasses before superclasses".
|
|
index = len(overloaded_args)
|
|
for i, old_arg in enumerate(overloaded_args):
|
|
if issubclass(arg_type, type(old_arg)):
|
|
index = i
|
|
break
|
|
overloaded_args.insert(index, arg)
|
|
else:
|
|
overloaded_types = {arg_type}
|
|
overloaded_args = [arg]
|
|
return overloaded_args
|
|
|
|
|
|
def handle_torch_function(
        public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Arguments
    ---------
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    object
        Result from the active ``__torch_function__`` mode, or from the first
        ``__torch_function__`` override that does not return
        ``NotImplemented``.

    Raises
    ------
    TypeError : if no implementation is found.

    Example
    -------
    >>> def func(a):
    ...     if has_torch_function_unary(a):
    ...         return handle_torch_function(func, (a,), a)
    ...     return a + 0
    """
    # Check for __torch_function__ methods.
    overloaded_args = _get_overloaded_args(relevant_args)
    # overloaded_args already have unique types.
    types = tuple(map(type, overloaded_args))

    # Check for __torch_function__ mode. The mode gets the first chance.
    mode = _get_torch_function_mode()
    if mode is not None:
        # NB: unlike on tensors, modes are instances
        with _no_torch_function_mode():
            result = mode.__torch_function__(public_api, types, args, kwargs)
        if result is not NotImplemented:
            return result

    # Call overrides
    for overloaded_arg in overloaded_args:
        # This call needs to become a classmethod call in the future.
        # See https://github.com/pytorch/pytorch/issues/63767
        torch_func_method = overloaded_arg.__torch_function__
        # A method bound to the instance (rather than to its class) means
        # __torch_function__ was defined as a plain method.
        if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
                torch_func_method is not torch._C._disabled_torch_function_impl:
            warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
                          "will be an error in future, please define it as a classmethod.",
                          DeprecationWarning)

        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = torch_func_method(public_api, types, args, kwargs)

        if result is not NotImplemented:
            return result

    # Not every callable carries __name__ / __module__ (method descriptors and
    # functools.partial objects, for instance). Building the message must not
    # mask the TypeError below with an AttributeError.
    api_name = getattr(public_api, '__name__', repr(public_api))
    if hasattr(public_api, '__module__'):
        func_name = '{}.{}'.format(public_api.__module__, api_name)
    else:
        func_name = api_name
    msg = (
        "no implementation found for '{}' on types that implement "
        '__torch_function__: {}'
    ).format(func_name, [type(arg) for arg in overloaded_args])
    if mode is not None:
        msg += f" nor in mode {mode}"
    raise TypeError(msg)
|
|
|
|
# Attach the Python-level documentation to the imported ``_has_torch_function``
# and expose the result under its public name.
has_torch_function = _add_docstr(
    _has_torch_function,
    r"""Check for __torch_function__ implementations in the elements of an iterable
    or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s
    and ``Parameter`` s non-dispatchable. Use this to guard a call to
    :func:`handle_torch_function`; don't use it to test if something
    is Tensor-like, use :func:`is_tensor_like` instead.

    Arguments
    ---------
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.

    Returns
    -------
    bool
        True if any of the elements of relevant_args have __torch_function__
        implementations or a __torch_function__ mode is enabled, False
        otherwise.

    See Also
    --------
    torch.overrides.is_tensor_like
        Checks if something is a Tensor-like, including an exact ``Tensor``.
    """
)
|
|
|
|
# Single-argument variant of ``has_torch_function``: the docstring is attached
# to ``_has_torch_function_unary`` and the result is exposed publicly.
has_torch_function_unary = _add_docstr(
    _has_torch_function_unary,
    r"""Special case of `has_torch_function` for single inputs.
    Instead of:
      `has_torch_function((t,))`
    call:
      `has_torch_function_unary(t)`
    which skips unnecessary packing and unpacking work.
    """
)
|
|
|
|
# Variadic variant of ``has_torch_function``: arguments are passed directly
# instead of being packed into a tuple first. The docstring is attached to
# ``_has_torch_function_variadic`` and the result is exposed publicly.
has_torch_function_variadic = _add_docstr(
    _has_torch_function_variadic,
    r"""Special case of `has_torch_function` that skips tuple creation.

    This uses the METH_FASTCALL protocol introduced in Python 3.7

    Instead of:
      `has_torch_function((a, b))`
    call:
      `has_torch_function_variadic(a, b)`
    which skips unnecessary packing and unpacking work.
    """
)
|
|
|
|
@functools.lru_cache(None)
def get_overridable_functions() -> Dict[Any, List[Callable]]:
    """List functions that are overridable via __torch_function__

    The result is computed once and cached, so every caller receives the same
    dictionary object.

    Returns
    -------
    Dict[Any, List[Callable]]
        A dictionary that maps namespaces that contain overridable functions
        to functions in that namespace that can be overridden.
    """
    overridable_funcs = collections.defaultdict(list)
    # Loop-invariant, so look it up once instead of per attribute.
    ignored = get_ignored_functions()
    conflict_msg = ("{}.{} is in the tuple returned by torch.overrides.get_ignored_functions "
                    "but still has an explicit override")
    tested_namespaces = [
        (torch, torch.__all__ + dir(torch._C._VariableFunctions)),
        (torch.functional, torch.functional.__all__),
        (torch.nn.functional, dir(torch.nn.functional)),
        (torch.nn.init, dir(torch.nn.init)),
        (torch.Tensor, dir(torch.Tensor)),
        (torch.linalg, dir(torch.linalg)),
        (torch.fft, dir(torch.fft)),
        (torch.special, dir(torch.special)),
    ]
    for namespace, ns_funcs in tested_namespaces:
        for func_name in ns_funcs:
            if namespace is torch.Tensor:
                func = getattr(namespace, func_name)
                # Skip attributes that compare equal to the one found on
                # ``object`` under the same name.
                if getattr(object, func_name, None) == func:
                    continue
                if func_name == '__weakref__':
                    continue
            else:
                # ignore private functions or functions that are deleted in torch.__init__
                if func_name.startswith('_') or func_name.endswith('_'):
                    continue
                # Also skip names not starting with a lowercase character.
                if not func_name[0].islower() or func_name == 'unique_dim':
                    continue
                func = getattr(namespace, func_name)

            # ignore re-exported modules
            if isinstance(func, types.ModuleType):
                continue
            # ignore __future__ imports
            if isinstance(func, __future__._Feature):
                continue

            # Non-callable descriptors (properties and similar): the getter is
            # what gets recorded.
            if not callable(func) and hasattr(func, "__get__"):
                getter = func.__get__
                if getter in ignored:
                    # The attribute name is used in the message because
                    # ``property`` objects have no ``__name__`` before
                    # Python 3.13.
                    assert getter not in get_testing_overrides(), conflict_msg.format(namespace, func_name)
                else:
                    # NOTE(review): getters are keyed by the descriptor itself,
                    # not by ``namespace`` -- kept as-is so the returned
                    # mapping is unchanged; confirm whether this is intended.
                    overridable_funcs[func].append(getter)
                continue

            if not callable(func):
                continue

            # cannot be overridden by __torch_function__
            if func in ignored:
                assert func not in get_testing_overrides(), conflict_msg.format(namespace, func_name)
                continue
            overridable_funcs[namespace].append(func)
    return overridable_funcs
|
|
|
|
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
    """Return the overridable callables registered under ``torch.Tensor`` as a set."""
    return set(get_overridable_functions()[torch.Tensor])
|
|
|
|
def is_tensor_method_or_property(func: Callable) -> bool:
    """
    Returns True if the function passed in is a handler for a
    method or property belonging to ``torch.Tensor``, as passed
    into ``__torch_function__``.

    .. note::
       For properties, their ``__get__`` method must be passed in.

    This may be needed, in particular, for the following reasons:

    1. Methods/properties sometimes don't contain a `__module__` slot.
    2. They require that the first passed-in argument is an instance
       of ``torch.Tensor``.

    Callables that have no ``__name__`` attribute (for example
    ``functools.partial`` objects) and are not Tensor methods yield ``False``.

    Examples
    --------
    >>> is_tensor_method_or_property(torch.Tensor.add)
    True
    >>> is_tensor_method_or_property(torch.add)
    False
    """
    if func in _get_tensor_methods():
        return True
    # Any callable named ``__get__`` is treated as a property getter. Use
    # getattr so callables without ``__name__`` do not raise AttributeError.
    return getattr(func, "__name__", None) == "__get__"
|
|
|
|
def is_tensor_like(inp):
    """
    Tell whether ``inp`` counts as a Tensor-like.

    An input is Tensor-like when its type is exactly ``torch.Tensor`` or when
    its type exposes a ``__torch_function__`` attribute.

    Examples
    --------
    A subclass of tensor is generally a Tensor-like.

    >>> class SubTensor(torch.Tensor): ...
    >>> is_tensor_like(SubTensor([0]))
    True

    Built-in or user types aren't usually Tensor-like.

    >>> is_tensor_like(6)
    False
    >>> is_tensor_like(None)
    False
    >>> class NotATensor: ...
    >>> is_tensor_like(NotATensor())
    False

    But, they can be made Tensor-like by implementing __torch_function__.

    >>> class TensorLike:
    ...     @classmethod
    ...     def __torch_function__(cls, func, types, args, kwargs):
    ...         return -1
    >>> is_tensor_like(TensorLike())
    True
    """
    inp_type = type(inp)
    if inp_type is torch.Tensor:
        return True
    return hasattr(inp_type, "__torch_function__")
|
|
|
|
def _wrap_init(f):
|
|
undef = object()
|
|
|
|
@functools.wraps(f)
|
|
def wrapped(self, *args, inner=undef, **kwargs):
|
|
if inner is undef:
|
|
raise TypeError(
|
|
"missing inner keyword argument; instead of constructing a TorchModeFunction directly, "
|
|
"pass the constructor to push_torch_function_mode"
|
|
)
|
|
self.inner = inner
|
|
return f(self, *args, **kwargs)
|
|
return wrapped
|
|
|
|
|
|
def _wrap_torch_function(f):
|
|
@functools.wraps(f)
|
|
def wrapped(self, *args, **kwargs):
|
|
with enable_torch_function_mode(self.inner):
|
|
return f(self, *args, **kwargs)
|
|
return wrapped
|
|
|
|
|
|
# Implementation note: I had a choice about how much of mode stacks
|
|
# to implement in Python versus in C++. At time of writing, I did not care
|
|
# too much about implementation efficiency; however, I do care about making it
|
|
# hard for users to implement modes in the wrong way. In the end, it turned
|
|
# out to be possible to implement mode stacks entirely from userland, with the
|
|
# C++ API providing only _get_torch_function_mode() and
|
|
# _set_torch_function_mode(), so I opted to provide some unsafe C++ bindings and
|
|
# have the bulk of the logic for managing the stack in Python, which helped
|
|
# simplify the C++ API surface. It would also have been valid to build in the
|
|
# notion of mode stack directly into C++ but in this design it's substantially
|
|
# more difficult to interact with TorchFunctionModeMeta.
|
|
|
|
|
|
class TorchFunctionModeMeta(type):
    """
    Metaclass for :class:`TorchFunctionMode`. When a class is created it
    rewrites two entries of the class namespace, if present:

    * ``__init__`` is wrapped so that it takes an implicit ``inner`` kwarg,
      which lets modes be chained together to form a stack.

    * ``__torch_function__`` is wrapped so that the inner mode is reenabled
      while it runs; PyTorch API calls made inside it therefore proceed
      compositionally to the next mode on the stack by default.

    The second point is why a metaclass is used at all: it is easy to
    accidentally write ``__torch_function__`` implementations that are not
    compositional, and the wrapping makes the obvious code do the right thing.
    """
    def __new__(metacls, name, bases, dct):
        # Only methods defined directly in this class body are wrapped.
        try:
            user_init = dct['__init__']
        except KeyError:
            pass
        else:
            dct['__init__'] = _wrap_init(user_init)

        try:
            user_handler = dct['__torch_function__']
        except KeyError:
            pass
        else:
            dct['__torch_function__'] = _wrap_torch_function(user_handler)

        return super().__new__(metacls, name, bases, dct)
|
|
|
|
|
|
class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
    """
    A ``TorchFunctionMode`` allows you to override the meaning of all
    ``__torch_function__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchFunctionMode` are compositional:
    modes can be pushed onto a stack with :func:`push_torch_function_mode`.
    When you call functions in the PyTorch API inside your
    ``__torch_function__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back
    into your current ``__torch_function__`` implementation, either explicitly
    invoke ``self.__torch_function__(...)``, or use the context manager
    ``enable_torch_function_mode(self, replace=self.inner)`` to make the
    PyTorch API self-referential (beware of infinite loops, in this case!)
    """
    # Force metaclass to generate constructor at the base of the hierarchy
    # (the metaclass only wraps an ``__init__`` that appears in the class
    # body, so one is defined here even though it does nothing).
    def __init__(self):
        pass

    # Subclasses are expected to override this; the base implementation
    # always raises.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()
|
|
|
|
|
|
class BaseTorchFunctionMode(TorchFunctionMode):
    """Mode whose handler just calls the requested function with its arguments."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # ``kwargs`` may be omitted; treat that as "no keyword arguments".
        call_kwargs = {} if kwargs is None else kwargs
        return func(*args, **call_kwargs)
|
|
|
|
|
|
# Kept private: it is unclear whether users could use this compositionally
# (it is easy to discard too many modes).  Library code such as
# handle_torch_function does rely on it, though.
@contextlib.contextmanager
def _no_torch_function_mode() -> Iterator[None]:
    """Clear the current torch function mode for the duration of the block.

    The mode that was active on entry is put back on exit, including when the
    body raises.
    """
    saved_mode = _get_torch_function_mode()
    _set_torch_function_mode(None)
    try:
        yield
    finally:
        _set_torch_function_mode(saved_mode)
|
|
|
|
|
|
@contextlib.contextmanager
def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
    """
    Context manager that sets the current :class:`TorchFunctionMode`; see the
    class for more information on what modes are.  This function is
    non-compositional; if there is already an existing mode, it will raise an
    error; prefer using :func:`push_torch_function_mode` if your
    ``__torch_function__`` implementation can defer to an inner mode.

    This function is safe to use inside a ``__torch_function__`` mode handler,
    as the mode is guaranteed to be disabled in this context.  You can use
    this context manager to reinstate the mode so that calls to overridable
    APIs recursively call back into your mode handler (this can easily cause
    infinite loops, so use with care!)

    Args:
        mode (:class:`TorchFunctionMode`, Tensor-like class or None): the
            mode to set as current mode.  If you pass a Tensor-like class,
            it will be treated as a non-compositional mode with no state,
            which is convenient if you have an existing tensor subclass
            that you'd like to apply globally in a quick and dirty way.
            Passing None will disable the current mode.
        replace (:class:`TorchFunctionMode` or Tensor-like class): the
            mode to replace.  You can use this argument to change the mode in
            a situation where you know what the current mode is (and you are
            intentionally overwriting it.)  If you don't know what the current
            mode is, use ``ignore_preexisting`` instead.
        ignore_preexisting (bool): if True, ignore any preexisting mode
            and overwrite it with the passed mode.

    Raises:
        TypeError: if ``mode`` is not a :class:`TorchFunctionMode` instance,
            a class that is not a :class:`TorchFunctionMode` subclass, or None.
        ValueError: if another mode is already active and neither ``replace``
            nor ``ignore_preexisting`` permits overwriting it, or if a
            non-None ``mode`` does not implement ``__torch_function__``.
    """
    if not (
        mode is None or
        isinstance(mode, TorchFunctionMode) or
        (isinstance(mode, type) and not issubclass(mode, TorchFunctionMode))
    ):
        raise TypeError(
            "expected to get TorchFunctionMode, Tensor-like class or None as argument, got "
            f"{type(mode)} instead"
        )
    old = _get_torch_function_mode()
    # Short circuit.  It is valid to reset a mode to yourself, we won't error
    # here
    if old is mode:
        yield
        return
    if old is not None and not ignore_preexisting and old is not replace:
        if isinstance(mode, TorchFunctionMode):
            help_text = (
                'Use push_torch_function_mode instead.'
            )
        else:
            help_text = (
                'If you intended to completely override the preexisting mode, '
                'pass ignore_preexisting=True.  This can result in unexpected '
                'behavior; please consider rewriting your mode to be a subclass '
                'of TorchFunctionMode to make it compositional!'
            )
        raise ValueError(
            'Attempted to enable_torch_function_mode, but there is already an '
            f'active mode {old}.  {help_text}'
        )
    # NB: we don't require TorchFunctionMode since this is intended to also
    # let you directly pass a Tensor subclass type to "mode-ify" it.
    # None is exempt from the check: it is the documented way to disable the
    # current mode and has no __torch_function__ attribute.
    if mode is not None and not hasattr(mode, '__torch_function__'):
        raise ValueError(
            'The argument passed to enable_torch_function_mode must implement '
            '__torch_function__'
        )
    _set_torch_function_mode(mode)
    try:
        yield
    finally:
        # Always reinstate whatever was active on entry.
        _set_torch_function_mode(old)
|
|
|
|
|
|
@contextlib.contextmanager
def push_torch_function_mode(ctor) -> Iterator[TorchFunctionMode]:
    """
    Context manager that builds a :class:`TorchFunctionMode` on top of the
    currently active mode and makes it current; see the class for more
    information on what modes are.  The new mode receives the previous one as
    its ``inner`` mode (or a fresh :class:`BaseTorchFunctionMode` when no mode
    was active) and can delegate to it by invoking its ``__torch_function__``
    method.  On exit the previously active mode is restored.

    Args:
        ctor: a function that when invoked as ``ctor(inner=...)`` produces
            a :class:`TorchFunctionMode`.  If your :class:`TorchFunctionMode`
            has no ``__init__`` implementation, you can simply pass the class
            itself (e.g., ``push_torch_function_mode(MyMode)``); otherwise,
            use ``functools.partial`` to partially apply the constructor with all
            non-inner arguments (e.g.,
            ``push_torch_function_mode(partial(MyMode, arg))``)
    """
    # An already constructed mode cannot be given an ``inner``; reject it.
    if isinstance(ctor, TorchFunctionMode):
        raise ValueError(
            'Expected a TorchFunctionMode constructor function, but got an '
            f'instance of TorchFunctionMode {ctor}.  Consider using '
            'enable_torch_function_mode instead.'
        )

    previous = _get_torch_function_mode()
    # With nothing active yet, the bottom of the stack is a base mode.
    inner = previous if previous is not None else BaseTorchFunctionMode(inner=None)

    mode = ctor(inner=inner)
    if not isinstance(mode, TorchFunctionMode):
        raise ValueError(
            'The callable passed to push_torch_function_mode must return '
            'a TorchFunctionMode'
        )

    _set_torch_function_mode(mode)
    try:
        yield mode
    finally:
        _set_torch_function_mode(previous)
|
|
|
|
# Context manager that holds a ``torch._C._RestorePythonTLSSnapshot`` object
# for the duration of the ``with`` block.
# NOTE(review): presumably the guard restores a saved Python thread-local
# dispatch state while it is alive -- confirm against the C++ binding.
class enable_reentrant_dispatch():
    def __enter__(self):
        # Keep the guard alive on the instance until __exit__ runs.
        # Nothing is returned, so ``with ... as x`` binds ``None``.
        self._raii_guard = torch._C._RestorePythonTLSSnapshot()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Drop this reference to the guard; exceptions are not suppressed
        # (the method returns None).
        del self._raii_guard
|